diff --git a/backend/internal/auth/middleware.go b/backend/internal/auth/middleware.go new file mode 100644 index 0000000..6d582ce --- /dev/null +++ b/backend/internal/auth/middleware.go @@ -0,0 +1,116 @@ +package auth + +import ( + "context" + "log/slog" + "net/http" +) + +// sessionCtxKey is the unexported context key type for session data owned by +// this package. Using an unexported named struct prevents collisions with +// other packages' context keys. +type sessionCtxKey struct{} + +// sessionKey is the singleton key value used in context.WithValue. +var sessionKey = sessionCtxKey{} + +// authed holds the resolved session and user attached to a request context. +type authed struct { + Session *Session + User *User +} + +// Authed extracts the session and user from the request context. +// Returns (session, user, true) when a valid session is present, and +// (nil, nil, false) when the request is unauthenticated. +func Authed(ctx context.Context) (*Session, *User, bool) { + a, ok := ctx.Value(sessionKey).(*authed) + if !ok || a == nil { + return nil, nil, false + } + return a.Session, a.User, true +} + +// ResolveSession reads the session cookie, looks up the session + user, and +// attaches them to the request context. It NEVER blocks the request — missing +// or invalid sessions are silently ignored; RequireAuth enforces access. +// +// On a valid session hit, MaybeExtend is called best-effort (logged but not +// fatal) to implement the sliding 30-day TTL (D-09). +func ResolveSession(store *Store) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie(SessionCookieName) + if err != nil || cookie.Value == "" { + // No cookie — pass through unauthenticated. + next.ServeHTTP(w, r) + return + } + + sess, user, err := store.Lookup(r.Context(), cookie.Value) + if err != nil { + // Invalid / expired / tampered cookie — do NOT clear the cookie + // here; the handler or RequireAuth will decide what to do. + next.ServeHTTP(w, r) + return + } + + // Session found — attempt lazy extension (D-09). Best-effort: log on + // error but do not fail the request. + if extErr := store.MaybeExtend(r.Context(), sess.ID, sess.ExpiresAt); extErr != nil { + slog.Default().Warn("session extend failed", "session_id", sess.ID, "err", extErr) + } + + // Attach session + user to context for downstream handlers. + ctx := context.WithValue(r.Context(), sessionKey, &authed{Session: sess, User: user}) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// RequireAuth is middleware that enforces an authenticated session. +// If no session is present in the context (set by ResolveSession), it +// redirects unauth requests to /login: +// - HTMX requests (HX-Request: true) → 200 with HX-Redirect: /login header +// - Plain requests → 303 See Other with Location: /login +// +// 303 is mandated (not 302) per D-23 and Pitfall 9: POST/Redirect/GET pattern +// requires 303 to guarantee the redirect uses GET. +func RequireAuth(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, _, ok := Authed(r.Context()); !ok { + redirectTo(w, r, "/login") + return + } + next.ServeHTTP(w, r) + }) +} + +// RedirectIfAuthed bounces already-authenticated users away from auth pages +// (e.g. /login, /signup) to the home route. This prevents authed users from +// accidentally re-logging-in and rotating their session unnecessarily. +// - HTMX requests → 200 with HX-Redirect: / +// - Plain requests → 303 See Other with Location: / +func RedirectIfAuthed(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, _, ok := Authed(r.Context()); ok { + redirectTo(w, r, "/") + return + } + next.ServeHTTP(w, r) + }) +} + +// redirectTo performs an HTMX-aware redirect: +// - When the request carries HX-Request: true, it returns 200 with an +// HX-Redirect header so HTMX can handle the navigation client-side +// (Pattern 5 — avoids confusing HTMX with a 303 response). +// - For plain browser requests it uses 303 See Other (NOT 302 — Pitfall 9). +func redirectTo(w http.ResponseWriter, r *http.Request, target string) { + if r.Header.Get("HX-Request") == "true" { + w.Header().Set("HX-Redirect", target) + w.WriteHeader(http.StatusOK) + return + } + http.Redirect(w, r, target, http.StatusSeeOther) +} diff --git a/backend/internal/auth/middleware_test.go b/backend/internal/auth/middleware_test.go new file mode 100644 index 0000000..bcdc9b7 --- /dev/null +++ b/backend/internal/auth/middleware_test.go @@ -0,0 +1,254 @@ +package auth + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "backend/internal/db/sqlc" + + "github.com/go-chi/chi/v5" + "github.com/google/uuid" +) + +// ---------- ResolveSession tests ---------- + +func TestResolveSession_NoCookie(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + q := sqlc.New(pool) + store := NewStore(q) + + var gotOk bool + handler := ResolveSession(store)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _, gotOk = Authed(r.Context()) + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rr.Code) + } + if gotOk { + t.Error("Authed(ctx) should return false when no cookie present") + } +} + +func TestResolveSession_InvalidCookie(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + q := sqlc.New(pool) + store := NewStore(q) + + var gotOk bool + handler := ResolveSession(store)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _, gotOk = Authed(r.Context()) + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: "garbage-invalid-cookie-value-xyz"}) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rr.Code) + } + if gotOk { + t.Error("Authed(ctx) should return false when cookie is invalid/not found") + } +} + +func TestResolveSession_ValidCookie(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + q := sqlc.New(pool) + store := NewStore(q) + ctx := context.Background() + + userID := mustInsertUser(t, pool) + cookieValue, _, err := store.Create(ctx, userID) + if err != nil { + t.Fatalf("Create session: %v", err) + } + + var gotSess *Session + var gotUser *User + var gotOk bool + handler := ResolveSession(store)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotSess, gotUser, gotOk = Authed(r.Context()) + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: cookieValue}) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rr.Code) + } + if !gotOk { + t.Fatal("Authed(ctx) should return true for valid session cookie") + } + if gotUser == nil || gotUser.ID != userID { + t.Errorf("user.ID %v != expected %v", gotUser, userID) + } + if gotSess == nil || gotSess.ID == "" { + t.Error("session should be non-nil with non-empty ID") + } +} + +// ---------- RequireAuth tests ---------- + +func TestRequireAuth_303WhenUnauth(t *testing.T) { + r := chi.NewRouter() + r.Use(func(next http.Handler) http.Handler { + // no session in ctx (no ResolveSession) + return next + }) + r.Handle("/protected", RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }))) + + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusSeeOther { + t.Errorf("expected 303 SeeOther, got %d", rr.Code) + } + if loc := rr.Header().Get("Location"); loc != "/login" { + t.Errorf("expected Location: /login, got %q", loc) + } +} + +func TestRequireAuth_HXRedirectWhenUnauth(t *testing.T) { + r := chi.NewRouter() + r.Handle("/protected", RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }))) + + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("HX-Request", "true") + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + // HTMX: must return 200 with HX-Redirect header, NOT 303 (Pattern 5). + if rr.Code != http.StatusOK { + t.Errorf("expected 200 for HTMX unauth, got %d", rr.Code) + } + if hxRedir := rr.Header().Get("HX-Redirect"); hxRedir != "/login" { + t.Errorf("expected HX-Redirect: /login, got %q", hxRedir) + } + // Must NOT be a 303 redirect. + if rr.Code == http.StatusSeeOther { + t.Error("HTMX request must NOT receive 303 (Pattern 5)") + } +} + +func TestRequireAuth_PassesWhenAuth(t *testing.T) { + r := chi.NewRouter() + // Inject a valid session into context directly. + r.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), sessionKey, &authed{ + Session: &Session{ID: "test-sess"}, + User: &User{ID: uuid.New()}, + }) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + }) + r.Handle("/protected", RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }))) + + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("expected 200 for authenticated request, got %d", rr.Code) + } +} + +// ---------- RedirectIfAuthed tests ---------- + +func TestRedirectIfAuthed_BouncesWhenAuth(t *testing.T) { + r := chi.NewRouter() + r.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), sessionKey, &authed{ + Session: &Session{ID: "test-sess"}, + User: &User{ID: uuid.New()}, + }) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + }) + r.Handle("/login", RedirectIfAuthed(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }))) + + // Non-HTMX: expect 303 Location: / + req := httptest.NewRequest(http.MethodGet, "/login", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusSeeOther { + t.Errorf("expected 303 for authed user on /login, got %d", rr.Code) + } + if loc := rr.Header().Get("Location"); loc != "/" { + t.Errorf("expected Location: /, got %q", loc) + } +} + +func TestRedirectIfAuthed_HXBounceWhenAuth(t *testing.T) { + r := chi.NewRouter() + r.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), sessionKey, &authed{ + Session: &Session{ID: "test-sess"}, + User: &User{ID: uuid.New()}, + }) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + }) + r.Handle("/login", RedirectIfAuthed(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }))) + + // HTMX: expect 200 + HX-Redirect: / + req := httptest.NewRequest(http.MethodGet, "/login", nil) + req.Header.Set("HX-Request", "true") + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("expected 200 for HTMX bounce, got %d", rr.Code) + } + if hxRedir := rr.Header().Get("HX-Redirect"); hxRedir != "/" { + t.Errorf("expected HX-Redirect: /, got %q", hxRedir) + } +} + +func TestRedirectIfAuthed_PassesWhenUnauth(t *testing.T) { + r := chi.NewRouter() + r.Handle("/login", RedirectIfAuthed(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }))) + + req := httptest.NewRequest(http.MethodGet, "/login", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("expected 200 for unauthenticated user on /login, got %d", rr.Code) + } +}