xtablo-source/backend/internal/web/handlers_auth_test.go
2026-05-15 21:09:14 +02:00

1306 lines
41 KiB
Go

package web
import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"html"
	"net/http"
	"net/http/httptest"
	"net/url"
	"os"
	"strings"
	"testing"
	"time"

	"backend/internal/auth"
	"backend/internal/db/sqlc"

	"github.com/jackc/pgx/v5/pgtype"
)
// testCSRFKey is the deterministic 32-byte CSRF key shared by every test
// router: it holds the byte values 1, 2, ..., 32 in order. It is NOT a secret;
// it is only ever used by tests, which build routers with env "dev".
var testCSRFKey = func() []byte {
	const keyLen = 32
	buf := make([]byte, 0, keyLen)
	for b := byte(1); b <= keyLen; b++ {
		buf = append(buf, b)
	}
	return buf
}()
// newTestRouter returns the full application router wired to a real database
// for integration tests. It has the following properties:
//   - CSRF protection is enabled with the fixed testCSRFKey.
//   - The environment is "dev", so the CSRF cookie is not marked Secure.
//   - "localhost" is a trusted origin, so httptest requests without a Referer
//     header are accepted.
//   - AuthDeps.Limiter is left at its zero value.
//
// It panics when the router cannot be constructed.
func newTestRouter(q *sqlc.Queries, store *auth.Store) http.Handler {
	authDeps := AuthDeps{Queries: q, Store: store, Secure: false}
	tablos, tasks, files := TablosDeps{Queries: q}, TasksDeps{Queries: q}, FilesDeps{Queries: q}
	h, buildErr := NewRouter(stubPinger{}, os.DirFS("./static"), authDeps, tablos, tasks, files, testCSRFKey, "dev", "localhost")
	if buildErr != nil {
		panic("newTestRouter: " + buildErr.Error())
	}
	return h
}
// newTestRouterWithLimiter is newTestRouter with a caller-supplied
// LimiterStore injected into AuthDeps. This lets rate-limit tests drive the
// limiter with a frozen or fake clock. It panics when the router cannot be
// constructed.
func newTestRouterWithLimiter(q *sqlc.Queries, store *auth.Store, rl *auth.LimiterStore) http.Handler {
	authDeps := AuthDeps{Queries: q, Store: store, Secure: false, Limiter: rl}
	tablos, tasks, files := TablosDeps{Queries: q}, TasksDeps{Queries: q}, FilesDeps{Queries: q}
	h, buildErr := NewRouter(stubPinger{}, os.DirFS("./static"), authDeps, tablos, tasks, files, testCSRFKey, "dev", "localhost")
	if buildErr != nil {
		panic("newTestRouterWithLimiter: " + buildErr.Error())
	}
	return h
}
// newAuthPageRouter builds a router for page-rendering tests that need no
// database. Only the supplied AuthDeps are wired; the tablos, tasks and files
// dependency sets are left empty. CSRF settings match newTestRouter. A router
// construction error fails the test immediately.
func newAuthPageRouter(t *testing.T, deps AuthDeps) http.Handler {
	t.Helper()
	h, buildErr := NewRouter(
		stubPinger{},
		os.DirFS("./static"),
		deps,
		TablosDeps{},
		TasksDeps{},
		FilesDeps{},
		testCSRFKey,
		"dev",
		"localhost",
	)
	if buildErr != nil {
		t.Fatalf("NewRouter: %v", buildErr)
	}
	return h
}
// configuredProviderDeps returns AuthDeps in which both the Google and the
// Apple OAuth provider configs are filled with dummy test credentials. No
// database dependencies are set. Tests use it to render auth pages with the
// providers configured.
func configuredProviderDeps() AuthDeps {
	google := auth.GoogleProviderConfig{
		ClientID:     "google-client",
		ClientSecret: "google-secret",
		RedirectURL:  "https://xtablo.test/auth/google/callback",
	}
	apple := auth.AppleProviderConfig{
		ClientID:    "com.xtablo.web",
		TeamID:      "TEAMID1234",
		KeyID:       "KEYID1234",
		PrivateKey:  "apple-private-key",
		RedirectURL: "https://xtablo.test/auth/apple/callback",
	}
	var deps AuthDeps
	deps.OAuth = auth.OAuthConfig{Google: google, Apple: apple}
	return deps
}
// getCSRFToken performs a GET request to path, sending cookies if any are
// given. It extracts the CSRF token from the first `name="_csrf" value="..."`
// hidden input in the rendered HTML.
//
// The attribute value is HTML-unescaped before it is returned, just as a
// browser decodes it when submitting the form. gorilla/csrf tokens are
// standard base64 and can contain '+', which some template engines emit as
// "&#43;". Posting the raw escaped text back would make CSRF validation fail.
//
// It returns the token and the cookies to attach to the follow-up POST: the
// caller's cookies followed by every cookie set on the GET response, which
// includes the CSRF cookie. The test fails immediately, reporting the response
// status, if the hidden input cannot be found.
func getCSRFToken(t *testing.T, router http.Handler, path string, cookies []*http.Cookie) (token string, respCookies []*http.Cookie) {
	t.Helper()
	req := httptest.NewRequest(http.MethodGet, path, nil)
	for _, c := range cookies {
		req.AddCookie(c)
	}
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	body := rec.Body.String()

	const needle = `name="_csrf" value="`
	_, rest, found := strings.Cut(body, needle)
	if !found {
		t.Fatalf("getCSRFToken: _csrf hidden input not found in GET %s (status %d); body snippet: %.300s", path, rec.Code, body)
	}
	raw, _, found := strings.Cut(rest, `"`)
	if !found {
		t.Fatalf("getCSRFToken: closing quote not found for _csrf in GET %s", path)
	}
	// Decode HTML character references (e.g. "&#43;" -> "+") like a browser.
	token = html.UnescapeString(raw)

	respCookies = append(respCookies, cookies...)
	respCookies = append(respCookies, rec.Result().Cookies()...)
	return token, respCookies
}
// preInsertUser creates a password user directly through sqlc, bypassing the
// signup handler. The password is hashed with auth.TestParams instead of the
// slower DefaultParams so test setup stays fast (W4 / Pitfall 4). The email is
// lower-cased before insertion. Any failure aborts the test.
func preInsertUser(t *testing.T, ctx context.Context, q *sqlc.Queries, email, password string) sqlc.User {
	t.Helper()
	hash, hashErr := auth.Hash(password, auth.TestParams)
	if hashErr != nil {
		t.Fatalf("preInsertUser: hash: %v", hashErr)
	}
	params := sqlc.InsertUserParams{
		Email:        strings.ToLower(email),
		PasswordHash: pgtype.Text{String: hash, Valid: true},
	}
	created, insertErr := q.InsertUser(ctx, params)
	if insertErr != nil {
		t.Fatalf("preInsertUser: InsertUser: %v", insertErr)
	}
	return created
}
// preInsertSocialOnlyUser creates a user through sqlc's InsertSocialUser.
// Only the lower-cased email is supplied and no password hash is set, which
// yields the "social-only" account the auth tests exercise. An insertion
// failure aborts the test.
func preInsertSocialOnlyUser(t *testing.T, ctx context.Context, q *sqlc.Queries, email string) sqlc.User {
	t.Helper()
	normalized := strings.ToLower(email)
	created, insertErr := q.InsertSocialUser(ctx, normalized)
	if insertErr != nil {
		t.Fatalf("preInsertSocialOnlyUser: InsertSocialUser: %v", insertErr)
	}
	return created
}
// hashCookieValue converts a session cookie value into the session ID stored
// in the database (D-05). It decodes the value as unpadded base64url and
// returns the lowercase hex SHA-256 digest of the decoded bytes. A value that
// does not decode fails the test.
func hashCookieValue(t *testing.T, cookieValue string) string {
	t.Helper()
	decoded, decodeErr := base64.RawURLEncoding.DecodeString(cookieValue)
	if decodeErr != nil {
		t.Fatalf("hashCookieValue: decode: %v", decodeErr)
	}
	hasher := sha256.New()
	hasher.Write(decoded) // hash.Hash.Write never returns an error
	return hex.EncodeToString(hasher.Sum(nil))
}
// getSessionCookie returns the first cookie named auth.SessionCookieName that
// the recorded response sets, or nil if the response sets no such cookie.
func getSessionCookie(rec *httptest.ResponseRecorder) *http.Cookie {
	set := rec.Result().Cookies()
	for i := range set {
		if set[i].Name == auth.SessionCookieName {
			return set[i]
		}
	}
	return nil
}
// ---- Signup Tests ----
// TestSignupProviderButtonsConfigured renders GET /signup with both OAuth
// providers configured. The page must contain:
//   - the Google and Apple buttons with their start links,
//   - the "or" divider,
//   - the email and password inputs.
func TestSignupProviderButtonsConfigured(t *testing.T) {
	router := newAuthPageRouter(t, configuredProviderDeps())
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/signup", nil))
	page := rec.Body.String()

	fragments := [...]string{
		"Continue with Google",
		"Continue with Apple",
		`href="/auth/google/start"`,
		`href="/auth/apple/start"`,
		">or<",
		`name="email"`,
		`name="password"`,
	}
	for i := range fragments {
		if !strings.Contains(page, fragments[i]) {
			t.Fatalf("signup page missing %q; body: %s", fragments[i], page)
		}
	}
}
// TestSignupProviderButtonsDisabledWhenConfigMissing renders GET /signup with
// empty AuthDeps. Each provider must show its "not configured" copy, and
// neither button may link to its /auth/<provider>/start endpoint.
func TestSignupProviderButtonsDisabledWhenConfigMissing(t *testing.T) {
	router := newAuthPageRouter(t, AuthDeps{})
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/signup", nil))
	page := rec.Body.String()

	disabledCopy := [...]string{"Google sign-in not configured", "Apple sign-in not configured"}
	for i := range disabledCopy {
		if !strings.Contains(page, disabledCopy[i]) {
			t.Fatalf("signup page missing disabled copy %q; body: %s", disabledCopy[i], page)
		}
	}

	hasGoogleStart := strings.Contains(page, `href="/auth/google/start"`)
	hasAppleStart := strings.Contains(page, `href="/auth/apple/start"`)
	if hasGoogleStart || hasAppleStart {
		t.Fatalf("disabled provider buttons must not include actionable start hrefs; body: %s", page)
	}
}
// TestSignup_Success posts a valid signup form with a fresh CSRF token and
// checks the full happy path:
//   - a 303 redirect to "/",
//   - an HttpOnly session cookie,
//   - a user row whose password_hash is an argon2id hash,
//   - exactly one session row for that user, and it is the row the issued
//     cookie refers to.
func TestSignup_Success(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	q := sqlc.New(pool)
	store := auth.NewStore(q)
	router := newTestRouter(q, store)

	csrfToken, csrfCookies := getCSRFToken(t, router, "/signup", nil)
	form := url.Values{"email": {"alice@example.com"}, "password": {"correct-horse-12"}, "_csrf": {csrfToken}}
	req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	for _, c := range csrfCookies {
		req.AddCookie(c)
	}
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusSeeOther {
		t.Fatalf("status = %d; want 303", rec.Code)
	}
	if loc := rec.Header().Get("Location"); loc != "/" {
		t.Errorf("Location = %q; want /", loc)
	}

	// Cookie must be present and HttpOnly.
	sessionCookie := getSessionCookie(rec)
	if sessionCookie == nil {
		t.Fatal("session cookie not set")
	}
	if !sessionCookie.HttpOnly {
		t.Error("session cookie must be HttpOnly")
	}

	// User row must exist with an argon2id hash.
	user, err := q.GetUserByEmail(ctx, "alice@example.com")
	if err != nil {
		t.Fatalf("GetUserByEmail: %v", err)
	}
	if !user.PasswordHash.Valid || !strings.HasPrefix(user.PasswordHash.String, "$argon2id$") {
		t.Errorf("password_hash = %#v; want valid $argon2id$ prefix", user.PasswordHash)
	}

	// Session row must exist for the user.
	var count int
	if err := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE user_id = $1", user.ID).Scan(&count); err != nil {
		t.Fatalf("session count query: %v", err)
	}
	if count != 1 {
		t.Errorf("session count = %d; want 1", count)
	}

	// The cookie must reference that session. The DB stores the SHA-256 of the
	// decoded cookie value as the session ID (D-05).
	var linked int
	sessionID := hashCookieValue(t, sessionCookie.Value)
	if err := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE id = $1 AND user_id = $2", sessionID, user.ID).Scan(&linked); err != nil {
		t.Fatalf("cookie session lookup: %v", err)
	}
	if linked != 1 {
		t.Errorf("sessions matching issued cookie = %d; want 1", linked)
	}
}
// TestSignup_SocialOnlyExistingUserShowsProviderMessage signs up with the
// email of an existing account created via InsertSocialUser, which has no
// password. The request must be rejected with 422 and a message telling the
// user to sign in with their provider.
func TestSignup_SocialOnlyExistingUserShowsProviderMessage(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	q := sqlc.New(pool)
	router := newTestRouter(q, auth.NewStore(q))
	preInsertSocialOnlyUser(t, ctx, q, "social-only@example.com")

	token, cookies := getCSRFToken(t, router, "/signup", nil)
	form := url.Values{}
	form.Set("email", "social-only@example.com")
	form.Set("password", "correct-horse-12")
	form.Set("_csrf", token)

	req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	for _, ck := range cookies {
		req.AddCookie(ck)
	}
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusUnprocessableEntity {
		t.Fatalf("status = %d; want 422", got)
	}
	const wantMsg = "An account already exists for this email. Sign in with your provider."
	if body := rec.Body.String(); !strings.Contains(body, wantMsg) {
		t.Fatalf("body missing social-only signup conflict copy; got: %s", body)
	}
}
// TestSignup_Success_HTMX submits a valid signup as an HTMX request
// (HX-Request: true). It expects a 200 response carrying HX-Redirect: /
// instead of the 303 that a plain form post receives.
func TestSignup_Success_HTMX(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	q := sqlc.New(pool)
	router := newTestRouter(q, auth.NewStore(q))

	token, cookies := getCSRFToken(t, router, "/signup", nil)
	form := url.Values{}
	form.Set("email", "bob@example.com")
	form.Set("password", "correct-horse-12")
	form.Set("_csrf", token)

	req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("HX-Request", "true")
	for _, ck := range cookies {
		req.AddCookie(ck)
	}
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusOK {
		t.Fatalf("HTMX status = %d; want 200", got)
	}
	if got := rec.Header().Get("HX-Redirect"); got != "/" {
		t.Errorf("HX-Redirect = %q; want /", got)
	}
}
// TestSignup_InvalidEmail posts a malformed email address. It expects a 422
// response mentioning "valid email", and no user row may be inserted. The
// count query's error is checked: if the query failed, count would stay 0 and
// the assertion would pass vacuously.
func TestSignup_InvalidEmail(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	q := sqlc.New(pool)
	store := auth.NewStore(q)
	router := newTestRouter(q, store)

	csrfToken, csrfCookies := getCSRFToken(t, router, "/signup", nil)
	form := url.Values{"email": {"not-an-email"}, "password": {"correct-horse-12"}, "_csrf": {csrfToken}}
	req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	for _, c := range csrfCookies {
		req.AddCookie(c)
	}
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusUnprocessableEntity {
		t.Fatalf("status = %d; want 422", rec.Code)
	}
	if !strings.Contains(rec.Body.String(), "valid email") {
		t.Errorf("body missing 'valid email' error; got: %s", rec.Body.String())
	}

	// No user row must have been inserted.
	var count int
	if err := pool.QueryRow(ctx, "SELECT COUNT(*) FROM users WHERE email = $1", "not-an-email").Scan(&count); err != nil {
		t.Fatalf("user count query: %v", err)
	}
	if count != 0 {
		t.Errorf("unexpected user row inserted for invalid email")
	}
}
// TestSignup_PasswordTooShort posts an 11-character password, one below the
// 12-character minimum. It expects a 422 response whose body mentions the "12"
// boundary, and no user row may be inserted. The count query's error is
// checked so a failed query cannot make the test pass vacuously.
func TestSignup_PasswordTooShort(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	q := sqlc.New(pool)
	store := auth.NewStore(q)
	router := newTestRouter(q, store)

	// 11 chars — below the 12-char minimum.
	csrfToken, csrfCookies := getCSRFToken(t, router, "/signup", nil)
	form := url.Values{"email": {"carol@example.com"}, "password": {"short12345!"}, "_csrf": {csrfToken}}
	req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	for _, c := range csrfCookies {
		req.AddCookie(c)
	}
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusUnprocessableEntity {
		t.Fatalf("status = %d; want 422", rec.Code)
	}
	if !strings.Contains(rec.Body.String(), "12") {
		t.Errorf("body missing '12' boundary; got: %s", rec.Body.String())
	}

	var count int
	if err := pool.QueryRow(ctx, "SELECT COUNT(*) FROM users WHERE email = $1", "carol@example.com").Scan(&count); err != nil {
		t.Fatalf("user count query: %v", err)
	}
	if count != 0 {
		t.Errorf("unexpected user row inserted for short password")
	}
}
// TestSignup_PasswordTooLong posts a 129-character password, one above the
// 128-character maximum. It expects a 422 response whose body mentions the
// "128" boundary, and no user row may be inserted. The count query's error is
// checked so a failed query cannot make the test pass vacuously.
func TestSignup_PasswordTooLong(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	q := sqlc.New(pool)
	store := auth.NewStore(q)
	router := newTestRouter(q, store)

	longPw := strings.Repeat("a", 129) // 129 chars — above the 128-char maximum.
	csrfToken, csrfCookies := getCSRFToken(t, router, "/signup", nil)
	form := url.Values{"email": {"dave@example.com"}, "password": {longPw}, "_csrf": {csrfToken}}
	req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	for _, c := range csrfCookies {
		req.AddCookie(c)
	}
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusUnprocessableEntity {
		t.Fatalf("status = %d; want 422", rec.Code)
	}
	if !strings.Contains(rec.Body.String(), "128") {
		t.Errorf("body missing '128' boundary; got: %s", rec.Body.String())
	}

	var count int
	if err := pool.QueryRow(ctx, "SELECT COUNT(*) FROM users WHERE email = $1", "dave@example.com").Scan(&count); err != nil {
		t.Fatalf("user count query: %v", err)
	}
	if count != 0 {
		t.Errorf("unexpected user row inserted for long password")
	}
}
// TestSignup_DuplicateEmail signs up with an email that already belongs to a
// password account. It expects a 422 "already in use" error, and no second
// user row may be created.
func TestSignup_DuplicateEmail(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	q := sqlc.New(pool)
	router := newTestRouter(q, auth.NewStore(q))
	// Seed the existing account with the cheap TestParams hash.
	preInsertUser(t, ctx, q, "eve@example.com", "correct-horse-12")

	token, cookies := getCSRFToken(t, router, "/signup", nil)
	form := url.Values{}
	form.Set("email", "eve@example.com")
	form.Set("password", "correct-horse-12")
	form.Set("_csrf", token)

	req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	for _, ck := range cookies {
		req.AddCookie(ck)
	}
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusUnprocessableEntity {
		t.Fatalf("status = %d; want 422", got)
	}
	if body := rec.Body.String(); !strings.Contains(body, "already in use") {
		t.Errorf("body missing 'already in use'; got: %s", body)
	}

	var rows int
	if err := pool.QueryRow(ctx, "SELECT COUNT(*) FROM users WHERE email = $1", "eve@example.com").Scan(&rows); err != nil {
		t.Fatalf("count query: %v", err)
	}
	if rows != 1 {
		t.Errorf("user count = %d; want exactly 1 (no duplicate inserted)", rows)
	}
}
// TestSignup_EmailNormalized signs up with a padded, mixed-case email. It
// expects the signup to succeed (303) and the stored address to be trimmed
// and lower-cased.
func TestSignup_EmailNormalized(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	q := sqlc.New(pool)
	router := newTestRouter(q, auth.NewStore(q))

	token, cookies := getCSRFToken(t, router, "/signup", nil)
	form := url.Values{}
	form.Set("email", " Frank@Example.COM ")
	form.Set("password", "correct-horse-12")
	form.Set("_csrf", token)

	req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	for _, ck := range cookies {
		req.AddCookie(ck)
	}
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusSeeOther {
		t.Fatalf("status = %d; want 303", got)
	}

	const normalized = "frank@example.com"
	stored, lookupErr := q.GetUserByEmail(ctx, normalized)
	if lookupErr != nil {
		t.Fatalf("GetUserByEmail: %v", lookupErr)
	}
	if stored.Email != normalized {
		t.Errorf("stored email = %q; want frank@example.com (trimmed + lowercased)", stored.Email)
	}
}
// TestSignup_AlreadyAuthedBouncesHome requests GET /signup with a valid
// session cookie obtained from store.Create. It expects a 303 redirect to "/"
// (RedirectIfAuthed). The session's expiry is not needed, so it is discarded
// at the call site instead of being bound and then silenced with `_ =`.
func TestSignup_AlreadyAuthedBouncesHome(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	q := sqlc.New(pool)
	store := auth.NewStore(q)
	router := newTestRouter(q, store)

	// Pre-insert a user and create a real session.
	user := preInsertUser(t, ctx, q, "grace@example.com", "correct-horse-12")
	cookieValue, _, err := store.Create(ctx, user.ID)
	if err != nil {
		t.Fatalf("store.Create: %v", err)
	}

	// GET /signup with a valid session cookie must redirect to /.
	req := httptest.NewRequest(http.MethodGet, "/signup", nil)
	req.AddCookie(&http.Cookie{
		Name:  auth.SessionCookieName,
		Value: cookieValue,
	})
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusSeeOther {
		t.Fatalf("status = %d; want 303 (RedirectIfAuthed)", rec.Code)
	}
	if loc := rec.Header().Get("Location"); loc != "/" {
		t.Errorf("Location = %q; want /", loc)
	}
}
// ---- Login Tests ----
// TestLoginProviderButtonsConfigured renders GET /login with both OAuth
// providers configured. The page must contain:
//   - the Google and Apple buttons with their start links,
//   - the "or" divider,
//   - the email and password inputs.
func TestLoginProviderButtonsConfigured(t *testing.T) {
	router := newAuthPageRouter(t, configuredProviderDeps())
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/login", nil))
	page := rec.Body.String()

	fragments := [...]string{
		"Continue with Google",
		"Continue with Apple",
		`href="/auth/google/start"`,
		`href="/auth/apple/start"`,
		">or<",
		`name="email"`,
		`name="password"`,
	}
	for i := range fragments {
		if !strings.Contains(page, fragments[i]) {
			t.Fatalf("login page missing %q; body: %s", fragments[i], page)
		}
	}
}
// TestLoginProviderButtonsDisabledWhenConfigMissing renders GET /login with
// empty AuthDeps. Each provider must show its "not configured" copy, and
// neither button may link to its /auth/<provider>/start endpoint.
func TestLoginProviderButtonsDisabledWhenConfigMissing(t *testing.T) {
	router := newAuthPageRouter(t, AuthDeps{})
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/login", nil))
	page := rec.Body.String()

	disabledCopy := [...]string{"Google sign-in not configured", "Apple sign-in not configured"}
	for i := range disabledCopy {
		if !strings.Contains(page, disabledCopy[i]) {
			t.Fatalf("login page missing disabled copy %q; body: %s", disabledCopy[i], page)
		}
	}

	hasGoogleStart := strings.Contains(page, `href="/auth/google/start"`)
	hasAppleStart := strings.Contains(page, `href="/auth/apple/start"`)
	if hasGoogleStart || hasAppleStart {
		t.Fatalf("disabled provider buttons must not include actionable start hrefs; body: %s", page)
	}
}
// TestLogin_Success logs in with correct credentials for a pre-inserted user.
// It expects:
//   - a 303 redirect to "/",
//   - a session cookie,
//   - exactly one session row for the user, and it is the row the issued
//     cookie refers to (session ID = SHA-256 of the decoded cookie, D-05).
func TestLogin_Success(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	q := sqlc.New(pool)
	store := auth.NewStore(q)
	router := newTestRouter(q, store)
	user := preInsertUser(t, ctx, q, "test@example.com", "correct-horse-12chars")

	csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil)
	form := url.Values{"email": {"test@example.com"}, "password": {"correct-horse-12chars"}, "_csrf": {csrfToken}}
	req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	for _, c := range csrfCookies {
		req.AddCookie(c)
	}
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusSeeOther {
		t.Fatalf("status = %d; want 303", rec.Code)
	}
	if loc := rec.Header().Get("Location"); loc != "/" {
		t.Errorf("Location = %q; want /", loc)
	}
	sessionCookie := getSessionCookie(rec)
	if sessionCookie == nil {
		t.Fatal("session cookie not set after login")
	}

	// Session row must exist.
	var count int
	if err := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE user_id = $1", user.ID).Scan(&count); err != nil {
		t.Fatalf("session count query: %v", err)
	}
	if count != 1 {
		t.Errorf("session count = %d; want 1", count)
	}

	// ...and the issued cookie must point at it.
	var linked int
	sessionID := hashCookieValue(t, sessionCookie.Value)
	if err := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE id = $1 AND user_id = $2", sessionID, user.ID).Scan(&linked); err != nil {
		t.Fatalf("cookie session lookup: %v", err)
	}
	if linked != 1 {
		t.Errorf("sessions matching issued cookie = %d; want 1", linked)
	}
}
// TestLogin_Success_HTMX logs in with correct credentials as an HTMX request.
// It expects a 200 response carrying HX-Redirect: / instead of a 303.
func TestLogin_Success_HTMX(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	q := sqlc.New(pool)
	router := newTestRouter(q, auth.NewStore(q))
	preInsertUser(t, ctx, q, "test2@example.com", "correct-horse-12chars")

	token, cookies := getCSRFToken(t, router, "/login", nil)
	form := url.Values{}
	form.Set("email", "test2@example.com")
	form.Set("password", "correct-horse-12chars")
	form.Set("_csrf", token)

	req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("HX-Request", "true")
	for _, ck := range cookies {
		req.AddCookie(ck)
	}
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusOK {
		t.Fatalf("HTMX status = %d; want 200", got)
	}
	if got := rec.Header().Get("HX-Redirect"); got != "/" {
		t.Errorf("HX-Redirect = %q; want /", got)
	}
}
// TestLogin_WrongPassword logs in with the wrong password for an existing
// user. It expects:
//   - the generic "Invalid email or password" message,
//   - no session cookie,
//   - no session row for the user.
//
// The session count query's error is checked: if the query failed, count
// would stay 0 and the final assertion would pass vacuously.
func TestLogin_WrongPassword(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	q := sqlc.New(pool)
	store := auth.NewStore(q)
	router := newTestRouter(q, store)
	preInsertUser(t, ctx, q, "testpw@example.com", "correct-horse-12chars")

	csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil)
	form := url.Values{"email": {"testpw@example.com"}, "password": {"wrong-password-12chars"}, "_csrf": {csrfToken}}
	req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	for _, c := range csrfCookies {
		req.AddCookie(c)
	}
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if !bytes.Contains(rec.Body.Bytes(), []byte("Invalid email or password")) {
		t.Errorf("body must contain 'Invalid email or password'; got: %s", rec.Body.String())
	}
	if c := getSessionCookie(rec); c != nil {
		t.Fatal("session cookie must NOT be set on wrong password")
	}

	// No session row for this user.
	var count int
	row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE user_id IN (SELECT id FROM users WHERE email = $1)", "testpw@example.com")
	if err := row.Scan(&count); err != nil {
		t.Fatalf("session count query: %v", err)
	}
	if count != 0 {
		t.Errorf("session count = %d; want 0 on wrong password", count)
	}
}
// TestLogin_UnknownEmail logs in with an email that has no account. Per D-20
// the response must carry exactly the same "Invalid email or password"
// message as the wrong-password case, and no session cookie may be set.
func TestLogin_UnknownEmail(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	q := sqlc.New(pool)
	router := newTestRouter(q, auth.NewStore(q))

	token, cookies := getCSRFToken(t, router, "/login", nil)
	form := url.Values{}
	form.Set("email", "nouser@example.com")
	form.Set("password", "correct-horse-12chars")
	form.Set("_csrf", token)

	req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	for _, ck := range cookies {
		req.AddCookie(ck)
	}
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	genericMsg := []byte("Invalid email or password")
	if !bytes.Contains(rec.Body.Bytes(), genericMsg) {
		t.Errorf("body must contain 'Invalid email or password' for unknown email; got: %s", rec.Body.String())
	}
	if getSessionCookie(rec) != nil {
		t.Fatal("session cookie must NOT be set on unknown email")
	}
}
// TestLogin_SocialOnlyUserGetsGenericInvalidCredentials tries a password
// login against an account created without a password (via
// InsertSocialUser). It expects a 401 with the generic "Invalid email or
// password" message and no session cookie.
func TestLogin_SocialOnlyUserGetsGenericInvalidCredentials(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	q := sqlc.New(pool)
	router := newTestRouter(q, auth.NewStore(q))
	preInsertSocialOnlyUser(t, ctx, q, "social-login@example.com")

	token, cookies := getCSRFToken(t, router, "/login", nil)
	form := url.Values{}
	form.Set("email", "social-login@example.com")
	form.Set("password", "correct-horse-12chars")
	form.Set("_csrf", token)

	req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	for _, ck := range cookies {
		req.AddCookie(ck)
	}
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusUnauthorized {
		t.Fatalf("status = %d; want 401", got)
	}
	genericMsg := []byte("Invalid email or password")
	if !bytes.Contains(rec.Body.Bytes(), genericMsg) {
		t.Fatalf("body must contain generic invalid credentials; got: %s", rec.Body.String())
	}
	if getSessionCookie(rec) != nil {
		t.Fatal("session cookie must NOT be set for social-only password login")
	}
}
// TestLogin_ValidationError_BadEmail submits a malformed email to /login. It
// expects a 422 validation response mentioning "valid email".
func TestLogin_ValidationError_BadEmail(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	q := sqlc.New(pool)
	router := newTestRouter(q, auth.NewStore(q))

	token, cookies := getCSRFToken(t, router, "/login", nil)
	form := url.Values{}
	form.Set("email", "not-an-email")
	form.Set("password", "correct-horse-12chars")
	form.Set("_csrf", token)

	req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	for _, ck := range cookies {
		req.AddCookie(ck)
	}
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusUnprocessableEntity {
		t.Fatalf("status = %d; want 422", got)
	}
	if body := rec.Body.String(); !strings.Contains(body, "valid email") {
		t.Errorf("body missing 'valid email'; got: %s", body)
	}
}
// TestLogin_ValidationError_ShortPassword submits a 9-character password to
// /login. It expects a 422 validation response whose body mentions the "12"
// minimum.
func TestLogin_ValidationError_ShortPassword(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	q := sqlc.New(pool)
	router := newTestRouter(q, auth.NewStore(q))

	token, cookies := getCSRFToken(t, router, "/login", nil)
	form := url.Values{}
	form.Set("email", "testval@example.com")
	form.Set("password", "shortpw12")
	form.Set("_csrf", token)

	req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	for _, ck := range cookies {
		req.AddCookie(ck)
	}
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusUnprocessableEntity {
		t.Fatalf("status = %d; want 422", got)
	}
	if body := rec.Body.String(); !strings.Contains(body, "12") {
		t.Errorf("body missing '12' boundary; got: %s", body)
	}
}
// TestLogin_RotatesExistingSession logs in while already presenting a valid
// session cookie. Rotation must:
//   - issue a new cookie value,
//   - delete the old session row,
//   - leave exactly one session row for the user, which the new cookie
//     references.
//
// Every count query's Scan error is checked, and each query gets its own
// variable. A failed query therefore cannot leave a stale or zero count that
// makes an assertion pass vacuously.
func TestLogin_RotatesExistingSession(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	q := sqlc.New(pool)
	store := auth.NewStore(q)
	router := newTestRouter(q, store)
	user := preInsertUser(t, ctx, q, "rotatetest@example.com", "correct-horse-12chars")

	// Pre-create a session for this user.
	oldCookieValue, _, err := store.Create(ctx, user.ID)
	if err != nil {
		t.Fatalf("store.Create: %v", err)
	}

	csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil)
	form := url.Values{"email": {"rotatetest@example.com"}, "password": {"correct-horse-12chars"}, "_csrf": {csrfToken}}
	req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: oldCookieValue})
	for _, c := range csrfCookies {
		req.AddCookie(c)
	}
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusSeeOther {
		t.Fatalf("status = %d; want 303", rec.Code)
	}

	// New cookie value must differ from old.
	newCookie := getSessionCookie(rec)
	if newCookie == nil {
		t.Fatal("no session cookie after login")
	}
	if newCookie.Value == oldCookieValue {
		t.Error("session cookie value must change on login (rotation)")
	}

	// Old session row must be gone (rotation deletes it).
	var oldCount int
	oldSessionID := hashCookieValue(t, oldCookieValue)
	if err := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE id = $1", oldSessionID).Scan(&oldCount); err != nil {
		t.Fatalf("old session count query: %v", err)
	}
	if oldCount != 0 {
		t.Errorf("old session row still exists after rotation; want 0")
	}

	// New session row must exist for the user.
	var userCount int
	if err := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE user_id = $1", user.ID).Scan(&userCount); err != nil {
		t.Fatalf("user session count query: %v", err)
	}
	if userCount != 1 {
		t.Errorf("new session count = %d; want 1", userCount)
	}

	// ...and it must be the one the new cookie references.
	var linked int
	newSessionID := hashCookieValue(t, newCookie.Value)
	if err := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE id = $1 AND user_id = $2", newSessionID, user.ID).Scan(&linked); err != nil {
		t.Fatalf("new cookie session lookup: %v", err)
	}
	if linked != 1 {
		t.Errorf("sessions matching new cookie = %d; want 1", linked)
	}
}
// TestLogin_AlreadyAuthedBouncesHome requests GET /login with a valid session
// cookie. It expects a 303 redirect to "/" (RedirectIfAuthed).
func TestLogin_AlreadyAuthedBouncesHome(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	q := sqlc.New(pool)
	store := auth.NewStore(q)
	router := newTestRouter(q, store)

	account := preInsertUser(t, ctx, q, "authed@example.com", "correct-horse-12chars")
	sessionValue, _, createErr := store.Create(ctx, account.ID)
	if createErr != nil {
		t.Fatalf("store.Create: %v", createErr)
	}

	req := httptest.NewRequest(http.MethodGet, "/login", nil)
	req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: sessionValue})
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusSeeOther {
		t.Fatalf("status = %d; want 303 (RedirectIfAuthed)", got)
	}
	if got := rec.Header().Get("Location"); got != "/" {
		t.Errorf("Location = %q; want /", got)
	}
}
// TestLogin_RateLimit_6thAttemptReturns429 makes six wrong-password logins for
// one email from one remote address. The limiter clock is frozen, so all six
// attempts fall at the same instant. The first five must not be throttled; the
// sixth must be a 429 whose body contains "Too many".
func TestLogin_RateLimit_6thAttemptReturns429(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	q := sqlc.New(pool)
	store := auth.NewStore(q)
	frozen := time.Now()
	limiter := auth.NewLimiterStoreWithClock(func() time.Time { return frozen })
	router := newTestRouterWithLimiter(q, store, limiter)
	preInsertUser(t, ctx, q, "ratelimit@example.com", "correct-horse-12chars")

	const lastAttempt = 6
	for attempt := 1; attempt <= lastAttempt; attempt++ {
		token, cookies := getCSRFToken(t, router, "/login", nil)
		form := url.Values{}
		form.Set("email", "ratelimit@example.com")
		form.Set("password", "wrong-password-12chars")
		form.Set("_csrf", token)

		req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
		// A fixed RemoteAddr keeps the client IP stable across attempts.
		req.RemoteAddr = "192.168.1.1:12345"
		for _, ck := range cookies {
			req.AddCookie(ck)
		}
		rec := httptest.NewRecorder()
		router.ServeHTTP(rec, req)

		if attempt != lastAttempt {
			if rec.Code == http.StatusTooManyRequests {
				t.Fatalf("attempt %d: got 429 early (before 6th attempt)", attempt)
			}
			continue
		}
		if rec.Code != http.StatusTooManyRequests {
			t.Fatalf("attempt %d: status = %d; want 429", attempt, rec.Code)
		}
		if !bytes.Contains(rec.Body.Bytes(), []byte("Too many")) {
			t.Errorf("attempt %d: body missing 'Too many'; got: %s", attempt, rec.Body.String())
		}
	}
}
// TestLogin_RateLimit_6thAttemptHTMXNoFullPage repeats the six-attempt
// lockout with HX-Request set on every request, using a frozen limiter clock.
// The sixth response must be a 429 fragment: its body must not contain an
// "<html" tag.
func TestLogin_RateLimit_6thAttemptHTMXNoFullPage(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	q := sqlc.New(pool)
	store := auth.NewStore(q)
	frozen := time.Now()
	limiter := auth.NewLimiterStoreWithClock(func() time.Time { return frozen })
	router := newTestRouterWithLimiter(q, store, limiter)
	preInsertUser(t, ctx, q, "ratelimithtmx@example.com", "correct-horse-12chars")

	const lastAttempt = 6
	for attempt := 1; attempt <= lastAttempt; attempt++ {
		token, cookies := getCSRFToken(t, router, "/login", nil)
		form := url.Values{}
		form.Set("email", "ratelimithtmx@example.com")
		form.Set("password", "wrong-password-12chars")
		form.Set("_csrf", token)

		req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
		req.Header.Set("HX-Request", "true")
		req.RemoteAddr = "192.168.1.2:12345"
		for _, ck := range cookies {
			req.AddCookie(ck)
		}
		rec := httptest.NewRecorder()
		router.ServeHTTP(rec, req)

		if attempt != lastAttempt {
			continue
		}
		if rec.Code != http.StatusTooManyRequests {
			t.Fatalf("HTMX attempt 6: status = %d; want 429", rec.Code)
		}
		// An HTMX response is swapped into the page, so it must be a fragment.
		if bytes.Contains(rec.Body.Bytes(), []byte("<html")) {
			t.Error("HTMX rate-limit response must not contain full <html> page")
		}
	}
}
// TestLogin_RateLimit_KeyedByEmailPlusIP checks that the login limiter is keyed
// by email+IP. The test:
//  1. exhausts the budget for emailA from one address,
//  2. asserts that the 6th attempt for emailA is actually throttled (without
//     this precondition the isolation check would pass even if rate limiting
//     were broken),
//  3. requires that emailB from the same address is still not throttled.
//
// The limiter clock is frozen so no budget refills during the test.
func TestLogin_RateLimit_KeyedByEmailPlusIP(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	q := sqlc.New(pool)
	store := auth.NewStore(q)
	t0 := time.Now()
	rl := auth.NewLimiterStoreWithClock(func() time.Time { return t0 })
	router := newTestRouterWithLimiter(q, store, rl)
	preInsertUser(t, ctx, q, "emailA@example.com", "correct-horse-12chars")
	preInsertUser(t, ctx, q, "emailB@example.com", "correct-horse-12chars")

	// postLogin submits one wrong-password login for email from remoteAddr,
	// using a fresh CSRF token, and returns the recorded response.
	postLogin := func(email, remoteAddr string) *httptest.ResponseRecorder {
		t.Helper()
		csrfToken, csrfCookies := getCSRFToken(t, router, "/login", nil)
		form := url.Values{"email": {email}, "password": {"wrong-password-12chars"}, "_csrf": {csrfToken}}
		req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
		req.RemoteAddr = remoteAddr
		for _, c := range csrfCookies {
			req.AddCookie(c)
		}
		rec := httptest.NewRecorder()
		router.ServeHTTP(rec, req)
		return rec
	}

	// Exhaust emailA from IP1; the 6th attempt must be throttled.
	for i := 1; i <= 6; i++ {
		rec := postLogin("emailA@example.com", "10.0.0.1:1234")
		if i == 6 && rec.Code != http.StatusTooManyRequests {
			t.Fatalf("precondition: attempt %d for emailA: status = %d; want 429", i, rec.Code)
		}
	}

	// emailB from same IP1 should still be allowed (separate key).
	if rec := postLogin("emailB@example.com", "10.0.0.1:1234"); rec.Code == http.StatusTooManyRequests {
		t.Error("emailB must not be rate-limited when only emailA was exhausted (key isolation)")
	}
}
// TestLogin_RateLimit_AppliesBeforeUserLookup checks that the login rate gate
// is enforced even for an email with no matching account. Six consecutive
// failed submissions for an unknown address from one client must produce a 429
// on the sixth.
func TestLogin_RateLimit_AppliesBeforeUserLookup(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	queries := sqlc.New(pool)
	sessions := auth.NewStore(queries)
	frozen := time.Now()
	limiter := auth.NewLimiterStoreWithClock(func() time.Time { return frozen })
	router := newTestRouterWithLimiter(queries, sessions, limiter)

	// Use an email that does NOT exist in the DB.
	const attempts = 6
	for n := 1; n <= attempts; n++ {
		token, cookies := getCSRFToken(t, router, "/login", nil)
		body := url.Values{
			"email":    {"nonexistent@example.com"},
			"password": {"wrong-password-12chars"},
			"_csrf":    {token},
		}.Encode()
		req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(body))
		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
		req.RemoteAddr = "10.0.0.2:1234"
		for _, cookie := range cookies {
			req.AddCookie(cookie)
		}
		rec := httptest.NewRecorder()
		router.ServeHTTP(rec, req)

		if n < attempts {
			continue
		}
		// Final attempt: must be 429 even though the email doesn't exist.
		if rec.Code != http.StatusTooManyRequests {
			t.Fatalf("6th attempt for unknown email: status = %d; want 429 (rate gate before user lookup)", rec.Code)
		}
	}
}
// ---- Logout Tests ----
// TestLogout_Success verifies the full logout flow for an authenticated user.
// It checks that POST /logout with a valid CSRF token:
//   - redirects to /login with 303,
//   - emits a Set-Cookie that expires the session cookie, and
//   - hard-deletes the session row from the database (D-06).
func TestLogout_Success(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	q := sqlc.New(pool)
	store := auth.NewStore(q)
	router := newTestRouter(q, store)
	user := preInsertUser(t, ctx, q, "logout@example.com", "correct-horse-12chars")
	cookieValue, _, err := store.Create(ctx, user.ID)
	if err != nil {
		t.Fatalf("store.Create: %v", err)
	}
	sessionID := hashCookieValue(t, cookieValue)
	sessionCookie := &http.Cookie{Name: auth.SessionCookieName, Value: cookieValue}
	csrfToken, csrfCookies := getCSRFToken(t, router, "/", []*http.Cookie{sessionCookie})
	req := httptest.NewRequest(http.MethodPost, "/logout", strings.NewReader(url.Values{"_csrf": {csrfToken}}.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	for _, c := range csrfCookies {
		req.AddCookie(c)
	}
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if rec.Code != http.StatusSeeOther {
		t.Fatalf("status = %d; want 303", rec.Code)
	}
	if loc := rec.Header().Get("Location"); loc != "/login" {
		t.Errorf("Location = %q; want /login", loc)
	}

	// Session cookie must be cleared (Max-Age=0 or Expires in the past).
	var found *http.Cookie
	for _, c := range rec.Result().Cookies() {
		if c.Name == auth.SessionCookieName {
			found = c
			break
		}
	}
	if found == nil {
		t.Fatal("expected Set-Cookie header to clear the session cookie; none found")
	}
	// net/http parses a "Max-Age=0" (or negative) attribute as MaxAge < 0.
	// MaxAge == 0 means the attribute was absent, which leaves a live
	// browser-session cookie, so it does not count as cleared on its own.
	expired := found.MaxAge < 0 || (!found.Expires.IsZero() && found.Expires.Before(time.Now()))
	if !expired {
		t.Errorf("session cookie not expired: MaxAge = %d, Expires = %v; want Max-Age<=0 attribute or Expires in the past", found.MaxAge, found.Expires)
	}

	// Session row must be hard-deleted from DB (D-06).
	var count int
	row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE id = $1", sessionID)
	if err := row.Scan(&count); err != nil {
		t.Fatalf("session count query: %v", err)
	}
	if count != 0 {
		t.Errorf("session row still exists after logout; want 0 (D-06 hard delete)")
	}
}
// TestLogout_UnauthRedirectsToLogin sends POST /logout with neither a session
// cookie nor a CSRF token and checks that the request is rejected rather than
// failing with a server error.
//
// gorilla/csrf runs before RequireAuth, so the missing CSRF token normally
// produces 403 before RequireAuth can redirect. A 303 from RequireAuth is
// equally acceptable. The invariant is only that the request is turned away
// (not 500).
func TestLogout_UnauthRedirectsToLogin(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	queries := sqlc.New(pool)
	router := newTestRouter(queries, auth.NewStore(queries))

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/logout", nil))

	switch rec.Code {
	case http.StatusForbidden, http.StatusSeeOther:
		// 403: CSRF middleware rejected the request.
		// 303: RequireAuth redirected the anonymous caller.
	default:
		t.Fatalf("status = %d; want 403 (CSRF) or 303 (RequireAuth)", rec.Code)
	}
}
// TestLogout_HXRedirect verifies that an HTMX-initiated logout (request header
// HX-Request: true) is answered with 200 plus an HX-Redirect header pointing at
// /login, instead of a plain 303 redirect.
func TestLogout_HXRedirect(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	queries := sqlc.New(pool)
	sessions := auth.NewStore(queries)
	router := newTestRouter(queries, sessions)

	user := preInsertUser(t, ctx, queries, "logouthtmx@example.com", "correct-horse-12chars")
	rawCookie, _, err := sessions.Create(ctx, user.ID)
	if err != nil {
		t.Fatalf("store.Create: %v", err)
	}

	// Obtain a CSRF token from the protected home page as the logged-in user.
	authCookie := &http.Cookie{Name: auth.SessionCookieName, Value: rawCookie}
	token, cookies := getCSRFToken(t, router, "/", []*http.Cookie{authCookie})

	payload := url.Values{"_csrf": {token}}
	req := httptest.NewRequest(http.MethodPost, "/logout", strings.NewReader(payload.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("HX-Request", "true")
	for _, cookie := range cookies {
		req.AddCookie(cookie)
	}
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("HTMX status = %d; want 200", rec.Code)
	}
	if got := rec.Header().Get("HX-Redirect"); got != "/login" {
		t.Errorf("HX-Redirect = %q; want /login", got)
	}
}
// TestLogout_AfterLogoutSubsequentRequestUnauth checks that a session cookie is
// useless once its owner has logged out. Replaying the same cookie value on
// GET / must be redirected to /login because the session row is gone.
func TestLogout_AfterLogoutSubsequentRequestUnauth(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	queries := sqlc.New(pool)
	sessions := auth.NewStore(queries)
	router := newTestRouter(queries, sessions)

	user := preInsertUser(t, ctx, queries, "stale@example.com", "correct-horse-12chars")
	staleValue, _, err := sessions.Create(ctx, user.ID)
	if err != nil {
		t.Fatalf("store.Create: %v", err)
	}

	// Step 1: log out, using a CSRF token obtained from the protected page.
	token, cookies := getCSRFToken(t, router, "/", []*http.Cookie{
		{Name: auth.SessionCookieName, Value: staleValue},
	})
	logoutBody := url.Values{"_csrf": {token}}.Encode()
	logoutReq := httptest.NewRequest(http.MethodPost, "/logout", strings.NewReader(logoutBody))
	logoutReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	for _, cookie := range cookies {
		logoutReq.AddCookie(cookie)
	}
	logoutRec := httptest.NewRecorder()
	router.ServeHTTP(logoutRec, logoutReq)
	if logoutRec.Code != http.StatusSeeOther {
		t.Fatalf("logout status = %d; want 303", logoutRec.Code)
	}

	// Step 2: an attacker still holding the old cookie requests GET /.
	replayReq := httptest.NewRequest(http.MethodGet, "/", nil)
	replayReq.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: staleValue})
	replayRec := httptest.NewRecorder()
	router.ServeHTTP(replayRec, replayReq)

	if replayRec.Code != http.StatusSeeOther {
		t.Fatalf("post-logout GET / status = %d; want 303 (session row deleted)", replayRec.Code)
	}
	if loc := replayRec.Header().Get("Location"); loc != "/login" {
		t.Errorf("Location = %q; want /login", loc)
	}
}
// ---- Protected Route Tests ----
// TestProtected_HomeUnauthRedirects checks that an anonymous (no session
// cookie) GET / receives a 303 See Other pointing at /login.
func TestProtected_HomeUnauthRedirects(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	queries := sqlc.New(pool)
	router := newTestRouter(queries, auth.NewStore(queries))

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))

	if rec.Code != http.StatusSeeOther {
		t.Fatalf("status = %d; want 303 (unauth GET / redirects to /login)", rec.Code)
	}
	if got := rec.Header().Get("Location"); got != "/login" {
		t.Errorf("Location = %q; want /login", got)
	}
}
// TestProtected_HomeUnauthHXRedirect checks the HTMX variant of the anonymous
// GET / case. With HX-Request: true, the response must be 200 with an
// HX-Redirect header of /login instead of a 303.
func TestProtected_HomeUnauthHXRedirect(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	queries := sqlc.New(pool)
	router := newTestRouter(queries, auth.NewStore(queries))

	htmxReq := httptest.NewRequest(http.MethodGet, "/", nil)
	htmxReq.Header.Set("HX-Request", "true")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, htmxReq)

	if rec.Code != http.StatusOK {
		t.Fatalf("HTMX status = %d; want 200", rec.Code)
	}
	if got := rec.Header().Get("HX-Redirect"); got != "/login" {
		t.Errorf("HX-Redirect = %q; want /login", got)
	}
}
// TestProtected_HomeAuthRendersUserEmail checks that GET / with a valid session
// cookie returns 200 and that the rendered page includes the signed-in user's
// email address.
func TestProtected_HomeAuthRendersUserEmail(t *testing.T) {
	pool, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	queries := sqlc.New(pool)
	sessions := auth.NewStore(queries)
	router := newTestRouter(queries, sessions)

	const email = "alice@example.com"
	user := preInsertUser(t, ctx, queries, email, "correct-horse-12chars")
	rawCookie, _, err := sessions.Create(ctx, user.ID)
	if err != nil {
		t.Fatalf("store.Create: %v", err)
	}

	homeReq := httptest.NewRequest(http.MethodGet, "/", nil)
	homeReq.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: rawCookie})
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, homeReq)

	if rec.Code != http.StatusOK {
		t.Fatalf("status = %d; want 200", rec.Code)
	}
	if body := rec.Body.String(); !strings.Contains(body, email) {
		t.Errorf("body must contain user email 'alice@example.com'; got: %s", body)
	}
}