feat(02-07): gorilla/csrf integration — mount middleware, wire all forms, env-driven key

- auth.Mount(env, key) wraps csrf.Protect with locked D-14/D-24 options
- auth.LoadKeyFromEnv() reads SESSION_SECRET, hex-decodes, validates 32 bytes; fails fast on error
- ui.CSRFField(token) templ component renders hidden _csrf input
- Layout, LoginPage/Fragment, SignupPage/Fragment, Index all embed @ui.CSRFField(csrfToken)
- Handlers thread csrf.Token(r) into every page/fragment render call
- NewRouter mounts auth.Mount after ResolveSession, before all route groups (D-24)
- main.go calls auth.LoadKeyFromEnv(); logs.Fatalf on missing/invalid SESSION_SECRET
- SESSION_SECRET documented in .env.example with openssl rand -hex 32 instruction
- go.mod: gorilla/csrf v1.7.3 (direct); prior tests updated with getCSRFToken helper
- All Plan 04/05/06 tests updated to acquire and submit valid _csrf tokens
This commit is contained in:
Arthur Belleville 2026-05-14 22:59:06 +02:00
parent ae2d356f87
commit 389e1bc8b4
No known key found for this signature in database
18 changed files with 421 additions and 75 deletions

View file

@ -43,6 +43,16 @@ func main() {
os.Exit(1) os.Exit(1)
} }
// Load the CSRF authentication key from SESSION_SECRET env var (D-15).
// Fails fast with a clear message if missing or wrong length — the server
// cannot operate without a valid CSRF key (AUTH-06).
csrfKey, err := auth.LoadKeyFromEnv()
if err != nil {
slog.Error("invalid SESSION_SECRET", "err", err,
"hint", "generate with: openssl rand -hex 32")
os.Exit(1)
}
// signal.NotifyContext (Go 1.21+) is the canonical idiom — equivalent // signal.NotifyContext (Go 1.21+) is the canonical idiom — equivalent
// to signal.Notify + a channel but the resulting ctx propagates the // to signal.Notify + a channel but the resulting ctx propagates the
// cancellation through to handlers, pgxpool dialing, etc. // cancellation through to handlers, pgxpool dialing, etc.
@ -67,7 +77,7 @@ func main() {
deps := web.AuthDeps{Queries: q, Store: store, Secure: secure, Limiter: rl} deps := web.AuthDeps{Queries: q, Store: store, Secure: secure, Limiter: rl}
router := web.NewRouter(pool, "./static", deps) router := web.NewRouter(pool, "./static", deps, csrfKey, env)
srv := &http.Server{ srv := &http.Server{
Addr: ":" + port, Addr: ":" + port,

View file

@ -6,13 +6,14 @@ require (
github.com/a-h/templ v0.3.1020 github.com/a-h/templ v0.3.1020
github.com/go-chi/chi/v5 v5.2.5 github.com/go-chi/chi/v5 v5.2.5
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/gorilla/csrf v1.7.3
github.com/jackc/pgx/v5 v5.9.2 github.com/jackc/pgx/v5 v5.9.2
github.com/pressly/goose/v3 v3.27.1 github.com/pressly/goose/v3 v3.27.1
golang.org/x/crypto v0.51.0 golang.org/x/crypto v0.51.0
golang.org/x/time v0.15.0
) )
require ( require (
github.com/gorilla/csrf v1.7.3 // indirect
github.com/gorilla/securecookie v1.1.2 // indirect github.com/gorilla/securecookie v1.1.2 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
@ -23,5 +24,4 @@ require (
golang.org/x/sync v0.20.0 // indirect golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.44.0 // indirect golang.org/x/sys v0.44.0 // indirect
golang.org/x/text v0.37.0 // indirect golang.org/x/text v0.37.0 // indirect
golang.org/x/time v0.15.0 // indirect
) )

View file

@ -9,6 +9,8 @@ github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/csrf v1.7.3 h1:BHWt6FTLZAb2HtWT5KDBf6qgpZzvtbp9QWDRKZMXJC0= github.com/gorilla/csrf v1.7.3 h1:BHWt6FTLZAb2HtWT5KDBf6qgpZzvtbp9QWDRKZMXJC0=

View file

@ -0,0 +1,83 @@
package auth
import (
"encoding/hex"
"errors"
"net/http"
"os"
"github.com/gorilla/csrf"
)
// ErrCSRFKeyInvalid is returned by LoadKeyFromEnv when SESSION_SECRET is
// missing or decodes to a length other than 32 bytes.
var ErrCSRFKeyInvalid = errors.New("SESSION_SECRET must be a 64-char hex string encoding 32 bytes; generate with: openssl rand -hex 32")
// LoadKeyFromEnv reads SESSION_SECRET from the environment, hex-decodes it,
// and validates that the result is exactly 32 bytes. Returns ErrCSRFKeyInvalid
// for a missing, non-hex, or wrong-length value. The caller (cmd/web/main.go)
// should log.Fatalf on error.
func LoadKeyFromEnv() ([]byte, error) {
raw := os.Getenv("SESSION_SECRET")
if raw == "" {
return nil, ErrCSRFKeyInvalid
}
key, err := hex.DecodeString(raw)
if err != nil {
return nil, ErrCSRFKeyInvalid
}
if len(key) != 32 {
return nil, ErrCSRFKeyInvalid
}
return key, nil
}
// Mount returns a gorilla/csrf middleware configured with the locked options
// from CONTEXT D-14 / D-24:
//
// - csrf.Secure(env != "dev"): sets the Secure flag on the _gorilla_csrf
// cookie in all environments except "dev" (plain-HTTP local development).
// - csrf.SameSite(csrf.SameSiteLaxMode): SameSite=Lax interim defense.
// - csrf.Path("/"): the CSRF cookie is scoped to the entire site.
// - csrf.FieldName("_csrf"): the hidden form field name (matches @ui.CSRFField).
// - csrf.RequestHeader("X-CSRF-Token"): accepted header for HTMX hx-headers usage.
//
// The middleware is mounted AFTER auth.ResolveSession and BEFORE any route
// group (D-24, Pitfall 7).
//
// When env == "dev", requests are additionally marked as plaintext HTTP via
// csrf.PlaintextHTTPRequest so gorilla/csrf skips the Referer-based origin
// check (which only applies to TLS). This allows local development and
// integration tests running over plain HTTP to function correctly.
//
// trustedOrigins is an optional list of additional trusted origins (used in
// tests to allow localhost requests without a Referer header).
// In production, pass nil — SameSite=Lax and the CSRF cookie handle the defense.
func Mount(env string, key []byte, trustedOrigins ...string) func(http.Handler) http.Handler {
opts := []csrf.Option{
csrf.Secure(env != "dev"),
csrf.SameSite(csrf.SameSiteLaxMode),
csrf.Path("/"),
csrf.FieldName("_csrf"),
csrf.RequestHeader("X-CSRF-Token"),
}
if len(trustedOrigins) > 0 {
opts = append(opts, csrf.TrustedOrigins(trustedOrigins))
}
csrfMiddleware := csrf.Protect(key, opts...)
// In dev mode, mark every request as plaintext HTTP so gorilla/csrf skips
// the Referer-based TLS origin check. This is safe: dev mode already has
// Secure=false on the cookie, and SameSite=Lax provides interim protection.
if env == "dev" {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Tag the request as plaintext HTTP before csrf.Protect sees it.
r = csrf.PlaintextHTTPRequest(r)
csrfMiddleware(next).ServeHTTP(w, r)
})
}
}
return csrfMiddleware
}

