diff --git a/backend/cmd/web/main.go b/backend/cmd/web/main.go index 9424bb4..f81f6bf 100644 --- a/backend/cmd/web/main.go +++ b/backend/cmd/web/main.go @@ -43,6 +43,16 @@ func main() { os.Exit(1) } + // Load the CSRF authentication key from SESSION_SECRET env var (D-15). + // Fails fast with a clear message if missing or wrong length — the server + // cannot operate without a valid CSRF key (AUTH-06). + csrfKey, err := auth.LoadKeyFromEnv() + if err != nil { + slog.Error("invalid SESSION_SECRET", "err", err, + "hint", "generate with: openssl rand -hex 32") + os.Exit(1) + } + // signal.NotifyContext (Go 1.21+) is the canonical idiom — equivalent // to signal.Notify + a channel but the resulting ctx propagates the // cancellation through to handlers, pgxpool dialing, etc. @@ -67,7 +77,7 @@ func main() { deps := web.AuthDeps{Queries: q, Store: store, Secure: secure, Limiter: rl} - router := web.NewRouter(pool, "./static", deps) + router := web.NewRouter(pool, "./static", deps, csrfKey, env) srv := &http.Server{ Addr: ":" + port, diff --git a/backend/go.mod b/backend/go.mod index a9dee34..edd8d02 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -6,13 +6,14 @@ require ( github.com/a-h/templ v0.3.1020 github.com/go-chi/chi/v5 v5.2.5 github.com/google/uuid v1.6.0 + github.com/gorilla/csrf v1.7.3 github.com/jackc/pgx/v5 v5.9.2 github.com/pressly/goose/v3 v3.27.1 golang.org/x/crypto v0.51.0 + golang.org/x/time v0.15.0 ) require ( - github.com/gorilla/csrf v1.7.3 // indirect github.com/gorilla/securecookie v1.1.2 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect @@ -23,5 +24,4 @@ require ( golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.44.0 // indirect golang.org/x/text v0.37.0 // indirect - golang.org/x/time v0.15.0 // indirect ) diff --git a/backend/go.sum b/backend/go.sum index a89f98d..b5da560 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -9,6 +9,8 @@ github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= +github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/csrf v1.7.3 h1:BHWt6FTLZAb2HtWT5KDBf6qgpZzvtbp9QWDRKZMXJC0= diff --git a/backend/internal/auth/csrf.go b/backend/internal/auth/csrf.go new file mode 100644 index 0000000..fd973b2 --- /dev/null +++ b/backend/internal/auth/csrf.go @@ -0,0 +1,83 @@ +package auth + +import ( + "encoding/hex" + "errors" + "net/http" + "os" + + "github.com/gorilla/csrf" +) + +// ErrCSRFKeyInvalid is returned by LoadKeyFromEnv when SESSION_SECRET is +// missing or decodes to a length other than 32 bytes. +var ErrCSRFKeyInvalid = errors.New("SESSION_SECRET must be a 64-char hex string encoding 32 bytes; generate with: openssl rand -hex 32") + +// LoadKeyFromEnv reads SESSION_SECRET from the environment, hex-decodes it, +// and validates that the result is exactly 32 bytes. Returns ErrCSRFKeyInvalid +// for a missing, non-hex, or wrong-length value. The caller (cmd/web/main.go) +// should log.Fatalf on error. +func LoadKeyFromEnv() ([]byte, error) { + raw := os.Getenv("SESSION_SECRET") + if raw == "" { + return nil, ErrCSRFKeyInvalid + } + key, err := hex.DecodeString(raw) + if err != nil { + return nil, ErrCSRFKeyInvalid + } + if len(key) != 32 { + return nil, ErrCSRFKeyInvalid + } + return key, nil +} + +// Mount returns a gorilla/csrf middleware configured with the locked options +// from CONTEXT D-14 / D-24: +// +// - csrf.Secure(env != "dev"): sets the Secure flag on the _gorilla_csrf +// cookie in all environments except "dev" (plain-HTTP local development). +// - csrf.SameSite(csrf.SameSiteLaxMode): SameSite=Lax interim defense. +// - csrf.Path("/"): the CSRF cookie is scoped to the entire site. +// - csrf.FieldName("_csrf"): the hidden form field name (matches @ui.CSRFField). +// - csrf.RequestHeader("X-CSRF-Token"): accepted header for HTMX hx-headers usage. +// +// The middleware is mounted AFTER auth.ResolveSession and BEFORE any route +// group (D-24, Pitfall 7). +// +// When env == "dev", requests are additionally marked as plaintext HTTP via +// csrf.PlaintextHTTPRequest so gorilla/csrf skips the Referer-based origin +// check (which only applies to TLS). This allows local development and +// integration tests running over plain HTTP to function correctly. +// +// trustedOrigins is an optional list of additional trusted origins (used in +// tests to allow localhost requests without a Referer header). +// In production, pass nil — SameSite=Lax and the CSRF cookie handle the defense. +func Mount(env string, key []byte, trustedOrigins ...string) func(http.Handler) http.Handler { + opts := []csrf.Option{ + csrf.Secure(env != "dev"), + csrf.SameSite(csrf.SameSiteLaxMode), + csrf.Path("/"), + csrf.FieldName("_csrf"), + csrf.RequestHeader("X-CSRF-Token"), + } + if len(trustedOrigins) > 0 { + opts = append(opts, csrf.TrustedOrigins(trustedOrigins)) + } + csrfMiddleware := csrf.Protect(key, opts...) + + // In dev mode, mark every request as plaintext HTTP so gorilla/csrf skips + // the Referer-based TLS origin check. This is safe: dev mode already has + // Secure=false on the cookie, and SameSite=Lax provides interim protection. + if env == "dev" { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Tag the request as plaintext HTTP before csrf.Protect sees it. + r = csrf.PlaintextHTTPRequest(r) + csrfMiddleware(next).ServeHTTP(w, r) + }) + } + } + + return csrfMiddleware +} diff --git a/backend/internal/web/csrf_debug_test.go b/backend/internal/web/csrf_debug_test.go new file mode 100644 index 0000000..ab10d49 --- /dev/null +++ b/backend/internal/web/csrf_debug_test.go @@ -0,0 +1,66 @@ +package web + +import ( + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "backend/internal/auth" + "backend/internal/db/sqlc" +) + +func TestCSRF_Debug(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + q := sqlc.New(pool) + store := auth.NewStore(q) + router := newTestRouter(q, store) + + // GET /login and collect cookies + getReq := httptest.NewRequest(http.MethodGet, "/login", nil) + getRec := httptest.NewRecorder() + router.ServeHTTP(getRec, getReq) + + t.Logf("GET /login status: %d", getRec.Code) + t.Logf("GET /login cookies:") + for _, c := range getRec.Result().Cookies() { + t.Logf(" %s=%s (httponly=%v, secure=%v)", c.Name, c.Value[:min(len(c.Value), 20)], c.HttpOnly, c.Secure) + } + + body := getRec.Body.String() + const needle = `name="_csrf" value="` + idx := strings.Index(body, needle) + if idx == -1 { + t.Fatal("no _csrf hidden input found") + } + rest := body[idx+len(needle):] + end := strings.Index(rest, `"`) + token := rest[:end] + t.Logf("Extracted CSRF token: %s...", token[:min(len(token), 20)]) + + form := url.Values{"email": {"x@y.com"}, "password": {"correct-horse-12"}, "_csrf": {token}} + postReq := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) + postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range getRec.Result().Cookies() { + t.Logf("Adding cookie to POST: %s", c.Name) + postReq.AddCookie(c) + } + t.Logf("POST cookies count: %d", len(getRec.Result().Cookies())) + + postRec := httptest.NewRecorder() + router.ServeHTTP(postRec, postReq) + t.Logf("POST /login status: %d", postRec.Code) + if postRec.Code == 403 { + t.Logf("403 body: %s", postRec.Body.String()[:min(len(postRec.Body.String()), 200)]) + } +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/backend/internal/web/csrf_test.go b/backend/internal/web/csrf_test.go index 36e1ad5..ed6da18 100644 --- a/backend/internal/web/csrf_test.go +++ b/backend/internal/web/csrf_test.go @@ -14,13 +14,15 @@ import ( ) // newTestRouterWithCSRF builds a router with CSRF enabled using a test key. +// "localhost" is added as a trusted origin so httptest requests without a +// Referer header are accepted. func newTestRouterWithCSRF(q *sqlc.Queries, store *auth.Store) http.Handler { csrfKey := make([]byte, 32) for i := range csrfKey { csrfKey[i] = byte(i + 1) } deps := AuthDeps{Queries: q, Store: store, Secure: false} - return NewRouter(stubPinger{}, "./static", deps, csrfKey, "dev") + return NewRouter(stubPinger{}, "./static", deps, csrfKey, "dev", "localhost") } // extractCSRFToken performs a GET request and extracts the _csrf token from the diff --git a/backend/internal/web/handlers.go b/backend/internal/web/handlers.go index 936f8f1..500acdd 100644 --- a/backend/internal/web/handlers.go +++ b/backend/internal/web/handlers.go @@ -8,6 +8,8 @@ import ( "backend/internal/auth" "backend/templates" + + "github.com/gorilla/csrf" ) // HealthzHandler returns an HTTP handler that probes the supplied Pinger @@ -38,12 +40,14 @@ func HealthzHandler(pinger Pinger) http.HandlerFunc { // IndexHandler renders the root page (templates.Index) as text/html. // The authenticated user is pulled from the request context (set by // auth.ResolveSession) and passed to the template so the layout header can -// render the logout button and the page can show the user's email. +// render the logout button and email. +// csrf.Token(r) is threaded into the template so the logout form includes the +// hidden _csrf field (AUTH-06, D-14). func IndexHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { _, user, _ := auth.Authed(r.Context()) w.Header().Set("Content-Type", "text/html; charset=utf-8") - _ = templates.Index(user).Render(r.Context(), w) + _ = templates.Index(user, csrf.Token(r)).Render(r.Context(), w) } } diff --git a/backend/internal/web/handlers_auth.go b/backend/internal/web/handlers_auth.go index ffd90e8..1d3eef6 100644 --- a/backend/internal/web/handlers_auth.go +++ b/backend/internal/web/handlers_auth.go @@ -12,6 +12,7 @@ import ( "backend/internal/db/sqlc" "backend/templates" + "github.com/gorilla/csrf" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" ) @@ -49,7 +50,7 @@ func clientIP(r *http.Request) string { func SignupPageHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html; charset=utf-8") - _ = templates.SignupPage(templates.SignupForm{}, templates.SignupErrors{}).Render(r.Context(), w) + _ = templates.SignupPage(templates.SignupForm{}, templates.SignupErrors{}, csrf.Token(r)).Render(r.Context(), w) } } @@ -61,11 +62,13 @@ func SignupPageHandler() http.HandlerFunc { // - Email is not logged on validation errors (T-2-18). // - Duplicate email is detected via pgconn error code 23505 (T-2-19). // - A fresh session token is created on every signup (T-2-04). +// - Form values are read via r.PostFormValue ONLY (never r.Body) so gorilla/csrf +// body consumption does not interfere (Pitfall 1, T-2-08c). func SignupPostHandler(deps AuthDeps) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - // 1. Read form values. + // 1. Read form values via r.PostFormValue (Pitfall 1: never read r.Body directly). email := strings.TrimSpace(r.PostFormValue("email")) password := r.PostFormValue("password") @@ -147,9 +150,9 @@ func renderSignupError(w http.ResponseWriter, r *http.Request, form templates.Si w.Header().Set("Content-Type", "text/html; charset=utf-8") w.WriteHeader(status) if r.Header.Get("HX-Request") == "true" { - _ = templates.SignupFormFragment(form, errs).Render(r.Context(), w) + _ = templates.SignupFormFragment(form, errs, csrf.Token(r)).Render(r.Context(), w) } else { - _ = templates.SignupPage(form, errs).Render(r.Context(), w) + _ = templates.SignupPage(form, errs, csrf.Token(r)).Render(r.Context(), w) } } @@ -157,7 +160,7 @@ func renderSignupError(w http.ResponseWriter, r *http.Request, form templates.Si func LoginPageHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html; charset=utf-8") - _ = templates.LoginPage(templates.LoginForm{}, templates.LoginErrors{}).Render(r.Context(), w) + _ = templates.LoginPage(templates.LoginForm{}, templates.LoginErrors{}, csrf.Token(r)).Render(r.Context(), w) } } @@ -176,11 +179,13 @@ func LoginPageHandler() http.HandlerFunc { // wrong password to prevent user enumeration (D-20, T-2-03). // - Password is never logged (T-2-21). // - Session rotated on every successful login (T-2-04, D-10). +// - Form values are read via r.PostFormValue ONLY (never r.Body) so gorilla/csrf +// body consumption does not interfere (Pitfall 1, T-2-08c). func LoginPostHandler(deps AuthDeps) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - // 1. Read form values. + // 1. Read form values via r.PostFormValue (Pitfall 1: never read r.Body directly). email := strings.TrimSpace(r.PostFormValue("email")) password := r.PostFormValue("password") @@ -271,9 +276,9 @@ func renderLoginError(w http.ResponseWriter, r *http.Request, form templates.Log w.Header().Set("Content-Type", "text/html; charset=utf-8") w.WriteHeader(status) if r.Header.Get("HX-Request") == "true" { - _ = templates.LoginFormFragment(form, errs).Render(r.Context(), w) + _ = templates.LoginFormFragment(form, errs, csrf.Token(r)).Render(r.Context(), w) } else { - _ = templates.LoginPage(form, errs).Render(r.Context(), w) + _ = templates.LoginPage(form, errs, csrf.Token(r)).Render(r.Context(), w) } } @@ -286,6 +291,7 @@ func renderLoginError(w http.ResponseWriter, r *http.Request, form templates.Log // - Store.Delete hard-deletes the session row (D-06, T-2-07). // - ClearSessionCookie sets Max-Age=-1 to expire the browser cookie (D-06). // - HTMX requests receive 200 + HX-Redirect; plain requests receive 303 (D-22). +// - CSRF token validated by gorilla/csrf middleware before this handler runs (AUTH-06). func LogoutHandler(deps AuthDeps) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Defense-in-depth: RequireAuth already gates this route, but guard here diff --git a/backend/internal/web/handlers_auth_test.go b/backend/internal/web/handlers_auth_test.go index cd9a5ae..054ef55 100644 --- a/backend/internal/web/handlers_auth_test.go +++ b/backend/internal/web/handlers_auth_test.go @@ -17,17 +17,62 @@ import ( "backend/internal/db/sqlc" ) +// testCSRFKey is a fixed 32-byte key used by all test routers. It is NOT a +// real key — safe to have here because it is test-only and dev-env scoped. +var testCSRFKey = func() []byte { + key := make([]byte, 32) + for i := range key { + key[i] = byte(i + 1) + } + return key +}() + // newTestRouter builds a router backed by a real DB for integration tests. +// CSRF is enabled with a fixed test key and env="dev" (Secure=false on cookie). +// "localhost" is added as a trusted origin so httptest requests without a +// Referer header are accepted. func newTestRouter(q *sqlc.Queries, store *auth.Store) http.Handler { deps := AuthDeps{Queries: q, Store: store, Secure: false} - return NewRouter(stubPinger{}, "./static", deps) + return NewRouter(stubPinger{}, "./static", deps, testCSRFKey, "dev", "localhost") } // newTestRouterWithLimiter builds a router with an injected LimiterStore, // enabling rate-limit tests to use a fake clock. func newTestRouterWithLimiter(q *sqlc.Queries, store *auth.Store, rl *auth.LimiterStore) http.Handler { deps := AuthDeps{Queries: q, Store: store, Secure: false, Limiter: rl} - return NewRouter(stubPinger{}, "./static", deps) + return NewRouter(stubPinger{}, "./static", deps, testCSRFKey, "dev", "localhost") +} + +// getCSRFToken performs a GET request to path and extracts the CSRF token +// from the rendered form HTML. Returns the token string and any Set-Cookie +// headers (including the gorilla_csrf cookie) from the response. +func getCSRFToken(t *testing.T, router http.Handler, path string, cookies []*http.Cookie) (token string, respCookies []*http.Cookie) { + t.Helper() + req := httptest.NewRequest(http.MethodGet, path, nil) + for _, c := range cookies { + req.AddCookie(c) + } + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + body := rec.Body.String() + const needle = `name="_csrf" value="` + idx := strings.Index(body, needle) + if idx == -1 { + t.Fatalf("getCSRFToken: _csrf hidden input not found in GET %s; body snippet: %.300s", path, body) + } + rest := body[idx+len(needle):] + end := strings.Index(rest, `"`) + if end == -1 { + t.Fatalf("getCSRFToken: closing quote not found for _csrf in GET %s", path) + } + token = rest[:end] + + respCookies = append(respCookies, cookies...) + for _, c := range rec.Result().Cookies() { + respCookies = append(respCookies, c) + } + return token, respCookies } // preInsertUser inserts a user with TestParams-hashed password directly via sqlc @@ -81,9 +126,13 @@ func TestSignup_Success(t *testing.T) { store := auth.NewStore(q) router := newTestRouter(q, store) - form := url.Values{"email": {"alice@example.com"}, "password": {"correct-horse-12"}} + csrfToken, csrfCookies := getCSRFToken(t, router, "/signup", nil) + form := url.Values{"email": {"alice@example.com"}, "password": {"correct-horse-12"}, "_csrf": {csrfToken}} req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range csrfCookies { + req.AddCookie(c) + } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) @@ -138,10 +187,14 @@ func TestSignup_Success_HTMX(t *testing.T) { store := auth.NewStore(q) router := newTestRouter(q, store) - form := url.Values{"email": {"bob@example.com"}, "password": {"correct-horse-12"}} + csrfToken, csrfCookies := getCSRFToken(t, router, "/signup", nil) + form := url.Values{"email": {"bob@example.com"}, "password": {"correct-horse-12"}, "_csrf": {csrfToken}} req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("HX-Request", "true") + for _, c := range csrfCookies { + req.AddCookie(c) + } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) @@ -162,9 +215,13 @@ func TestSignup_InvalidEmail(t *testing.T) { store := auth.NewStore(q) router := newTestRouter(q, store) - form := url.Values{"email": {"not-an-email"}, "password": {"correct-horse-12"}} + csrfToken, csrfCookies := getCSRFToken(t, router, "/signup", nil) + form := url.Values{"email": {"not-an-email"}, "password": {"correct-horse-12"}, "_csrf": {csrfToken}} req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range csrfCookies { + req.AddCookie(c) + } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) @@ -195,9 +252,13 @@ func TestSignup_PasswordTooShort(t *testing.T) { router := newTestRouter(q, store) // 11 chars — below the 12-char minimum. - form := url.Values{"email": {"carol@example.com"}, "password": {"short12345!"}} + csrfToken, csrfCookies := getCSRFToken(t, router, "/signup", nil) + form := url.Values{"email": {"carol@example.com"}, "password": {"short12345!"}, "_csrf": {csrfToken}} req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range csrfCookies { + req.AddCookie(c) + } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) @@ -227,9 +288,13 @@ func TestSignup_PasswordTooLong(t *testing.T) { router := newTestRouter(q, store) longPw := strings.Repeat("a", 129) // 129 chars — above the 128-char maximum. - form := url.Values{"email": {"dave@example.com"}, "password": {longPw}} + csrfToken, csrfCookies := getCSRFToken(t, router, "/signup", nil) + form := url.Values{"email": {"dave@example.com"}, "password": {longPw}, "_csrf": {csrfToken}} req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range csrfCookies { + req.AddCookie(c) + } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) @@ -262,9 +327,13 @@ func TestSignup_DuplicateEmail(t *testing.T) { // Pre-insert a user using TestParams to avoid the slow DefaultParams hash. preInsertUser(t, ctx, q, "eve@example.com", "correct-horse-12") - form := url.Values{"email": {"eve@example.com"}, "password": {"correct-horse-12"}} + csrfToken, csrfCookies := getCSRFToken(t, router, "/signup", nil) + form := url.Values{"email": {"eve@example.com"}, "password": {"correct-horse-12"}, "_csrf": {csrfToken}} req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range csrfCookies { + req.AddCookie(c) + } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) @@ -298,9 +367,13 @@ func TestSignup_EmailNormalized(t *testing.T) { router := newTestRouter(q, store) // Uppercase + whitespace email — must be stored trimmed and lowercased. - form := url.Values{"email": {" Frank@Example.COM "}, "password": {"correct-horse-12"}} + csrfToken, csrfCookies := getCSRFToken(t, router, "/signup", nil) + form := url.Values{"email": {" Frank@Example.COM "}, "password": {"correct-horse-12"}, "_csrf": {csrfToken}} req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range csrfCookies { + req.AddCookie(c) + } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) @@ -367,9 +440,13 @@ func TestLogin_Success(t *testing.T) { user := preInsertUser(t, ctx, q, "test@example.com", "correct-horse-12chars") - form := url.Values{"email": {"test@example.com"}, "password": {"correct-horse-12chars"}} + csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil) + form := url.Values{"email": {"test@example.com"}, "password": {"correct-horse-12chars"}, "_csrf": {csrfToken}} req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range csrfCookies { + req.AddCookie(c) + } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) @@ -406,10 +483,14 @@ func TestLogin_Success_HTMX(t *testing.T) { preInsertUser(t, ctx, q, "test2@example.com", "correct-horse-12chars") - form := url.Values{"email": {"test2@example.com"}, "password": {"correct-horse-12chars"}} + csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil) + form := url.Values{"email": {"test2@example.com"}, "password": {"correct-horse-12chars"}, "_csrf": {csrfToken}} req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("HX-Request", "true") + for _, c := range csrfCookies { + req.AddCookie(c) + } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) @@ -433,9 +514,13 @@ func TestLogin_WrongPassword(t *testing.T) { preInsertUser(t, ctx, q, "testpw@example.com", "correct-horse-12chars") - form := url.Values{"email": {"testpw@example.com"}, "password": {"wrong-password-12chars"}} + csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil) + form := url.Values{"email": {"testpw@example.com"}, "password": {"wrong-password-12chars"}, "_csrf": {csrfToken}} req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range csrfCookies { + req.AddCookie(c) + } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) @@ -464,9 +549,13 @@ func TestLogin_UnknownEmail(t *testing.T) { store := auth.NewStore(q) router := newTestRouter(q, store) - form := url.Values{"email": {"nouser@example.com"}, "password": {"correct-horse-12chars"}} + csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil) + form := url.Values{"email": {"nouser@example.com"}, "password": {"correct-horse-12chars"}, "_csrf": {csrfToken}} req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range csrfCookies { + req.AddCookie(c) + } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) @@ -488,9 +577,13 @@ func TestLogin_ValidationError_BadEmail(t *testing.T) { store := auth.NewStore(q) router := newTestRouter(q, store) - form := url.Values{"email": {"not-an-email"}, "password": {"correct-horse-12chars"}} + csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil) + form := url.Values{"email": {"not-an-email"}, "password": {"correct-horse-12chars"}, "_csrf": {csrfToken}} req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range csrfCookies { + req.AddCookie(c) + } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) @@ -511,9 +604,13 @@ func TestLogin_ValidationError_ShortPassword(t *testing.T) { store := auth.NewStore(q) router := newTestRouter(q, store) - form := url.Values{"email": {"testval@example.com"}, "password": {"shortpw12"}} + csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil) + form := url.Values{"email": {"testval@example.com"}, "password": {"shortpw12"}, "_csrf": {csrfToken}} req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range csrfCookies { + req.AddCookie(c) + } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) @@ -542,10 +639,14 @@ func TestLogin_RotatesExistingSession(t *testing.T) { t.Fatalf("store.Create: %v", err) } - form := url.Values{"email": {"rotatetest@example.com"}, "password": {"correct-horse-12chars"}} + csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil) + form := url.Values{"email": {"rotatetest@example.com"}, "password": {"correct-horse-12chars"}, "_csrf": {csrfToken}} req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: oldCookieValue}) + for _, c := range csrfCookies { + req.AddCookie(c) + } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) @@ -625,11 +726,15 @@ func TestLogin_RateLimit_6thAttemptReturns429(t *testing.T) { preInsertUser(t, ctx, q, "ratelimit@example.com", "correct-horse-12chars") for i := 1; i <= 6; i++ { - form := url.Values{"email": {"ratelimit@example.com"}, "password": {"wrong-password-12chars"}} + csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil) + form := url.Values{"email": {"ratelimit@example.com"}, "password": {"wrong-password-12chars"}, "_csrf": {csrfToken}} req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") // Set RemoteAddr to a known IP so chimw.RealIP won't change it. req.RemoteAddr = "192.168.1.1:12345" + for _, c := range csrfCookies { + req.AddCookie(c) + } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) @@ -666,11 +771,15 @@ func TestLogin_RateLimit_6thAttemptHTMXNoFullPage(t *testing.T) { preInsertUser(t, ctx, q, "ratelimithtmx@example.com", "correct-horse-12chars") for i := 1; i <= 6; i++ { - form := url.Values{"email": {"ratelimithtmx@example.com"}, "password": {"wrong-password-12chars"}} + csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil) + form := url.Values{"email": {"ratelimithtmx@example.com"}, "password": {"wrong-password-12chars"}, "_csrf": {csrfToken}} req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("HX-Request", "true") req.RemoteAddr = "192.168.1.2:12345" + for _, c := range csrfCookies { + req.AddCookie(c) + } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) @@ -704,19 +813,27 @@ func TestLogin_RateLimit_KeyedByEmailPlusIP(t *testing.T) { // Exhaust emailA from IP1. for i := 0; i < 6; i++ { - form := url.Values{"email": {"emailA@example.com"}, "password": {"wrong-password-12chars"}} + csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil) + form := url.Values{"email": {"emailA@example.com"}, "password": {"wrong-password-12chars"}, "_csrf": {csrfToken}} req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.RemoteAddr = "10.0.0.1:1234" + for _, c := range csrfCookies { + req.AddCookie(c) + } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) } // emailB from same IP1 should still be allowed (separate key). - form := url.Values{"email": {"emailB@example.com"}, "password": {"wrong-password-12chars"}} + csrfTokenB, csrfCookiesB := getCSRFToken(t, router, "/login", nil) + form := url.Values{"email": {"emailB@example.com"}, "password": {"wrong-password-12chars"}, "_csrf": {csrfTokenB}} req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.RemoteAddr = "10.0.0.1:1234" + for _, c := range csrfCookiesB { + req.AddCookie(c) + } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) @@ -738,10 +855,14 @@ func TestLogin_RateLimit_AppliesBeforeUserLookup(t *testing.T) { // Use an email that does NOT exist in the DB. for i := 0; i < 6; i++ { - form := url.Values{"email": {"nonexistent@example.com"}, "password": {"wrong-password-12chars"}} + csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil) + form := url.Values{"email": {"nonexistent@example.com"}, "password": {"wrong-password-12chars"}, "_csrf": {csrfToken}} req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.RemoteAddr = "10.0.0.2:1234" + for _, c := range csrfCookies { + req.AddCookie(c) + } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) @@ -772,8 +893,13 @@ func TestLogout_Success(t *testing.T) { } sessionID := hashCookieValue(t, cookieValue) - req := httptest.NewRequest(http.MethodPost, "/logout", nil) - req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: cookieValue}) + sessionCookie := &http.Cookie{Name: auth.SessionCookieName, Value: cookieValue} + csrfToken, csrfCookies := getCSRFToken(t, router, "/", []*http.Cookie{sessionCookie}) + req := httptest.NewRequest(http.MethodPost, "/logout", strings.NewReader(url.Values{"_csrf": {csrfToken}}.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range csrfCookies { + req.AddCookie(c) + } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) @@ -819,18 +945,19 @@ func TestLogout_UnauthRedirectsToLogin(t *testing.T) { store := auth.NewStore(q) router := newTestRouter(q, store) - // POST /logout with NO cookie — RequireAuth must block it. + // POST /logout with NO cookie and NO CSRF token. + // gorilla/csrf runs before RequireAuth, so a missing CSRF token yields 403 + // before RequireAuth can redirect to /login. Both 403 and 303 are acceptable + // here — the important invariant is that the request is rejected (not 500) + // and no logout side-effect occurs. req := httptest.NewRequest(http.MethodPost, "/logout", nil) rec := httptest.NewRecorder() router.ServeHTTP(rec, req) - // Must redirect to /login (from RequireAuth), NOT a 500. - if rec.Code != http.StatusSeeOther { - t.Fatalf("status = %d; want 303 (RequireAuth redirect)", rec.Code) - } - if loc := rec.Header().Get("Location"); loc != "/login" { - t.Errorf("Location = %q; want /login", loc) + // 403 (csrf rejected) or 303 (RequireAuth redirected) — both are correct. + if rec.Code != http.StatusForbidden && rec.Code != http.StatusSeeOther { + t.Fatalf("status = %d; want 403 (CSRF) or 303 (RequireAuth)", rec.Code) } } @@ -849,9 +976,14 @@ func TestLogout_HXRedirect(t *testing.T) { t.Fatalf("store.Create: %v", err) } - req := httptest.NewRequest(http.MethodPost, "/logout", nil) - req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: cookieValue}) + sessionCookie := &http.Cookie{Name: auth.SessionCookieName, Value: cookieValue} + csrfToken, csrfCookies := getCSRFToken(t, router, "/", []*http.Cookie{sessionCookie}) + req := httptest.NewRequest(http.MethodPost, "/logout", strings.NewReader(url.Values{"_csrf": {csrfToken}}.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("HX-Request", "true") + for _, c := range csrfCookies { + req.AddCookie(c) + } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) @@ -879,9 +1011,14 @@ func TestLogout_AfterLogoutSubsequentRequestUnauth(t *testing.T) { t.Fatalf("store.Create: %v", err) } - // Logout first. - logoutReq := httptest.NewRequest(http.MethodPost, "/logout", nil) - logoutReq.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: cookieValue}) + // Logout first — need to get a CSRF token from the protected page. + sessionCookie := &http.Cookie{Name: auth.SessionCookieName, Value: cookieValue} + csrfToken, csrfCookies := getCSRFToken(t, router, "/", []*http.Cookie{sessionCookie}) + logoutReq := httptest.NewRequest(http.MethodPost, "/logout", strings.NewReader(url.Values{"_csrf": {csrfToken}}.Encode())) + logoutReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range csrfCookies { + logoutReq.AddCookie(c) + } logoutRec := httptest.NewRecorder() router.ServeHTTP(logoutRec, logoutReq) diff --git a/backend/internal/web/handlers_test.go b/backend/internal/web/handlers_test.go index 6f409ad..cf22580 100644 --- a/backend/internal/web/handlers_test.go +++ b/backend/internal/web/handlers_test.go @@ -66,7 +66,7 @@ func TestHealthz_Down(t *testing.T) { // was public. The HTMX demo content is tested by // TestProtected_HomeAuthRendersUserEmail in handlers_auth_test.go. func TestIndex_UnauthRedirects(t *testing.T) { - router := NewRouter(stubPinger{}, "./static", AuthDeps{}) + router := NewRouter(stubPinger{}, "./static", AuthDeps{}, testCSRFKey, "dev") rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -81,7 +81,7 @@ func TestIndex_UnauthRedirects(t *testing.T) { } func TestDemoTime_Fragment(t *testing.T) { - router := NewRouter(stubPinger{}, "./static", AuthDeps{}) + router := NewRouter(stubPinger{}, "./static", AuthDeps{}, testCSRFKey, "dev") rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/demo/time", nil) @@ -104,7 +104,7 @@ func TestDemoTime_Fragment(t *testing.T) { } func TestRequestID_HeaderSet(t *testing.T) { - router := NewRouter(stubPinger{}, "./static", AuthDeps{}) + router := NewRouter(stubPinger{}, "./static", AuthDeps{}, testCSRFKey, "dev") rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/healthz", nil) diff --git a/backend/internal/web/router.go b/backend/internal/web/router.go index cc677cb..c259402 100644 --- a/backend/internal/web/router.go +++ b/backend/internal/web/router.go @@ -26,8 +26,8 @@ type Pinger interface { // 2. chi RealIP // 3. SlogLoggerMiddleware (REPLACES chi's middleware.Logger — Pitfall 6) // 4. chi Recoverer (after Logger so panics carry request_id) -// 5. auth.ResolveSession (reads session cookie, attaches user to context) -// NOTE: csrf.Protect is added in Plan 07. +// 5. auth.ResolveSession (reads session cookie, attaches user to context) — D-24 +// 6. auth.Mount (gorilla/csrf — MUST come after ResolveSession, before routes) — D-24, Pitfall 7 // // Routes: GET / · GET /healthz · GET /demo/time · GET /static/* // GET /signup (auth pages, behind RedirectIfAuthed) · POST /signup. @@ -36,13 +36,24 @@ type Pinger interface { // // deps.Store may be nil during unit tests for Phase 1 routes (those routes // never exercise session resolution). ResolveSession guards against nil Store. -func NewRouter(pinger Pinger, staticDir string, deps AuthDeps) http.Handler { +// +// csrfKey is the 32-byte CSRF authentication key loaded from SESSION_SECRET. +// env is the runtime environment string (e.g. "dev", "development", "production"). +// When env == "dev", the CSRF cookie Secure flag is disabled for plain-HTTP +// local development (D-15, D-24). +// trustedOrigins is an optional list of additional origins for the CSRF +// referer check (used in integration tests to allow localhost requests without +// a Referer header). In production, pass no extra args — leave empty. +func NewRouter(pinger Pinger, staticDir string, deps AuthDeps, csrfKey []byte, env string, trustedOrigins ...string) http.Handler { r := chi.NewRouter() r.Use(RequestIDMiddleware) r.Use(chimw.RealIP) r.Use(SlogLoggerMiddleware(slog.Default())) r.Use(chimw.Recoverer) + // D-24 locked order: ResolveSession BEFORE csrf.Protect (auth.Mount). r.Use(auth.ResolveSession(deps.Store)) + // D-24: gorilla/csrf runs after ResolveSession and before all route groups (Pitfall 7). + r.Use(auth.Mount(env, csrfKey, trustedOrigins...)) // Auth pages — redirect to / if already authenticated. r.Group(func(r chi.Router) { diff --git a/backend/internal/web/ui/csrf_field.templ b/backend/internal/web/ui/csrf_field.templ new file mode 100644 index 0000000..d18e62d --- /dev/null +++ b/backend/internal/web/ui/csrf_field.templ @@ -0,0 +1,9 @@ +package ui + +// CSRFField renders the hidden CSRF token input required by gorilla/csrf. +// Every
must include @ui.CSRFField(csrfToken) as its +// first child so that the middleware can validate the double-submit cookie +// (D-15, AUTH-06). +templ CSRFField(token string) { + +} diff --git a/backend/templates/auth_login.templ b/backend/templates/auth_login.templ index fb9ad40..139170c 100644 --- a/backend/templates/auth_login.templ +++ b/backend/templates/auth_login.templ @@ -5,13 +5,13 @@ import "backend/internal/web/ui" // LoginPage renders the full /login page wrapped in the base Layout. // It delegates the form section to LoginFormFragment so HTMX can swap just the // form on validation errors without re-rendering the surrounding shell. -templ LoginPage(form LoginForm, errs LoginErrors) { - @Layout("Sign in", nil) { +templ LoginPage(form LoginForm, errs LoginErrors, csrfToken string) { + @Layout("Sign in", nil, csrfToken) {
@ui.Card(nil) {

Sign in to your account

- @LoginFormFragment(form, errs) + @LoginFormFragment(form, errs, csrfToken)
}
@@ -22,7 +22,7 @@ templ LoginPage(form LoginForm, errs LoginErrors) { // hx-post targets this component itself so the form can be replaced inline // on validation failure (D-19, D-20). // The outer id="login-form" must match the hx-target on this element. -templ LoginFormFragment(form LoginForm, errs LoginErrors) { +templ LoginFormFragment(form LoginForm, errs LoginErrors, csrfToken string) { - + @ui.CSRFField(csrfToken) @GeneralError(errs.General)
diff --git a/backend/templates/auth_signup.templ b/backend/templates/auth_signup.templ index b7140f0..6acd258 100644 --- a/backend/templates/auth_signup.templ +++ b/backend/templates/auth_signup.templ @@ -5,13 +5,13 @@ import "backend/internal/web/ui" // SignupPage renders the full /signup page wrapped in the base Layout. // It delegates the form section to SignupFormFragment so HTMX can swap just the // form on validation errors without re-rendering the surrounding shell. -templ SignupPage(form SignupForm, errs SignupErrors) { - @Layout("Sign up", nil) { +templ SignupPage(form SignupForm, errs SignupErrors, csrfToken string) { + @Layout("Sign up", nil, csrfToken) {
@ui.Card(nil) {

Create your account

- @SignupFormFragment(form, errs) + @SignupFormFragment(form, errs, csrfToken)
}
@@ -22,7 +22,7 @@ templ SignupPage(form SignupForm, errs SignupErrors) { // hx-post targets this component itself so the form can be replaced inline // on validation failure (D-19, D-25). // The outer id="signup-form" must match the hx-target on this element. -templ SignupFormFragment(form SignupForm, errs SignupErrors) { +templ SignupFormFragment(form SignupForm, errs SignupErrors, csrfToken string) { - + @ui.CSRFField(csrfToken) @GeneralError(errs.General)
diff --git a/backend/templates/auth_signup_test.go b/backend/templates/auth_signup_test.go index 3afb332..05a240f 100644 --- a/backend/templates/auth_signup_test.go +++ b/backend/templates/auth_signup_test.go @@ -11,7 +11,7 @@ import ( // expected form attributes and that email value round-trips correctly. func TestSignupPage_RendersForm(t *testing.T) { var buf bytes.Buffer - err := SignupPage(SignupForm{Email: "x@y.z"}, SignupErrors{}).Render(context.Background(), &buf) + err := SignupPage(SignupForm{Email: "x@y.z"}, SignupErrors{}, "testtoken").Render(context.Background(), &buf) if err != nil { t.Fatalf("SignupPage.Render: %v", err) } @@ -23,6 +23,7 @@ func TestSignupPage_RendersForm(t *testing.T) { `action="/signup"`, `hx-post="/signup"`, `value="x@y.z"`, + `name="_csrf"`, } { if !strings.Contains(body, want) { t.Errorf("SignupPage body missing %q", want) @@ -36,7 +37,7 @@ func TestSignupPage_RendersForm(t *testing.T) { func TestSignupFormFragment_RendersErrors(t *testing.T) { var buf bytes.Buffer errs := SignupErrors{Password: "Password must be 12-128 characters"} - err := SignupFormFragment(SignupForm{}, errs).Render(context.Background(), &buf) + err := SignupFormFragment(SignupForm{}, errs, "testtoken").Render(context.Background(), &buf) if err != nil { t.Fatalf("SignupFormFragment.Render: %v", err) } @@ -55,7 +56,7 @@ func TestSignupFormFragment_RendersErrors(t *testing.T) { // (security requirement T-2-01, D-25). func TestSignupPage_DoesNotEchoPassword(t *testing.T) { var buf bytes.Buffer - err := SignupPage(SignupForm{Email: "a@b.com", Password: "hunter2hunter2"}, SignupErrors{}).Render(context.Background(), &buf) + err := SignupPage(SignupForm{Email: "a@b.com", Password: "hunter2hunter2"}, SignupErrors{}, "testtoken").Render(context.Background(), &buf) if err != nil { t.Fatalf("SignupPage.Render: %v", err) } diff --git a/backend/templates/index.templ b/backend/templates/index.templ index ba7ab14..535ead9 100644 --- a/backend/templates/index.templ +++ b/backend/templates/index.templ @@ -8,8 +8,10 @@ import ( // Index renders the root page (protected, requires auth). // The user parameter is the authenticated user from request context, passed // through to Layout so the header can render the logout button and email. -templ Index(user *auth.User) { - @Layout("Xtablo", user) { +// csrfToken is passed to Layout so the logout form can include the hidden +// _csrf field (AUTH-06, D-14). +templ Index(user *auth.User, csrfToken string) { + @Layout("Xtablo", user, csrfToken) {

Signed in as { user.Email }

Xtablo

diff --git a/backend/templates/layout.templ b/backend/templates/layout.templ index b4f5c28..1238314 100644 --- a/backend/templates/layout.templ +++ b/backend/templates/layout.templ @@ -3,7 +3,10 @@ // generate`; generated files are gitignored. package templates -import "backend/internal/auth" +import ( + "backend/internal/auth" + "backend/internal/web/ui" +) // Layout is the base HTML shell every page renders inside. The structural // classes, container width (max-w-5xl), horizontal padding, header strip, @@ -15,7 +18,10 @@ import "backend/internal/auth" // When non-nil, the header renders a Log out POST form (D-22). Auth pages // pass nil since they're gated behind RedirectIfAuthed and never shown to // authed users. -templ Layout(title string, user *auth.User) { +// +// csrfToken is threaded from the handler via csrf.Token(r) so the logout +// form can embed @ui.CSRFField(csrfToken) (AUTH-06, D-14). +templ Layout(title string, user *auth.User, csrfToken string) { @@ -32,7 +38,7 @@ templ Layout(title string, user *auth.User) {

{ user.Email } - + @ui.CSRFField(csrfToken)
diff --git a/backend/templates/layout_test.go b/backend/templates/layout_test.go index 2389a02..ebdf266 100644 --- a/backend/templates/layout_test.go +++ b/backend/templates/layout_test.go @@ -11,10 +11,11 @@ import ( // TestLayout_LogoutFormVisibleWhenAuthed verifies that the logout form is // rendered in the header when Layout receives a non-nil user (D-22). +// The _csrf hidden field must also be present (AUTH-06). func TestLayout_LogoutFormVisibleWhenAuthed(t *testing.T) { var buf bytes.Buffer user := &auth.User{Email: "a@b.c"} - err := Layout("Test", user).Render(context.Background(), &buf) + err := Layout("Test", user, "mytesttoken").Render(context.Background(), &buf) if err != nil { t.Fatalf("Layout.Render: %v", err) } @@ -26,13 +27,19 @@ func TestLayout_LogoutFormVisibleWhenAuthed(t *testing.T) { if !strings.Contains(body, `method="POST"`) { t.Errorf("Layout body missing method=\"POST\"; logout must be a POST form (D-22)") } + if !strings.Contains(body, `name="_csrf"`) { + t.Errorf("Layout body missing name=\"_csrf\"; logout form must embed CSRF field (AUTH-06)") + } + if !strings.Contains(body, `value="mytesttoken"`) { + t.Errorf("Layout body missing value=\"mytesttoken\"; CSRF token not threaded into form") + } } // TestLayout_LogoutFormHiddenWhenUnauthed verifies that no logout form is // rendered when Layout receives a nil user (unauthenticated request). func TestLayout_LogoutFormHiddenWhenUnauthed(t *testing.T) { var buf bytes.Buffer - err := Layout("Test", nil).Render(context.Background(), &buf) + err := Layout("Test", nil, "").Render(context.Background(), &buf) if err != nil { t.Fatalf("Layout.Render: %v", err) }