package auth

import (
	"context"
	"net/http"
	"net/http/httptest"
	"testing"

	"backend/internal/db/sqlc"
	"github.com/go-chi/chi/v5"
	"github.com/google/uuid"
)

// ---------- shared test helpers ----------

// trackedOKHandler returns a handler that sets *called to true and responds
// with 200 OK. httptest.NewRecorder pre-populates Code with 200, so a status
// check alone cannot tell whether the wrapped handler actually ran; the flag
// can.
func trackedOKHandler(called *bool) http.HandlerFunc {
	return func(w http.ResponseWriter, _ *http.Request) {
		*called = true
		w.WriteHeader(http.StatusOK)
	}
}

// injectAuthedSession is a middleware that stores a synthetic authed value
// under sessionKey in the request context, so the RequireAuth and
// RedirectIfAuthed tests can simulate a logged-in request without going
// through ResolveSession.
func injectAuthedSession(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := context.WithValue(r.Context(), sessionKey, &authed{
			Session: &Session{ID: "test-sess"},
			User:    &User{ID: uuid.New()},
		})
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// ---------- ResolveSession tests ----------

// TestResolveSession_NoCookie checks that a request without a session cookie
// still reaches the next handler and that Authed reports no session there.
func TestResolveSession_NoCookie(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()

	store := NewStore(sqlc.New(pool))

	var called, gotOk bool
	handler := ResolveSession(store)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		_, _, gotOk = Authed(r.Context())
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	// Without this, gotOk == false would also hold if the handler never ran.
	if !called {
		t.Fatal("ResolveSession should call the next handler when no cookie is present")
	}
	if rr.Code != http.StatusOK {
		t.Errorf("expected 200, got %d", rr.Code)
	}
	if gotOk {
		t.Error("Authed(ctx) should return false when no cookie present")
	}
}

// TestResolveSession_InvalidCookie checks that a request carrying a cookie
// value that matches no session still reaches the next handler, unauthenticated.
func TestResolveSession_InvalidCookie(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()

	store := NewStore(sqlc.New(pool))

	var called, gotOk bool
	handler := ResolveSession(store)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		_, _, gotOk = Authed(r.Context())
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: "garbage-invalid-cookie-value-xyz"})
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	// Without this, gotOk == false would also hold if the handler never ran.
	if !called {
		t.Fatal("ResolveSession should call the next handler when the cookie is invalid")
	}
	if rr.Code != http.StatusOK {
		t.Errorf("expected 200, got %d", rr.Code)
	}
	if gotOk {
		t.Error("Authed(ctx) should return false when cookie is invalid/not found")
	}
}

// TestResolveSession_ValidCookie creates a real session through the store and
// checks that the next handler sees that session and its user via Authed.
func TestResolveSession_ValidCookie(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()

	store := NewStore(sqlc.New(pool))
	ctx := context.Background()

	userID := mustInsertUser(t, pool)
	cookieValue, _, err := store.Create(ctx, userID)
	if err != nil {
		t.Fatalf("Create session: %v", err)
	}

	var (
		called  bool
		gotSess *Session
		gotUser *User
		gotOk   bool
	)
	handler := ResolveSession(store)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		gotSess, gotUser, gotOk = Authed(r.Context())
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: cookieValue})
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if !called {
		t.Fatal("ResolveSession should call the next handler for a valid session cookie")
	}
	if rr.Code != http.StatusOK {
		t.Errorf("expected 200, got %d", rr.Code)
	}
	if !gotOk {
		t.Fatal("Authed(ctx) should return true for valid session cookie")
	}

	// Report nil and mismatch separately so the failure shows the actual ID
	// rather than a pointer value.
	switch {
	case gotUser == nil:
		t.Error("user should be non-nil for valid session cookie")
	case gotUser.ID != userID:
		t.Errorf("user.ID %v != expected %v", gotUser.ID, userID)
	}
	if gotSess == nil || gotSess.ID == "" {
		t.Error("session should be non-nil with non-empty ID")
	}
}

// ---------- RequireAuth tests ----------

// TestRequireAuth_303WhenUnauth checks that a plain (non-HTMX) request with no
// session in its context is redirected to /login with 303 and never reaches
// the protected handler.
func TestRequireAuth_303WhenUnauth(t *testing.T) {
	var called bool

	// No session-injecting middleware: the request context carries no session.
	router := chi.NewRouter()
	router.Handle("/protected", RequireAuth(trackedOKHandler(&called)))

	req := httptest.NewRequest(http.MethodGet, "/protected", nil)
	rr := httptest.NewRecorder()
	router.ServeHTTP(rr, req)

	if rr.Code != http.StatusSeeOther {
		t.Errorf("expected 303 SeeOther, got %d", rr.Code)
	}
	if loc := rr.Header().Get("Location"); loc != "/login" {
		t.Errorf("expected Location: /login, got %q", loc)
	}
	if called {
		t.Error("protected handler must not run for an unauthenticated request")
	}
}