View file

@ -0,0 +1,66 @@
package web
import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"backend/internal/auth"
"backend/internal/db/sqlc"
)
func TestCSRF_Debug(t *testing.T) {
pool, cleanup := setupTestDB(t)
defer cleanup()
q := sqlc.New(pool)
store := auth.NewStore(q)
router := newTestRouter(q, store)
// GET /login and collect cookies
getReq := httptest.NewRequest(http.MethodGet, "/login", nil)
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
t.Logf("GET /login status: %d", getRec.Code)
t.Logf("GET /login cookies:")
for _, c := range getRec.Result().Cookies() {
t.Logf(" %s=%s (httponly=%v, secure=%v)", c.Name, c.Value[:min(len(c.Value), 20)], c.HttpOnly, c.Secure)
}
body := getRec.Body.String()
const needle = `name="_csrf" value="`
idx := strings.Index(body, needle)
if idx == -1 {
t.Fatal("no _csrf hidden input found")
}
rest := body[idx+len(needle):]
end := strings.Index(rest, `"`)
token := rest[:end]
t.Logf("Extracted CSRF token: %s...", token[:min(len(token), 20)])
form := url.Values{"email": {"x@y.com"}, "password": {"correct-horse-12"}, "_csrf": {token}}
postReq := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
for _, c := range getRec.Result().Cookies() {
t.Logf("Adding cookie to POST: %s", c.Name)
postReq.AddCookie(c)
}
t.Logf("POST cookies count: %d", len(getRec.Result().Cookies()))
postRec := httptest.NewRecorder()
router.ServeHTTP(postRec, postReq)
t.Logf("POST /login status: %d", postRec.Code)
if postRec.Code == 403 {
t.Logf("403 body: %s", postRec.Body.String()[:min(len(postRec.Body.String()), 200)])
}
}
func min(a, b int) int {
if a < b {
return a
}
return b
}

View file

@ -14,13 +14,15 @@ import (
) )
// newTestRouterWithCSRF builds a router with CSRF enabled using a test key. // newTestRouterWithCSRF builds a router with CSRF enabled using a test key.
// "localhost" is added as a trusted origin so httptest requests without a
// Referer header are accepted.
func newTestRouterWithCSRF(q *sqlc.Queries, store *auth.Store) http.Handler { func newTestRouterWithCSRF(q *sqlc.Queries, store *auth.Store) http.Handler {
csrfKey := make([]byte, 32) csrfKey := make([]byte, 32)
for i := range csrfKey { for i := range csrfKey {
csrfKey[i] = byte(i + 1) csrfKey[i] = byte(i + 1)
} }
deps := AuthDeps{Queries: q, Store: store, Secure: false} deps := AuthDeps{Queries: q, Store: store, Secure: false}
return NewRouter(stubPinger{}, "./static", deps, csrfKey, "dev") return NewRouter(stubPinger{}, "./static", deps, csrfKey, "dev", "localhost")
} }
// extractCSRFToken performs a GET request and extracts the _csrf token from the // extractCSRFToken performs a GET request and extracts the _csrf token from the

View file

@ -8,6 +8,8 @@ import (
"backend/internal/auth" "backend/internal/auth"
"backend/templates" "backend/templates"
"github.com/gorilla/csrf"
) )
// HealthzHandler returns an HTTP handler that probes the supplied Pinger // HealthzHandler returns an HTTP handler that probes the supplied Pinger
@ -38,12 +40,14 @@ func HealthzHandler(pinger Pinger) http.HandlerFunc {
// IndexHandler renders the root page (templates.Index) as text/html. // IndexHandler renders the root page (templates.Index) as text/html.
// The authenticated user is pulled from the request context (set by // The authenticated user is pulled from the request context (set by
// auth.ResolveSession) and passed to the template so the layout header can // auth.ResolveSession) and passed to the template so the layout header can
// render the logout button and the page can show the user's email. // render the logout button and email.
// csrf.Token(r) is threaded into the template so the logout form includes the
// hidden _csrf field (AUTH-06, D-14).
func IndexHandler() http.HandlerFunc { func IndexHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
_, user, _ := auth.Authed(r.Context()) _, user, _ := auth.Authed(r.Context())
w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set("Content-Type", "text/html; charset=utf-8")
_ = templates.Index(user).Render(r.Context(), w) _ = templates.Index(user, csrf.Token(r)).Render(r.Context(), w)
} }
} }

View file

