xtablo-source/backend/internal/auth/session_test.go
Arthur Belleville fd2301decf
feat(02-03): session store + cookie helpers (real-DB TDD)
- Store.Create: 32-byte crypto/rand token, SHA-256 hex as DB id (D-05)
- Store.Lookup: hashes cookie, maps pgx.ErrNoRows to ErrSessionNotFound (D-07)
- Store.Delete: hard-deletes session row (D-06)
- Store.Rotate: deletes old row before creating new one (D-10, T-2-04)
- Store.MaybeExtend: extends only when remaining < 7 days (D-09)
- SetSessionCookie: HttpOnly + Secure (env-gated) + SameSite=Lax (D-12)
- ClearSessionCookie: MaxAge=-1 not 0 (RESEARCH Pattern 3 / D-06)
- 10 tests: 7 real-DB (skip without TEST_DATABASE_URL) + 3 cookie unit tests
2026-05-14 22:08:04 +02:00

391 lines
10 KiB
Go

package auth
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"backend/internal/db/sqlc"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/pgtype"
)
// ---------- DB tests (skip when TEST_DATABASE_URL unset) ----------

// TestSession_StoresHashedID checks that Store.Create never persists the raw
// token: the sessions.id column must hold hex(sha256(token)), where token is
// the base64url-decoded cookie value (Pitfall 6 guard).
func TestSession_StoresHashedID(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()
	store := NewStore(sqlc.New(pool))
	owner := mustInsertUser(t, pool)

	cookie, _, err := store.Create(ctx, owner)
	if err != nil {
		t.Fatalf("Create: %v", err)
	}

	// The cookie carries the raw token encoded as unpadded base64url.
	token, err := base64.RawURLEncoding.DecodeString(cookie)
	if err != nil {
		t.Fatalf("base64url decode cookie value: %v", err)
	}
	if n := len(token); n != 32 {
		t.Fatalf("expected 32 raw bytes, got %d", n)
	}

	digest := sha256.Sum256(token)
	wantID := hex.EncodeToString(digest[:])

	// Look the row up by the hashed id; a miss means Create stored something else.
	var storedID string
	if err := pool.QueryRow(ctx, "SELECT id FROM sessions WHERE id = $1", wantID).Scan(&storedID); err != nil {
		t.Fatalf("SELECT session id: %v (expected hashed id %q)", err, wantID)
	}
	if storedID != wantID {
		t.Errorf("DB id %q != expected %q", storedID, wantID)
	}
}
// TestSession_LookupRoundtrip checks that a cookie issued by Store.Create
// resolves back, via Store.Lookup, to a session with a non-empty ID whose
// user is the one the session was created for.
func TestSession_LookupRoundtrip(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()
	store := NewStore(sqlc.New(pool))
	owner := mustInsertUser(t, pool)

	cookie, _, err := store.Create(ctx, owner)
	if err != nil {
		t.Fatalf("Create: %v", err)
	}

	sess, user, err := store.Lookup(ctx, cookie)
	switch {
	case err != nil:
		t.Fatalf("Lookup: %v", err)
	case sess == nil || user == nil:
		t.Fatal("Lookup returned nil session or user")
	}

	if user.ID != owner {
		t.Errorf("user.ID %v != expected %v", user.ID, owner)
	}
	if sess.ID == "" {
		t.Error("session.ID is empty")
	}
}
// TestSession_LookupRejectsExpired checks that an expired session row is
// invisible both to the sqlc query and to Store.Lookup (D-07 guard), and that
// Store.Lookup also rejects a cookie value that maps to no session at all.
func TestSession_LookupRejectsExpired(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	q := sqlc.New(pool)
	store := NewStore(q)
	ctx := context.Background()
	userID := mustInsertUser(t, pool)

	// Build a well-formed 32-byte token and its cookie / DB-id pair the same
	// way Store.Create does (cookie = base64url(token), id = hex(sha256(token))),
	// so Store.Lookup gets past decoding and actually hits the expired row.
	// Deriving the token from a fresh UUID keeps the primary key unique per run.
	token := sha256.Sum256([]byte(uuid.NewString()))
	expiredCookie := base64.RawURLEncoding.EncodeToString(token[:])
	idSum := sha256.Sum256(token[:])
	expiredID := hex.EncodeToString(idSum[:])

	// Insert the expired session directly, bypassing Store.Create.
	expiredAt := time.Now().Add(-1 * time.Hour)
	if _, err := pool.Exec(ctx,
		"INSERT INTO sessions (id, user_id, expires_at) VALUES ($1, $2, $3)",
		expiredID, userID, expiredAt,
	); err != nil {
		t.Fatalf("INSERT expired session: %v", err)
	}

	// The query itself must filter out the expired row (D-07 guard).
	_, sqlcErr := q.GetSessionWithUser(ctx, expiredID)
	if sqlcErr == nil {
		t.Fatal("expected error for expired session, got nil")
	}
	if !strings.Contains(sqlcErr.Error(), "no rows") {
		t.Errorf("unexpected error type for expired session: %v", sqlcErr)
	}

	// Store.Lookup, given the real cookie for that expired row, must report
	// ErrSessionNotFound rather than returning the stale session.
	if _, _, err := store.Lookup(ctx, expiredCookie); err != ErrSessionNotFound {
		t.Errorf("Lookup with expired cookie: expected ErrSessionNotFound, got %v", err)
	}

	// Any value that does not correspond to a live session is rejected too.
	if _, _, err := store.Lookup(ctx, "garbage-not-base64url-valid-32bytes"); err != ErrSessionNotFound {
		t.Errorf("Lookup with garbage cookie: expected ErrSessionNotFound, got %v", err)
	}
}
// TestSession_RotateDeletesOld checks that Store.Rotate issues a different
// cookie and removes the previous session row while creating exactly one new
// row, so an old session id cannot survive rotation (Pitfall 5 /
// session-fixation guard).
func TestSession_RotateDeletesOld(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()
	store := NewStore(sqlc.New(pool))
	userID := mustInsertUser(t, pool)

	// hashedID derives the DB id for a cookie: hex(sha256(base64url-decode(cookie))).
	hashedID := func(label, cookie string) string {
		t.Helper()
		raw, err := base64.RawURLEncoding.DecodeString(cookie)
		if err != nil {
			t.Fatalf("decode %s: %v", label, err)
		}
		sum := sha256.Sum256(raw)
		return hex.EncodeToString(sum[:])
	}
	// rowsWithID counts the sessions rows carrying the given id.
	rowsWithID := func(label, id string) int {
		t.Helper()
		var n int
		if err := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE id = $1", id).Scan(&n); err != nil {
			t.Fatalf("COUNT %s session: %v", label, err)
		}
		return n
	}

	cookieA, _, err := store.Create(ctx, userID)
	if err != nil {
		t.Fatalf("Create A: %v", err)
	}
	idA := hashedID("cookieA", cookieA)

	cookieB, _, err := store.Rotate(ctx, idA, userID)
	if err != nil {
		t.Fatalf("Rotate: %v", err)
	}
	if cookieA == cookieB {
		t.Error("Rotate returned same cookie value as original (Pitfall 5)")
	}

	// Old row must be gone (Pitfall 5 guard).
	if n := rowsWithID("old", idA); n != 0 {
		t.Errorf("old session row still exists after Rotate (count=%d), session fixation risk", n)
	}
	// New row must exist exactly once.
	if n := rowsWithID("new", hashedID("cookieB", cookieB)); n != 1 {
		t.Errorf("new session row not found after Rotate (count=%d)", n)
	}
}
// TestSession_DeleteRemovesRow checks that Store.Delete hard-deletes a session
// (D-06): the row must be gone from the sessions table, and Store.Lookup with
// the same cookie must then report ErrSessionNotFound.
func TestSession_DeleteRemovesRow(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	q := sqlc.New(pool)
	store := NewStore(q)
	ctx := context.Background()
	userID := mustInsertUser(t, pool)

	cookieValue, _, err := store.Create(ctx, userID)
	if err != nil {
		t.Fatalf("Create: %v", err)
	}

	// Derive the DB id as hex(sha256(base64url-decode(cookie))). A decode
	// failure must stop the test: hashing a partial result would make Delete
	// target the wrong id and the test would fail for a misleading reason.
	rawBytes, err := base64.RawURLEncoding.DecodeString(cookieValue)
	if err != nil {
		t.Fatalf("decode cookie: %v", err)
	}
	sum := sha256.Sum256(rawBytes)
	id := hex.EncodeToString(sum[:])

	if err := store.Delete(ctx, id); err != nil {
		t.Fatalf("Delete: %v", err)
	}

	// Check the table directly: Lookup alone cannot tell a hard delete from a
	// row that was merely expired or otherwise hidden.
	var count int
	if err := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE id = $1", id).Scan(&count); err != nil {
		t.Fatalf("COUNT session after Delete: %v", err)
	}
	if count != 0 {
		t.Errorf("session row still exists after Delete (count=%d)", count)
	}

	_, _, err = store.Lookup(ctx, cookieValue)
	if err != ErrSessionNotFound {
		t.Errorf("expected ErrSessionNotFound after Delete, got: %v", err)
	}
}
// TestSession_MaybeExtend_NoOp checks that Store.MaybeExtend leaves
// expires_at unchanged while more than the 7-day extension threshold remains
// (D-09). The store's clock is pinned so "now" is deterministic.
func TestSession_MaybeExtend_NoOp(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()

	store := NewStore(sqlc.New(pool))
	now := time.Now().UTC()
	store.now = func() time.Time { return now }
	ctx := context.Background()
	userID := mustInsertUser(t, pool)

	// 29 days left: above the 7-day threshold, so no extension is due.
	const sessionID = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
	originalExpiry := now.Add(29 * 24 * time.Hour)
	if _, err := pool.Exec(ctx,
		"INSERT INTO sessions (id, user_id, expires_at) VALUES ($1, $2, $3)",
		sessionID, userID, originalExpiry,
	); err != nil {
		t.Fatalf("INSERT session: %v", err)
	}

	if err := store.MaybeExtend(ctx, sessionID, originalExpiry); err != nil {
		t.Fatalf("MaybeExtend: %v", err)
	}

	var stored pgtype.Timestamptz
	if err := pool.QueryRow(ctx, "SELECT expires_at FROM sessions WHERE id = $1", sessionID).Scan(&stored); err != nil {
		t.Fatalf("SELECT expires_at: %v", err)
	}

	// The stored expiry must match the original within a 1-second tolerance.
	drift := stored.Time.Sub(originalExpiry)
	if drift < 0 {
		drift = -drift
	}
	if drift > time.Second {
		t.Errorf("MaybeExtend changed expires_at when remaining > threshold (diff=%v)", drift)
	}
}
// TestSession_MaybeExtend_Extends checks that Store.MaybeExtend moves
// expires_at out to now+SessionTTL once less than the 7-day threshold remains
// (D-09). The store's clock is pinned so the new expiry is predictable.
func TestSession_MaybeExtend_Extends(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()

	store := NewStore(sqlc.New(pool))
	now := time.Now().UTC()
	store.now = func() time.Time { return now }
	ctx := context.Background()
	userID := mustInsertUser(t, pool)

	// Only 1 day left: below the 7-day threshold, so an extension is due.
	const sessionID = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"
	originalExpiry := now.Add(1 * 24 * time.Hour)
	if _, err := pool.Exec(ctx,
		"INSERT INTO sessions (id, user_id, expires_at) VALUES ($1, $2, $3)",
		sessionID, userID, originalExpiry,
	); err != nil {
		t.Fatalf("INSERT session: %v", err)
	}

	if err := store.MaybeExtend(ctx, sessionID, originalExpiry); err != nil {
		t.Fatalf("MaybeExtend: %v", err)
	}

	var stored pgtype.Timestamptz
	if err := pool.QueryRow(ctx, "SELECT expires_at FROM sessions WHERE id = $1", sessionID).Scan(&stored); err != nil {
		t.Fatalf("SELECT expires_at: %v", err)
	}
	if !stored.Time.After(originalExpiry) {
		t.Errorf("MaybeExtend did not extend expires_at (still %v)", stored.Time)
	}

	// The new expiry should be the pinned "now" plus the full session TTL.
	want := now.Add(SessionTTL)
	drift := stored.Time.Sub(want)
	if drift < 0 {
		drift = -drift
	}
	if drift > 5*time.Second {
		t.Errorf("extended expires_at %v not within 5s of expected %v (diff=%v)", stored.Time, want, drift)
	}
}
// ---------- Cookie tests (no DB) ----------