// TestRequireAuth_HXRedirectWhenUnauth checks that an unauthenticated HTMX
// request (HX-Request: true) gets 200 plus an HX-Redirect header instead of a
// 303, and never reaches the protected handler.
func TestRequireAuth_HXRedirectWhenUnauth(t *testing.T) {
	var called bool

	router := chi.NewRouter()
	router.Handle("/protected", RequireAuth(trackedOKHandler(&called)))

	req := httptest.NewRequest(http.MethodGet, "/protected", nil)
	req.Header.Set("HX-Request", "true")
	rr := httptest.NewRecorder()
	router.ServeHTTP(rr, req)

	// HTMX: must return 200 with HX-Redirect header, NOT 303 (Pattern 5).
	if rr.Code != http.StatusOK {
		t.Errorf("expected 200 for HTMX unauth, got %d", rr.Code)
	}
	if hxRedir := rr.Header().Get("HX-Redirect"); hxRedir != "/login" {
		t.Errorf("expected HX-Redirect: /login, got %q", hxRedir)
	}
	// The protected handler also answers 200, so the status code alone cannot
	// prove the middleware short-circuited; the call flag can.
	if called {
		t.Error("protected handler must not run for an unauthenticated HTMX request")
	}
}

// TestRequireAuth_PassesWhenAuth checks that a request whose context already
// holds a session is passed through to the protected handler.
func TestRequireAuth_PassesWhenAuth(t *testing.T) {
	var called bool

	router := chi.NewRouter()
	// Inject a valid session into context directly.
	router.Use(injectAuthedSession)
	router.Handle("/protected", RequireAuth(trackedOKHandler(&called)))

	req := httptest.NewRequest(http.MethodGet, "/protected", nil)
	rr := httptest.NewRecorder()
	router.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("expected 200 for authenticated request, got %d", rr.Code)
	}
	if !called {
		t.Error("protected handler should run for an authenticated request")
	}
}

// ---------- RedirectIfAuthed tests ----------

// TestRedirectIfAuthed_BouncesWhenAuth checks that an authenticated, non-HTMX
// request to /login is redirected to / with 303 and the login handler is
// skipped.
func TestRedirectIfAuthed_BouncesWhenAuth(t *testing.T) {
	var called bool

	router := chi.NewRouter()
	router.Use(injectAuthedSession)
	router.Handle("/login", RedirectIfAuthed(trackedOKHandler(&called)))

	// Non-HTMX: expect 303 Location: /
	req := httptest.NewRequest(http.MethodGet, "/login", nil)
	rr := httptest.NewRecorder()
	router.ServeHTTP(rr, req)

	if rr.Code != http.StatusSeeOther {
		t.Errorf("expected 303 for authed user on /login, got %d", rr.Code)
	}
	if loc := rr.Header().Get("Location"); loc != "/" {
		t.Errorf("expected Location: /, got %q", loc)
	}
	if called {
		t.Error("login handler must not run for an already-authenticated request")
	}
}

// TestRedirectIfAuthed_HXBounceWhenAuth checks that an authenticated HTMX
// request to /login gets 200 plus HX-Redirect: / and the login handler is
// skipped.
func TestRedirectIfAuthed_HXBounceWhenAuth(t *testing.T) {
	var called bool

	router := chi.NewRouter()
	router.Use(injectAuthedSession)
	router.Handle("/login", RedirectIfAuthed(trackedOKHandler(&called)))

	// HTMX: expect 200 + HX-Redirect: /
	req := httptest.NewRequest(http.MethodGet, "/login", nil)
	req.Header.Set("HX-Request", "true")
	rr := httptest.NewRecorder()
	router.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("expected 200 for HTMX bounce, got %d", rr.Code)
	}
	if hxRedir := rr.Header().Get("HX-Redirect"); hxRedir != "/" {
		t.Errorf("expected HX-Redirect: /, got %q", hxRedir)
	}
	// The login handler also answers 200, so only the call flag shows that the
	// middleware bounced the request instead of passing it through.
	if called {
		t.Error("login handler must not run for an already-authenticated HTMX request")
	}
}

// TestRedirectIfAuthed_PassesWhenUnauth checks that a request with no session
// in its context reaches the login handler.
func TestRedirectIfAuthed_PassesWhenUnauth(t *testing.T) {
	var called bool

	router := chi.NewRouter()
	router.Handle("/login", RedirectIfAuthed(trackedOKHandler(&called)))

	req := httptest.NewRequest(http.MethodGet, "/login", nil)
	rr := httptest.NewRecorder()
	router.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("expected 200 for unauthenticated user on /login, got %d", rr.Code)
	}
	if !called {
		t.Error("login handler should run for an unauthenticated request")
	}
}