@ -12,6 +12,7 @@ import (
"backend/internal/db/sqlc" "backend/internal/db/sqlc"
"backend/templates" "backend/templates"
"github.com/gorilla/csrf"
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgconn"
) )
@ -49,7 +50,7 @@ func clientIP(r *http.Request) string {
func SignupPageHandler() http.HandlerFunc { func SignupPageHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set("Content-Type", "text/html; charset=utf-8")
_ = templates.SignupPage(templates.SignupForm{}, templates.SignupErrors{}).Render(r.Context(), w) _ = templates.SignupPage(templates.SignupForm{}, templates.SignupErrors{}, csrf.Token(r)).Render(r.Context(), w)
} }
} }
@ -61,11 +62,13 @@ func SignupPageHandler() http.HandlerFunc {
// - Email is not logged on validation errors (T-2-18). // - Email is not logged on validation errors (T-2-18).
// - Duplicate email is detected via pgconn error code 23505 (T-2-19). // - Duplicate email is detected via pgconn error code 23505 (T-2-19).
// - A fresh session token is created on every signup (T-2-04). // - A fresh session token is created on every signup (T-2-04).
// - Form values are read via r.PostFormValue ONLY (never r.Body) so gorilla/csrf
// body consumption does not interfere (Pitfall 1, T-2-08c).
func SignupPostHandler(deps AuthDeps) http.HandlerFunc { func SignupPostHandler(deps AuthDeps) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
// 1. Read form values. // 1. Read form values via r.PostFormValue (Pitfall 1: never read r.Body directly).
email := strings.TrimSpace(r.PostFormValue("email")) email := strings.TrimSpace(r.PostFormValue("email"))
password := r.PostFormValue("password") password := r.PostFormValue("password")
@ -147,9 +150,9 @@ func renderSignupError(w http.ResponseWriter, r *http.Request, form templates.Si
w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(status) w.WriteHeader(status)
if r.Header.Get("HX-Request") == "true" { if r.Header.Get("HX-Request") == "true" {
_ = templates.SignupFormFragment(form, errs).Render(r.Context(), w) _ = templates.SignupFormFragment(form, errs, csrf.Token(r)).Render(r.Context(), w)
} else { } else {
_ = templates.SignupPage(form, errs).Render(r.Context(), w) _ = templates.SignupPage(form, errs, csrf.Token(r)).Render(r.Context(), w)
} }
} }
@ -157,7 +160,7 @@ func renderSignupError(w http.ResponseWriter, r *http.Request, form templates.Si
func LoginPageHandler() http.HandlerFunc { func LoginPageHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set("Content-Type", "text/html; charset=utf-8")
_ = templates.LoginPage(templates.LoginForm{}, templates.LoginErrors{}).Render(r.Context(), w) _ = templates.LoginPage(templates.LoginForm{}, templates.LoginErrors{}, csrf.Token(r)).Render(r.Context(), w)
} }
} }
@ -176,11 +179,13 @@ func LoginPageHandler() http.HandlerFunc {
// wrong password to prevent user enumeration (D-20, T-2-03). // wrong password to prevent user enumeration (D-20, T-2-03).
// - Password is never logged (T-2-21). // - Password is never logged (T-2-21).
// - Session rotated on every successful login (T-2-04, D-10). // - Session rotated on every successful login (T-2-04, D-10).
// - Form values are read via r.PostFormValue ONLY (never r.Body) so gorilla/csrf
// body consumption does not interfere (Pitfall 1, T-2-08c).
func LoginPostHandler(deps AuthDeps) http.HandlerFunc { func LoginPostHandler(deps AuthDeps) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
// 1. Read form values. // 1. Read form values via r.PostFormValue (Pitfall 1: never read r.Body directly).
email := strings.TrimSpace(r.PostFormValue("email")) email := strings.TrimSpace(r.PostFormValue("email"))
password := r.PostFormValue("password") password := r.PostFormValue("password")
@ -271,9 +276,9 @@ func renderLoginError(w http.ResponseWriter, r *http.Request, form templates.Log
w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(status) w.WriteHeader(status)
if r.Header.Get("HX-Request") == "true" { if r.Header.Get("HX-Request") == "true" {
_ = templates.LoginFormFragment(form, errs).Render(r.Context(), w) _ = templates.LoginFormFragment(form, errs, csrf.Token(r)).Render(r.Context(), w)
} else { } else {
_ = templates.LoginPage(form, errs).Render(r.Context(), w) _ = templates.LoginPage(form, errs, csrf.Token(r)).Render(r.Context(), w)
} }
} }
@ -286,6 +291,7 @@ func renderLoginError(w http.ResponseWriter, r *http.Request, form templates.Log
// - Store.Delete hard-deletes the session row (D-06, T-2-07). // - Store.Delete hard-deletes the session row (D-06, T-2-07).
// - ClearSessionCookie sets Max-Age=-1 to expire the browser cookie (D-06). // - ClearSessionCookie sets Max-Age=-1 to expire the browser cookie (D-06).
// - HTMX requests receive 200 + HX-Redirect; plain requests receive 303 (D-22). // - HTMX requests receive 200 + HX-Redirect; plain requests receive 303 (D-22).
// - CSRF token validated by gorilla/csrf middleware before this handler runs (AUTH-06).
func LogoutHandler(deps AuthDeps) http.HandlerFunc { func LogoutHandler(deps AuthDeps) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
// Defense-in-depth: RequireAuth already gates this route, but guard here // Defense-in-depth: RequireAuth already gates this route, but guard here

View file

@ -17,17 +17,62 @@ import (
"backend/internal/db/sqlc" "backend/internal/db/sqlc"
) )
// testCSRFKey is a fixed 32-byte key used by all test routers. It is NOT a
// real key — safe to have here because it is test-only and dev-env scoped.
var testCSRFKey = func() []byte {
key := make([]byte, 32)
for i := range key {
key[i] = byte(i + 1)
}
return key
}()
// newTestRouter builds a router backed by a real DB for integration tests. // newTestRouter builds a router backed by a real DB for integration tests.
// CSRF is enabled with a fixed test key and env="dev" (Secure=false on cookie).
// "localhost" is added as a trusted origin so httptest requests without a
// Referer header are accepted.
func newTestRouter(q *sqlc.Queries, store *auth.Store) http.Handler { func newTestRouter(q *sqlc.Queries, store *auth.Store) http.Handler {
deps := AuthDeps{Queries: q, Store: store, Secure: false} deps := AuthDeps{Queries: q, Store: store, Secure: false}
return NewRouter(stubPinger{}, "./static", deps) return NewRouter(stubPinger{}, "./static", deps, testCSRFKey, "dev", "localhost")
} }
// newTestRouterWithLimiter builds a router with an injected LimiterStore, // newTestRouterWithLimiter builds a router with an injected LimiterStore,
// enabling rate-limit tests to use a fake clock. // enabling rate-limit tests to use a fake clock.
func newTestRouterWithLimiter(q *sqlc.Queries, store *auth.Store, rl *auth.LimiterStore) http.Handler { func newTestRouterWithLimiter(q *sqlc.Queries, store *auth.Store, rl *auth.LimiterStore) http.Handler {
deps := AuthDeps{Queries: q, Store: store, Secure: false, Limiter: rl} deps := AuthDeps{Queries: q, Store: store, Secure: false, Limiter: rl}
return NewRouter(stubPinger{}, "./static", deps) return NewRouter(stubPinger{}, "./static", deps, testCSRFKey, "dev", "localhost")
}
// getCSRFToken performs a GET request to path and extracts the CSRF token
// from the rendered form HTML. Returns the token string and any Set-Cookie
// headers (including the gorilla_csrf cookie) from the response.
func getCSRFToken(t *testing.T, router http.Handler, path string, cookies []*http.Cookie) (token string, respCookies []*http.Cookie) {
t.Helper()
req := httptest.NewRequest(http.MethodGet, path, nil)
for _, c := range cookies {
req.AddCookie(c)
}
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
body := rec.Body.String()
const needle = `name="_csrf" value="`
idx := strings.Index(body, needle)
if idx == -1 {
t.Fatalf("getCSRFToken: _csrf hidden input not found in GET %s; body snippet: %.300s", path, body)
}
rest := body[idx+len(needle):]
end := strings.Index(rest, `"`)
if end == -1 {
t.Fatalf("getCSRFToken: closing quote not found for _csrf in GET %s", path)
}
token = rest[:end]
respCookies = append(respCookies, cookies...)
for _, c := range rec.Result().Cookies() {
respCookies = append(respCookies, c)
}
return token, respCookies
} }
// preInsertUser inserts a user with TestParams-hashed password directly via sqlc // preInsertUser inserts a user with TestParams-hashed password directly via sqlc
@ -81,9 +126,13 @@ func TestSignup_Success(t *testing.T) {
store := auth.NewStore(q) store := auth.NewStore(q)
router := newTestRouter(q, store) router := newTestRouter(q, store)
form := url.Values{"email": {"alice@example.com"}, "password": {"correct-horse-12"}} csrfToken, csrfCookies := getCSRFToken(t, router, "/signup", nil)
form := url.Values{"email": {"alice@example.com"}, "password": {"correct-horse-12"}, "_csrf": {csrfToken}}
req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
for _, c := range csrfCookies {
req.AddCookie(c)
}
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
@ -138,10 +187,14 @@ func TestSignup_Success_HTMX(t *testing.T) {
store := auth.NewStore(q) store := auth.NewStore(q)
router := newTestRouter(q, store) router := newTestRouter(q, store)
form := url.Values{"email": {"bob@example.com"}, "password": {"correct-horse-12"}} csrfToken, csrfCookies := getCSRFToken(t, router, "/signup", nil)
form := url.Values{"email": {"bob@example.com"}, "password": {"correct-horse-12"}, "_csrf": {csrfToken}}
req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("HX-Request", "true") req.Header.Set("HX-Request", "true")
for _, c := range csrfCookies {
req.AddCookie(c)
}
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
@ -162,9 +215,13 @@ func TestSignup_InvalidEmail(t *testing.T) {
store := auth.NewStore(q) store := auth.NewStore(q)
router := newTestRouter(q, store) router := newTestRouter(q, store)
form := url.Values{"email": {"not-an-email"}, "password": {"correct-horse-12"}} csrfToken, csrfCookies := getCSRFToken(t, router, "/signup", nil)
form := url.Values{"email": {"not-an-email"}, "password": {"correct-horse-12"}, "_csrf": {csrfToken}}
req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
for _, c := range csrfCookies {
req.AddCookie(c)
}
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
@ -195,9 +252,13 @@ func TestSignup_PasswordTooShort(t *testing.T) {
router := newTestRouter(q, store) router := newTestRouter(q, store)
// 11 chars — below the 12-char minimum. // 11 chars — below the 12-char minimum.
form := url.Values{"email": {"carol@example.com"}, "password": {"short12345!"}} csrfToken, csrfCookies := getCSRFToken(t, router, "/signup", nil)
form := url.Values{"email": {"carol@example.com"}, "password": {"short12345!"}, "_csrf": {csrfToken}}
req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
for _, c := range csrfCookies {
req.AddCookie(c)
}
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
@ -227,9 +288,13 @@ func TestSignup_PasswordTooLong(t *testing.T) {
router := newTestRouter(q, store) router := newTestRouter(q, store)
longPw := strings.Repeat("a", 129) // 129 chars — above the 128-char maximum. longPw := strings.Repeat("a", 129) // 129 chars — above the 128-char maximum.
form := url.Values{"email": {"dave@example.com"}, "password": {longPw}} csrfToken, csrfCookies := getCSRFToken(t, router, "/signup", nil)
form := url.Values{"email": {"dave@example.com"}, "password": {longPw}, "_csrf": {csrfToken}}
req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
for _, c := range csrfCookies {
req.AddCookie(c)
}
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
@ -262,9 +327,13 @@ func TestSignup_DuplicateEmail(t *testing.T) {
// Pre-insert a user using TestParams to avoid the slow DefaultParams hash. // Pre-insert a user using TestParams to avoid the slow DefaultParams hash.
preInsertUser(t, ctx, q, "eve@example.com", "correct-horse-12") preInsertUser(t, ctx, q, "eve@example.com", "correct-horse-12")
form := url.Values{"email": {"eve@example.com"}, "password": {"correct-horse-12"}} csrfToken, csrfCookies := getCSRFToken(t, router, "/signup", nil)
form := url.Values{"email": {"eve@example.com"}, "password": {"correct-horse-12"}, "_csrf": {csrfToken}}
req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
for _, c := range csrfCookies {
req.AddCookie(c)
}
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
@ -298,9 +367,13 @@ func TestSignup_EmailNormalized(t *testing.T) {
router := newTestRouter(q, store) router := newTestRouter(q, store)
// Uppercase + whitespace email — must be stored trimmed and lowercased. // Uppercase + whitespace email — must be stored trimmed and lowercased.
form := url.Values{"email": {" Frank@Example.COM "}, "password": {"correct-horse-12"}} csrfToken, csrfCookies := getCSRFToken(t, router, "/signup", nil)
form := url.Values{"email": {" Frank@Example.COM "}, "password": {"correct-horse-12"}, "_csrf": {csrfToken}}
req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
for _, c := range csrfCookies {
req.AddCookie(c)
}
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
@ -367,9 +440,13 @@ func TestLogin_Success(t *testing.T) {
user := preInsertUser(t, ctx, q, "test@example.com", "correct-horse-12chars") user := preInsertUser(t, ctx, q, "test@example.com", "correct-horse-12chars")
form := url.Values{"email": {"test@example.com"}, "password": {"correct-horse-12chars"}} csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil)
form := url.Values{"email": {"test@example.com"}, "password": {"correct-horse-12chars"}, "_csrf": {csrfToken}}
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
for _, c := range csrfCookies {
req.AddCookie(c)
}
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
@ -406,10 +483,14 @@ func TestLogin_Success_HTMX(t *testing.T) {
preInsertUser(t, ctx, q, "test2@example.com", "correct-horse-12chars") preInsertUser(t, ctx, q, "test2@example.com", "correct-horse-12chars")
form := url.Values{"email": {"test2@example.com"}, "password": {"correct-horse-12chars"}} csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil)
form := url.Values{"email": {"test2@example.com"}, "password": {"correct-horse-12chars"}, "_csrf": {csrfToken}}
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("HX-Request", "true") req.Header.Set("HX-Request", "true")
for _, c := range csrfCookies {
req.AddCookie(c)
}
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
@ -433,9 +514,13 @@ func TestLogin_WrongPassword(t *testing.T) {
preInsertUser(t, ctx, q, "testpw@example.com", "correct-horse-12chars") preInsertUser(t, ctx, q, "testpw@example.com", "correct-horse-12chars")
form := url.Values{"email": {"testpw@example.com"}, "password": {"wrong-password-12chars"}} csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil)
form := url.Values{"email": {"testpw@example.com"}, "password": {"wrong-password-12chars"}, "_csrf": {csrfToken}}
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
for _, c := range csrfCookies {
req.AddCookie(c)
}
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
@ -464,9 +549,13 @@ func TestLogin_UnknownEmail(t *testing.T) {
store := auth.NewStore(q) store := auth.NewStore(q)
router := newTestRouter(q, store) router := newTestRouter(q, store)
form := url.Values{"email": {"nouser@example.com"}, "password": {"correct-horse-12chars"}} csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil)
form := url.Values{"email": {"nouser@example.com"}, "password": {"correct-horse-12chars"}, "_csrf": {csrfToken}}
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
for _, c := range csrfCookies {
req.AddCookie(c)
}
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
@ -488,9 +577,13 @@ func TestLogin_ValidationError_BadEmail(t *testing.T) {
store := auth.NewStore(q) store := auth.NewStore(q)
router := newTestRouter(q, store) router := newTestRouter(q, store)
form := url.Values{"email": {"not-an-email"}, "password": {"correct-horse-12chars"}} csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil)
form := url.Values{"email": {"not-an-email"}, "password": {"correct-horse-12chars"}, "_csrf": {csrfToken}}
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
for _, c := range csrfCookies {
req.AddCookie(c)
}
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
@ -511,9 +604,13 @@ func TestLogin_ValidationError_ShortPassword(t *testing.T) {
store := auth.NewStore(q) store := auth.NewStore(q)
router := newTestRouter(q, store) router := newTestRouter(q, store)
form := url.Values{"email": {"testval@example.com"}, "password": {"shortpw12"}} csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil)
form := url.Values{"email": {"testval@example.com"}, "password": {"shortpw12"}, "_csrf": {csrfToken}}
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
for _, c := range csrfCookies {
req.AddCookie(c)
}
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
@ -542,10 +639,14 @@ func TestLogin_RotatesExistingSession(t *testing.T) {
t.Fatalf("store.Create: %v", err) t.Fatalf("store.Create: %v", err)
} }
form := url.Values{"email": {"rotatetest@example.com"}, "password": {"correct-horse-12chars"}} csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil)
form := url.Values{"email": {"rotatetest@example.com"}, "password": {"correct-horse-12chars"}, "_csrf": {csrfToken}}
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: oldCookieValue}) req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: oldCookieValue})
for _, c := range csrfCookies {
req.AddCookie(c)
}
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
@ -625,11 +726,15 @@ func TestLogin_RateLimit_6thAttemptReturns429(t *testing.T) {
preInsertUser(t, ctx, q, "ratelimit@example.com", "correct-horse-12chars") preInsertUser(t, ctx, q, "ratelimit@example.com", "correct-horse-12chars")
for i := 1; i <= 6; i++ { for i := 1; i <= 6; i++ {
form := url.Values{"email": {"ratelimit@example.com"}, "password": {"wrong-password-12chars"}} csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil)
form := url.Values{"email": {"ratelimit@example.com"}, "password": {"wrong-password-12chars"}, "_csrf": {csrfToken}}
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// Set RemoteAddr to a known IP so chimw.RealIP won't change it. // Set RemoteAddr to a known IP so chimw.RealIP won't change it.
req.RemoteAddr = "192.168.1.1:12345" req.RemoteAddr = "192.168.1.1:12345"
for _, c := range csrfCookies {
req.AddCookie(c)
}
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
@ -666,11 +771,15 @@ func TestLogin_RateLimit_6thAttemptHTMXNoFullPage(t *testing.T) {
preInsertUser(t, ctx, q, "ratelimithtmx@example.com", "correct-horse-12chars") preInsertUser(t, ctx, q, "ratelimithtmx@example.com", "correct-horse-12chars")
for i := 1; i <= 6; i++ { for i := 1; i <= 6; i++ {
form := url.Values{"email": {"ratelimithtmx@example.com"}, "password": {"wrong-password-12chars"}} csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil)
form := url.Values{"email": {"ratelimithtmx@example.com"}, "password": {"wrong-password-12chars"}, "_csrf": {csrfToken}}
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("HX-Request", "true") req.Header.Set("HX-Request", "true")
req.RemoteAddr = "192.168.1.2:12345" req.RemoteAddr = "192.168.1.2:12345"
for _, c := range csrfCookies {
req.AddCookie(c)
}
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
@ -704,19 +813,27 @@ func TestLogin_RateLimit_KeyedByEmailPlusIP(t *testing.T) {
// Exhaust emailA from IP1. // Exhaust emailA from IP1.
for i := 0; i < 6; i++ { for i := 0; i < 6; i++ {
form := url.Values{"email": {"emailA@example.com"}, "password": {"wrong-password-12chars"}} csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil)
form := url.Values{"email": {"emailA@example.com"}, "password": {"wrong-password-12chars"}, "_csrf": {csrfToken}}
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.RemoteAddr = "10.0.0.1:1234" req.RemoteAddr = "10.0.0.1:1234"
for _, c := range csrfCookies {
req.AddCookie(c)
}
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
} }
// emailB from same IP1 should still be allowed (separate key). // emailB from same IP1 should still be allowed (separate key).
form := url.Values{"email": {"emailB@example.com"}, "password": {"wrong-password-12chars"}} csrfTokenB, csrfCookiesB := getCSRFToken(t, router, "/login", nil)
form := url.Values{"email": {"emailB@example.com"}, "password": {"wrong-password-12chars"}, "_csrf": {csrfTokenB}}
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.RemoteAddr = "10.0.0.1:1234" req.RemoteAddr = "10.0.0.1:1234"
for _, c := range csrfCookiesB {
req.AddCookie(c)
}
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
@ -738,10 +855,14 @@ func TestLogin_RateLimit_AppliesBeforeUserLookup(t *testing.T) {
// Use an email that does NOT exist in the DB. // Use an email that does NOT exist in the DB.
for i := 0; i < 6; i++ { for i := 0; i < 6; i++ {
form := url.Values{"email": {"nonexistent@example.com"}, "password": {"wrong-password-12chars"}} csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil)
form := url.Values{"email": {"nonexistent@example.com"}, "password": {"wrong-password-12chars"}, "_csrf": {csrfToken}}
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.RemoteAddr = "10.0.0.2:1234" req.RemoteAddr = "10.0.0.2:1234"
for _, c := range csrfCookies {
req.AddCookie(c)
}
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
@ -772,8 +893,13 @@ func TestLogout_Success(t *testing.T) {
} }
sessionID := hashCookieValue(t, cookieValue) sessionID := hashCookieValue(t, cookieValue)
req := httptest.NewRequest(http.MethodPost, "/logout", nil) sessionCookie := &http.Cookie{Name: auth.SessionCookieName, Value: cookieValue}
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: cookieValue}) csrfToken, csrfCookies := getCSRFToken(t, router, "/", []*http.Cookie{sessionCookie})
req := httptest.NewRequest(http.MethodPost, "/logout", strings.NewReader(url.Values{"_csrf": {csrfToken}}.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
for _, c := range csrfCookies {
req.AddCookie(c)
}
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
@ -819,18 +945,19 @@ func TestLogout_UnauthRedirectsToLogin(t *testing.T) {
store := auth.NewStore(q) store := auth.NewStore(q)
router := newTestRouter(q, store) router := newTestRouter(q, store)
// POST /logout with NO cookie — RequireAuth must block it. // POST /logout with NO cookie and NO CSRF token.
// gorilla/csrf runs before RequireAuth, so a missing CSRF token yields 403
// before RequireAuth can redirect to /login. Both 403 and 303 are acceptable
// here — the important invariant is that the request is rejected (not 500)
// and no logout side-effect occurs.
req := httptest.NewRequest(http.MethodPost, "/logout", nil) req := httptest.NewRequest(http.MethodPost, "/logout", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
// Must redirect to /login (from RequireAuth), NOT a 500. // 403 (csrf rejected) or 303 (RequireAuth redirected) — both are correct.
if rec.Code != http.StatusSeeOther { if rec.Code != http.StatusForbidden && rec.Code != http.StatusSeeOther {
t.Fatalf("status = %d; want 303 (RequireAuth redirect)", rec.Code) t.Fatalf("status = %d; want 403 (CSRF) or 303 (RequireAuth)", rec.Code)
}
if loc := rec.Header().Get("Location"); loc != "/login" {
t.Errorf("Location = %q; want /login", loc)
} }
} }
@ -849,9 +976,14 @@ func TestLogout_HXRedirect(t *testing.T) {
t.Fatalf("store.Create: %v", err) t.Fatalf("store.Create: %v", err)
} }
req := httptest.NewRequest(http.MethodPost, "/logout", nil) sessionCookie := &http.Cookie{Name: auth.SessionCookieName, Value: cookieValue}
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: cookieValue}) csrfToken, csrfCookies := getCSRFToken(t, router, "/", []*http.Cookie{sessionCookie})
req := httptest.NewRequest(http.MethodPost, "/logout", strings.NewReader(url.Values{"_csrf": {csrfToken}}.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("HX-Request", "true") req.Header.Set("HX-Request", "true")
for _, c := range csrfCookies {
req.AddCookie(c)
}
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
@ -879,9 +1011,14 @@ func TestLogout_AfterLogoutSubsequentRequestUnauth(t *testing.T) {
t.Fatalf("store.Create: %v", err) t.Fatalf("store.Create: %v", err)
} }
// Logout first. // Logout first — need to get a CSRF token from the protected page.
logoutReq := httptest.NewRequest(http.MethodPost, "/logout", nil) sessionCookie := &http.Cookie{Name: auth.SessionCookieName, Value: cookieValue}
logoutReq.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: cookieValue}) csrfToken, csrfCookies := getCSRFToken(t, router, "/", []*http.Cookie{sessionCookie})
logoutReq := httptest.NewRequest(http.MethodPost, "/logout", strings.NewReader(url.Values{"_csrf": {csrfToken}}.Encode()))
logoutReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
for _, c := range csrfCookies {
logoutReq.AddCookie(c)
}
logoutRec := httptest.NewRecorder() logoutRec := httptest.NewRecorder()
router.ServeHTTP(logoutRec, logoutReq) router.ServeHTTP(logoutRec, logoutReq)

View file

@ -66,7 +66,7 @@ func TestHealthz_Down(t *testing.T) {
// was public. The HTMX demo content is tested by // was public. The HTMX demo content is tested by
// TestProtected_HomeAuthRendersUserEmail in handlers_auth_test.go. // TestProtected_HomeAuthRendersUserEmail in handlers_auth_test.go.
func TestIndex_UnauthRedirects(t *testing.T) { func TestIndex_UnauthRedirects(t *testing.T) {
router := NewRouter(stubPinger{}, "./static", AuthDeps{}) router := NewRouter(stubPinger{}, "./static", AuthDeps{}, testCSRFKey, "dev")
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
@ -81,7 +81,7 @@ func TestIndex_UnauthRedirects(t *testing.T) {
} }
func TestDemoTime_Fragment(t *testing.T) { func TestDemoTime_Fragment(t *testing.T) {
router := NewRouter(stubPinger{}, "./static", AuthDeps{}) router := NewRouter(stubPinger{}, "./static", AuthDeps{}, testCSRFKey, "dev")
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/demo/time", nil) req := httptest.NewRequest(http.MethodGet, "/demo/time", nil)
@ -104,7 +104,7 @@ func TestDemoTime_Fragment(t *testing.T) {
} }
func TestRequestID_HeaderSet(t *testing.T) { func TestRequestID_HeaderSet(t *testing.T) {
router := NewRouter(stubPinger{}, "./static", AuthDeps{}) router := NewRouter(stubPinger{}, "./static", AuthDeps{}, testCSRFKey, "dev")
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/healthz", nil) req := httptest.NewRequest(http.MethodGet, "/healthz", nil)

View file

@ -26,8 +26,8 @@ type Pinger interface {
// 2. chi RealIP // 2. chi RealIP
// 3. SlogLoggerMiddleware (REPLACES chi's middleware.Logger — Pitfall 6) // 3. SlogLoggerMiddleware (REPLACES chi's middleware.Logger — Pitfall 6)
// 4. chi Recoverer (after Logger so panics carry request_id) // 4. chi Recoverer (after Logger so panics carry request_id)
// 5. auth.ResolveSession (reads session cookie, attaches user to context) // 5. auth.ResolveSession (reads session cookie, attaches user to context) — D-24
// NOTE: csrf.Protect is added in Plan 07. // 6. auth.Mount (gorilla/csrf — MUST come after ResolveSession, before routes) — D-24, Pitfall 7
// //
// Routes: GET / · GET /healthz · GET /demo/time · GET /static/* // Routes: GET / · GET /healthz · GET /demo/time · GET /static/*
// GET /signup (auth pages, behind RedirectIfAuthed) · POST /signup. // GET /signup (auth pages, behind RedirectIfAuthed) · POST /signup.
@ -36,13 +36,24 @@ type Pinger interface {
// //
// deps.Store may be nil during unit tests for Phase 1 routes (those routes // deps.Store may be nil during unit tests for Phase 1 routes (those routes
// never exercise session resolution). ResolveSession guards against nil Store. // never exercise session resolution). ResolveSession guards against nil Store.
func NewRouter(pinger Pinger, staticDir string, deps AuthDeps) http.Handler { //
// csrfKey is the 32-byte CSRF authentication key loaded from SESSION_SECRET.
// env is the runtime environment string (e.g. "dev", "development", "production").
// When env == "dev", the CSRF cookie Secure flag is disabled for plain-HTTP
// local development (D-15, D-24).
// trustedOrigins is an optional list of additional origins for the CSRF
// referer check (used in integration tests to allow localhost requests without
// a Referer header). In production, pass no extra args — leave empty.
func NewRouter(pinger Pinger, staticDir string, deps AuthDeps, csrfKey []byte, env string, trustedOrigins ...string) http.Handler {
r := chi.NewRouter() r := chi.NewRouter()
r.Use(RequestIDMiddleware) r.Use(RequestIDMiddleware)
r.Use(chimw.RealIP) r.Use(chimw.RealIP)
r.Use(SlogLoggerMiddleware(slog.Default())) r.Use(SlogLoggerMiddleware(slog.Default()))
r.Use(chimw.Recoverer) r.Use(chimw.Recoverer)
// D-24 locked order: ResolveSession BEFORE csrf.Protect (auth.Mount).
r.Use(auth.ResolveSession(deps.Store)) r.Use(auth.ResolveSession(deps.Store))
// D-24: gorilla/csrf runs after ResolveSession and before all route groups (Pitfall 7).
r.Use(auth.Mount(env, csrfKey, trustedOrigins...))
// Auth pages — redirect to / if already authenticated. // Auth pages — redirect to / if already authenticated.
r.Group(func(r chi.Router) { r.Group(func(r chi.Router) {

View file

@ -0,0 +1,9 @@
package ui
// CSRFField renders the hidden CSRF token input required by gorilla/csrf.
// Every <form method="POST"> must include @ui.CSRFField(csrfToken) as its
// first child so that the middleware can validate the double-submit cookie
// (D-15, AUTH-06).
templ CSRFField(token string) {
<input type="hidden" name="_csrf" value={ token }/>
}

View file

@ -5,13 +5,13 @@ import "backend/internal/web/ui"
// LoginPage renders the full /login page wrapped in the base Layout. // LoginPage renders the full /login page wrapped in the base Layout.
// It delegates the form section to LoginFormFragment so HTMX can swap just the // It delegates the form section to LoginFormFragment so HTMX can swap just the
// form on validation errors without re-rendering the surrounding shell. // form on validation errors without re-rendering the surrounding shell.
templ LoginPage(form LoginForm, errs LoginErrors) { templ LoginPage(form LoginForm, errs LoginErrors, csrfToken string) {
@Layout("Sign in", nil) { @Layout("Sign in", nil, csrfToken) {
<div class="flex min-h-[60vh] items-start justify-center pt-16"> <div class="flex min-h-[60vh] items-start justify-center pt-16">
@ui.Card(nil) { @ui.Card(nil) {
<div class="w-full max-w-sm px-6 py-8"> <div class="w-full max-w-sm px-6 py-8">
<h1 class="mb-6 text-2xl font-semibold">Sign in to your account</h1> <h1 class="mb-6 text-2xl font-semibold">Sign in to your account</h1>
@LoginFormFragment(form, errs) @LoginFormFragment(form, errs, csrfToken)
</div> </div>
} }
</div> </div>
@ -22,7 +22,7 @@ templ LoginPage(form LoginForm, errs LoginErrors) {
// hx-post targets this component itself so the form can be replaced inline // hx-post targets this component itself so the form can be replaced inline
// on validation failure (D-19, D-20). // on validation failure (D-19, D-20).
// The outer id="login-form" must match the hx-target on this element. // The outer id="login-form" must match the hx-target on this element.
templ LoginFormFragment(form LoginForm, errs LoginErrors) { templ LoginFormFragment(form LoginForm, errs LoginErrors, csrfToken string) {
<form <form
id="login-form" id="login-form"
method="POST" method="POST"
@ -32,7 +32,7 @@ templ LoginFormFragment(form LoginForm, errs LoginErrors) {
hx-swap="outerHTML" hx-swap="outerHTML"
class="space-y-5" class="space-y-5"
> >
<!-- CSRF field added in Plan 07 --> @ui.CSRFField(csrfToken)
@GeneralError(errs.General) @GeneralError(errs.General)
<div> <div>
<label for="email" class="block text-sm font-medium text-slate-700">Email address</label> <label for="email" class="block text-sm font-medium text-slate-700">Email address</label>

View file

@ -5,13 +5,13 @@ import "backend/internal/web/ui"
// SignupPage renders the full /signup page wrapped in the base Layout. // SignupPage renders the full /signup page wrapped in the base Layout.
// It delegates the form section to SignupFormFragment so HTMX can swap just the // It delegates the form section to SignupFormFragment so HTMX can swap just the
// form on validation errors without re-rendering the surrounding shell. // form on validation errors without re-rendering the surrounding shell.
templ SignupPage(form SignupForm, errs SignupErrors) { templ SignupPage(form SignupForm, errs SignupErrors, csrfToken string) {
@Layout("Sign up", nil) { @Layout("Sign up", nil, csrfToken) {
<div class="flex min-h-[60vh] items-start justify-center pt-16"> <div class="flex min-h-[60vh] items-start justify-center pt-16">
@ui.Card(nil) { @ui.Card(nil) {
<div class="w-full max-w-sm px-6 py-8"> <div class="w-full max-w-sm px-6 py-8">
<h1 class="mb-6 text-2xl font-semibold">Create your account</h1> <h1 class="mb-6 text-2xl font-semibold">Create your account</h1>
@SignupFormFragment(form, errs) @SignupFormFragment(form, errs, csrfToken)
</div> </div>
} }
</div> </div>
@ -22,7 +22,7 @@ templ SignupPage(form SignupForm, errs SignupErrors) {
// hx-post targets this component itself so the form can be replaced inline // hx-post targets this component itself so the form can be replaced inline
// on validation failure (D-19, D-25). // on validation failure (D-19, D-25).
// The outer id="signup-form" must match the hx-target on this element. // The outer id="signup-form" must match the hx-target on this element.
templ SignupFormFragment(form SignupForm, errs SignupErrors) { templ SignupFormFragment(form SignupForm, errs SignupErrors, csrfToken string) {
<form <form
id="signup-form" id="signup-form"
method="POST" method="POST"
@ -32,7 +32,7 @@ templ SignupFormFragment(form SignupForm, errs SignupErrors) {
hx-swap="outerHTML" hx-swap="outerHTML"
class="space-y-5" class="space-y-5"
> >
<!-- CSRF field added in Plan 07 --> @ui.CSRFField(csrfToken)
@GeneralError(errs.General) @GeneralError(errs.General)
<div> <div>
<label for="email" class="block text-sm font-medium text-slate-700">Email address</label> <label for="email" class="block text-sm font-medium text-slate-700">Email address</label>

View file

@ -11,7 +11,7 @@ import (
// expected form attributes and that email value round-trips correctly. // expected form attributes and that email value round-trips correctly.
func TestSignupPage_RendersForm(t *testing.T) { func TestSignupPage_RendersForm(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
err := SignupPage(SignupForm{Email: "x@y.z"}, SignupErrors{}).Render(context.Background(), &buf) err := SignupPage(SignupForm{Email: "x@y.z"}, SignupErrors{}, "testtoken").Render(context.Background(), &buf)
if err != nil { if err != nil {
t.Fatalf("SignupPage.Render: %v", err) t.Fatalf("SignupPage.Render: %v", err)
} }
@ -23,6 +23,7 @@ func TestSignupPage_RendersForm(t *testing.T) {
`action="/signup"`, `action="/signup"`,
`hx-post="/signup"`, `hx-post="/signup"`,
`value="x@y.z"`, `value="x@y.z"`,
`name="_csrf"`,
} { } {
if !strings.Contains(body, want) { if !strings.Contains(body, want) {
t.Errorf("SignupPage body missing %q", want) t.Errorf("SignupPage body missing %q", want)
@ -36,7 +37,7 @@ func TestSignupPage_RendersForm(t *testing.T) {
func TestSignupFormFragment_RendersErrors(t *testing.T) { func TestSignupFormFragment_RendersErrors(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
errs := SignupErrors{Password: "Password must be 12-128 characters"} errs := SignupErrors{Password: "Password must be 12-128 characters"}
err := SignupFormFragment(SignupForm{}, errs).Render(context.Background(), &buf) err := SignupFormFragment(SignupForm{}, errs, "testtoken").Render(context.Background(), &buf)
if err != nil { if err != nil {
t.Fatalf("SignupFormFragment.Render: %v", err) t.Fatalf("SignupFormFragment.Render: %v", err)
} }
@ -55,7 +56,7 @@ func TestSignupFormFragment_RendersErrors(t *testing.T) {
// (security requirement T-2-01, D-25). // (security requirement T-2-01, D-25).
func TestSignupPage_DoesNotEchoPassword(t *testing.T) { func TestSignupPage_DoesNotEchoPassword(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
err := SignupPage(SignupForm{Email: "a@b.com", Password: "hunter2hunter2"}, SignupErrors{}).Render(context.Background(), &buf) err := SignupPage(SignupForm{Email: "a@b.com", Password: "hunter2hunter2"}, SignupErrors{}, "testtoken").Render(context.Background(), &buf)
if err != nil { if err != nil {
t.Fatalf("SignupPage.Render: %v", err) t.Fatalf("SignupPage.Render: %v", err)
} }

View file

@ -8,8 +8,10 @@ import (
// Index renders the root page (protected, requires auth). // Index renders the root page (protected, requires auth).
// The user parameter is the authenticated user from request context, passed // The user parameter is the authenticated user from request context, passed
// through to Layout so the header can render the logout button and email. // through to Layout so the header can render the logout button and email.
templ Index(user *auth.User) { // csrfToken is passed to Layout so the logout form can include the hidden
@Layout("Xtablo", user) { // _csrf field (AUTH-06, D-14).
templ Index(user *auth.User, csrfToken string) {
@Layout("Xtablo", user, csrfToken) {
<p class="text-sm text-slate-500 mb-6">Signed in as { user.Email }</p> <p class="text-sm text-slate-500 mb-6">Signed in as { user.Email }</p>
<h1 class="text-[28px] font-semibold leading-tight">Xtablo</h1> <h1 class="text-[28px] font-semibold leading-tight">Xtablo</h1>
<p class="mt-2 text-base text-slate-600"> <p class="mt-2 text-base text-slate-600">

View file

@ -3,7 +3,10 @@
// generate`; generated files are gitignored. // generate`; generated files are gitignored.
package templates package templates
import "backend/internal/auth" import (
"backend/internal/auth"
"backend/internal/web/ui"
)
// Layout is the base HTML shell every page renders inside. The structural // Layout is the base HTML shell every page renders inside. The structural
// classes, container width (max-w-5xl), horizontal padding, header strip, // classes, container width (max-w-5xl), horizontal padding, header strip,
@ -15,7 +18,10 @@ import "backend/internal/auth"
// When non-nil, the header renders a Log out POST form (D-22). Auth pages // When non-nil, the header renders a Log out POST form (D-22). Auth pages
// pass nil since they're gated behind RedirectIfAuthed and never shown to // pass nil since they're gated behind RedirectIfAuthed and never shown to
// authed users. // authed users.
templ Layout(title string, user *auth.User) { //
// csrfToken is threaded from the handler via csrf.Token(r) so the logout
// form can embed @ui.CSRFField(csrfToken) (AUTH-06, D-14).
templ Layout(title string, user *auth.User, csrfToken string) {
<!DOCTYPE html> <!DOCTYPE html>
<html lang="en"> <html lang="en">
<head> <head>
@ -32,7 +38,7 @@ templ Layout(title string, user *auth.User) {
<div class="flex items-center gap-3"> <div class="flex items-center gap-3">
<span class="text-sm text-slate-600">{ user.Email }</span> <span class="text-sm text-slate-600">{ user.Email }</span>
<form method="POST" action="/logout" class="inline"> <form method="POST" action="/logout" class="inline">
<!-- CSRF field added in Plan 07 --> @ui.CSRFField(csrfToken)
<button type="submit" class="text-sm text-slate-700 hover:underline">Log out</button> <button type="submit" class="text-sm text-slate-700 hover:underline">Log out</button>
</form> </form>
</div> </div>

View file

@ -11,10 +11,11 @@ import (
// TestLayout_LogoutFormVisibleWhenAuthed verifies that the logout form is // TestLayout_LogoutFormVisibleWhenAuthed verifies that the logout form is
// rendered in the header when Layout receives a non-nil user (D-22). // rendered in the header when Layout receives a non-nil user (D-22).
// The _csrf hidden field must also be present (AUTH-06).
func TestLayout_LogoutFormVisibleWhenAuthed(t *testing.T) { func TestLayout_LogoutFormVisibleWhenAuthed(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
user := &auth.User{Email: "a@b.c"} user := &auth.User{Email: "a@b.c"}
err := Layout("Test", user).Render(context.Background(), &buf) err := Layout("Test", user, "mytesttoken").Render(context.Background(), &buf)
if err != nil { if err != nil {
t.Fatalf("Layout.Render: %v", err) t.Fatalf("Layout.Render: %v", err)
} }
@ -26,13 +27,19 @@ func TestLayout_LogoutFormVisibleWhenAuthed(t *testing.T) {
if !strings.Contains(body, `method="POST"`) { if !strings.Contains(body, `method="POST"`) {
t.Errorf("Layout body missing method=\"POST\"; logout must be a POST form (D-22)") t.Errorf("Layout body missing method=\"POST\"; logout must be a POST form (D-22)")
} }
if !strings.Contains(body, `name="_csrf"`) {
t.Errorf("Layout body missing name=\"_csrf\"; logout form must embed CSRF field (AUTH-06)")
}
if !strings.Contains(body, `value="mytesttoken"`) {
t.Errorf("Layout body missing value=\"mytesttoken\"; CSRF token not threaded into form")
}
} }
// TestLayout_LogoutFormHiddenWhenUnauthed verifies that no logout form is // TestLayout_LogoutFormHiddenWhenUnauthed verifies that no logout form is
// rendered when Layout receives a nil user (unauthenticated request). // rendered when Layout receives a nil user (unauthenticated request).
func TestLayout_LogoutFormHiddenWhenUnauthed(t *testing.T) { func TestLayout_LogoutFormHiddenWhenUnauthed(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
err := Layout("Test", nil).Render(context.Background(), &buf) err := Layout("Test", nil, "").Render(context.Background(), &buf)
if err != nil { if err != nil {
t.Fatalf("Layout.Render: %v", err) t.Fatalf("Layout.Render: %v", err)
} }