From ae2d356f8789d31381bc2d847c1e8905c2b34ba0 Mon Sep 17 00:00:00 2001 From: Arthur Belleville Date: Thu, 14 May 2026 22:45:36 +0200 Subject: [PATCH] test(02-07): add failing CSRF tests (RED gate) - TestLoadCSRFKey_* in internal/auth for env key loading - TestCSRF_*MissingToken / TestCSRF_*ValidToken for all three POST routes - TestForms_ContainCSRFField for hidden _csrf input in rendered HTML - TestRouter_CSRFMountedAfterResolveSession for middleware order (D-24) - TestCSRF_HeaderFallback for X-CSRF-Token header support - Add gorilla/csrf v1.7.3 dependency --- backend/go.mod | 2 + backend/go.sum | 4 + backend/internal/auth/csrf_test.go | 43 ++++ backend/internal/web/csrf_test.go | 369 +++++++++++++++++++++++++++++ 4 files changed, 418 insertions(+) create mode 100644 backend/internal/auth/csrf_test.go create mode 100644 backend/internal/web/csrf_test.go diff --git a/backend/go.mod b/backend/go.mod index bbf1d23..a9dee34 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -12,6 +12,8 @@ require ( ) require ( + github.com/gorilla/csrf v1.7.3 // indirect + github.com/gorilla/securecookie v1.1.2 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect diff --git a/backend/go.sum b/backend/go.sum index 1a8a457..a89f98d 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -11,6 +11,10 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/csrf v1.7.3 h1:BHWt6FTLZAb2HtWT5KDBf6qgpZzvtbp9QWDRKZMXJC0= +github.com/gorilla/csrf v1.7.3/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= +github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= +github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= diff --git a/backend/internal/auth/csrf_test.go b/backend/internal/auth/csrf_test.go new file mode 100644 index 0000000..2217b5c --- /dev/null +++ b/backend/internal/auth/csrf_test.go @@ -0,0 +1,43 @@ +package auth + +import ( + "os" + "testing" +) + +func TestLoadCSRFKey_Missing(t *testing.T) { + os.Unsetenv("SESSION_SECRET") + _, err := LoadKeyFromEnv() + if err == nil { + t.Fatal("expected error when SESSION_SECRET is unset; got nil") + } +} + +func TestLoadCSRFKey_WrongLength(t *testing.T) { + // 31 bytes hex-encoded = 62 hex chars — one byte short. + t.Setenv("SESSION_SECRET", "aabbccddeeff00112233445566778899aabbccddeeff001122334455667788") + _, err := LoadKeyFromEnv() + if err == nil { + t.Fatal("expected error when SESSION_SECRET decodes to != 32 bytes; got nil") + } +} + +func TestLoadCSRFKey_Valid(t *testing.T) { + // 32 bytes = 64 hex chars. + t.Setenv("SESSION_SECRET", "aabbccddeeff00112233445566778899aabbccddeeff00112233445566778899") + key, err := LoadKeyFromEnv() + if err != nil { + t.Fatalf("unexpected error with valid 32-byte key: %v", err) + } + if len(key) != 32 { + t.Errorf("key length = %d; want 32", len(key)) + } +} + +func TestLoadCSRFKey_InvalidHex(t *testing.T) { + t.Setenv("SESSION_SECRET", "ZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ") + _, err := LoadKeyFromEnv() + if err == nil { + t.Fatal("expected error with invalid hex string; got nil") + } +} diff --git a/backend/internal/web/csrf_test.go b/backend/internal/web/csrf_test.go new file mode 100644 index 0000000..36e1ad5 --- /dev/null +++ b/backend/internal/web/csrf_test.go @@ -0,0 +1,369 @@ +package web + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strings" + "testing" + + "backend/internal/auth" + "backend/internal/db/sqlc" +) + +// newTestRouterWithCSRF builds a router with CSRF enabled using a test key. +func newTestRouterWithCSRF(q *sqlc.Queries, store *auth.Store) http.Handler { + csrfKey := make([]byte, 32) + for i := range csrfKey { + csrfKey[i] = byte(i + 1) + } + deps := AuthDeps{Queries: q, Store: store, Secure: false} + return NewRouter(stubPinger{}, "./static", deps, csrfKey, "dev") +} + +// extractCSRFToken performs a GET request and extracts the _csrf token from the +// rendered HTML form. It parses the hidden input value from the response body. +func extractCSRFToken(t *testing.T, router http.Handler, path string, cookies []*http.Cookie) (string, []*http.Cookie) { + t.Helper() + req := httptest.NewRequest(http.MethodGet, path, nil) + for _, c := range cookies { + req.AddCookie(c) + } + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + body := rec.Body.String() + // Look for: name="_csrf" value="TOKEN" + const needle = `name="_csrf" value="` + idx := strings.Index(body, needle) + if idx == -1 { + t.Fatalf("extractCSRFToken: _csrf hidden input not found in GET %s response\nbody snippet: %s", + path, truncate(body, 500)) + } + rest := body[idx+len(needle):] + end := strings.Index(rest, `"`) + if end == -1 { + t.Fatalf("extractCSRFToken: could not find closing quote for _csrf value in GET %s response", path) + } + token := rest[:end] + + // Collect set cookies (includes the gorilla_csrf cookie). + var respCookies []*http.Cookie + respCookies = append(respCookies, cookies...) + for _, c := range rec.Result().Cookies() { + respCookies = append(respCookies, c) + } + return token, respCookies +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] + "..." +} + +// TestCSRF_LoginMissingToken: POST /login without _csrf → 403. +func TestCSRF_LoginMissingToken(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + q := sqlc.New(pool) + store := auth.NewStore(q) + router := newTestRouterWithCSRF(q, store) + + form := url.Values{"email": {"csrf@example.com"}, "password": {"correct-horse-12"}} + req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusForbidden { + t.Fatalf("status = %d; want 403 (missing CSRF token)", rec.Code) + } +} + +// TestCSRF_LoginValidToken: GET /login first → extract token → POST with token → 200/303. +func TestCSRF_LoginValidToken(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + ctx := context.Background() + q := sqlc.New(pool) + store := auth.NewStore(q) + router := newTestRouterWithCSRF(q, store) + + // Pre-seed user. + preInsertUser(t, ctx, q, "csrflogin@example.com", "correct-horse-12") + + token, cookies := extractCSRFToken(t, router, "/login", nil) + + form := url.Values{ + "email": {"csrflogin@example.com"}, + "password": {"correct-horse-12"}, + "_csrf": {token}, + } + req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range cookies { + req.AddCookie(c) + } + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusSeeOther && rec.Code != http.StatusOK { + t.Fatalf("status = %d; want 303 or 200 (valid CSRF token, successful login)", rec.Code) + } +} + +// TestCSRF_SignupMissingToken: POST /signup without _csrf → 403. +func TestCSRF_SignupMissingToken(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + q := sqlc.New(pool) + store := auth.NewStore(q) + router := newTestRouterWithCSRF(q, store) + + form := url.Values{"email": {"nosignup@example.com"}, "password": {"correct-horse-12"}} + req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusForbidden { + t.Fatalf("status = %d; want 403 (missing CSRF token)", rec.Code) + } +} + +// TestCSRF_SignupValidToken: GET /signup → POST with valid token → 303. +func TestCSRF_SignupValidToken(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + q := sqlc.New(pool) + store := auth.NewStore(q) + router := newTestRouterWithCSRF(q, store) + + token, cookies := extractCSRFToken(t, router, "/signup", nil) + + form := url.Values{ + "email": {"csrfsignup@example.com"}, + "password": {"correct-horse-12"}, + "_csrf": {token}, + } + req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range cookies { + req.AddCookie(c) + } + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusSeeOther && rec.Code != http.StatusOK { + t.Fatalf("status = %d; want 303 or 200 (valid CSRF token, successful signup)", rec.Code) + } +} + +// TestCSRF_LogoutMissingToken: pre-seed session, POST /logout without _csrf → 403, session NOT deleted. +func TestCSRF_LogoutMissingToken(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + ctx := context.Background() + q := sqlc.New(pool) + store := auth.NewStore(q) + router := newTestRouterWithCSRF(q, store) + + user := preInsertUser(t, ctx, q, "csrflogout@example.com", "correct-horse-12") + cookieValue, _, err := store.Create(ctx, user.ID) + if err != nil { + t.Fatalf("store.Create: %v", err) + } + sessionID := hashCookieValue(t, cookieValue) + + req := httptest.NewRequest(http.MethodPost, "/logout", nil) + req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: cookieValue}) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusForbidden { + t.Fatalf("status = %d; want 403 (missing CSRF on logout)", rec.Code) + } + + // Session must NOT be deleted. + var count int + row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE id = $1", sessionID) + if err := row.Scan(&count); err != nil { + t.Fatalf("session count query: %v", err) + } + if count != 1 { + t.Errorf("session count = %d; want 1 (session must survive when CSRF missing on logout)", count) + } +} + +// TestCSRF_LogoutValidToken: GET / → extract token → POST /logout → 303, session deleted. +func TestCSRF_LogoutValidToken(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + ctx := context.Background() + q := sqlc.New(pool) + store := auth.NewStore(q) + router := newTestRouterWithCSRF(q, store) + + user := preInsertUser(t, ctx, q, "csrflogout2@example.com", "correct-horse-12") + cookieValue, _, err := store.Create(ctx, user.ID) + if err != nil { + t.Fatalf("store.Create: %v", err) + } + sessionID := hashCookieValue(t, cookieValue) + + sessionCookie := &http.Cookie{Name: auth.SessionCookieName, Value: cookieValue} + token, cookies := extractCSRFToken(t, router, "/", []*http.Cookie{sessionCookie}) + + req := httptest.NewRequest(http.MethodPost, "/logout", strings.NewReader(url.Values{"_csrf": {token}}.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range cookies { + req.AddCookie(c) + } + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusSeeOther && rec.Code != http.StatusOK { + t.Fatalf("status = %d; want 303 or 200 (valid CSRF, logout succeeded)", rec.Code) + } + + // Session must be deleted. + var count int + row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE id = $1", sessionID) + if err := row.Scan(&count); err != nil { + t.Fatalf("session count query: %v", err) + } + if count != 0 { + t.Errorf("session count = %d; want 0 (session deleted after logout with valid CSRF)", count) + } +} + +// TestCSRF_HeaderFallback: POST /login with X-CSRF-Token header → accepted. +func TestCSRF_HeaderFallback(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + q := sqlc.New(pool) + store := auth.NewStore(q) + router := newTestRouterWithCSRF(q, store) + + // Get a token via GET. + getReq := httptest.NewRequest(http.MethodGet, "/login", nil) + getRec := httptest.NewRecorder() + router.ServeHTTP(getRec, getReq) + + // Extract the gorilla_csrf cookie for the double-submit. + var csrfCookie *http.Cookie + for _, c := range getRec.Result().Cookies() { + if strings.Contains(c.Name, "csrf") || strings.Contains(c.Name, "gorilla") { + csrfCookie = c + break + } + } + + // Extract token from body. + body := getRec.Body.String() + const needle = `name="_csrf" value="` + idx := strings.Index(body, needle) + if idx == -1 { + t.Skip("could not find CSRF token in GET /login body — skipping header fallback test") + } + rest := body[idx+len(needle):] + end := strings.Index(rest, `"`) + token := rest[:end] + + // POST with token in header, not form body. + form := url.Values{"email": {"headercsrf@example.com"}, "password": {"correct-horse-12"}} + req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("X-CSRF-Token", token) + if csrfCookie != nil { + req.AddCookie(csrfCookie) + } + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + // Should NOT be 403 (token accepted via header). + if rec.Code == http.StatusForbidden { + t.Fatal("X-CSRF-Token header not accepted; got 403") + } +} + +// TestForms_ContainCSRFField checks that all form-rendering templ components +// include the hidden _csrf field when rendered. +func TestForms_ContainCSRFField(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + ctx := context.Background() + q := sqlc.New(pool) + store := auth.NewStore(q) + router := newTestRouterWithCSRF(q, store) + + pages := []string{"/login", "/signup"} + for _, path := range pages { + req := httptest.NewRequest(http.MethodGet, path, nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + body := rec.Body.String() + if !strings.Contains(body, `name="_csrf"`) { + t.Errorf("GET %s: rendered HTML missing name=\"_csrf\" hidden input\nbody snippet: %s", + path, truncate(body, 500)) + } + } + + // Also check the index page (has the logout form). + user := preInsertUser(t, ctx, q, "csrfform@example.com", "correct-horse-12") + cookieValue, _, err := store.Create(ctx, user.ID) + if err != nil { + t.Fatalf("store.Create: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: cookieValue}) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if !strings.Contains(rec.Body.String(), `name="_csrf"`) { + t.Errorf("GET /: rendered HTML missing name=\"_csrf\" in logout form\nbody snippet: %s", + truncate(rec.Body.String(), 500)) + } +} + +// TestRouter_CSRFMountedAfterResolveSession checks middleware order in source. +func TestRouter_CSRFMountedAfterResolveSession(t *testing.T) { + data, err := os.ReadFile("router.go") + if err != nil { + t.Fatalf("could not read router.go: %v", err) + } + src := string(data) + + resolveIdx := strings.Index(src, "auth.ResolveSession") + mountIdx := strings.Index(src, "auth.Mount") + if resolveIdx == -1 { + t.Fatal("router.go: auth.ResolveSession not found") + } + if mountIdx == -1 { + t.Fatal("router.go: auth.Mount not found") + } + if resolveIdx >= mountIdx { + t.Errorf("middleware order violation (D-24): auth.ResolveSession must appear before auth.Mount in router.go") + } +}