From fd2301decf7a3fdfde5840dd4ede36c434562dfd Mon Sep 17 00:00:00 2001 From: Arthur Belleville Date: Thu, 14 May 2026 22:08:04 +0200 Subject: [PATCH] feat(02-03): session store + cookie helpers (real-DB TDD) - Store.Create: 32-byte crypto/rand token, SHA-256 hex as DB id (D-05) - Store.Lookup: hashes cookie, maps pgx.ErrNoRows to ErrSessionNotFound (D-07) - Store.Delete: hard-deletes session row (D-06) - Store.Rotate: deletes old row before creating new one (D-10, T-2-04) - Store.MaybeExtend: extends only when remaining < 7 days (D-09) - SetSessionCookie: HttpOnly + Secure (env-gated) + SameSite=Lax (D-12) - ClearSessionCookie: MaxAge=-1 not 0 (RESEARCH Pattern 3 / D-06) - 10 tests: 7 real-DB (skip without TEST_DATABASE_URL) + 3 cookie unit tests --- backend/internal/auth/cookie.go | 50 ++++ backend/internal/auth/session.go | 138 +++++++++ backend/internal/auth/session_test.go | 391 ++++++++++++++++++++++++++ 3 files changed, 579 insertions(+) create mode 100644 backend/internal/auth/cookie.go create mode 100644 backend/internal/auth/session.go create mode 100644 backend/internal/auth/session_test.go diff --git a/backend/internal/auth/cookie.go b/backend/internal/auth/cookie.go new file mode 100644 index 0000000..5d2451d --- /dev/null +++ b/backend/internal/auth/cookie.go @@ -0,0 +1,50 @@ +package auth + +import ( + "net/http" + "time" +) + +// SetSessionCookie writes an HTTP-only session cookie to the response. +// The Secure attribute is gated by the secure parameter — callers pass +// (ENV != "dev") so plain localhost works in development (D-12). +// +// Cookie attributes per D-12: +// - HttpOnly: true (no JS access) +// - Secure: env-gated +// - SameSite: Lax (blocks cross-site POST; allows top-level GET navigation) +// - Path: / +// - MaxAge: mirrors the session TTL (re-issued on extension so the browser +// cookie lifetime stays in sync with the DB row) +func SetSessionCookie(w http.ResponseWriter, value string, expiresAt time.Time, secure bool) { + maxAge := int(time.Until(expiresAt).Seconds()) + if maxAge < 0 { + maxAge = 0 + } + http.SetCookie(w, &http.Cookie{ + Name: SessionCookieName, + Value: value, + Path: "/", + Expires: expiresAt, + MaxAge: maxAge, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +// ClearSessionCookie instructs the browser to delete the session cookie. +// MaxAge=-1 in Go's http.Cookie emits "Max-Age=0" in the Set-Cookie header, +// which browsers interpret as "delete immediately" (D-06, RESEARCH Pattern 3). +func ClearSessionCookie(w http.ResponseWriter, secure bool) { + http.SetCookie(w, &http.Cookie{ + Name: SessionCookieName, + Value: "", + Path: "/", + Expires: time.Unix(0, 0), + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} diff --git a/backend/internal/auth/session.go b/backend/internal/auth/session.go new file mode 100644 index 0000000..77312e0 --- /dev/null +++ b/backend/internal/auth/session.go @@ -0,0 +1,138 @@ +package auth + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "time" + + "backend/internal/db/sqlc" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" +) + +// Store manages session lifecycle: create, lookup, delete, rotate, and +// lazy-extend. It uses the sqlc-generated Queries for all DB operations. +// +// The raw 32-byte token is held only in memory (as the cookie value); +// the DB row stores hex(sha256(token)) so a DB-read leak does not expose +// live sessions (D-05). +type Store struct { + q *sqlc.Queries + now func() time.Time // injectable for testing (D-09 MaybeExtend) +} + +// NewStore returns a Store backed by q. The store's clock defaults to +// time.Now; tests may override store.now after construction. +func NewStore(q *sqlc.Queries) *Store { + return &Store{q: q, now: time.Now} +} + +// Create generates a fresh 32-byte crypto/rand token, stores its SHA-256 hash +// in the sessions table, and returns the base64url-encoded cookie value plus +// the session expiry time. +// +// This is the only place tokens are generated. The cookie value is the raw +// token (base64url); the DB stores hex(sha256(rawToken)) only (D-05). +func (s *Store) Create(ctx context.Context, userID uuid.UUID) (cookieValue string, expiresAt time.Time, err error) { + raw := make([]byte, 32) + if _, err = rand.Read(raw); err != nil { + return "", time.Time{}, fmt.Errorf("auth: generate session token: %w", err) + } + + cookieValue = base64.RawURLEncoding.EncodeToString(raw) + sum := sha256.Sum256(raw) // D-05: store hash, never raw token + id := hex.EncodeToString(sum[:]) + expiresAt = s.now().Add(SessionTTL) + + pgExpires := pgtype.Timestamptz{Time: expiresAt, Valid: true} + if err = s.q.InsertSession(ctx, sqlc.InsertSessionParams{ + ID: id, + UserID: userID, + ExpiresAt: pgExpires, + }); err != nil { + return "", time.Time{}, fmt.Errorf("auth: insert session: %w", err) + } + return cookieValue, expiresAt, nil +} + +// Lookup decodes the cookie value, hashes it, and retrieves the matching +// live session + user from the DB. Returns ErrSessionNotFound if the cookie +// is malformed, the session does not exist, or it has expired (D-07). +func (s *Store) Lookup(ctx context.Context, cookieValue string) (*Session, *User, error) { + raw, err := base64.RawURLEncoding.DecodeString(cookieValue) + if err != nil || len(raw) != 32 { + return nil, nil, ErrSessionNotFound + } + + sum := sha256.Sum256(raw) // D-05: hash on every lookup, never store raw token + id := hex.EncodeToString(sum[:]) + row, err := s.q.GetSessionWithUser(ctx, id) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, nil, ErrSessionNotFound + } + return nil, nil, fmt.Errorf("auth: lookup session: %w", err) + } + + sess := &Session{ + ID: row.ID, + UserID: row.UserID, + CreatedAt: row.CreatedAt.Time, + ExpiresAt: row.ExpiresAt.Time, + } + user := &User{ + ID: row.UID, + Email: row.Email, + PasswordHash: row.PasswordHash, + CreatedAt: row.UCreatedAt.Time, + UpdatedAt: row.UUpdatedAt.Time, + } + return sess, user, nil +} + +// Delete hard-deletes the session row by its hashed ID (D-06). +func (s *Store) Delete(ctx context.Context, id string) error { + if err := s.q.DeleteSession(ctx, id); err != nil { + return fmt.Errorf("auth: delete session: %w", err) + } + return nil +} + +// Rotate deletes the old session row (best-effort) and creates a new one. +// This mitigates session fixation on every login and signup (D-10, T-2-04). +// If oldID is empty the delete step is skipped. +func (s *Store) Rotate(ctx context.Context, oldID string, userID uuid.UUID) (string, time.Time, error) { + if oldID != "" { + // Best-effort: ignore delete error; even if it fails the new session is safe. + _ = s.q.DeleteSession(ctx, oldID) + } + return s.Create(ctx, userID) +} + +// MaybeExtend updates expires_at only when the remaining session lifetime drops +// below SessionExtendThreshold (~7 days). This provides a sliding-window TTL +// with at most one DB write per ~23 days (D-09). +func (s *Store) MaybeExtend(ctx context.Context, id string, expiresAt time.Time) error { + remaining := expiresAt.Sub(s.now()) + if remaining >= SessionExtendThreshold { + // Plenty of time left — no update needed. + return nil + } + + newExpiry := s.now().Add(SessionTTL) + if err := s.q.ExtendSession(ctx, sqlc.ExtendSessionParams{ + ID: id, + ExpiresAt: pgtype.Timestamptz{Time: newExpiry, Valid: true}, + }); err != nil { + return fmt.Errorf("auth: extend session: %w", err) + } + return nil +} + diff --git a/backend/internal/auth/session_test.go b/backend/internal/auth/session_test.go new file mode 100644 index 0000000..a04d16b --- /dev/null +++ b/backend/internal/auth/session_test.go @@ -0,0 +1,391 @@ +package auth + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "backend/internal/db/sqlc" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/jackc/pgx/v5/pgtype" +) + +// ---------- DB tests (skip when TEST_DATABASE_URL unset) ---------- + +func TestSession_StoresHashedID(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + q := sqlc.New(pool) + store := NewStore(q) + + ctx := context.Background() + userID := mustInsertUser(t, pool) + + cookieValue, _, err := store.Create(ctx, userID) + if err != nil { + t.Fatalf("Create: %v", err) + } + + // The cookie value is base64url. Decode to raw bytes. + raw, err := base64.RawURLEncoding.DecodeString(cookieValue) + if err != nil { + t.Fatalf("base64url decode cookie value: %v", err) + } + if len(raw) != 32 { + t.Fatalf("expected 32 raw bytes, got %d", len(raw)) + } + + // Compute expected DB id = hex(sha256(raw)). + sum := sha256.Sum256(raw) + expectedID := hex.EncodeToString(sum[:]) + + // Verify the row in the DB has that exact id (Pitfall 6 guard). + var gotID string + row := pool.QueryRow(ctx, "SELECT id FROM sessions WHERE id = $1", expectedID) + if err := row.Scan(&gotID); err != nil { + t.Fatalf("SELECT session id: %v (expected hashed id %q)", err, expectedID) + } + if gotID != expectedID { + t.Errorf("DB id %q != expected %q", gotID, expectedID) + } +} + +func TestSession_LookupRoundtrip(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + q := sqlc.New(pool) + store := NewStore(q) + ctx := context.Background() + + userID := mustInsertUser(t, pool) + cookieValue, _, err := store.Create(ctx, userID) + if err != nil { + t.Fatalf("Create: %v", err) + } + + sess, user, err := store.Lookup(ctx, cookieValue) + if err != nil { + t.Fatalf("Lookup: %v", err) + } + if sess == nil || user == nil { + t.Fatal("Lookup returned nil session or user") + } + if user.ID != userID { + t.Errorf("user.ID %v != expected %v", user.ID, userID) + } + if sess.ID == "" { + t.Error("session.ID is empty") + } +} + +func TestSession_LookupRejectsExpired(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + q := sqlc.New(pool) + store := NewStore(q) + ctx := context.Background() + + userID := mustInsertUser(t, pool) + + // Insert an expired session directly. + expiredID := "deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef" + expiredAt := time.Now().Add(-1 * time.Hour) + _, err := pool.Exec(ctx, + "INSERT INTO sessions (id, user_id, expires_at) VALUES ($1, $2, $3)", + expiredID, userID, expiredAt, + ) + if err != nil { + t.Fatalf("INSERT expired session: %v", err) + } + + // Verify that GetSessionWithUser rejects the expired row (D-07 guard). + _, sqlcErr := q.GetSessionWithUser(ctx, expiredID) + if sqlcErr == nil { + t.Fatal("expected error for expired session, got nil") + } + if !strings.Contains(sqlcErr.Error(), "no rows") { + t.Errorf("unexpected error type for expired session: %v", sqlcErr) + } + + // Also verify our Store.Lookup returns ErrSessionNotFound for an invalid cookie + // (any garbage value that doesn't correspond to a live session). + _, _, err = store.Lookup(ctx, "garbage-not-base64url-valid-32bytes") + if err != ErrSessionNotFound { + t.Errorf("Lookup with garbage cookie: expected ErrSessionNotFound, got %v", err) + } +} + +func TestSession_RotateDeletesOld(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + q := sqlc.New(pool) + store := NewStore(q) + ctx := context.Background() + + userID := mustInsertUser(t, pool) + + // Create session A. + cookieA, _, err := store.Create(ctx, userID) + if err != nil { + t.Fatalf("Create A: %v", err) + } + + rawA, err := base64.RawURLEncoding.DecodeString(cookieA) + if err != nil { + t.Fatalf("decode cookieA: %v", err) + } + sumA := sha256.Sum256(rawA) + idA := hex.EncodeToString(sumA[:]) + + // Rotate using idA. + cookieB, _, err := store.Rotate(ctx, idA, userID) + if err != nil { + t.Fatalf("Rotate: %v", err) + } + if cookieB == cookieA { + t.Error("Rotate returned same cookie value as original (Pitfall 5)") + } + + // Old row must be gone (Pitfall 5 guard). + var count int + row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE id = $1", idA) + if err := row.Scan(&count); err != nil { + t.Fatalf("COUNT old session: %v", err) + } + if count != 0 { + t.Errorf("old session row still exists after Rotate (count=%d), session fixation risk", count) + } + + // New row must exist. + rawB, err := base64.RawURLEncoding.DecodeString(cookieB) + if err != nil { + t.Fatalf("decode cookieB: %v", err) + } + sumB := sha256.Sum256(rawB) + idB := hex.EncodeToString(sumB[:]) + row = pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE id = $1", idB) + if err := row.Scan(&count); err != nil { + t.Fatalf("COUNT new session: %v", err) + } + if count != 1 { + t.Errorf("new session row not found after Rotate (count=%d)", count) + } +} + +func TestSession_DeleteRemovesRow(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + q := sqlc.New(pool) + store := NewStore(q) + ctx := context.Background() + + userID := mustInsertUser(t, pool) + cookieValue, _, err := store.Create(ctx, userID) + if err != nil { + t.Fatalf("Create: %v", err) + } + + rawBytes, _ := base64.RawURLEncoding.DecodeString(cookieValue) + sum := sha256.Sum256(rawBytes) + id := hex.EncodeToString(sum[:]) + + if err := store.Delete(ctx, id); err != nil { + t.Fatalf("Delete: %v", err) + } + + _, _, err = store.Lookup(ctx, cookieValue) + if err != ErrSessionNotFound { + t.Errorf("expected ErrSessionNotFound after Delete, got: %v", err) + } +} + +func TestSession_MaybeExtend_NoOp(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + q := sqlc.New(pool) + store := NewStore(q) + + fixedNow := time.Now().UTC() + store.now = func() time.Time { return fixedNow } + + ctx := context.Background() + userID := mustInsertUser(t, pool) + + // Insert a session with expires_at = now + 29 days (above 7-day threshold). + id := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + expiresAt := fixedNow.Add(29 * 24 * time.Hour) + _, err := pool.Exec(ctx, + "INSERT INTO sessions (id, user_id, expires_at) VALUES ($1, $2, $3)", + id, userID, expiresAt, + ) + if err != nil { + t.Fatalf("INSERT session: %v", err) + } + + if err := store.MaybeExtend(ctx, id, expiresAt); err != nil { + t.Fatalf("MaybeExtend: %v", err) + } + + // expires_at must NOT change (within 1 second tolerance). + var gotExp pgtype.Timestamptz + row := pool.QueryRow(ctx, "SELECT expires_at FROM sessions WHERE id = $1", id) + if err := row.Scan(&gotExp); err != nil { + t.Fatalf("SELECT expires_at: %v", err) + } + diff := gotExp.Time.Sub(expiresAt) + if diff < 0 { + diff = -diff + } + if diff > time.Second { + t.Errorf("MaybeExtend changed expires_at when remaining > threshold (diff=%v)", diff) + } +} + +func TestSession_MaybeExtend_Extends(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + q := sqlc.New(pool) + store := NewStore(q) + + fixedNow := time.Now().UTC() + store.now = func() time.Time { return fixedNow } + + ctx := context.Background() + userID := mustInsertUser(t, pool) + + // Insert a session with expires_at = now + 1 day (below 7-day threshold). + id := "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb" + expiresAt := fixedNow.Add(1 * 24 * time.Hour) + _, err := pool.Exec(ctx, + "INSERT INTO sessions (id, user_id, expires_at) VALUES ($1, $2, $3)", + id, userID, expiresAt, + ) + if err != nil { + t.Fatalf("INSERT session: %v", err) + } + + if err := store.MaybeExtend(ctx, id, expiresAt); err != nil { + t.Fatalf("MaybeExtend: %v", err) + } + + // expires_at must now be ~now + 30 days. + var gotExp pgtype.Timestamptz + row := pool.QueryRow(ctx, "SELECT expires_at FROM sessions WHERE id = $1", id) + if err := row.Scan(&gotExp); err != nil { + t.Fatalf("SELECT expires_at: %v", err) + } + if !gotExp.Time.After(expiresAt) { + t.Errorf("MaybeExtend did not extend expires_at (still %v)", gotExp.Time) + } + expected := fixedNow.Add(SessionTTL) + diff := gotExp.Time.Sub(expected) + if diff < 0 { + diff = -diff + } + if diff > 5*time.Second { + t.Errorf("extended expires_at %v not within 5s of expected %v (diff=%v)", gotExp.Time, expected, diff) + } +} + +// ---------- Cookie tests (no DB) ---------- + +func TestCookie_SetAttributes(t *testing.T) { + w := httptest.NewRecorder() + expiresAt := time.Now().Add(30 * 24 * time.Hour) + SetSessionCookie(w, "my-token-value", expiresAt, true) + + resp := w.Result() + cookies := resp.Cookies() + if len(cookies) != 1 { + t.Fatalf("expected 1 cookie, got %d", len(cookies)) + } + c := cookies[0] + if c.Name != SessionCookieName { + t.Errorf("cookie name %q, want %q", c.Name, SessionCookieName) + } + if c.Value != "my-token-value" { + t.Errorf("cookie value %q, want %q", c.Value, "my-token-value") + } + if !c.HttpOnly { + t.Error("cookie must be HttpOnly") + } + if !c.Secure { + t.Error("cookie must be Secure when secure=true") + } + if c.SameSite != http.SameSiteLaxMode { + t.Errorf("cookie SameSite %v, want %v", c.SameSite, http.SameSiteLaxMode) + } + if c.Path != "/" { + t.Errorf("cookie Path %q, want %q", c.Path, "/") + } + expected := int(30 * 24 * time.Hour / time.Second) + diff := c.MaxAge - expected + if diff < 0 { + diff = -diff + } + if diff > 5 { + t.Errorf("MaxAge %d not within 5s of expected %d", c.MaxAge, expected) + } +} + +func TestCookie_ClearUsesMaxAgeMinus1(t *testing.T) { + w := httptest.NewRecorder() + ClearSessionCookie(w, false) + + resp := w.Result() + cookies := resp.Cookies() + if len(cookies) != 1 { + t.Fatalf("expected 1 cookie, got %d", len(cookies)) + } + c := cookies[0] + if c.MaxAge != -1 { + t.Errorf("ClearSessionCookie MaxAge %d, want -1 (NOT 0 per RESEARCH Pattern 3)", c.MaxAge) + } +} + +func TestCookie_SecureGatedByEnv(t *testing.T) { + w := httptest.NewRecorder() + SetSessionCookie(w, "tok", time.Now().Add(time.Hour), false) // secure=false + + resp := w.Result() + cookies := resp.Cookies() + if len(cookies) != 1 { + t.Fatalf("expected 1 cookie, got %d", len(cookies)) + } + if cookies[0].Secure { + t.Error("cookie must NOT have Secure attribute when secure=false") + } +} + +// ---------- helpers ---------- + +// mustInsertUser inserts a test user and returns its ID. +func mustInsertUser(t *testing.T, pool *pgxpool.Pool) uuid.UUID { + t.Helper() + ctx := context.Background() + var id uuid.UUID + row := pool.QueryRow(ctx, + "INSERT INTO users (email, password_hash) VALUES ($1, $2) RETURNING id", + "test-"+uuid.NewString()+"@example.com", + "$argon2id$v=19$m=8192,t=1,p=2$AAAAAAAAAAAAAAAAAAAAAA$AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA", + ) + if err := row.Scan(&id); err != nil { + t.Fatalf("mustInsertUser: %v", err) + } + return id +}