package web import ( "context" "net/http" "net/http/httptest" "net/url" "os" "strings" "testing" "backend/internal/auth" "backend/internal/db/sqlc" ) // newTestRouterWithCSRF builds a router with CSRF enabled using a test key. // "localhost" is added as a trusted origin so httptest requests without a // Referer header are accepted. func newTestRouterWithCSRF(q *sqlc.Queries, store *auth.Store) http.Handler { csrfKey := make([]byte, 32) for i := range csrfKey { csrfKey[i] = byte(i + 1) } deps := AuthDeps{Queries: q, Store: store, Secure: false} router, err := NewRouter(stubPinger{}, os.DirFS("./static"), deps, TablosDeps{Queries: q}, TasksDeps{Queries: q}, EtapesDeps{Queries: q}, EventsDeps{Queries: q}, DiscussionDeps{Queries: q}, PlanningDeps{Queries: q}, FilesDeps{Queries: q}, csrfKey, "dev", "localhost") if err != nil { panic("newTestRouterWithCSRF: " + err.Error()) } return router } // extractCSRFToken performs a GET request and extracts the _csrf token from the // rendered HTML form. It parses the hidden input value from the response body. func extractCSRFToken(t *testing.T, router http.Handler, path string, cookies []*http.Cookie) (string, []*http.Cookie) { t.Helper() req := httptest.NewRequest(http.MethodGet, path, nil) for _, c := range cookies { req.AddCookie(c) } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) body := rec.Body.String() // Look for: name="_csrf" value="TOKEN" const needle = `name="_csrf" value="` idx := strings.Index(body, needle) if idx == -1 { t.Fatalf("extractCSRFToken: _csrf hidden input not found in GET %s response\nbody snippet: %s", path, truncate(body, 500)) } rest := body[idx+len(needle):] end := strings.Index(rest, `"`) if end == -1 { t.Fatalf("extractCSRFToken: could not find closing quote for _csrf value in GET %s response", path) } token := rest[:end] // Collect set cookies (includes the gorilla_csrf cookie). var respCookies []*http.Cookie respCookies = append(respCookies, cookies...) for _, c := range rec.Result().Cookies() { respCookies = append(respCookies, c) } return token, respCookies } func truncate(s string, n int) string { if len(s) <= n { return s } return s[:n] + "..." } // TestCSRF_LoginMissingToken: POST /login without _csrf → 403. func TestCSRF_LoginMissingToken(t *testing.T) { pool, cleanup := setupTestDB(t) defer cleanup() q := sqlc.New(pool) store := auth.NewStore(q) router := newTestRouterWithCSRF(q, store) form := url.Values{"email": {"csrf@example.com"}, "password": {"correct-horse-12"}} req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rec := httptest.NewRecorder() router.ServeHTTP(rec, req) if rec.Code != http.StatusForbidden { t.Fatalf("status = %d; want 403 (missing CSRF token)", rec.Code) } } // TestCSRF_LoginValidToken: GET /login first → extract token → POST with token → 200/303. func TestCSRF_LoginValidToken(t *testing.T) { pool, cleanup := setupTestDB(t) defer cleanup() ctx := context.Background() q := sqlc.New(pool) store := auth.NewStore(q) router := newTestRouterWithCSRF(q, store) // Pre-seed user. preInsertUser(t, ctx, q, "csrflogin@example.com", "correct-horse-12") token, cookies := extractCSRFToken(t, router, "/login", nil) form := url.Values{ "email": {"csrflogin@example.com"}, "password": {"correct-horse-12"}, "_csrf": {token}, } req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") for _, c := range cookies { req.AddCookie(c) } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) if rec.Code != http.StatusSeeOther && rec.Code != http.StatusOK { t.Fatalf("status = %d; want 303 or 200 (valid CSRF token, successful login)", rec.Code) } } // TestCSRF_SignupMissingToken: POST /signup without _csrf → 403. func TestCSRF_SignupMissingToken(t *testing.T) { pool, cleanup := setupTestDB(t) defer cleanup() q := sqlc.New(pool) store := auth.NewStore(q) router := newTestRouterWithCSRF(q, store) form := url.Values{"email": {"nosignup@example.com"}, "password": {"correct-horse-12"}} req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rec := httptest.NewRecorder() router.ServeHTTP(rec, req) if rec.Code != http.StatusForbidden { t.Fatalf("status = %d; want 403 (missing CSRF token)", rec.Code) } } // TestCSRF_SignupValidToken: GET /signup → POST with valid token → 303. func TestCSRF_SignupValidToken(t *testing.T) { pool, cleanup := setupTestDB(t) defer cleanup() q := sqlc.New(pool) store := auth.NewStore(q) router := newTestRouterWithCSRF(q, store) token, cookies := extractCSRFToken(t, router, "/signup", nil) form := url.Values{ "email": {"csrfsignup@example.com"}, "password": {"correct-horse-12"}, "_csrf": {token}, } req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") for _, c := range cookies { req.AddCookie(c) } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) if rec.Code != http.StatusSeeOther && rec.Code != http.StatusOK { t.Fatalf("status = %d; want 303 or 200 (valid CSRF token, successful signup)", rec.Code) } } // TestCSRF_LogoutMissingToken: pre-seed session, POST /logout without _csrf → 403, session NOT deleted. func TestCSRF_LogoutMissingToken(t *testing.T) { pool, cleanup := setupTestDB(t) defer cleanup() ctx := context.Background() q := sqlc.New(pool) store := auth.NewStore(q) router := newTestRouterWithCSRF(q, store) user := preInsertUser(t, ctx, q, "csrflogout@example.com", "correct-horse-12") cookieValue, _, err := store.Create(ctx, user.ID) if err != nil { t.Fatalf("store.Create: %v", err) } sessionID := hashCookieValue(t, cookieValue) req := httptest.NewRequest(http.MethodPost, "/logout", nil) req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: cookieValue}) rec := httptest.NewRecorder() router.ServeHTTP(rec, req) if rec.Code != http.StatusForbidden { t.Fatalf("status = %d; want 403 (missing CSRF on logout)", rec.Code) } // Session must NOT be deleted. var count int row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE id = $1", sessionID) if err := row.Scan(&count); err != nil { t.Fatalf("session count query: %v", err) } if count != 1 { t.Errorf("session count = %d; want 1 (session must survive when CSRF missing on logout)", count) } } // TestCSRF_LogoutValidToken: GET / → extract token → POST /logout → 303, session deleted. func TestCSRF_LogoutValidToken(t *testing.T) { pool, cleanup := setupTestDB(t) defer cleanup() ctx := context.Background() q := sqlc.New(pool) store := auth.NewStore(q) router := newTestRouterWithCSRF(q, store) user := preInsertUser(t, ctx, q, "csrflogout2@example.com", "correct-horse-12") cookieValue, _, err := store.Create(ctx, user.ID) if err != nil { t.Fatalf("store.Create: %v", err) } sessionID := hashCookieValue(t, cookieValue) sessionCookie := &http.Cookie{Name: auth.SessionCookieName, Value: cookieValue} token, cookies := extractCSRFToken(t, router, "/", []*http.Cookie{sessionCookie}) req := httptest.NewRequest(http.MethodPost, "/logout", strings.NewReader(url.Values{"_csrf": {token}}.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") for _, c := range cookies { req.AddCookie(c) } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) if rec.Code != http.StatusSeeOther && rec.Code != http.StatusOK { t.Fatalf("status = %d; want 303 or 200 (valid CSRF, logout succeeded)", rec.Code) } // Session must be deleted. var count int row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE id = $1", sessionID) if err := row.Scan(&count); err != nil { t.Fatalf("session count query: %v", err) } if count != 0 { t.Errorf("session count = %d; want 0 (session deleted after logout with valid CSRF)", count) } } // TestCSRF_HeaderFallback: POST /login with X-CSRF-Token header → accepted. func TestCSRF_HeaderFallback(t *testing.T) { pool, cleanup := setupTestDB(t) defer cleanup() q := sqlc.New(pool) store := auth.NewStore(q) router := newTestRouterWithCSRF(q, store) // Get a token via GET. getReq := httptest.NewRequest(http.MethodGet, "/login", nil) getRec := httptest.NewRecorder() router.ServeHTTP(getRec, getReq) // Extract the gorilla_csrf cookie for the double-submit. var csrfCookie *http.Cookie for _, c := range getRec.Result().Cookies() { if strings.Contains(c.Name, "csrf") || strings.Contains(c.Name, "gorilla") { csrfCookie = c break } } // Extract token from body. body := getRec.Body.String() const needle = `name="_csrf" value="` idx := strings.Index(body, needle) if idx == -1 { t.Skip("could not find CSRF token in GET /login body — skipping header fallback test") } rest := body[idx+len(needle):] end := strings.Index(rest, `"`) token := rest[:end] // POST with token in header, not form body. form := url.Values{"email": {"headercsrf@example.com"}, "password": {"correct-horse-12"}} req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("X-CSRF-Token", token) if csrfCookie != nil { req.AddCookie(csrfCookie) } rec := httptest.NewRecorder() router.ServeHTTP(rec, req) // Should NOT be 403 (token accepted via header). if rec.Code == http.StatusForbidden { t.Fatal("X-CSRF-Token header not accepted; got 403") } } // TestForms_ContainCSRFField checks that all form-rendering templ components // include the hidden _csrf field when rendered. func TestForms_ContainCSRFField(t *testing.T) { pool, cleanup := setupTestDB(t) defer cleanup() ctx := context.Background() q := sqlc.New(pool) store := auth.NewStore(q) router := newTestRouterWithCSRF(q, store) pages := []string{"/login", "/signup"} for _, path := range pages { req := httptest.NewRequest(http.MethodGet, path, nil) rec := httptest.NewRecorder() router.ServeHTTP(rec, req) body := rec.Body.String() if !strings.Contains(body, `name="_csrf"`) { t.Errorf("GET %s: rendered HTML missing name=\"_csrf\" hidden input\nbody snippet: %s", path, truncate(body, 500)) } } // Also check the index page (has the logout form). user := preInsertUser(t, ctx, q, "csrfform@example.com", "correct-horse-12") cookieValue, _, err := store.Create(ctx, user.ID) if err != nil { t.Fatalf("store.Create: %v", err) } req := httptest.NewRequest(http.MethodGet, "/", nil) req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: cookieValue}) rec := httptest.NewRecorder() router.ServeHTTP(rec, req) if !strings.Contains(rec.Body.String(), `name="_csrf"`) { t.Errorf("GET /: rendered HTML missing name=\"_csrf\" in logout form\nbody snippet: %s", truncate(rec.Body.String(), 500)) } } // TestRouter_CSRFMountedAfterResolveSession checks middleware order in source. func TestRouter_CSRFMountedAfterResolveSession(t *testing.T) { data, err := os.ReadFile("router.go") if err != nil { t.Fatalf("could not read router.go: %v", err) } src := string(data) resolveIdx := strings.Index(src, "auth.ResolveSession") mountIdx := strings.Index(src, "auth.Mount") if resolveIdx == -1 { t.Fatal("router.go: auth.ResolveSession not found") } if mountIdx == -1 { t.Fatal("router.go: auth.Mount not found") } if resolveIdx >= mountIdx { t.Errorf("middleware order violation (D-24): auth.ResolveSession must appear before auth.Mount in router.go") } }