xtablo-source/backend/internal/auth/middleware_test.go

255 lines
7.1 KiB
Go
Raw Normal View History

package auth
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"backend/internal/db/sqlc"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
)
// ---------- ResolveSession tests ----------
// TestResolveSession_NoCookie verifies that ResolveSession forwards a request
// that carries no session cookie to the next handler, and that the context
// seen by that handler is not authenticated.
func TestResolveSession_NoCookie(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	store := NewStore(sqlc.New(pool))

	var called, gotOk bool
	handler := ResolveSession(store)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		_, _, gotOk = Authed(r.Context())
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("expected 200, got %d", rr.Code)
	}
	// ResponseRecorder starts with Code 200 and gotOk starts false, so without
	// this check a middleware that never called next would pass vacuously.
	if !called {
		t.Fatal("ResolveSession must call the next handler when no cookie is present")
	}
	if gotOk {
		t.Error("Authed(ctx) should return false when no cookie present")
	}
}
// TestResolveSession_InvalidCookie verifies that a session cookie which does
// not resolve to a stored session is tolerated: the request still reaches the
// next handler, and the context is not authenticated.
func TestResolveSession_InvalidCookie(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	store := NewStore(sqlc.New(pool))

	var called, gotOk bool
	handler := ResolveSession(store)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		_, _, gotOk = Authed(r.Context())
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: "garbage-invalid-cookie-value-xyz"})
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("expected 200, got %d", rr.Code)
	}
	// Guard against a vacuous pass: the recorder defaults to 200 and gotOk
	// defaults to false even if next is never invoked.
	if !called {
		t.Fatal("ResolveSession must call the next handler when the cookie is invalid")
	}
	if gotOk {
		t.Error("Authed(ctx) should return false when cookie is invalid/not found")
	}
}
// TestResolveSession_ValidCookie creates a real session for a freshly inserted
// user, then checks that ResolveSession exposes that session and user to the
// next handler via Authed(ctx).
func TestResolveSession_ValidCookie(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	store := NewStore(sqlc.New(pool))
	ctx := context.Background()

	userID := mustInsertUser(t, pool)
	// Only the cookie value is needed to drive the middleware; the second
	// return value of Create is not used by this test.
	cookieValue, _, err := store.Create(ctx, userID)
	if err != nil {
		t.Fatalf("Create session: %v", err)
	}

	var (
		gotSess *Session
		gotUser *User
		gotOk   bool
	)
	handler := ResolveSession(store)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotSess, gotUser, gotOk = Authed(r.Context())
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: cookieValue})
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("expected 200, got %d", rr.Code)
	}
	if !gotOk {
		t.Fatal("Authed(ctx) should return true for valid session cookie")
	}
	// Report the nil case and the mismatch case separately so the failure
	// message shows the actual ID rather than a pointer or "<nil>".
	switch {
	case gotUser == nil:
		t.Error("user should be non-nil for a valid session cookie")
	case gotUser.ID != userID:
		t.Errorf("user.ID = %v, want %v", gotUser.ID, userID)
	}
	if gotSess == nil || gotSess.ID == "" {
		t.Error("session should be non-nil with non-empty ID")
	}
}
// ---------- RequireAuth tests ----------
// TestRequireAuth_303WhenUnauth verifies that a plain (non-HTMX) request whose
// context holds no session is redirected to /login with 303 See Other and
// never reaches the protected handler.
func TestRequireAuth_303WhenUnauth(t *testing.T) {
	var reached bool
	router := chi.NewRouter()
	// No ResolveSession (or any other middleware) is installed, so the request
	// context carries no session.
	router.Handle("/protected", RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	})))

	req := httptest.NewRequest(http.MethodGet, "/protected", nil)
	rr := httptest.NewRecorder()
	router.ServeHTTP(rr, req)

	if rr.Code != http.StatusSeeOther {
		t.Errorf("expected 303 SeeOther, got %d", rr.Code)
	}
	if loc := rr.Header().Get("Location"); loc != "/login" {
		t.Errorf("expected Location: /login, got %q", loc)
	}
	if reached {
		t.Error("protected handler must not run for an unauthenticated request")
	}
}
// TestRequireAuth_HXRedirectWhenUnauth verifies that an unauthenticated HTMX
// request (HX-Request: true) gets a 200 response carrying HX-Redirect: /login
// instead of a 303. The protected handler must not run.
func TestRequireAuth_HXRedirectWhenUnauth(t *testing.T) {
	var reached bool
	router := chi.NewRouter()
	router.Handle("/protected", RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	})))

	req := httptest.NewRequest(http.MethodGet, "/protected", nil)
	req.Header.Set("HX-Request", "true")
	rr := httptest.NewRecorder()
	router.ServeHTTP(rr, req)

	// HTMX: must return 200 with HX-Redirect header, NOT 303 (Pattern 5).
	// Requiring exactly 200 also rules out a 303.
	if rr.Code != http.StatusOK {
		t.Errorf("expected 200 for HTMX unauth, got %d", rr.Code)
	}
	if hxRedir := rr.Header().Get("HX-Redirect"); hxRedir != "/login" {
		t.Errorf("expected HX-Redirect: /login, got %q", hxRedir)
	}
	// The protected handler also answers 200, so the status check alone would
	// not catch a pass-through.
	if reached {
		t.Error("protected handler must not run for an unauthenticated HTMX request")
	}
}
// TestRequireAuth_PassesWhenAuth verifies that RequireAuth hands the request to
// the protected handler when the context already holds an authenticated
// session and user.
func TestRequireAuth_PassesWhenAuth(t *testing.T) {
	var reached bool
	router := chi.NewRouter()
	// Inject a valid session into context directly, bypassing ResolveSession.
	router.Use(func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			ctx := context.WithValue(req.Context(), sessionKey, &authed{
				Session: &Session{ID: "test-sess"},
				User:    &User{ID: uuid.New()},
			})
			next.ServeHTTP(w, req.WithContext(ctx))
		})
	})
	router.Handle("/protected", RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	})))

	req := httptest.NewRequest(http.MethodGet, "/protected", nil)
	rr := httptest.NewRecorder()
	router.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("expected 200 for authenticated request, got %d", rr.Code)
	}
	// ResponseRecorder defaults to 200, so also confirm next actually ran.
	if !reached {
		t.Error("protected handler should run for an authenticated request")
	}
}
// ---------- RedirectIfAuthed tests ----------
// TestRedirectIfAuthed_BouncesWhenAuth verifies that an already-authenticated
// user who requests /login without HTMX is redirected to / with 303 See Other.
// The login handler must not run.
func TestRedirectIfAuthed_BouncesWhenAuth(t *testing.T) {
	var reached bool
	router := chi.NewRouter()
	// Inject a valid session into context directly, bypassing ResolveSession.
	router.Use(func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			ctx := context.WithValue(req.Context(), sessionKey, &authed{
				Session: &Session{ID: "test-sess"},
				User:    &User{ID: uuid.New()},
			})
			next.ServeHTTP(w, req.WithContext(ctx))
		})
	})
	router.Handle("/login", RedirectIfAuthed(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	})))

	// Non-HTMX: expect 303 Location: /
	req := httptest.NewRequest(http.MethodGet, "/login", nil)
	rr := httptest.NewRecorder()
	router.ServeHTTP(rr, req)

	if rr.Code != http.StatusSeeOther {
		t.Errorf("expected 303 for authed user on /login, got %d", rr.Code)
	}
	if loc := rr.Header().Get("Location"); loc != "/" {
		t.Errorf("expected Location: /, got %q", loc)
	}
	if reached {
		t.Error("login handler must not run for an authenticated user")
	}
}
// TestRedirectIfAuthed_HXBounceWhenAuth verifies that an already-authenticated
// HTMX request to /login gets a 200 response carrying HX-Redirect: /. The login
// handler must not run.
func TestRedirectIfAuthed_HXBounceWhenAuth(t *testing.T) {
	var reached bool
	router := chi.NewRouter()
	// Inject a valid session into context directly, bypassing ResolveSession.
	router.Use(func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			ctx := context.WithValue(req.Context(), sessionKey, &authed{
				Session: &Session{ID: "test-sess"},
				User:    &User{ID: uuid.New()},
			})
			next.ServeHTTP(w, req.WithContext(ctx))
		})
	})
	router.Handle("/login", RedirectIfAuthed(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	})))

	// HTMX: expect 200 + HX-Redirect: /
	req := httptest.NewRequest(http.MethodGet, "/login", nil)
	req.Header.Set("HX-Request", "true")
	rr := httptest.NewRecorder()
	router.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("expected 200 for HTMX bounce, got %d", rr.Code)
	}
	if hxRedir := rr.Header().Get("HX-Redirect"); hxRedir != "/" {
		t.Errorf("expected HX-Redirect: /, got %q", hxRedir)
	}
	// The login handler also answers 200, so the status check alone would not
	// catch a pass-through.
	if reached {
		t.Error("login handler must not run for an authenticated HTMX request")
	}
}
// TestRedirectIfAuthed_PassesWhenUnauth verifies that RedirectIfAuthed lets an
// unauthenticated request reach the /login handler.
func TestRedirectIfAuthed_PassesWhenUnauth(t *testing.T) {
	var reached bool
	router := chi.NewRouter()
	router.Handle("/login", RedirectIfAuthed(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	})))

	req := httptest.NewRequest(http.MethodGet, "/login", nil)
	rr := httptest.NewRecorder()
	router.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("expected 200 for unauthenticated user on /login, got %d", rr.Code)
	}
	// ResponseRecorder defaults to 200, so also confirm next actually ran.
	if !reached {
		t.Error("login handler should run for an unauthenticated user")
	}
}