375 lines
12 KiB
Go
375 lines
12 KiB
Go
package web
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
|
|
"backend/internal/auth"
|
|
"backend/internal/db/sqlc"
|
|
)
|
|
|
|
// newTestRouterWithCSRF builds a router with CSRF enabled using a test key.
|
|
// "localhost" is added as a trusted origin so httptest requests without a
|
|
// Referer header are accepted.
|
|
func newTestRouterWithCSRF(q *sqlc.Queries, store *auth.Store) http.Handler {
|
|
csrfKey := make([]byte, 32)
|
|
for i := range csrfKey {
|
|
csrfKey[i] = byte(i + 1)
|
|
}
|
|
deps := AuthDeps{Queries: q, Store: store, Secure: false}
|
|
router, err := NewRouter(stubPinger{}, os.DirFS("./static"), deps, TablosDeps{Queries: q}, TasksDeps{Queries: q}, EtapesDeps{Queries: q}, EventsDeps{Queries: q}, DiscussionDeps{Queries: q}, PlanningDeps{Queries: q}, FilesDeps{Queries: q}, csrfKey, "dev", "localhost")
|
|
if err != nil {
|
|
panic("newTestRouterWithCSRF: " + err.Error())
|
|
}
|
|
return router
|
|
}
|
|
|
|
// extractCSRFToken performs a GET request and extracts the _csrf token from the
|
|
// rendered HTML form. It parses the hidden input value from the response body.
|
|
func extractCSRFToken(t *testing.T, router http.Handler, path string, cookies []*http.Cookie) (string, []*http.Cookie) {
|
|
t.Helper()
|
|
req := httptest.NewRequest(http.MethodGet, path, nil)
|
|
for _, c := range cookies {
|
|
req.AddCookie(c)
|
|
}
|
|
rec := httptest.NewRecorder()
|
|
router.ServeHTTP(rec, req)
|
|
|
|
body := rec.Body.String()
|
|
// Look for: name="_csrf" value="TOKEN"
|
|
const needle = `name="_csrf" value="`
|
|
idx := strings.Index(body, needle)
|
|
if idx == -1 {
|
|
t.Fatalf("extractCSRFToken: _csrf hidden input not found in GET %s response\nbody snippet: %s",
|
|
path, truncate(body, 500))
|
|
}
|
|
rest := body[idx+len(needle):]
|
|
end := strings.Index(rest, `"`)
|
|
if end == -1 {
|
|
t.Fatalf("extractCSRFToken: could not find closing quote for _csrf value in GET %s response", path)
|
|
}
|
|
token := rest[:end]
|
|
|
|
// Collect set cookies (includes the gorilla_csrf cookie).
|
|
var respCookies []*http.Cookie
|
|
respCookies = append(respCookies, cookies...)
|
|
for _, c := range rec.Result().Cookies() {
|
|
respCookies = append(respCookies, c)
|
|
}
|
|
return token, respCookies
|
|
}
|
|
|
|
// truncate shortens s to at most n bytes for use in failure messages and
// appends "..." when anything was cut; strings that already fit are returned
// unchanged.
//
// The cut point is moved back to the start of a UTF-8 sequence, so a
// multi-byte character (e.g. an accented letter in rendered HTML) is never
// split into invalid bytes. A negative n is treated as zero instead of
// panicking.
func truncate(s string, n int) string {
	if len(s) <= n {
		return s
	}
	if n < 0 {
		n = 0
	}
	// Bytes of the form 10xxxxxx are UTF-8 continuation bytes; step back over
	// them so that s[:n] ends on a character boundary.
	for n > 0 && s[n]&0xC0 == 0x80 {
		n--
	}
	return s[:n] + "..."
}
|
|
|
|
// TestCSRF_LoginMissingToken: POST /login without _csrf → 403.
|
|
func TestCSRF_LoginMissingToken(t *testing.T) {
|
|
pool, cleanup := setupTestDB(t)
|
|
defer cleanup()
|
|
|
|
q := sqlc.New(pool)
|
|
store := auth.NewStore(q)
|
|
router := newTestRouterWithCSRF(q, store)
|
|
|
|
form := url.Values{"email": {"csrf@example.com"}, "password": {"correct-horse-12"}}
|
|
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusForbidden {
|
|
t.Fatalf("status = %d; want 403 (missing CSRF token)", rec.Code)
|
|
}
|
|
}
|
|
|
|
// TestCSRF_LoginValidToken: GET /login first → extract token → POST with token → 200/303.
|
|
func TestCSRF_LoginValidToken(t *testing.T) {
|
|
pool, cleanup := setupTestDB(t)
|
|
defer cleanup()
|
|
|
|
ctx := context.Background()
|
|
q := sqlc.New(pool)
|
|
store := auth.NewStore(q)
|
|
router := newTestRouterWithCSRF(q, store)
|
|
|
|
// Pre-seed user.
|
|
preInsertUser(t, ctx, q, "csrflogin@example.com", "correct-horse-12")
|
|
|
|
token, cookies := extractCSRFToken(t, router, "/login", nil)
|
|
|
|
form := url.Values{
|
|
"email": {"csrflogin@example.com"},
|
|
"password": {"correct-horse-12"},
|
|
"_csrf": {token},
|
|
}
|
|
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
for _, c := range cookies {
|
|
req.AddCookie(c)
|
|
}
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusSeeOther && rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d; want 303 or 200 (valid CSRF token, successful login)", rec.Code)
|
|
}
|
|
}
|
|
|
|
// TestCSRF_SignupMissingToken: POST /signup without _csrf → 403.
|
|
func TestCSRF_SignupMissingToken(t *testing.T) {
|
|
pool, cleanup := setupTestDB(t)
|
|
defer cleanup()
|
|
|
|
q := sqlc.New(pool)
|
|
store := auth.NewStore(q)
|
|
router := newTestRouterWithCSRF(q, store)
|
|
|
|
form := url.Values{"email": {"nosignup@example.com"}, "password": {"correct-horse-12"}}
|
|
req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusForbidden {
|
|
t.Fatalf("status = %d; want 403 (missing CSRF token)", rec.Code)
|
|
}
|
|
}
|
|
|
|
// TestCSRF_SignupValidToken: GET /signup → POST with valid token → 303.
|
|
func TestCSRF_SignupValidToken(t *testing.T) {
|
|
pool, cleanup := setupTestDB(t)
|
|
defer cleanup()
|
|
|
|
q := sqlc.New(pool)
|
|
store := auth.NewStore(q)
|
|
router := newTestRouterWithCSRF(q, store)
|
|
|
|
token, cookies := extractCSRFToken(t, router, "/signup", nil)
|
|
|
|
form := url.Values{
|
|
"email": {"csrfsignup@example.com"},
|
|
"password": {"correct-horse-12"},
|
|
"_csrf": {token},
|
|
}
|
|
req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
for _, c := range cookies {
|
|
req.AddCookie(c)
|
|
}
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusSeeOther && rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d; want 303 or 200 (valid CSRF token, successful signup)", rec.Code)
|
|
}
|
|
}
|
|
|
|
// TestCSRF_LogoutMissingToken: pre-seed session, POST /logout without _csrf → 403, session NOT deleted.
|
|
func TestCSRF_LogoutMissingToken(t *testing.T) {
|
|
pool, cleanup := setupTestDB(t)
|
|
defer cleanup()
|
|
|
|
ctx := context.Background()
|
|
q := sqlc.New(pool)
|
|
store := auth.NewStore(q)
|
|
router := newTestRouterWithCSRF(q, store)
|
|
|
|
user := preInsertUser(t, ctx, q, "csrflogout@example.com", "correct-horse-12")
|
|
cookieValue, _, err := store.Create(ctx, user.ID)
|
|
if err != nil {
|
|
t.Fatalf("store.Create: %v", err)
|
|
}
|
|
sessionID := hashCookieValue(t, cookieValue)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/logout", nil)
|
|
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: cookieValue})
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusForbidden {
|
|
t.Fatalf("status = %d; want 403 (missing CSRF on logout)", rec.Code)
|
|
}
|
|
|
|
// Session must NOT be deleted.
|
|
var count int
|
|
row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE id = $1", sessionID)
|
|
if err := row.Scan(&count); err != nil {
|
|
t.Fatalf("session count query: %v", err)
|
|
}
|
|
if count != 1 {
|
|
t.Errorf("session count = %d; want 1 (session must survive when CSRF missing on logout)", count)
|
|
}
|
|
}
|
|
|
|
// TestCSRF_LogoutValidToken: GET / → extract token → POST /logout → 303, session deleted.
|
|
func TestCSRF_LogoutValidToken(t *testing.T) {
|
|
pool, cleanup := setupTestDB(t)
|
|
defer cleanup()
|
|
|
|
ctx := context.Background()
|
|
q := sqlc.New(pool)
|
|
store := auth.NewStore(q)
|
|
router := newTestRouterWithCSRF(q, store)
|
|
|
|
user := preInsertUser(t, ctx, q, "csrflogout2@example.com", "correct-horse-12")
|
|
cookieValue, _, err := store.Create(ctx, user.ID)
|
|
if err != nil {
|
|
t.Fatalf("store.Create: %v", err)
|
|
}
|
|
sessionID := hashCookieValue(t, cookieValue)
|
|
|
|
sessionCookie := &http.Cookie{Name: auth.SessionCookieName, Value: cookieValue}
|
|
token, cookies := extractCSRFToken(t, router, "/", []*http.Cookie{sessionCookie})
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/logout", strings.NewReader(url.Values{"_csrf": {token}}.Encode()))
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
for _, c := range cookies {
|
|
req.AddCookie(c)
|
|
}
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusSeeOther && rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d; want 303 or 200 (valid CSRF, logout succeeded)", rec.Code)
|
|
}
|
|
|
|
// Session must be deleted.
|
|
var count int
|
|
row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE id = $1", sessionID)
|
|
if err := row.Scan(&count); err != nil {
|
|
t.Fatalf("session count query: %v", err)
|
|
}
|
|
if count != 0 {
|
|
t.Errorf("session count = %d; want 0 (session deleted after logout with valid CSRF)", count)
|
|
}
|
|
}
|
|
|
|
// TestCSRF_HeaderFallback: POST /login with X-CSRF-Token header → accepted.
|
|
func TestCSRF_HeaderFallback(t *testing.T) {
|
|
pool, cleanup := setupTestDB(t)
|
|
defer cleanup()
|
|
|
|
q := sqlc.New(pool)
|
|
store := auth.NewStore(q)
|
|
router := newTestRouterWithCSRF(q, store)
|
|
|
|
// Get a token via GET.
|
|
getReq := httptest.NewRequest(http.MethodGet, "/login", nil)
|
|
getRec := httptest.NewRecorder()
|
|
router.ServeHTTP(getRec, getReq)
|
|
|
|
// Extract the gorilla_csrf cookie for the double-submit.
|
|
var csrfCookie *http.Cookie
|
|
for _, c := range getRec.Result().Cookies() {
|
|
if strings.Contains(c.Name, "csrf") || strings.Contains(c.Name, "gorilla") {
|
|
csrfCookie = c
|
|
break
|
|
}
|
|
}
|
|
|
|
// Extract token from body.
|
|
body := getRec.Body.String()
|
|
const needle = `name="_csrf" value="`
|
|
idx := strings.Index(body, needle)
|
|
if idx == -1 {
|
|
t.Skip("could not find CSRF token in GET /login body — skipping header fallback test")
|
|
}
|
|
rest := body[idx+len(needle):]
|
|
end := strings.Index(rest, `"`)
|
|
token := rest[:end]
|
|
|
|
// POST with token in header, not form body.
|
|
form := url.Values{"email": {"headercsrf@example.com"}, "password": {"correct-horse-12"}}
|
|
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
req.Header.Set("X-CSRF-Token", token)
|
|
if csrfCookie != nil {
|
|
req.AddCookie(csrfCookie)
|
|
}
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.ServeHTTP(rec, req)
|
|
|
|
// Should NOT be 403 (token accepted via header).
|
|
if rec.Code == http.StatusForbidden {
|
|
t.Fatal("X-CSRF-Token header not accepted; got 403")
|
|
}
|
|
}
|
|
|
|
// TestForms_ContainCSRFField checks that all form-rendering templ components
|
|
// include the hidden _csrf field when rendered.
|
|
func TestForms_ContainCSRFField(t *testing.T) {
|
|
pool, cleanup := setupTestDB(t)
|
|
defer cleanup()
|
|
|
|
ctx := context.Background()
|
|
q := sqlc.New(pool)
|
|
store := auth.NewStore(q)
|
|
router := newTestRouterWithCSRF(q, store)
|
|
|
|
pages := []string{"/login", "/signup"}
|
|
for _, path := range pages {
|
|
req := httptest.NewRequest(http.MethodGet, path, nil)
|
|
rec := httptest.NewRecorder()
|
|
router.ServeHTTP(rec, req)
|
|
|
|
body := rec.Body.String()
|
|
if !strings.Contains(body, `name="_csrf"`) {
|
|
t.Errorf("GET %s: rendered HTML missing name=\"_csrf\" hidden input\nbody snippet: %s",
|
|
path, truncate(body, 500))
|
|
}
|
|
}
|
|
|
|
// Also check the index page (has the logout form).
|
|
user := preInsertUser(t, ctx, q, "csrfform@example.com", "correct-horse-12")
|
|
cookieValue, _, err := store.Create(ctx, user.ID)
|
|
if err != nil {
|
|
t.Fatalf("store.Create: %v", err)
|
|
}
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: cookieValue})
|
|
rec := httptest.NewRecorder()
|
|
router.ServeHTTP(rec, req)
|
|
|
|
if !strings.Contains(rec.Body.String(), `name="_csrf"`) {
|
|
t.Errorf("GET /: rendered HTML missing name=\"_csrf\" in logout form\nbody snippet: %s",
|
|
truncate(rec.Body.String(), 500))
|
|
}
|
|
}
|
|
|
|
// TestRouter_CSRFMountedAfterResolveSession is a source-level guard on
// middleware order. In router.go, the first occurrence of the text
// "auth.ResolveSession" must come before the first occurrence of "auth.Mount".
// It is a plain substring search over the file, so mentions inside comments
// count as well.
func TestRouter_CSRFMountedAfterResolveSession(t *testing.T) {
	data, err := os.ReadFile("router.go")
	if err != nil {
		t.Fatalf("could not read router.go: %v", err)
	}
	src := string(data)

	resolveAt := strings.Index(src, "auth.ResolveSession")
	mountAt := strings.Index(src, "auth.Mount")

	switch {
	case resolveAt == -1:
		t.Fatal("router.go: auth.ResolveSession not found")
	case mountAt == -1:
		t.Fatal("router.go: auth.Mount not found")
	case resolveAt >= mountAt:
		t.Errorf("middleware order violation (D-24): auth.ResolveSession must appear before auth.Mount in router.go")
	}
}
|