package auth

import (
	"context"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"errors"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"backend/internal/db/sqlc"
	"github.com/google/uuid"
	"github.com/jackc/pgx/v5/pgtype"
	"github.com/jackc/pgx/v5/pgxpool"
)

// ---------- DB tests (skip when TEST_DATABASE_URL unset) ----------

// TestSession_StoresHashedID checks that the cookie handed to the client is
// 32 random bytes in unpadded base64url, and that the sessions table stores
// hex(sha256(raw)) rather than the raw token (Pitfall 6 guard).
func TestSession_StoresHashedID(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	q := sqlc.New(pool)
	store := NewStore(q)
	ctx := context.Background()

	userID := mustInsertUser(t, pool)

	cookieValue, _, err := store.Create(ctx, userID)
	if err != nil {
		t.Fatalf("Create: %v", err)
	}

	// The cookie value is base64url. Decode to raw bytes to check the token size.
	raw, err := base64.RawURLEncoding.DecodeString(cookieValue)
	if err != nil {
		t.Fatalf("base64url decode cookie value: %v", err)
	}
	if len(raw) != 32 {
		t.Fatalf("expected 32 raw bytes, got %d", len(raw))
	}

	// The DB row must be keyed by the hash, never by the raw token.
	expectedID := sessionIDFromCookie(t, cookieValue)
	if n := countSessions(t, pool, expectedID); n != 1 {
		t.Errorf("expected exactly one session row with hashed id %q, got %d", expectedID, n)
	}
}

// TestSession_LookupRoundtrip checks that a freshly created cookie resolves
// back to its session and owning user.
func TestSession_LookupRoundtrip(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	q := sqlc.New(pool)
	store := NewStore(q)
	ctx := context.Background()

	userID := mustInsertUser(t, pool)

	cookieValue, _, err := store.Create(ctx, userID)
	if err != nil {
		t.Fatalf("Create: %v", err)
	}

	sess, user, err := store.Lookup(ctx, cookieValue)
	if err != nil {
		t.Fatalf("Lookup: %v", err)
	}
	if sess == nil || user == nil {
		t.Fatal("Lookup returned nil session or user")
	}
	if user.ID != userID {
		t.Errorf("user.ID %v != expected %v", user.ID, userID)
	}
	if sess.ID == "" {
		t.Error("session.ID is empty")
	}
}

// TestSession_LookupRejectsExpired checks that expired sessions are invisible
// both at the query level (GetSessionWithUser) and through Store.Lookup, and
// that a cookie matching no live session yields ErrSessionNotFound (D-07 guard).
func TestSession_LookupRejectsExpired(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	q := sqlc.New(pool)
	store := NewStore(q)
	ctx := context.Background()

	userID := mustInsertUser(t, pool)

	// Insert an expired session directly.
	expiredID := "deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef"
	insertSession(t, pool, expiredID, userID, time.Now().Add(-1*time.Hour))

	// GetSessionWithUser must not return the expired row.
	_, sqlcErr := q.GetSessionWithUser(ctx, expiredID)
	if sqlcErr == nil {
		t.Fatal("expected error for expired session, got nil")
	}
	if !strings.Contains(sqlcErr.Error(), "no rows") {
		t.Errorf("unexpected error type for expired session: %v", sqlcErr)
	}

	// Store.Lookup must reject a genuine cookie once its session has expired.
	cookieValue, _, err := store.Create(ctx, userID)
	if err != nil {
		t.Fatalf("Create: %v", err)
	}
	liveID := sessionIDFromCookie(t, cookieValue)
	if _, err := pool.Exec(ctx,
		"UPDATE sessions SET expires_at = $1 WHERE id = $2",
		time.Now().Add(-1*time.Hour), liveID,
	); err != nil {
		t.Fatalf("UPDATE session to expired: %v", err)
	}
	if _, _, err := store.Lookup(ctx, cookieValue); !errors.Is(err, ErrSessionNotFound) {
		t.Errorf("Lookup of expired session: expected ErrSessionNotFound, got %v", err)
	}

	// A cookie value that does not correspond to any live session must also
	// yield ErrSessionNotFound. Note: this literal uses only base64url-alphabet
	// characters, so it exercises the "unknown / wrong-length token" path rather
	// than a base64 decode failure.
	_, _, err = store.Lookup(ctx, "garbage-not-base64url-valid-32bytes")
	if !errors.Is(err, ErrSessionNotFound) {
		t.Errorf("Lookup with garbage cookie: expected ErrSessionNotFound, got %v", err)
	}
}

