feat(02-03): ResolveSession + RequireAuth + RedirectIfAuthed middleware
- ResolveSession: reads cookie, SHA-256 lookup, MaybeExtend best-effort, attaches Session+User to ctx - RequireAuth: 303 /login for plain requests; HX-Redirect: /login for HTMX (D-23, Pattern 5) - RedirectIfAuthed: bounces authed users to / from login/signup pages - Authed(ctx): typed context accessor for session + user - redirect helper centralizes 303 vs HX-Redirect logic (Pitfall 9: no 302) - 9 tests: 3 real-DB (ResolveSession) + 6 pure ctx/routing (RequireAuth, RedirectIfAuthed)
This commit is contained in:
parent
fd2301decf
commit
1d07830954
2 changed files with 370 additions and 0 deletions
116
backend/internal/auth/middleware.go
Normal file
116
backend/internal/auth/middleware.go
Normal file
|
|
@ -0,0 +1,116 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// sessionCtxKey is the unexported context key type for session data owned by
|
||||||
|
// this package. Using an unexported named struct prevents collisions with
|
||||||
|
// other packages' context keys.
|
||||||
|
type sessionCtxKey struct{}
|
||||||
|
|
||||||
|
// sessionKey is the singleton key value used in context.WithValue.
|
||||||
|
var sessionKey = sessionCtxKey{}
|
||||||
|
|
||||||
|
// authed holds the resolved session and user attached to a request context.
|
||||||
|
type authed struct {
|
||||||
|
Session *Session
|
||||||
|
User *User
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authed extracts the session and user from the request context.
|
||||||
|
// Returns (session, user, true) when a valid session is present, and
|
||||||
|
// (nil, nil, false) when the request is unauthenticated.
|
||||||
|
func Authed(ctx context.Context) (*Session, *User, bool) {
|
||||||
|
a, ok := ctx.Value(sessionKey).(*authed)
|
||||||
|
if !ok || a == nil {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
return a.Session, a.User, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveSession reads the session cookie, looks up the session + user, and
|
||||||
|
// attaches them to the request context. It NEVER blocks the request — missing
|
||||||
|
// or invalid sessions are silently ignored; RequireAuth enforces access.
|
||||||
|
//
|
||||||
|
// On a valid session hit, MaybeExtend is called best-effort (logged but not
|
||||||
|
// fatal) to implement the sliding 30-day TTL (D-09).
|
||||||
|
func ResolveSession(store *Store) func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
cookie, err := r.Cookie(SessionCookieName)
|
||||||
|
if err != nil || cookie.Value == "" {
|
||||||
|
// No cookie — pass through unauthenticated.
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sess, user, err := store.Lookup(r.Context(), cookie.Value)
|
||||||
|
if err != nil {
|
||||||
|
// Invalid / expired / tampered cookie — do NOT clear the cookie
|
||||||
|
// here; the handler or RequireAuth will decide what to do.
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Session found — attempt lazy extension (D-09). Best-effort: log on
|
||||||
|
// error but do not fail the request.
|
||||||
|
if extErr := store.MaybeExtend(r.Context(), sess.ID, sess.ExpiresAt); extErr != nil {
|
||||||
|
slog.Default().Warn("session extend failed", "session_id", sess.ID, "err", extErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attach session + user to context for downstream handlers.
|
||||||
|
ctx := context.WithValue(r.Context(), sessionKey, &authed{Session: sess, User: user})
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequireAuth is middleware that enforces an authenticated session.
|
||||||
|
// If no session is present in the context (set by ResolveSession), it
|
||||||
|
// redirects unauth requests to /login:
|
||||||
|
// - HTMX requests (HX-Request: true) → 200 with HX-Redirect: /login header
|
||||||
|
// - Plain requests → 303 See Other with Location: /login
|
||||||
|
//
|
||||||
|
// 303 is mandated (not 302) per D-23 and Pitfall 9: POST/Redirect/GET pattern
|
||||||
|
// requires 303 to guarantee the redirect uses GET.
|
||||||
|
func RequireAuth(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if _, _, ok := Authed(r.Context()); !ok {
|
||||||
|
redirectTo(w, r, "/login")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RedirectIfAuthed bounces already-authenticated users away from auth pages
|
||||||
|
// (e.g. /login, /signup) to the home route. This prevents authed users from
|
||||||
|
// accidentally re-logging-in and rotating their session unnecessarily.
|
||||||
|
// - HTMX requests → 200 with HX-Redirect: /
|
||||||
|
// - Plain requests → 303 See Other with Location: /
|
||||||
|
func RedirectIfAuthed(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if _, _, ok := Authed(r.Context()); ok {
|
||||||
|
redirectTo(w, r, "/")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// redirectTo performs an HTMX-aware redirect:
|
||||||
|
// - When the request carries HX-Request: true, it returns 200 with an
|
||||||
|
// HX-Redirect header so HTMX can handle the navigation client-side
|
||||||
|
// (Pattern 5 — avoids confusing HTMX with a 303 response).
|
||||||
|
// - For plain browser requests it uses 303 See Other (NOT 302 — Pitfall 9).
|
||||||
|
func redirectTo(w http.ResponseWriter, r *http.Request, target string) {
|
||||||
|
if r.Header.Get("HX-Request") == "true" {
|
||||||
|
w.Header().Set("HX-Redirect", target)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.Redirect(w, r, target, http.StatusSeeOther)
|
||||||
|
}
|
||||||
254
backend/internal/auth/middleware_test.go
Normal file
254
backend/internal/auth/middleware_test.go
Normal file
|
|
@ -0,0 +1,254 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"backend/internal/db/sqlc"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------- ResolveSession tests ----------
|
||||||
|
|
||||||
|
func TestResolveSession_NoCookie(t *testing.T) {
|
||||||
|
pool, cleanup := setupTestDB(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
q := sqlc.New(pool)
|
||||||
|
store := NewStore(q)
|
||||||
|
|
||||||
|
var gotOk bool
|
||||||
|
handler := ResolveSession(store)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _, gotOk = Authed(r.Context())
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
if gotOk {
|
||||||
|
t.Error("Authed(ctx) should return false when no cookie present")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveSession_InvalidCookie(t *testing.T) {
|
||||||
|
pool, cleanup := setupTestDB(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
q := sqlc.New(pool)
|
||||||
|
store := NewStore(q)
|
||||||
|
|
||||||
|
var gotOk bool
|
||||||
|
handler := ResolveSession(store)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _, gotOk = Authed(r.Context())
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: "garbage-invalid-cookie-value-xyz"})
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
if gotOk {
|
||||||
|
t.Error("Authed(ctx) should return false when cookie is invalid/not found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveSession_ValidCookie(t *testing.T) {
|
||||||
|
pool, cleanup := setupTestDB(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
q := sqlc.New(pool)
|
||||||
|
store := NewStore(q)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
userID := mustInsertUser(t, pool)
|
||||||
|
cookieValue, _, err := store.Create(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Create session: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var gotSess *Session
|
||||||
|
var gotUser *User
|
||||||
|
var gotOk bool
|
||||||
|
handler := ResolveSession(store)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotSess, gotUser, gotOk = Authed(r.Context())
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: cookieValue})
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
if !gotOk {
|
||||||
|
t.Fatal("Authed(ctx) should return true for valid session cookie")
|
||||||
|
}
|
||||||
|
if gotUser == nil || gotUser.ID != userID {
|
||||||
|
t.Errorf("user.ID %v != expected %v", gotUser, userID)
|
||||||
|
}
|
||||||
|
if gotSess == nil || gotSess.ID == "" {
|
||||||
|
t.Error("session should be non-nil with non-empty ID")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- RequireAuth tests ----------
|
||||||
|
|
||||||
|
func TestRequireAuth_303WhenUnauth(t *testing.T) {
|
||||||
|
r := chi.NewRouter()
|
||||||
|
r.Use(func(next http.Handler) http.Handler {
|
||||||
|
// no session in ctx (no ResolveSession)
|
||||||
|
return next
|
||||||
|
})
|
||||||
|
r.Handle("/protected", RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusSeeOther {
|
||||||
|
t.Errorf("expected 303 SeeOther, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
if loc := rr.Header().Get("Location"); loc != "/login" {
|
||||||
|
t.Errorf("expected Location: /login, got %q", loc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireAuth_HXRedirectWhenUnauth(t *testing.T) {
|
||||||
|
r := chi.NewRouter()
|
||||||
|
r.Handle("/protected", RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||||
|
req.Header.Set("HX-Request", "true")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
// HTMX: must return 200 with HX-Redirect header, NOT 303 (Pattern 5).
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200 for HTMX unauth, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
if hxRedir := rr.Header().Get("HX-Redirect"); hxRedir != "/login" {
|
||||||
|
t.Errorf("expected HX-Redirect: /login, got %q", hxRedir)
|
||||||
|
}
|
||||||
|
// Must NOT be a 303 redirect.
|
||||||
|
if rr.Code == http.StatusSeeOther {
|
||||||
|
t.Error("HTMX request must NOT receive 303 (Pattern 5)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireAuth_PassesWhenAuth(t *testing.T) {
|
||||||
|
r := chi.NewRouter()
|
||||||
|
// Inject a valid session into context directly.
|
||||||
|
r.Use(func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := context.WithValue(r.Context(), sessionKey, &authed{
|
||||||
|
Session: &Session{ID: "test-sess"},
|
||||||
|
User: &User{ID: uuid.New()},
|
||||||
|
})
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
r.Handle("/protected", RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200 for authenticated request, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- RedirectIfAuthed tests ----------
|
||||||
|
|
||||||
|
func TestRedirectIfAuthed_BouncesWhenAuth(t *testing.T) {
|
||||||
|
r := chi.NewRouter()
|
||||||
|
r.Use(func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := context.WithValue(r.Context(), sessionKey, &authed{
|
||||||
|
Session: &Session{ID: "test-sess"},
|
||||||
|
User: &User{ID: uuid.New()},
|
||||||
|
})
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
r.Handle("/login", RedirectIfAuthed(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})))
|
||||||
|
|
||||||
|
// Non-HTMX: expect 303 Location: /
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/login", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusSeeOther {
|
||||||
|
t.Errorf("expected 303 for authed user on /login, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
if loc := rr.Header().Get("Location"); loc != "/" {
|
||||||
|
t.Errorf("expected Location: /, got %q", loc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRedirectIfAuthed_HXBounceWhenAuth(t *testing.T) {
|
||||||
|
r := chi.NewRouter()
|
||||||
|
r.Use(func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := context.WithValue(r.Context(), sessionKey, &authed{
|
||||||
|
Session: &Session{ID: "test-sess"},
|
||||||
|
User: &User{ID: uuid.New()},
|
||||||
|
})
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
r.Handle("/login", RedirectIfAuthed(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})))
|
||||||
|
|
||||||
|
// HTMX: expect 200 + HX-Redirect: /
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/login", nil)
|
||||||
|
req.Header.Set("HX-Request", "true")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200 for HTMX bounce, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
if hxRedir := rr.Header().Get("HX-Redirect"); hxRedir != "/" {
|
||||||
|
t.Errorf("expected HX-Redirect: /, got %q", hxRedir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRedirectIfAuthed_PassesWhenUnauth(t *testing.T) {
|
||||||
|
r := chi.NewRouter()
|
||||||
|
r.Handle("/login", RedirectIfAuthed(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/login", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200 for unauthenticated user on /login, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue