feat(02-03): session store + cookie helpers (real-DB TDD)

- Store.Create: 32-byte crypto/rand token, SHA-256 hex as DB id (D-05)
- Store.Lookup: hashes cookie, maps pgx.ErrNoRows to ErrSessionNotFound (D-07)
- Store.Delete: hard-deletes session row (D-06)
- Store.Rotate: deletes old row before creating new one (D-10, T-2-04)
- Store.MaybeExtend: extends only when remaining < 7 days (D-09)
- SetSessionCookie: HttpOnly + Secure (env-gated) + SameSite=Lax (D-12)
- ClearSessionCookie: MaxAge=-1 not 0 (RESEARCH Pattern 3 / D-06)
- 10 tests: 7 real-DB (skip without TEST_DATABASE_URL) + 3 cookie unit tests
This commit is contained in:
Arthur Belleville 2026-05-14 22:08:04 +02:00
parent 648ce143a2
commit fd2301decf
No known key found for this signature in database
3 changed files with 579 additions and 0 deletions

View file

@ -0,0 +1,50 @@
package auth
import (
"net/http"
"time"
)
// SetSessionCookie writes an HTTP-only session cookie to the response.
// The Secure attribute is gated by the secure parameter — callers pass
// (ENV != "dev") so plain localhost works in development (D-12).
//
// Cookie attributes per D-12:
// - HttpOnly: true (no JS access)
// - Secure: env-gated
// - SameSite: Lax (blocks cross-site POST; allows top-level GET navigation)
// - Path: /
// - MaxAge: mirrors the session TTL (re-issued on extension so the browser
// cookie lifetime stays in sync with the DB row)
func SetSessionCookie(w http.ResponseWriter, value string, expiresAt time.Time, secure bool) {
maxAge := int(time.Until(expiresAt).Seconds())
if maxAge < 0 {
maxAge = 0
}
http.SetCookie(w, &http.Cookie{
Name: SessionCookieName,
Value: value,
Path: "/",
Expires: expiresAt,
MaxAge: maxAge,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
// ClearSessionCookie instructs the browser to delete the session cookie.
// MaxAge=-1 in Go's http.Cookie emits "Max-Age=0" in the Set-Cookie header,
// which browsers interpret as "delete immediately" (D-06, RESEARCH Pattern 3).
func ClearSessionCookie(w http.ResponseWriter, secure bool) {
http.SetCookie(w, &http.Cookie{
Name: SessionCookieName,
Value: "",
Path: "/",
Expires: time.Unix(0, 0),
MaxAge: -1,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}

View file

@ -0,0 +1,138 @@
package auth
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"time"
"backend/internal/db/sqlc"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
)
// Store manages session lifecycle: create, lookup, delete, rotate, and
// lazy-extend. It uses the sqlc-generated Queries for all DB operations.
//
// The raw 32-byte token is held only in memory (as the cookie value);
// the DB row stores hex(sha256(token)) so a DB-read leak does not expose
// live sessions (D-05).
type Store struct {
q *sqlc.Queries
now func() time.Time // injectable for testing (D-09 MaybeExtend)
}
// NewStore returns a Store backed by q. The store's clock defaults to
// time.Now; tests may override store.now after construction.
func NewStore(q *sqlc.Queries) *Store {
return &Store{q: q, now: time.Now}
}
// Create generates a fresh 32-byte crypto/rand token, stores its SHA-256 hash
// in the sessions table, and returns the base64url-encoded cookie value plus
// the session expiry time.
//
// This is the only place tokens are generated. The cookie value is the raw
// token (base64url); the DB stores hex(sha256(rawToken)) only (D-05).
func (s *Store) Create(ctx context.Context, userID uuid.UUID) (cookieValue string, expiresAt time.Time, err error) {
raw := make([]byte, 32)
if _, err = rand.Read(raw); err != nil {
return "", time.Time{}, fmt.Errorf("auth: generate session token: %w", err)
}
cookieValue = base64.RawURLEncoding.EncodeToString(raw)
sum := sha256.Sum256(raw) // D-05: store hash, never raw token
id := hex.EncodeToString(sum[:])
expiresAt = s.now().Add(SessionTTL)
pgExpires := pgtype.Timestamptz{Time: expiresAt, Valid: true}
if err = s.q.InsertSession(ctx, sqlc.InsertSessionParams{
ID: id,
UserID: userID,
ExpiresAt: pgExpires,
}); err != nil {
return "", time.Time{}, fmt.Errorf("auth: insert session: %w", err)
}
return cookieValue, expiresAt, nil
}
// Lookup decodes the cookie value, hashes it, and retrieves the matching
// live session + user from the DB. Returns ErrSessionNotFound if the cookie
// is malformed, the session does not exist, or it has expired (D-07).
func (s *Store) Lookup(ctx context.Context, cookieValue string) (*Session, *User, error) {
raw, err := base64.RawURLEncoding.DecodeString(cookieValue)
if err != nil || len(raw) != 32 {
return nil, nil, ErrSessionNotFound
}
sum := sha256.Sum256(raw) // D-05: hash on every lookup, never store raw token
id := hex.EncodeToString(sum[:])
row, err := s.q.GetSessionWithUser(ctx, id)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil, ErrSessionNotFound
}
return nil, nil, fmt.Errorf("auth: lookup session: %w", err)
}
sess := &Session{
ID: row.ID,
UserID: row.UserID,
CreatedAt: row.CreatedAt.Time,
ExpiresAt: row.ExpiresAt.Time,
}
user := &User{
ID: row.UID,
Email: row.Email,
PasswordHash: row.PasswordHash,
CreatedAt: row.UCreatedAt.Time,
UpdatedAt: row.UUpdatedAt.Time,
}
return sess, user, nil
}
// Delete hard-deletes the session row by its hashed ID (D-06).
func (s *Store) Delete(ctx context.Context, id string) error {
if err := s.q.DeleteSession(ctx, id); err != nil {
return fmt.Errorf("auth: delete session: %w", err)
}
return nil
}
// Rotate deletes the old session row (best-effort) and creates a new one.
// This mitigates session fixation on every login and signup (D-10, T-2-04).
// If oldID is empty the delete step is skipped.
func (s *Store) Rotate(ctx context.Context, oldID string, userID uuid.UUID) (string, time.Time, error) {
if oldID != "" {
// Best-effort: ignore delete error; even if it fails the new session is safe.
_ = s.q.DeleteSession(ctx, oldID)
}
return s.Create(ctx, userID)
}
// MaybeExtend updates expires_at only when the remaining session lifetime drops
// below SessionExtendThreshold (~7 days). This provides a sliding-window TTL
// with at most one DB write per ~23 days (D-09).
func (s *Store) MaybeExtend(ctx context.Context, id string, expiresAt time.Time) error {
remaining := expiresAt.Sub(s.now())
if remaining >= SessionExtendThreshold {
// Plenty of time left — no update needed.
return nil
}
newExpiry := s.now().Add(SessionTTL)
if err := s.q.ExtendSession(ctx, sqlc.ExtendSessionParams{
ID: id,
ExpiresAt: pgtype.Timestamptz{Time: newExpiry, Valid: true},
}); err != nil {
return fmt.Errorf("auth: extend session: %w", err)
}
return nil
}

View file

@ -0,0 +1,391 @@
package auth
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"backend/internal/db/sqlc"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/pgtype"
)
// ---------- DB tests (skip when TEST_DATABASE_URL unset) ----------
func TestSession_StoresHashedID(t *testing.T) {
pool, cleanup := setupTestDB(t)
defer cleanup()
q := sqlc.New(pool)
store := NewStore(q)
ctx := context.Background()
userID := mustInsertUser(t, pool)
cookieValue, _, err := store.Create(ctx, userID)
if err != nil {
t.Fatalf("Create: %v", err)
}
// The cookie value is base64url. Decode to raw bytes.
raw, err := base64.RawURLEncoding.DecodeString(cookieValue)
if err != nil {
t.Fatalf("base64url decode cookie value: %v", err)
}
if len(raw) != 32 {
t.Fatalf("expected 32 raw bytes, got %d", len(raw))
}
// Compute expected DB id = hex(sha256(raw)).
sum := sha256.Sum256(raw)
expectedID := hex.EncodeToString(sum[:])
// Verify the row in the DB has that exact id (Pitfall 6 guard).
var gotID string
row := pool.QueryRow(ctx, "SELECT id FROM sessions WHERE id = $1", expectedID)
if err := row.Scan(&gotID); err != nil {
t.Fatalf("SELECT session id: %v (expected hashed id %q)", err, expectedID)
}
if gotID != expectedID {
t.Errorf("DB id %q != expected %q", gotID, expectedID)
}
}
func TestSession_LookupRoundtrip(t *testing.T) {
pool, cleanup := setupTestDB(t)
defer cleanup()
q := sqlc.New(pool)
store := NewStore(q)
ctx := context.Background()
userID := mustInsertUser(t, pool)
cookieValue, _, err := store.Create(ctx, userID)
if err != nil {
t.Fatalf("Create: %v", err)
}
sess, user, err := store.Lookup(ctx, cookieValue)
if err != nil {
t.Fatalf("Lookup: %v", err)
}
if sess == nil || user == nil {
t.Fatal("Lookup returned nil session or user")
}
if user.ID != userID {
t.Errorf("user.ID %v != expected %v", user.ID, userID)
}
if sess.ID == "" {
t.Error("session.ID is empty")
}
}
func TestSession_LookupRejectsExpired(t *testing.T) {
pool, cleanup := setupTestDB(t)
defer cleanup()
q := sqlc.New(pool)
store := NewStore(q)
ctx := context.Background()
userID := mustInsertUser(t, pool)
// Insert an expired session directly.
expiredID := "deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef"
expiredAt := time.Now().Add(-1 * time.Hour)
_, err := pool.Exec(ctx,
"INSERT INTO sessions (id, user_id, expires_at) VALUES ($1, $2, $3)",
expiredID, userID, expiredAt,
)
if err != nil {
t.Fatalf("INSERT expired session: %v", err)
}
// Verify that GetSessionWithUser rejects the expired row (D-07 guard).
_, sqlcErr := q.GetSessionWithUser(ctx, expiredID)
if sqlcErr == nil {
t.Fatal("expected error for expired session, got nil")
}
if !strings.Contains(sqlcErr.Error(), "no rows") {
t.Errorf("unexpected error type for expired session: %v", sqlcErr)
}
// Also verify our Store.Lookup returns ErrSessionNotFound for an invalid cookie
// (any garbage value that doesn't correspond to a live session).
_, _, err = store.Lookup(ctx, "garbage-not-base64url-valid-32bytes")
if err != ErrSessionNotFound {
t.Errorf("Lookup with garbage cookie: expected ErrSessionNotFound, got %v", err)
}
}
func TestSession_RotateDeletesOld(t *testing.T) {
pool, cleanup := setupTestDB(t)
defer cleanup()
q := sqlc.New(pool)
store := NewStore(q)
ctx := context.Background()
userID := mustInsertUser(t, pool)
// Create session A.
cookieA, _, err := store.Create(ctx, userID)
if err != nil {
t.Fatalf("Create A: %v", err)
}
rawA, err := base64.RawURLEncoding.DecodeString(cookieA)
if err != nil {
t.Fatalf("decode cookieA: %v", err)
}
sumA := sha256.Sum256(rawA)
idA := hex.EncodeToString(sumA[:])
// Rotate using idA.
cookieB, _, err := store.Rotate(ctx, idA, userID)
if err != nil {
t.Fatalf("Rotate: %v", err)
}
if cookieB == cookieA {
t.Error("Rotate returned same cookie value as original (Pitfall 5)")
}
// Old row must be gone (Pitfall 5 guard).
var count int
row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE id = $1", idA)
if err := row.Scan(&count); err != nil {
t.Fatalf("COUNT old session: %v", err)
}
if count != 0 {
t.Errorf("old session row still exists after Rotate (count=%d), session fixation risk", count)
}
// New row must exist.
rawB, err := base64.RawURLEncoding.DecodeString(cookieB)
if err != nil {
t.Fatalf("decode cookieB: %v", err)
}
sumB := sha256.Sum256(rawB)
idB := hex.EncodeToString(sumB[:])
row = pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE id = $1", idB)
if err := row.Scan(&count); err != nil {
t.Fatalf("COUNT new session: %v", err)
}
if count != 1 {
t.Errorf("new session row not found after Rotate (count=%d)", count)
}
}
func TestSession_DeleteRemovesRow(t *testing.T) {
pool, cleanup := setupTestDB(t)
defer cleanup()
q := sqlc.New(pool)
store := NewStore(q)
ctx := context.Background()
userID := mustInsertUser(t, pool)
cookieValue, _, err := store.Create(ctx, userID)
if err != nil {
t.Fatalf("Create: %v", err)
}
rawBytes, _ := base64.RawURLEncoding.DecodeString(cookieValue)
sum := sha256.Sum256(rawBytes)
id := hex.EncodeToString(sum[:])
if err := store.Delete(ctx, id); err != nil {
t.Fatalf("Delete: %v", err)
}
_, _, err = store.Lookup(ctx, cookieValue)
if err != ErrSessionNotFound {
t.Errorf("expected ErrSessionNotFound after Delete, got: %v", err)
}
}
func TestSession_MaybeExtend_NoOp(t *testing.T) {
pool, cleanup := setupTestDB(t)
defer cleanup()
q := sqlc.New(pool)
store := NewStore(q)
fixedNow := time.Now().UTC()
store.now = func() time.Time { return fixedNow }
ctx := context.Background()
userID := mustInsertUser(t, pool)
// Insert a session with expires_at = now + 29 days (above 7-day threshold).
id := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
expiresAt := fixedNow.Add(29 * 24 * time.Hour)
_, err := pool.Exec(ctx,
"INSERT INTO sessions (id, user_id, expires_at) VALUES ($1, $2, $3)",
id, userID, expiresAt,
)
if err != nil {
t.Fatalf("INSERT session: %v", err)
}
if err := store.MaybeExtend(ctx, id, expiresAt); err != nil {
t.Fatalf("MaybeExtend: %v", err)
}
// expires_at must NOT change (within 1 second tolerance).
var gotExp pgtype.Timestamptz
row := pool.QueryRow(ctx, "SELECT expires_at FROM sessions WHERE id = $1", id)
if err := row.Scan(&gotExp); err != nil {
t.Fatalf("SELECT expires_at: %v", err)
}
diff := gotExp.Time.Sub(expiresAt)
if diff < 0 {
diff = -diff
}
if diff > time.Second {
t.Errorf("MaybeExtend changed expires_at when remaining > threshold (diff=%v)", diff)
}
}
func TestSession_MaybeExtend_Extends(t *testing.T) {
pool, cleanup := setupTestDB(t)
defer cleanup()
q := sqlc.New(pool)
store := NewStore(q)
fixedNow := time.Now().UTC()
store.now = func() time.Time { return fixedNow }
ctx := context.Background()
userID := mustInsertUser(t, pool)
// Insert a session with expires_at = now + 1 day (below 7-day threshold).
id := "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"
expiresAt := fixedNow.Add(1 * 24 * time.Hour)
_, err := pool.Exec(ctx,
"INSERT INTO sessions (id, user_id, expires_at) VALUES ($1, $2, $3)",
id, userID, expiresAt,
)
if err != nil {
t.Fatalf("INSERT session: %v", err)
}
if err := store.MaybeExtend(ctx, id, expiresAt); err != nil {
t.Fatalf("MaybeExtend: %v", err)
}
// expires_at must now be ~now + 30 days.
var gotExp pgtype.Timestamptz
row := pool.QueryRow(ctx, "SELECT expires_at FROM sessions WHERE id = $1", id)
if err := row.Scan(&gotExp); err != nil {
t.Fatalf("SELECT expires_at: %v", err)
}
if !gotExp.Time.After(expiresAt) {
t.Errorf("MaybeExtend did not extend expires_at (still %v)", gotExp.Time)
}
expected := fixedNow.Add(SessionTTL)
diff := gotExp.Time.Sub(expected)
if diff < 0 {
diff = -diff
}
if diff > 5*time.Second {
t.Errorf("extended expires_at %v not within 5s of expected %v (diff=%v)", gotExp.Time, expected, diff)
}
}
// ---------- Cookie tests (no DB) ----------
func TestCookie_SetAttributes(t *testing.T) {
w := httptest.NewRecorder()
expiresAt := time.Now().Add(30 * 24 * time.Hour)
SetSessionCookie(w, "my-token-value", expiresAt, true)
resp := w.Result()
cookies := resp.Cookies()
if len(cookies) != 1 {
t.Fatalf("expected 1 cookie, got %d", len(cookies))
}
c := cookies[0]
if c.Name != SessionCookieName {
t.Errorf("cookie name %q, want %q", c.Name, SessionCookieName)
}
if c.Value != "my-token-value" {
t.Errorf("cookie value %q, want %q", c.Value, "my-token-value")
}
if !c.HttpOnly {
t.Error("cookie must be HttpOnly")
}
if !c.Secure {
t.Error("cookie must be Secure when secure=true")
}
if c.SameSite != http.SameSiteLaxMode {
t.Errorf("cookie SameSite %v, want %v", c.SameSite, http.SameSiteLaxMode)
}
if c.Path != "/" {
t.Errorf("cookie Path %q, want %q", c.Path, "/")
}
expected := int(30 * 24 * time.Hour / time.Second)
diff := c.MaxAge - expected
if diff < 0 {
diff = -diff
}
if diff > 5 {
t.Errorf("MaxAge %d not within 5s of expected %d", c.MaxAge, expected)
}
}
func TestCookie_ClearUsesMaxAgeMinus1(t *testing.T) {
w := httptest.NewRecorder()
ClearSessionCookie(w, false)
resp := w.Result()
cookies := resp.Cookies()
if len(cookies) != 1 {
t.Fatalf("expected 1 cookie, got %d", len(cookies))
}
c := cookies[0]
if c.MaxAge != -1 {
t.Errorf("ClearSessionCookie MaxAge %d, want -1 (NOT 0 per RESEARCH Pattern 3)", c.MaxAge)
}
}
func TestCookie_SecureGatedByEnv(t *testing.T) {
w := httptest.NewRecorder()
SetSessionCookie(w, "tok", time.Now().Add(time.Hour), false) // secure=false
resp := w.Result()
cookies := resp.Cookies()
if len(cookies) != 1 {
t.Fatalf("expected 1 cookie, got %d", len(cookies))
}
if cookies[0].Secure {
t.Error("cookie must NOT have Secure attribute when secure=false")
}
}
// ---------- helpers ----------
// mustInsertUser inserts a test user and returns its ID.
func mustInsertUser(t *testing.T, pool *pgxpool.Pool) uuid.UUID {
t.Helper()
ctx := context.Background()
var id uuid.UUID
row := pool.QueryRow(ctx,
"INSERT INTO users (email, password_hash) VALUES ($1, $2) RETURNING id",
"test-"+uuid.NewString()+"@example.com",
"$argon2id$v=19$m=8192,t=1,p=2$AAAAAAAAAAAAAAAAAAAAAA$AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
)
if err := row.Scan(&id); err != nil {
t.Fatalf("mustInsertUser: %v", err)
}
return id
}