// TestSession_RotateDeletesOld checks that Rotate issues a new cookie, removes
// the old row and persists the new one (Pitfall 5: session fixation guard).
func TestSession_RotateDeletesOld(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	q := sqlc.New(pool)
	store := NewStore(q)
	ctx := context.Background()

	userID := mustInsertUser(t, pool)

	// Create session A.
	cookieA, _, err := store.Create(ctx, userID)
	if err != nil {
		t.Fatalf("Create A: %v", err)
	}
	idA := sessionIDFromCookie(t, cookieA)

	// Rotate using idA.
	cookieB, _, err := store.Rotate(ctx, idA, userID)
	if err != nil {
		t.Fatalf("Rotate: %v", err)
	}
	if cookieB == cookieA {
		t.Error("Rotate returned same cookie value as original (Pitfall 5)")
	}

	// Old row must be gone (Pitfall 5 guard).
	if count := countSessions(t, pool, idA); count != 0 {
		t.Errorf("old session row still exists after Rotate (count=%d), session fixation risk", count)
	}

	// New row must exist.
	idB := sessionIDFromCookie(t, cookieB)
	if count := countSessions(t, pool, idB); count != 1 {
		t.Errorf("new session row not found after Rotate (count=%d)", count)
	}
}

// TestSession_DeleteRemovesRow checks that a deleted session can no longer be
// looked up by its cookie.
func TestSession_DeleteRemovesRow(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	q := sqlc.New(pool)
	store := NewStore(q)
	ctx := context.Background()

	userID := mustInsertUser(t, pool)

	cookieValue, _, err := store.Create(ctx, userID)
	if err != nil {
		t.Fatalf("Create: %v", err)
	}

	// Fails the test on a malformed cookie instead of silently hashing nothing.
	id := sessionIDFromCookie(t, cookieValue)

	if err := store.Delete(ctx, id); err != nil {
		t.Fatalf("Delete: %v", err)
	}

	_, _, err = store.Lookup(ctx, cookieValue)
	if !errors.Is(err, ErrSessionNotFound) {
		t.Errorf("expected ErrSessionNotFound after Delete, got: %v", err)
	}
}

// TestSession_MaybeExtend_NoOp checks that MaybeExtend leaves expires_at alone
// when the remaining lifetime is above the extension threshold.
func TestSession_MaybeExtend_NoOp(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	q := sqlc.New(pool)
	store := NewStore(q)
	fixedNow := time.Now().UTC()
	store.now = func() time.Time { return fixedNow }
	ctx := context.Background()

	userID := mustInsertUser(t, pool)

	// Insert a session with expires_at = now + 29 days (above 7-day threshold).
	id := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
	expiresAt := fixedNow.Add(29 * 24 * time.Hour)
	insertSession(t, pool, id, userID, expiresAt)

	if err := store.MaybeExtend(ctx, id, expiresAt); err != nil {
		t.Fatalf("MaybeExtend: %v", err)
	}

	// expires_at must NOT change (1s tolerance absorbs DB timestamp precision).
	got := queryExpiresAt(t, pool, id)
	if diff := absDuration(got.Sub(expiresAt)); diff > time.Second {
		t.Errorf("MaybeExtend changed expires_at when remaining > threshold (diff=%v)", diff)
	}
}

// TestSession_MaybeExtend_Extends checks that MaybeExtend pushes expires_at to
// now + SessionTTL when the remaining lifetime is below the threshold.
func TestSession_MaybeExtend_Extends(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	q := sqlc.New(pool)
	store := NewStore(q)
	fixedNow := time.Now().UTC()
	store.now = func() time.Time { return fixedNow }
	ctx := context.Background()

	userID := mustInsertUser(t, pool)

	// Insert a session with expires_at = now + 1 day (below 7-day threshold).
	id := "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"
	expiresAt := fixedNow.Add(1 * 24 * time.Hour)
	insertSession(t, pool, id, userID, expiresAt)

	if err := store.MaybeExtend(ctx, id, expiresAt); err != nil {
		t.Fatalf("MaybeExtend: %v", err)
	}

	// expires_at must now be ~fixedNow + SessionTTL.
	got := queryExpiresAt(t, pool, id)
	if !got.After(expiresAt) {
		t.Errorf("MaybeExtend did not extend expires_at (still %v)", got)
	}
	expected := fixedNow.Add(SessionTTL)
	if diff := absDuration(got.Sub(expected)); diff > 5*time.Second {
		t.Errorf("extended expires_at %v not within 5s of expected %v (diff=%v)", got, expected, diff)
	}
}

// ---------- Cookie tests (no DB) ----------

// TestCookie_SetAttributes checks every security-relevant attribute written by
// SetSessionCookie when secure=true.
func TestCookie_SetAttributes(t *testing.T) {
	w := httptest.NewRecorder()
	expiresAt := time.Now().Add(30 * 24 * time.Hour)
	SetSessionCookie(w, "my-token-value", expiresAt, true)

	resp := w.Result()
	cookies := resp.Cookies()
	if len(cookies) != 1 {
		t.Fatalf("expected 1 cookie, got %d", len(cookies))
	}
	c := cookies[0]

	if c.Name != SessionCookieName {
		t.Errorf("cookie name %q, want %q", c.Name, SessionCookieName)
	}
	if c.Value != "my-token-value" {
		t.Errorf("cookie value %q, want %q", c.Value, "my-token-value")
	}
	if !c.HttpOnly {
		t.Error("cookie must be HttpOnly")
	}
	if !c.Secure {
		t.Error("cookie must be Secure when secure=true")
	}
	if c.SameSite != http.SameSiteLaxMode {
		t.Errorf("cookie SameSite %v, want %v", c.SameSite, http.SameSiteLaxMode)
	}
	if c.Path != "/" {
		t.Errorf("cookie Path %q, want %q", c.Path, "/")
	}

	// MaxAge is in seconds; allow a few seconds of drift between the
	// time.Now() above and the one inside SetSessionCookie.
	expected := int(30 * 24 * time.Hour / time.Second)
	diff := c.MaxAge - expected
	if diff < 0 {
		diff = -diff
	}
	if diff > 5 {
		t.Errorf("MaxAge %d not within 5s of expected %d", c.MaxAge, expected)
	}
}

