xtablo-source/backend/internal/web/csrf_test.go
2026-05-16 07:26:49 +02:00

375 lines
12 KiB
Go

package web
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"
"backend/internal/auth"
"backend/internal/db/sqlc"
)
// newTestRouterWithCSRF returns the application router with CSRF protection
// enabled, for use in tests.
//
// The 32-byte key is deterministic (bytes 1 through 32), so every run uses the
// same key. "localhost" is passed as a trusted origin so that httptest requests
// without a Referer header are accepted. It panics if NewRouter returns an
// error, because no test can proceed without a router.
func newTestRouterWithCSRF(q *sqlc.Queries, store *auth.Store) http.Handler {
	const keyLen = 32
	key := make([]byte, keyLen)
	for n := 0; n < keyLen; n++ {
		key[n] = byte(n + 1)
	}

	authDeps := AuthDeps{Queries: q, Store: store, Secure: false}
	h, err := NewRouter(
		stubPinger{},
		os.DirFS("./static"),
		authDeps,
		TablosDeps{Queries: q},
		TasksDeps{Queries: q},
		EtapesDeps{Queries: q},
		EventsDeps{Queries: q},
		PlanningDeps{Queries: q},
		FilesDeps{Queries: q},
		key,
		"dev",
		"localhost",
	)
	if err != nil {
		panic("newTestRouterWithCSRF: " + err.Error())
	}
	return h
}
// extractCSRFToken issues a GET to path (sending the given cookies) and returns
// the value of the hidden _csrf input in the rendered HTML. It also returns the
// cookies a follow-up request should send.
//
// The returned cookies start from the supplied ones. Any cookie set by the GET
// response replaces a supplied cookie of the same name, as a browser cookie
// jar would. A response cookie with a negative MaxAge (a deletion) is dropped.
// Replacing rather than appending matters: when several cookies share a name,
// Request.Cookie returns only one of them, so a stale duplicate could shadow
// the freshly issued value.
//
// The test fails immediately if the _csrf input is missing, its value
// attribute is unterminated, or the value is empty.
func extractCSRFToken(t *testing.T, router http.Handler, path string, cookies []*http.Cookie) (string, []*http.Cookie) {
	t.Helper()
	req := httptest.NewRequest(http.MethodGet, path, nil)
	for _, c := range cookies {
		req.AddCookie(c)
	}
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	body := rec.Body.String()

	// Look for: name="_csrf" value="TOKEN"
	const needle = `name="_csrf" value="`
	idx := strings.Index(body, needle)
	if idx == -1 {
		t.Fatalf("extractCSRFToken: _csrf hidden input not found in GET %s response (status %d)\nbody snippet: %s",
			path, rec.Code, truncate(body, 500))
	}
	rest := body[idx+len(needle):]
	end := strings.Index(rest, `"`)
	if end == -1 {
		t.Fatalf("extractCSRFToken: could not find closing quote for _csrf value in GET %s response", path)
	}
	token := rest[:end]
	if token == "" {
		// An empty token would make every follow-up POST fail with a confusing 403.
		t.Fatalf("extractCSRFToken: empty _csrf value in GET %s response (status %d)", path, rec.Code)
	}

	// Merge the cookies like a browser jar: fresh cookies supersede same-named
	// request cookies, and deletions (negative MaxAge) are not forwarded.
	fresh := rec.Result().Cookies()
	superseded := make(map[string]bool, len(fresh))
	for _, c := range fresh {
		superseded[c.Name] = true
	}
	merged := make([]*http.Cookie, 0, len(cookies)+len(fresh))
	for _, c := range cookies {
		if !superseded[c.Name] {
			merged = append(merged, c)
		}
	}
	for _, c := range fresh {
		if c.MaxAge >= 0 {
			merged = append(merged, c)
		}
	}
	return token, merged
}
// truncate shortens s to at most n bytes for use in failure messages and
// appends "..." when anything was cut. The cut point is moved back to the
// nearest UTF-8 rune boundary, so a multi-byte character (for example, an
// accented letter in rendered HTML) is never split into invalid bytes.
// A negative n is treated as 0.
func truncate(s string, n int) string {
	if n < 0 {
		n = 0
	}
	if len(s) <= n {
		return s
	}
	// UTF-8 continuation bytes have the bit pattern 10xxxxxx. Step back over
	// them so that s[:n] ends on a rune boundary.
	for n > 0 && s[n]&0xC0 == 0x80 {
		n--
	}
	return s[:n] + "..."
}
// TestCSRF_LoginMissingToken checks that a POST /login whose form body carries
// no _csrf field is rejected with 403 Forbidden.
func TestCSRF_LoginMissingToken(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	queries := sqlc.New(pool)
	router := newTestRouterWithCSRF(queries, auth.NewStore(queries))

	// The form deliberately omits _csrf.
	payload := url.Values{}
	payload.Set("email", "csrf@example.com")
	payload.Set("password", "correct-horse-12")

	req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(payload.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if got := rec.Code; got != http.StatusForbidden {
		t.Fatalf("status = %d; want 403 (missing CSRF token)", got)
	}
}
// TestCSRF_LoginValidToken replays the browser flow for /login. It first GETs
// the form to obtain a _csrf token and its cookies. It then POSTs the
// credentials of a pre-seeded user together with that token. The POST must be
// answered with 303 or 200.
func TestCSRF_LoginValidToken(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	queries := sqlc.New(pool)
	sessions := auth.NewStore(queries)
	router := newTestRouterWithCSRF(queries, sessions)

	const (
		email    = "csrflogin@example.com"
		password = "correct-horse-12"
	)
	preInsertUser(t, ctx, queries, email, password)

	csrfToken, jar := extractCSRFToken(t, router, "/login", nil)
	encoded := url.Values{
		"email":    {email},
		"password": {password},
		"_csrf":    {csrfToken},
	}.Encode()

	req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(encoded))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	for _, cookie := range jar {
		req.AddCookie(cookie)
	}

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	switch rec.Code {
	case http.StatusSeeOther, http.StatusOK:
		// Accepted.
	default:
		t.Fatalf("status = %d; want 303 or 200 (valid CSRF token, successful login)", rec.Code)
	}
}
// TestCSRF_SignupMissingToken checks that a POST /signup whose form body
// carries no _csrf field is rejected with 403 Forbidden.
func TestCSRF_SignupMissingToken(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	queries := sqlc.New(pool)
	sessions := auth.NewStore(queries)
	router := newTestRouterWithCSRF(queries, sessions)

	// The form deliberately omits _csrf.
	encoded := url.Values{
		"password": {"correct-horse-12"},
		"email":    {"nosignup@example.com"},
	}.Encode()
	req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(encoded))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if status := rec.Code; status != http.StatusForbidden {
		t.Fatalf("status = %d; want 403 (missing CSRF token)", status)
	}
}
// TestCSRF_SignupValidToken replays the browser flow for /signup. It first GETs
// the form to obtain a _csrf token and its cookies. It then POSTs a new account
// together with that token. The POST must be answered with 303 or 200.
func TestCSRF_SignupValidToken(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	queries := sqlc.New(pool)
	router := newTestRouterWithCSRF(queries, auth.NewStore(queries))

	csrfToken, jar := extractCSRFToken(t, router, "/signup", nil)

	payload := url.Values{}
	payload.Set("email", "csrfsignup@example.com")
	payload.Set("password", "correct-horse-12")
	payload.Set("_csrf", csrfToken)

	req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(payload.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	for _, cookie := range jar {
		req.AddCookie(cookie)
	}

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if accepted := rec.Code == http.StatusSeeOther || rec.Code == http.StatusOK; !accepted {
		t.Fatalf("status = %d; want 303 or 200 (valid CSRF token, successful signup)", rec.Code)
	}
}
// TestCSRF_LogoutMissingToken checks two things for a POST /logout that carries
// a valid session cookie but no _csrf token: the request is rejected with 403,
// and the session row is still present afterwards.
func TestCSRF_LogoutMissingToken(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	queries := sqlc.New(pool)
	sessions := auth.NewStore(queries)
	router := newTestRouterWithCSRF(queries, sessions)

	u := preInsertUser(t, ctx, queries, "csrflogout@example.com", "correct-horse-12")
	rawCookie, _, err := sessions.Create(ctx, u.ID)
	if err != nil {
		t.Fatalf("store.Create: %v", err)
	}
	// The sessions table is queried by hashCookieValue(cookie), not by the raw
	// cookie value.
	sessionID := hashCookieValue(t, rawCookie)

	req := httptest.NewRequest(http.MethodPost, "/logout", nil)
	req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: rawCookie})
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if rec.Code != http.StatusForbidden {
		t.Fatalf("status = %d; want 403 (missing CSRF on logout)", rec.Code)
	}

	// A rejected logout must leave the session row in place.
	var remaining int
	if err := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE id = $1", sessionID).Scan(&remaining); err != nil {
		t.Fatalf("session count query: %v", err)
	}
	if remaining != 1 {
		t.Errorf("session count = %d; want 1 (session must survive when CSRF missing on logout)", remaining)
	}
}
// TestCSRF_LogoutValidToken performs an authenticated logout the way a browser
// would. It GETs / with a session cookie to obtain a _csrf token and any
// cookies the page sets, then POSTs /logout with that token. It accepts 303 or
// 200 and then confirms that the session row has been deleted.
func TestCSRF_LogoutValidToken(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	queries := sqlc.New(pool)
	sessions := auth.NewStore(queries)
	router := newTestRouterWithCSRF(queries, sessions)

	u := preInsertUser(t, ctx, queries, "csrflogout2@example.com", "correct-horse-12")
	rawCookie, _, err := sessions.Create(ctx, u.ID)
	if err != nil {
		t.Fatalf("store.Create: %v", err)
	}
	sessionID := hashCookieValue(t, rawCookie)

	csrfToken, jar := extractCSRFToken(t, router, "/", []*http.Cookie{
		{Name: auth.SessionCookieName, Value: rawCookie},
	})

	form := url.Values{"_csrf": {csrfToken}}
	req := httptest.NewRequest(http.MethodPost, "/logout", strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	for _, cookie := range jar {
		req.AddCookie(cookie)
	}
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if rec.Code != http.StatusSeeOther && rec.Code != http.StatusOK {
		t.Fatalf("status = %d; want 303 or 200 (valid CSRF, logout succeeded)", rec.Code)
	}

	// A successful logout must remove the session row.
	var remaining int
	if err := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE id = $1", sessionID).Scan(&remaining); err != nil {
		t.Fatalf("session count query: %v", err)
	}
	if remaining != 0 {
		t.Errorf("session count = %d; want 0 (session deleted after logout with valid CSRF)", remaining)
	}
}
// TestCSRF_HeaderFallback checks that the CSRF token is also accepted from the
// X-CSRF-Token request header when the form body has no _csrf field.
//
// The token and its cookies come from a prior GET /login. No user is seeded,
// so the login itself may fail. The test only asserts that the request is not
// rejected with 403.
func TestCSRF_HeaderFallback(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	q := sqlc.New(pool)
	store := auth.NewStore(q)
	router := newTestRouterWithCSRF(q, store)

	// Use the shared helper instead of re-parsing the HTML here. A missing or
	// malformed _csrf input is a regression, so the helper fails the test
	// instead of skipping it.
	token, cookies := extractCSRFToken(t, router, "/login", nil)

	// POST with the token in the header, not in the form body.
	form := url.Values{"email": {"headercsrf@example.com"}, "password": {"correct-horse-12"}}
	req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("X-CSRF-Token", token)
	// Forward every cookie from the GET, including the CSRF cookie needed for
	// the double-submit comparison, instead of guessing it by name.
	for _, c := range cookies {
		req.AddCookie(c)
	}

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	// A 403 means the header token was not accepted.
	if rec.Code == http.StatusForbidden {
		t.Fatal("X-CSRF-Token header not accepted; got 403")
	}
}
// TestForms_ContainCSRFField renders the pages that contain forms and checks
// that each response includes an input named _csrf. The pages are the public
// /login and /signup pages, plus the index page (which holds the logout form),
// requested with a session cookie.
func TestForms_ContainCSRFField(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	queries := sqlc.New(pool)
	sessions := auth.NewStore(queries)
	router := newTestRouterWithCSRF(queries, sessions)

	const marker = `name="_csrf"`
	// render GETs path with the given cookies and returns the response body.
	render := func(path string, cookies ...*http.Cookie) string {
		r := httptest.NewRequest(http.MethodGet, path, nil)
		for _, c := range cookies {
			r.AddCookie(c)
		}
		w := httptest.NewRecorder()
		router.ServeHTTP(w, r)
		return w.Body.String()
	}

	for _, page := range []string{"/login", "/signup"} {
		if out := render(page); !strings.Contains(out, marker) {
			t.Errorf("GET %s: rendered HTML missing name=\"_csrf\" hidden input\nbody snippet: %s",
				page, truncate(out, 500))
		}
	}

	// The index page holds the logout form and is requested with a session cookie.
	u := preInsertUser(t, ctx, queries, "csrfform@example.com", "correct-horse-12")
	cookieValue, _, err := sessions.Create(ctx, u.ID)
	if err != nil {
		t.Fatalf("store.Create: %v", err)
	}
	index := render("/", &http.Cookie{Name: auth.SessionCookieName, Value: cookieValue})
	if !strings.Contains(index, marker) {
		t.Errorf("GET /: rendered HTML missing name=\"_csrf\" in logout form\nbody snippet: %s",
			truncate(index, 500))
	}
}
// TestRouter_CSRFMountedAfterResolveSession is a source-level check of
// middleware order. In router.go (read from the test's working directory), the
// first occurrence of "auth.ResolveSession" must come before the first
// occurrence of "auth.Mount". The match is on plain text, so mentions inside
// comments count too.
// NOTE(review): the test name suggests auth.Mount installs the CSRF middleware;
// confirm against router.go.
func TestRouter_CSRFMountedAfterResolveSession(t *testing.T) {
	raw, readErr := os.ReadFile("router.go")
	if readErr != nil {
		t.Fatalf("could not read router.go: %v", readErr)
	}
	source := string(raw)
	sessionPos := strings.Index(source, "auth.ResolveSession")
	csrfPos := strings.Index(source, "auth.Mount")

	switch {
	case sessionPos == -1:
		t.Fatal("router.go: auth.ResolveSession not found")
	case csrfPos == -1:
		t.Fatal("router.go: auth.Mount not found")
	case sessionPos >= csrfPos:
		t.Errorf("middleware order violation (D-24): auth.ResolveSession must appear before auth.Mount in router.go")
	}
}