// TestCookie_SetAttributes checks every attribute SetSessionCookie writes
// when secure=true:
//   - name and value
//   - HttpOnly and Secure
//   - SameSite=Lax and Path=/
//   - a Max-Age within 5 seconds of the requested expiry (D-12)
func TestCookie_SetAttributes(t *testing.T) {
	const token = "my-token-value"
	ttl := 30 * 24 * time.Hour

	rec := httptest.NewRecorder()
	SetSessionCookie(rec, token, time.Now().Add(ttl), true)

	got := rec.Result().Cookies()
	if len(got) != 1 {
		t.Fatalf("expected 1 cookie, got %d", len(got))
	}
	c := got[0]

	if c.Name != SessionCookieName {
		t.Errorf("cookie name %q, want %q", c.Name, SessionCookieName)
	}
	if c.Value != token {
		t.Errorf("cookie value %q, want %q", c.Value, token)
	}
	if !c.HttpOnly {
		t.Error("cookie must be HttpOnly")
	}
	if !c.Secure {
		t.Error("cookie must be Secure when secure=true")
	}
	if c.SameSite != http.SameSiteLaxMode {
		t.Errorf("cookie SameSite %v, want %v", c.SameSite, http.SameSiteLaxMode)
	}
	if c.Path != "/" {
		t.Errorf("cookie Path %q, want %q", c.Path, "/")
	}

	// Max-Age is in whole seconds and computed from a wall-clock deadline,
	// so allow a few seconds of skew.
	wantMaxAge := int(ttl / time.Second)
	skew := c.MaxAge - wantMaxAge
	if skew < 0 {
		skew = -skew
	}
	if skew > 5 {
		t.Errorf("MaxAge %d not within 5s of expected %d", c.MaxAge, wantMaxAge)
	}
}
// TestCookie_ClearUsesMaxAgeMinus1 checks that ClearSessionCookie emits a
// deletion cookie the browser will actually apply. Two things are required:
//   - MaxAge -1 (RESEARCH Pattern 3 / D-06). net/http serializes MaxAge < 0 as
//     "Max-Age=0"; a MaxAge of 0 would omit the attribute and leave a
//     session-lifetime cookie behind.
//   - The same name and Path ("/") as the cookie written by SetSessionCookie.
//     A cookie is only replaced when name, domain and path all match, so a
//     mismatched Path would leave the real session cookie in place.
func TestCookie_ClearUsesMaxAgeMinus1(t *testing.T) {
	w := httptest.NewRecorder()
	ClearSessionCookie(w, false)
	resp := w.Result()
	cookies := resp.Cookies()
	if len(cookies) != 1 {
		t.Fatalf("expected 1 cookie, got %d", len(cookies))
	}
	c := cookies[0]
	if c.MaxAge != -1 {
		t.Errorf("ClearSessionCookie MaxAge %d, want -1 (NOT 0 per RESEARCH Pattern 3)", c.MaxAge)
	}
	if c.Name != SessionCookieName {
		t.Errorf("cleared cookie name %q, want %q", c.Name, SessionCookieName)
	}
	if c.Path != "/" {
		t.Errorf("cleared cookie Path %q, want %q (must match SetSessionCookie or the session cookie is not removed)", c.Path, "/")
	}
}
// TestCookie_SecureGatedByEnv checks that SetSessionCookie omits the Secure
// attribute when called with secure=false, i.e. the Secure flag follows the
// caller-supplied (environment-gated) setting rather than being hard-coded.
func TestCookie_SecureGatedByEnv(t *testing.T) {
	const secure = false

	rec := httptest.NewRecorder()
	SetSessionCookie(rec, "tok", time.Now().Add(time.Hour), secure)

	set := rec.Result().Cookies()
	if got := len(set); got != 1 {
		t.Fatalf("expected 1 cookie, got %d", got)
	}
	if set[0].Secure {
		t.Error("cookie must NOT have Secure attribute when secure=false")
	}
}
// ---------- helpers ----------

// mustInsertUser inserts a throwaway user row and returns the id the database
// generated for it. The email is made unique with a random UUID, and the
// password_hash is a fixed argon2id-format placeholder. The test fails
// immediately if the insert does not succeed.
func mustInsertUser(t *testing.T, pool *pgxpool.Pool) uuid.UUID {
	t.Helper()

	const placeholderHash = "$argon2id$v=19$m=8192,t=1,p=2$AAAAAAAAAAAAAAAAAAAAAA$AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
	email := "test-" + uuid.NewString() + "@example.com"

	var id uuid.UUID
	err := pool.QueryRow(context.Background(),
		"INSERT INTO users (email, password_hash) VALUES ($1, $2) RETURNING id",
		email, placeholderHash,
	).Scan(&id)
	if err != nil {
		t.Fatalf("mustInsertUser: %v", err)
	}
	return id
}