test(02-07): add failing CSRF tests (RED gate)
- TestLoadCSRFKey_* in internal/auth for env key loading - TestCSRF_*MissingToken / TestCSRF_*ValidToken for all three POST routes - TestForms_ContainCSRFField for hidden _csrf input in rendered HTML - TestRouter_CSRFMountedAfterResolveSession for middleware order (D-24) - TestCSRF_HeaderFallback for X-CSRF-Token header support - Add gorilla/csrf v1.7.3 dependency
This commit is contained in:
parent
00a9388c32
commit
ae2d356f87
4 changed files with 418 additions and 0 deletions
|
|
@ -12,6 +12,8 @@ require (
|
|||
)
|
||||
|
||||
require (
|
||||
github.com/gorilla/csrf v1.7.3 // indirect
|
||||
github.com/gorilla/securecookie v1.1.2 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
|
|
|
|||
|
|
@ -11,6 +11,10 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
|||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/csrf v1.7.3 h1:BHWt6FTLZAb2HtWT5KDBf6qgpZzvtbp9QWDRKZMXJC0=
|
||||
github.com/gorilla/csrf v1.7.3/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk=
|
||||
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
|
||||
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
|
|
|
|||
43
backend/internal/auth/csrf_test.go
Normal file
43
backend/internal/auth/csrf_test.go
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadCSRFKey_Missing(t *testing.T) {
|
||||
os.Unsetenv("SESSION_SECRET")
|
||||
_, err := LoadKeyFromEnv()
|
||||
if err == nil {
|
||||
t.Fatal("expected error when SESSION_SECRET is unset; got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCSRFKey_WrongLength(t *testing.T) {
|
||||
// 31 bytes hex-encoded = 62 hex chars — one byte short.
|
||||
t.Setenv("SESSION_SECRET", "aabbccddeeff00112233445566778899aabbccddeeff001122334455667788")
|
||||
_, err := LoadKeyFromEnv()
|
||||
if err == nil {
|
||||
t.Fatal("expected error when SESSION_SECRET decodes to != 32 bytes; got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCSRFKey_Valid(t *testing.T) {
|
||||
// 32 bytes = 64 hex chars.
|
||||
t.Setenv("SESSION_SECRET", "aabbccddeeff00112233445566778899aabbccddeeff00112233445566778899")
|
||||
key, err := LoadKeyFromEnv()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error with valid 32-byte key: %v", err)
|
||||
}
|
||||
if len(key) != 32 {
|
||||
t.Errorf("key length = %d; want 32", len(key))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCSRFKey_InvalidHex(t *testing.T) {
|
||||
t.Setenv("SESSION_SECRET", "ZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ")
|
||||
_, err := LoadKeyFromEnv()
|
||||
if err == nil {
|
||||
t.Fatal("expected error with invalid hex string; got nil")
|
||||
}
|
||||
}
|
||||
369
backend/internal/web/csrf_test.go
Normal file
369
backend/internal/web/csrf_test.go
Normal file
|
|
@ -0,0 +1,369 @@
|
|||
package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"backend/internal/auth"
|
||||
"backend/internal/db/sqlc"
|
||||
)
|
||||
|
||||
// newTestRouterWithCSRF builds a router with CSRF enabled using a test key.
|
||||
func newTestRouterWithCSRF(q *sqlc.Queries, store *auth.Store) http.Handler {
|
||||
csrfKey := make([]byte, 32)
|
||||
for i := range csrfKey {
|
||||
csrfKey[i] = byte(i + 1)
|
||||
}
|
||||
deps := AuthDeps{Queries: q, Store: store, Secure: false}
|
||||
return NewRouter(stubPinger{}, "./static", deps, csrfKey, "dev")
|
||||
}
|
||||
|
||||
// extractCSRFToken performs a GET request and extracts the _csrf token from the
|
||||
// rendered HTML form. It parses the hidden input value from the response body.
|
||||
func extractCSRFToken(t *testing.T, router http.Handler, path string, cookies []*http.Cookie) (string, []*http.Cookie) {
|
||||
t.Helper()
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
for _, c := range cookies {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
body := rec.Body.String()
|
||||
// Look for: name="_csrf" value="TOKEN"
|
||||
const needle = `name="_csrf" value="`
|
||||
idx := strings.Index(body, needle)
|
||||
if idx == -1 {
|
||||
t.Fatalf("extractCSRFToken: _csrf hidden input not found in GET %s response\nbody snippet: %s",
|
||||
path, truncate(body, 500))
|
||||
}
|
||||
rest := body[idx+len(needle):]
|
||||
end := strings.Index(rest, `"`)
|
||||
if end == -1 {
|
||||
t.Fatalf("extractCSRFToken: could not find closing quote for _csrf value in GET %s response", path)
|
||||
}
|
||||
token := rest[:end]
|
||||
|
||||
// Collect set cookies (includes the gorilla_csrf cookie).
|
||||
var respCookies []*http.Cookie
|
||||
respCookies = append(respCookies, cookies...)
|
||||
for _, c := range rec.Result().Cookies() {
|
||||
respCookies = append(respCookies, c)
|
||||
}
|
||||
return token, respCookies
|
||||
}
|
||||
|
||||
func truncate(s string, n int) string {
|
||||
if len(s) <= n {
|
||||
return s
|
||||
}
|
||||
return s[:n] + "..."
|
||||
}
|
||||
|
||||
// TestCSRF_LoginMissingToken: POST /login without _csrf → 403.
|
||||
func TestCSRF_LoginMissingToken(t *testing.T) {
|
||||
pool, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
q := sqlc.New(pool)
|
||||
store := auth.NewStore(q)
|
||||
router := newTestRouterWithCSRF(q, store)
|
||||
|
||||
form := url.Values{"email": {"csrf@example.com"}, "password": {"correct-horse-12"}}
|
||||
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusForbidden {
|
||||
t.Fatalf("status = %d; want 403 (missing CSRF token)", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCSRF_LoginValidToken: GET /login first → extract token → POST with token → 200/303.
|
||||
func TestCSRF_LoginValidToken(t *testing.T) {
|
||||
pool, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
q := sqlc.New(pool)
|
||||
store := auth.NewStore(q)
|
||||
router := newTestRouterWithCSRF(q, store)
|
||||
|
||||
// Pre-seed user.
|
||||
preInsertUser(t, ctx, q, "csrflogin@example.com", "correct-horse-12")
|
||||
|
||||
token, cookies := extractCSRFToken(t, router, "/login", nil)
|
||||
|
||||
form := url.Values{
|
||||
"email": {"csrflogin@example.com"},
|
||||
"password": {"correct-horse-12"},
|
||||
"_csrf": {token},
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
for _, c := range cookies {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusSeeOther && rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d; want 303 or 200 (valid CSRF token, successful login)", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCSRF_SignupMissingToken: POST /signup without _csrf → 403.
|
||||
func TestCSRF_SignupMissingToken(t *testing.T) {
|
||||
pool, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
q := sqlc.New(pool)
|
||||
store := auth.NewStore(q)
|
||||
router := newTestRouterWithCSRF(q, store)
|
||||
|
||||
form := url.Values{"email": {"nosignup@example.com"}, "password": {"correct-horse-12"}}
|
||||
req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusForbidden {
|
||||
t.Fatalf("status = %d; want 403 (missing CSRF token)", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCSRF_SignupValidToken: GET /signup → POST with valid token → 303.
|
||||
func TestCSRF_SignupValidToken(t *testing.T) {
|
||||
pool, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
q := sqlc.New(pool)
|
||||
store := auth.NewStore(q)
|
||||
router := newTestRouterWithCSRF(q, store)
|
||||
|
||||
token, cookies := extractCSRFToken(t, router, "/signup", nil)
|
||||
|
||||
form := url.Values{
|
||||
"email": {"csrfsignup@example.com"},
|
||||
"password": {"correct-horse-12"},
|
||||
"_csrf": {token},
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
for _, c := range cookies {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusSeeOther && rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d; want 303 or 200 (valid CSRF token, successful signup)", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCSRF_LogoutMissingToken: pre-seed session, POST /logout without _csrf → 403, session NOT deleted.
|
||||
func TestCSRF_LogoutMissingToken(t *testing.T) {
|
||||
pool, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
q := sqlc.New(pool)
|
||||
store := auth.NewStore(q)
|
||||
router := newTestRouterWithCSRF(q, store)
|
||||
|
||||
user := preInsertUser(t, ctx, q, "csrflogout@example.com", "correct-horse-12")
|
||||
cookieValue, _, err := store.Create(ctx, user.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("store.Create: %v", err)
|
||||
}
|
||||
sessionID := hashCookieValue(t, cookieValue)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/logout", nil)
|
||||
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: cookieValue})
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusForbidden {
|
||||
t.Fatalf("status = %d; want 403 (missing CSRF on logout)", rec.Code)
|
||||
}
|
||||
|
||||
// Session must NOT be deleted.
|
||||
var count int
|
||||
row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE id = $1", sessionID)
|
||||
if err := row.Scan(&count); err != nil {
|
||||
t.Fatalf("session count query: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("session count = %d; want 1 (session must survive when CSRF missing on logout)", count)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCSRF_LogoutValidToken: GET / → extract token → POST /logout → 303, session deleted.
|
||||
func TestCSRF_LogoutValidToken(t *testing.T) {
|
||||
pool, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
q := sqlc.New(pool)
|
||||
store := auth.NewStore(q)
|
||||
router := newTestRouterWithCSRF(q, store)
|
||||
|
||||
user := preInsertUser(t, ctx, q, "csrflogout2@example.com", "correct-horse-12")
|
||||
cookieValue, _, err := store.Create(ctx, user.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("store.Create: %v", err)
|
||||
}
|
||||
sessionID := hashCookieValue(t, cookieValue)
|
||||
|
||||
sessionCookie := &http.Cookie{Name: auth.SessionCookieName, Value: cookieValue}
|
||||
token, cookies := extractCSRFToken(t, router, "/", []*http.Cookie{sessionCookie})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/logout", strings.NewReader(url.Values{"_csrf": {token}}.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
for _, c := range cookies {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusSeeOther && rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d; want 303 or 200 (valid CSRF, logout succeeded)", rec.Code)
|
||||
}
|
||||
|
||||
// Session must be deleted.
|
||||
var count int
|
||||
row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE id = $1", sessionID)
|
||||
if err := row.Scan(&count); err != nil {
|
||||
t.Fatalf("session count query: %v", err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Errorf("session count = %d; want 0 (session deleted after logout with valid CSRF)", count)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCSRF_HeaderFallback: POST /login with X-CSRF-Token header → accepted.
|
||||
func TestCSRF_HeaderFallback(t *testing.T) {
|
||||
pool, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
q := sqlc.New(pool)
|
||||
store := auth.NewStore(q)
|
||||
router := newTestRouterWithCSRF(q, store)
|
||||
|
||||
// Get a token via GET.
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/login", nil)
|
||||
getRec := httptest.NewRecorder()
|
||||
router.ServeHTTP(getRec, getReq)
|
||||
|
||||
// Extract the gorilla_csrf cookie for the double-submit.
|
||||
var csrfCookie *http.Cookie
|
||||
for _, c := range getRec.Result().Cookies() {
|
||||
if strings.Contains(c.Name, "csrf") || strings.Contains(c.Name, "gorilla") {
|
||||
csrfCookie = c
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Extract token from body.
|
||||
body := getRec.Body.String()
|
||||
const needle = `name="_csrf" value="`
|
||||
idx := strings.Index(body, needle)
|
||||
if idx == -1 {
|
||||
t.Skip("could not find CSRF token in GET /login body — skipping header fallback test")
|
||||
}
|
||||
rest := body[idx+len(needle):]
|
||||
end := strings.Index(rest, `"`)
|
||||
token := rest[:end]
|
||||
|
||||
// POST with token in header, not form body.
|
||||
form := url.Values{"email": {"headercsrf@example.com"}, "password": {"correct-horse-12"}}
|
||||
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("X-CSRF-Token", token)
|
||||
if csrfCookie != nil {
|
||||
req.AddCookie(csrfCookie)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
// Should NOT be 403 (token accepted via header).
|
||||
if rec.Code == http.StatusForbidden {
|
||||
t.Fatal("X-CSRF-Token header not accepted; got 403")
|
||||
}
|
||||
}
|
||||
|
||||
// TestForms_ContainCSRFField checks that all form-rendering templ components
|
||||
// include the hidden _csrf field when rendered.
|
||||
func TestForms_ContainCSRFField(t *testing.T) {
|
||||
pool, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
q := sqlc.New(pool)
|
||||
store := auth.NewStore(q)
|
||||
router := newTestRouterWithCSRF(q, store)
|
||||
|
||||
pages := []string{"/login", "/signup"}
|
||||
for _, path := range pages {
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, `name="_csrf"`) {
|
||||
t.Errorf("GET %s: rendered HTML missing name=\"_csrf\" hidden input\nbody snippet: %s",
|
||||
path, truncate(body, 500))
|
||||
}
|
||||
}
|
||||
|
||||
// Also check the index page (has the logout form).
|
||||
user := preInsertUser(t, ctx, q, "csrfform@example.com", "correct-horse-12")
|
||||
cookieValue, _, err := store.Create(ctx, user.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("store.Create: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: cookieValue})
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if !strings.Contains(rec.Body.String(), `name="_csrf"`) {
|
||||
t.Errorf("GET /: rendered HTML missing name=\"_csrf\" in logout form\nbody snippet: %s",
|
||||
truncate(rec.Body.String(), 500))
|
||||
}
|
||||
}
|
||||
|
||||
// TestRouter_CSRFMountedAfterResolveSession checks middleware order in source.
|
||||
func TestRouter_CSRFMountedAfterResolveSession(t *testing.T) {
|
||||
data, err := os.ReadFile("router.go")
|
||||
if err != nil {
|
||||
t.Fatalf("could not read router.go: %v", err)
|
||||
}
|
||||
src := string(data)
|
||||
|
||||
resolveIdx := strings.Index(src, "auth.ResolveSession")
|
||||
mountIdx := strings.Index(src, "auth.Mount")
|
||||
if resolveIdx == -1 {
|
||||
t.Fatal("router.go: auth.ResolveSession not found")
|
||||
}
|
||||
if mountIdx == -1 {
|
||||
t.Fatal("router.go: auth.Mount not found")
|
||||
}
|
||||
if resolveIdx >= mountIdx {
|
||||
t.Errorf("middleware order violation (D-24): auth.ResolveSession must appear before auth.Mount in router.go")
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue