- Store.Create: 32-byte crypto/rand token, SHA-256 hex as DB id (D-05) - Store.Lookup: hashes cookie, maps pgx.ErrNoRows to ErrSessionNotFound (D-07) - Store.Delete: hard-deletes session row (D-06) - Store.Rotate: deletes old row before creating new one (D-10, T-2-04) - Store.MaybeExtend: extends only when remaining < 7 days (D-09) - SetSessionCookie: HttpOnly + Secure (env-gated) + SameSite=Lax (D-12) - ClearSessionCookie: MaxAge=-1 not 0 (RESEARCH Pattern 3 / D-06) - 10 tests: 7 real-DB (skip without TEST_DATABASE_URL) + 3 cookie unit tests
391 lines
10 KiB
Go
391 lines
10 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"backend/internal/db/sqlc"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
"github.com/jackc/pgx/v5/pgtype"
|
|
)
|
|
|
|
// ---------- DB tests (skip when TEST_DATABASE_URL unset) ----------
|
|
|
|
func TestSession_StoresHashedID(t *testing.T) {
|
|
pool, cleanup := setupTestDB(t)
|
|
defer cleanup()
|
|
|
|
q := sqlc.New(pool)
|
|
store := NewStore(q)
|
|
|
|
ctx := context.Background()
|
|
userID := mustInsertUser(t, pool)
|
|
|
|
cookieValue, _, err := store.Create(ctx, userID)
|
|
if err != nil {
|
|
t.Fatalf("Create: %v", err)
|
|
}
|
|
|
|
// The cookie value is base64url. Decode to raw bytes.
|
|
raw, err := base64.RawURLEncoding.DecodeString(cookieValue)
|
|
if err != nil {
|
|
t.Fatalf("base64url decode cookie value: %v", err)
|
|
}
|
|
if len(raw) != 32 {
|
|
t.Fatalf("expected 32 raw bytes, got %d", len(raw))
|
|
}
|
|
|
|
// Compute expected DB id = hex(sha256(raw)).
|
|
sum := sha256.Sum256(raw)
|
|
expectedID := hex.EncodeToString(sum[:])
|
|
|
|
// Verify the row in the DB has that exact id (Pitfall 6 guard).
|
|
var gotID string
|
|
row := pool.QueryRow(ctx, "SELECT id FROM sessions WHERE id = $1", expectedID)
|
|
if err := row.Scan(&gotID); err != nil {
|
|
t.Fatalf("SELECT session id: %v (expected hashed id %q)", err, expectedID)
|
|
}
|
|
if gotID != expectedID {
|
|
t.Errorf("DB id %q != expected %q", gotID, expectedID)
|
|
}
|
|
}
|
|
|
|
func TestSession_LookupRoundtrip(t *testing.T) {
|
|
pool, cleanup := setupTestDB(t)
|
|
defer cleanup()
|
|
|
|
q := sqlc.New(pool)
|
|
store := NewStore(q)
|
|
ctx := context.Background()
|
|
|
|
userID := mustInsertUser(t, pool)
|
|
cookieValue, _, err := store.Create(ctx, userID)
|
|
if err != nil {
|
|
t.Fatalf("Create: %v", err)
|
|
}
|
|
|
|
sess, user, err := store.Lookup(ctx, cookieValue)
|
|
if err != nil {
|
|
t.Fatalf("Lookup: %v", err)
|
|
}
|
|
if sess == nil || user == nil {
|
|
t.Fatal("Lookup returned nil session or user")
|
|
}
|
|
if user.ID != userID {
|
|
t.Errorf("user.ID %v != expected %v", user.ID, userID)
|
|
}
|
|
if sess.ID == "" {
|
|
t.Error("session.ID is empty")
|
|
}
|
|
}
|
|
|
|
func TestSession_LookupRejectsExpired(t *testing.T) {
|
|
pool, cleanup := setupTestDB(t)
|
|
defer cleanup()
|
|
|
|
q := sqlc.New(pool)
|
|
store := NewStore(q)
|
|
ctx := context.Background()
|
|
|
|
userID := mustInsertUser(t, pool)
|
|
|
|
// Insert an expired session directly.
|
|
expiredID := "deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef"
|
|
expiredAt := time.Now().Add(-1 * time.Hour)
|
|
_, err := pool.Exec(ctx,
|
|
"INSERT INTO sessions (id, user_id, expires_at) VALUES ($1, $2, $3)",
|
|
expiredID, userID, expiredAt,
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("INSERT expired session: %v", err)
|
|
}
|
|
|
|
// Verify that GetSessionWithUser rejects the expired row (D-07 guard).
|
|
_, sqlcErr := q.GetSessionWithUser(ctx, expiredID)
|
|
if sqlcErr == nil {
|
|
t.Fatal("expected error for expired session, got nil")
|
|
}
|
|
if !strings.Contains(sqlcErr.Error(), "no rows") {
|
|
t.Errorf("unexpected error type for expired session: %v", sqlcErr)
|
|
}
|
|
|
|
// Also verify our Store.Lookup returns ErrSessionNotFound for an invalid cookie
|
|
// (any garbage value that doesn't correspond to a live session).
|
|
_, _, err = store.Lookup(ctx, "garbage-not-base64url-valid-32bytes")
|
|
if err != ErrSessionNotFound {
|
|
t.Errorf("Lookup with garbage cookie: expected ErrSessionNotFound, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestSession_RotateDeletesOld(t *testing.T) {
|
|
pool, cleanup := setupTestDB(t)
|
|
defer cleanup()
|
|
|
|
q := sqlc.New(pool)
|
|
store := NewStore(q)
|
|
ctx := context.Background()
|
|
|
|
userID := mustInsertUser(t, pool)
|
|
|
|
// Create session A.
|
|
cookieA, _, err := store.Create(ctx, userID)
|
|
if err != nil {
|
|
t.Fatalf("Create A: %v", err)
|
|
}
|
|
|
|
rawA, err := base64.RawURLEncoding.DecodeString(cookieA)
|
|
if err != nil {
|
|
t.Fatalf("decode cookieA: %v", err)
|
|
}
|
|
sumA := sha256.Sum256(rawA)
|
|
idA := hex.EncodeToString(sumA[:])
|
|
|
|
// Rotate using idA.
|
|
cookieB, _, err := store.Rotate(ctx, idA, userID)
|
|
if err != nil {
|
|
t.Fatalf("Rotate: %v", err)
|
|
}
|
|
if cookieB == cookieA {
|
|
t.Error("Rotate returned same cookie value as original (Pitfall 5)")
|
|
}
|
|
|
|
// Old row must be gone (Pitfall 5 guard).
|
|
var count int
|
|
row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE id = $1", idA)
|
|
if err := row.Scan(&count); err != nil {
|
|
t.Fatalf("COUNT old session: %v", err)
|
|
}
|
|
if count != 0 {
|
|
t.Errorf("old session row still exists after Rotate (count=%d), session fixation risk", count)
|
|
}
|
|
|
|
// New row must exist.
|
|
rawB, err := base64.RawURLEncoding.DecodeString(cookieB)
|
|
if err != nil {
|
|
t.Fatalf("decode cookieB: %v", err)
|
|
}
|
|
sumB := sha256.Sum256(rawB)
|
|
idB := hex.EncodeToString(sumB[:])
|
|
row = pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE id = $1", idB)
|
|
if err := row.Scan(&count); err != nil {
|
|
t.Fatalf("COUNT new session: %v", err)
|
|
}
|
|
if count != 1 {
|
|
t.Errorf("new session row not found after Rotate (count=%d)", count)
|
|
}
|
|
}
|
|
|
|
func TestSession_DeleteRemovesRow(t *testing.T) {
|
|
pool, cleanup := setupTestDB(t)
|
|
defer cleanup()
|
|
|
|
q := sqlc.New(pool)
|
|
store := NewStore(q)
|
|
ctx := context.Background()
|
|
|
|
userID := mustInsertUser(t, pool)
|
|
cookieValue, _, err := store.Create(ctx, userID)
|
|
if err != nil {
|
|
t.Fatalf("Create: %v", err)
|
|
}
|
|
|
|
rawBytes, _ := base64.RawURLEncoding.DecodeString(cookieValue)
|
|
sum := sha256.Sum256(rawBytes)
|
|
id := hex.EncodeToString(sum[:])
|
|
|
|
if err := store.Delete(ctx, id); err != nil {
|
|
t.Fatalf("Delete: %v", err)
|
|
}
|
|
|
|
_, _, err = store.Lookup(ctx, cookieValue)
|
|
if err != ErrSessionNotFound {
|
|
t.Errorf("expected ErrSessionNotFound after Delete, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestSession_MaybeExtend_NoOp(t *testing.T) {
|
|
pool, cleanup := setupTestDB(t)
|
|
defer cleanup()
|
|
|
|
q := sqlc.New(pool)
|
|
store := NewStore(q)
|
|
|
|
fixedNow := time.Now().UTC()
|
|
store.now = func() time.Time { return fixedNow }
|
|
|
|
ctx := context.Background()
|
|
userID := mustInsertUser(t, pool)
|
|
|
|
// Insert a session with expires_at = now + 29 days (above 7-day threshold).
|
|
id := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
|
expiresAt := fixedNow.Add(29 * 24 * time.Hour)
|
|
_, err := pool.Exec(ctx,
|
|
"INSERT INTO sessions (id, user_id, expires_at) VALUES ($1, $2, $3)",
|
|
id, userID, expiresAt,
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("INSERT session: %v", err)
|
|
}
|
|
|
|
if err := store.MaybeExtend(ctx, id, expiresAt); err != nil {
|
|
t.Fatalf("MaybeExtend: %v", err)
|
|
}
|
|
|
|
// expires_at must NOT change (within 1 second tolerance).
|
|
var gotExp pgtype.Timestamptz
|
|
row := pool.QueryRow(ctx, "SELECT expires_at FROM sessions WHERE id = $1", id)
|
|
if err := row.Scan(&gotExp); err != nil {
|
|
t.Fatalf("SELECT expires_at: %v", err)
|
|
}
|
|
diff := gotExp.Time.Sub(expiresAt)
|
|
if diff < 0 {
|
|
diff = -diff
|
|
}
|
|
if diff > time.Second {
|
|
t.Errorf("MaybeExtend changed expires_at when remaining > threshold (diff=%v)", diff)
|
|
}
|
|
}
|
|
|
|
func TestSession_MaybeExtend_Extends(t *testing.T) {
|
|
pool, cleanup := setupTestDB(t)
|
|
defer cleanup()
|
|
|
|
q := sqlc.New(pool)
|
|
store := NewStore(q)
|
|
|
|
fixedNow := time.Now().UTC()
|
|
store.now = func() time.Time { return fixedNow }
|
|
|
|
ctx := context.Background()
|
|
userID := mustInsertUser(t, pool)
|
|
|
|
// Insert a session with expires_at = now + 1 day (below 7-day threshold).
|
|
id := "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"
|
|
expiresAt := fixedNow.Add(1 * 24 * time.Hour)
|
|
_, err := pool.Exec(ctx,
|
|
"INSERT INTO sessions (id, user_id, expires_at) VALUES ($1, $2, $3)",
|
|
id, userID, expiresAt,
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("INSERT session: %v", err)
|
|
}
|
|
|
|
if err := store.MaybeExtend(ctx, id, expiresAt); err != nil {
|
|
t.Fatalf("MaybeExtend: %v", err)
|
|
}
|
|
|
|
// expires_at must now be ~now + 30 days.
|
|
var gotExp pgtype.Timestamptz
|
|
row := pool.QueryRow(ctx, "SELECT expires_at FROM sessions WHERE id = $1", id)
|
|
if err := row.Scan(&gotExp); err != nil {
|
|
t.Fatalf("SELECT expires_at: %v", err)
|
|
}
|
|
if !gotExp.Time.After(expiresAt) {
|
|
t.Errorf("MaybeExtend did not extend expires_at (still %v)", gotExp.Time)
|
|
}
|
|
expected := fixedNow.Add(SessionTTL)
|
|
diff := gotExp.Time.Sub(expected)
|
|
if diff < 0 {
|
|
diff = -diff
|
|
}
|
|
if diff > 5*time.Second {
|
|
t.Errorf("extended expires_at %v not within 5s of expected %v (diff=%v)", gotExp.Time, expected, diff)
|
|
}
|
|
}
|
|
|
|
// ---------- Cookie tests (no DB) ----------
|
|
|
|
func TestCookie_SetAttributes(t *testing.T) {
|
|
w := httptest.NewRecorder()
|
|
expiresAt := time.Now().Add(30 * 24 * time.Hour)
|
|
SetSessionCookie(w, "my-token-value", expiresAt, true)
|
|
|
|
resp := w.Result()
|
|
cookies := resp.Cookies()
|
|
if len(cookies) != 1 {
|
|
t.Fatalf("expected 1 cookie, got %d", len(cookies))
|
|
}
|
|
c := cookies[0]
|
|
if c.Name != SessionCookieName {
|
|
t.Errorf("cookie name %q, want %q", c.Name, SessionCookieName)
|
|
}
|
|
if c.Value != "my-token-value" {
|
|
t.Errorf("cookie value %q, want %q", c.Value, "my-token-value")
|
|
}
|
|
if !c.HttpOnly {
|
|
t.Error("cookie must be HttpOnly")
|
|
}
|
|
if !c.Secure {
|
|
t.Error("cookie must be Secure when secure=true")
|
|
}
|
|
if c.SameSite != http.SameSiteLaxMode {
|
|
t.Errorf("cookie SameSite %v, want %v", c.SameSite, http.SameSiteLaxMode)
|
|
}
|
|
if c.Path != "/" {
|
|
t.Errorf("cookie Path %q, want %q", c.Path, "/")
|
|
}
|
|
expected := int(30 * 24 * time.Hour / time.Second)
|
|
diff := c.MaxAge - expected
|
|
if diff < 0 {
|
|
diff = -diff
|
|
}
|
|
if diff > 5 {
|
|
t.Errorf("MaxAge %d not within 5s of expected %d", c.MaxAge, expected)
|
|
}
|
|
}
|
|
|
|
func TestCookie_ClearUsesMaxAgeMinus1(t *testing.T) {
|
|
w := httptest.NewRecorder()
|
|
ClearSessionCookie(w, false)
|
|
|
|
resp := w.Result()
|
|
cookies := resp.Cookies()
|
|
if len(cookies) != 1 {
|
|
t.Fatalf("expected 1 cookie, got %d", len(cookies))
|
|
}
|
|
c := cookies[0]
|
|
if c.MaxAge != -1 {
|
|
t.Errorf("ClearSessionCookie MaxAge %d, want -1 (NOT 0 per RESEARCH Pattern 3)", c.MaxAge)
|
|
}
|
|
}
|
|
|
|
func TestCookie_SecureGatedByEnv(t *testing.T) {
|
|
w := httptest.NewRecorder()
|
|
SetSessionCookie(w, "tok", time.Now().Add(time.Hour), false) // secure=false
|
|
|
|
resp := w.Result()
|
|
cookies := resp.Cookies()
|
|
if len(cookies) != 1 {
|
|
t.Fatalf("expected 1 cookie, got %d", len(cookies))
|
|
}
|
|
if cookies[0].Secure {
|
|
t.Error("cookie must NOT have Secure attribute when secure=false")
|
|
}
|
|
}
|
|
|
|
// ---------- helpers ----------
|
|
|
|
// mustInsertUser inserts a test user and returns its ID.
|
|
func mustInsertUser(t *testing.T, pool *pgxpool.Pool) uuid.UUID {
|
|
t.Helper()
|
|
ctx := context.Background()
|
|
var id uuid.UUID
|
|
row := pool.QueryRow(ctx,
|
|
"INSERT INTO users (email, password_hash) VALUES ($1, $2) RETURNING id",
|
|
"test-"+uuid.NewString()+"@example.com",
|
|
"$argon2id$v=19$m=8192,t=1,p=2$AAAAAAAAAAAAAAAAAAAAAA$AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
|
|
)
|
|
if err := row.Scan(&id); err != nil {
|
|
t.Fatalf("mustInsertUser: %v", err)
|
|
}
|
|
return id
|
|
}
|