// TestCookie_ClearUsesMaxAgeMinus1 checks that clearing emits MaxAge -1, which
// net/http serialises as "Max-Age=0" (delete now); MaxAge 0 would omit it.
func TestCookie_ClearUsesMaxAgeMinus1(t *testing.T) {
	w := httptest.NewRecorder()
	ClearSessionCookie(w, false)

	resp := w.Result()
	cookies := resp.Cookies()
	if len(cookies) != 1 {
		t.Fatalf("expected 1 cookie, got %d", len(cookies))
	}
	c := cookies[0]
	if c.MaxAge != -1 {
		t.Errorf("ClearSessionCookie MaxAge %d, want -1 (NOT 0 per RESEARCH Pattern 3)", c.MaxAge)
	}
}

// TestCookie_SecureGatedByEnv checks that the Secure attribute follows the
// secure argument (which callers presumably derive from the environment).
func TestCookie_SecureGatedByEnv(t *testing.T) {
	w := httptest.NewRecorder()
	SetSessionCookie(w, "tok", time.Now().Add(time.Hour), false) // secure=false

	resp := w.Result()
	cookies := resp.Cookies()
	if len(cookies) != 1 {
		t.Fatalf("expected 1 cookie, got %d", len(cookies))
	}
	if cookies[0].Secure {
		t.Error("cookie must NOT have Secure attribute when secure=false")
	}
}

// ---------- helpers ----------

// mustInsertUser inserts a test user with a unique email and a dummy argon2id
// hash, and returns its ID.
func mustInsertUser(t *testing.T, pool *pgxpool.Pool) uuid.UUID {
	t.Helper()
	ctx := context.Background()
	var id uuid.UUID
	row := pool.QueryRow(ctx,
		"INSERT INTO users (email, password_hash) VALUES ($1, $2) RETURNING id",
		"test-"+uuid.NewString()+"@example.com",
		"$argon2id$v=19$m=8192,t=1,p=2$AAAAAAAAAAAAAAAAAAAAAA$AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
	)
	if err := row.Scan(&id); err != nil {
		t.Fatalf("mustInsertUser: %v", err)
	}
	return id
}

// sessionIDFromCookie returns the DB key the tests expect for a cookie value:
// hex(sha256(base64url-decode(cookieValue))). It fails the test if the cookie
// is not valid unpadded base64url.
func sessionIDFromCookie(t *testing.T, cookieValue string) string {
	t.Helper()
	raw, err := base64.RawURLEncoding.DecodeString(cookieValue)
	if err != nil {
		t.Fatalf("base64url decode cookie value %q: %v", cookieValue, err)
	}
	sum := sha256.Sum256(raw)
	return hex.EncodeToString(sum[:])
}

// insertSession writes a session row directly, bypassing the Store, so tests
// can control id and expires_at exactly.
func insertSession(t *testing.T, pool *pgxpool.Pool, id string, userID uuid.UUID, expiresAt time.Time) {
	t.Helper()
	_, err := pool.Exec(context.Background(),
		"INSERT INTO sessions (id, user_id, expires_at) VALUES ($1, $2, $3)",
		id, userID, expiresAt,
	)
	if err != nil {
		t.Fatalf("INSERT session %q: %v", id, err)
	}
}

// countSessions returns how many session rows have the given id.
func countSessions(t *testing.T, pool *pgxpool.Pool, id string) int {
	t.Helper()
	var count int
	row := pool.QueryRow(context.Background(), "SELECT COUNT(*) FROM sessions WHERE id = $1", id)
	if err := row.Scan(&count); err != nil {
		t.Fatalf("COUNT sessions with id %q: %v", id, err)
	}
	return count
}

// queryExpiresAt reads expires_at for a session, failing the test if the row
// is missing or the column is NULL (which would otherwise read as zero time).
func queryExpiresAt(t *testing.T, pool *pgxpool.Pool, id string) time.Time {
	t.Helper()
	var exp pgtype.Timestamptz
	row := pool.QueryRow(context.Background(), "SELECT expires_at FROM sessions WHERE id = $1", id)
	if err := row.Scan(&exp); err != nil {
		t.Fatalf("SELECT expires_at: %v", err)
	}
	if !exp.Valid {
		t.Fatalf("session %q has NULL expires_at", id)
	}
	return exp.Time
}

// absDuration returns |d|.
func absDuration(d time.Duration) time.Duration {
	if d < 0 {
		return -d
	}
	return d
}