From efdc16babe79fc186dfac627b1fd6b27123f34a6 Mon Sep 17 00:00:00 2001 From: Arthur Belleville Date: Thu, 14 May 2026 22:17:50 +0200 Subject: [PATCH] feat(02-04): signup handler, router wiring, and integration tests - Add handlers_auth.go: SignupPageHandler + SignupPostHandler (validate -> hash -> insert -> session -> redirect) - Add AuthDeps struct; wire argon2id hash, InsertUser, Store.Create, SetSessionCookie - Update router.go: NewRouter accepts AuthDeps; mount ResolveSession (D-24); wire /signup routes behind RedirectIfAuthed - Update cmd/web/main.go: build AuthDeps (sqlc.Queries + auth.Store + secure flag) and pass to NewRouter - Add nil-Store guard to auth.ResolveSession for Phase 1 unit-test compatibility - Update handlers_test.go: pass AuthDeps{} zero value to NewRouter (Phase 1 routes unaffected) - Add testdb_test.go: isolated-schema test helper for web package integration tests - Add handlers_auth_test.go: 8 TestSignup_* integration tests (all pass against real Postgres) --- backend/cmd/web/main.go | 9 +- backend/internal/auth/middleware.go | 10 + backend/internal/web/handlers_auth.go | 131 +++++++++ backend/internal/web/handlers_auth_test.go | 319 +++++++++++++++++++++ backend/internal/web/handlers_test.go | 6 +- backend/internal/web/router.go | 34 ++- backend/internal/web/testdb_test.go | 142 +++++++++ 7 files changed, 640 insertions(+), 11 deletions(-) create mode 100644 backend/internal/web/handlers_auth.go create mode 100644 backend/internal/web/handlers_auth_test.go create mode 100644 backend/internal/web/testdb_test.go diff --git a/backend/cmd/web/main.go b/backend/cmd/web/main.go index a008e0f..986a6ae 100644 --- a/backend/cmd/web/main.go +++ b/backend/cmd/web/main.go @@ -17,7 +17,9 @@ import ( "syscall" "time" + "backend/internal/auth" "backend/internal/db" + "backend/internal/db/sqlc" "backend/internal/web" ) @@ -54,7 +56,12 @@ func main() { os.Exit(1) } - router := web.NewRouter(pool, "./static") + q := sqlc.New(pool) + store := auth.NewStore(q) + secure := env != "development" && env != "dev" + deps := web.AuthDeps{Queries: q, Store: store, Secure: secure} + + router := web.NewRouter(pool, "./static", deps) srv := &http.Server{ Addr: ":" + port, diff --git a/backend/internal/auth/middleware.go b/backend/internal/auth/middleware.go index 6d582ce..27b8185 100644 --- a/backend/internal/auth/middleware.go +++ b/backend/internal/auth/middleware.go @@ -35,11 +35,21 @@ func Authed(ctx context.Context) (*Session, *User, bool) { // attaches them to the request context. It NEVER blocks the request — missing // or invalid sessions are silently ignored; RequireAuth enforces access. // +// When store is nil (e.g. in Phase 1 unit tests that pass a zero AuthDeps), +// the middleware is a no-op pass-through. Cookie resolution only happens when +// store is non-nil and a cookie is present. +// // On a valid session hit, MaybeExtend is called best-effort (logged but not // fatal) to implement the sliding 30-day TTL (D-09). func ResolveSession(store *Store) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Nil-store guard: Phase 1 route tests pass AuthDeps{} (zero value). + if store == nil { + next.ServeHTTP(w, r) + return + } + cookie, err := r.Cookie(SessionCookieName) if err != nil || cookie.Value == "" { // No cookie — pass through unauthenticated. diff --git a/backend/internal/web/handlers_auth.go b/backend/internal/web/handlers_auth.go new file mode 100644 index 0000000..867d2ff --- /dev/null +++ b/backend/internal/web/handlers_auth.go @@ -0,0 +1,131 @@ +package web + +import ( + "errors" + "net/http" + "net/mail" + "strings" + + "backend/internal/auth" + "backend/internal/db/sqlc" + "backend/templates" + + "github.com/jackc/pgx/v5/pgconn" +) + +// AuthDeps holds the dependencies shared by all auth handlers. +// Secure should be true in all environments except "dev"/"development"; +// it gates the cookie Secure attribute (D-12). +type AuthDeps struct { + Queries *sqlc.Queries + Store *auth.Store + Secure bool +} + +// SignupPageHandler renders the GET /signup page with an empty form. +func SignupPageHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + _ = templates.SignupPage(templates.SignupForm{}, templates.SignupErrors{}).Render(r.Context(), w) + } +} + +// SignupPostHandler handles POST /signup: validate → hash → insert → create session → redirect. +// +// Security invariants (threat model): +// - Password length is validated BEFORE calling auth.Hash to prevent long-password DoS (T-2-14). +// - The raw password is never passed to any template (T-2-01). +// - Email is not logged on validation errors (T-2-18). +// - Duplicate email is detected via pgconn error code 23505 (T-2-19). +// - A fresh session token is created on every signup (T-2-04). +func SignupPostHandler(deps AuthDeps) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // 1. Read form values. + email := strings.TrimSpace(r.PostFormValue("email")) + password := r.PostFormValue("password") + + var errs templates.SignupErrors + + // 2. Validate email. + if _, err := mail.ParseAddress(email); err != nil { + errs.Email = "Enter a valid email address" + } + + // 3. Validate password length BEFORE calling Hash (DoS guard T-2-14, D-25). + if len(password) < 12 { + errs.Password = "Password must be at least 12 characters" + } else if len(password) > 128 { + errs.Password = "Password must be at most 128 characters" + } + + if errs.Email != "" || errs.Password != "" { + // Re-populate the email field but NOT the password (T-2-01). + renderSignupError(w, r, templates.SignupForm{Email: email}, errs, http.StatusUnprocessableEntity) + return + } + + // 4. Normalize email (lowercase) before insert; citext handles case-insensitive + // uniqueness at DB level but we store canonical lowercase for consistency (D-01). + normalized := strings.ToLower(email) + + // 5. Hash password with production cost parameters. + hash, err := auth.Hash(password, auth.DefaultParams) + if err != nil { + http.Error(w, "internal server error", http.StatusInternalServerError) + return + } + + // 6. Insert user row. + user, err := deps.Queries.InsertUser(ctx, sqlc.InsertUserParams{ + Email: normalized, + PasswordHash: hash, + }) + if err != nil { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && pgErr.Code == "23505" { + // Unique-constraint violation on email (T-2-19). + // Specific error message is acceptable on signup per CONTEXT.md specifics. + errs.Email = "That email is already in use." + renderSignupError(w, r, templates.SignupForm{Email: email}, errs, http.StatusUnprocessableEntity) + return + } + http.Error(w, "internal server error", http.StatusInternalServerError) + return + } + + // 7. Create session (D-10: fresh token on every signup auto-login). + cookieValue, expiresAt, err := deps.Store.Create(ctx, user.ID) + if err != nil { + http.Error(w, "internal server error", http.StatusInternalServerError) + return + } + + // 8. Set session cookie (D-12). + auth.SetSessionCookie(w, cookieValue, expiresAt, deps.Secure) + + // 9. Redirect to home (D-11, D-21). + // HTMX form submissions receive HX-Redirect so HTMX handles navigation client-side. + // Plain (no-JS) form submissions receive 303 See Other (NOT 302 — Pitfall 9). + if r.Header.Get("HX-Request") == "true" { + w.Header().Set("HX-Redirect", "/") + w.WriteHeader(http.StatusOK) + return + } + http.Redirect(w, r, "/", http.StatusSeeOther) + } +} + +// renderSignupError writes a validation-error response. +// For HTMX requests it renders only the form fragment; for plain requests it +// renders the full page (D-19, D-25). +func renderSignupError(w http.ResponseWriter, r *http.Request, form templates.SignupForm, errs templates.SignupErrors, status int) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(status) + if r.Header.Get("HX-Request") == "true" { + _ = templates.SignupFormFragment(form, errs).Render(r.Context(), w) + } else { + _ = templates.SignupPage(form, errs).Render(r.Context(), w) + } +} diff --git a/backend/internal/web/handlers_auth_test.go b/backend/internal/web/handlers_auth_test.go new file mode 100644 index 0000000..90a88e1 --- /dev/null +++ b/backend/internal/web/handlers_auth_test.go @@ -0,0 +1,319 @@ +package web + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "backend/internal/auth" + "backend/internal/db/sqlc" +) + +// newTestRouter builds a router backed by a real DB for integration tests. +func newTestRouter(q *sqlc.Queries, store *auth.Store) http.Handler { + deps := AuthDeps{Queries: q, Store: store, Secure: false} + return NewRouter(stubPinger{}, "./static", deps) +} + +// preInsertUser inserts a user with TestParams-hashed password directly via sqlc +// (avoids slow DefaultParams hash in test setup — W4 / Pitfall 4). +func preInsertUser(t *testing.T, ctx context.Context, q *sqlc.Queries, email, password string) sqlc.User { + t.Helper() + hash, err := auth.Hash(password, auth.TestParams) + if err != nil { + t.Fatalf("preInsertUser: hash: %v", err) + } + user, err := q.InsertUser(ctx, sqlc.InsertUserParams{ + Email: strings.ToLower(email), + PasswordHash: hash, + }) + if err != nil { + t.Fatalf("preInsertUser: InsertUser: %v", err) + } + return user +} + +func TestSignup_Success(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + ctx := context.Background() + q := sqlc.New(pool) + store := auth.NewStore(q) + router := newTestRouter(q, store) + + form := url.Values{"email": {"alice@example.com"}, "password": {"correct-horse-12"}} + req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusSeeOther { + t.Fatalf("status = %d; want 303", rec.Code) + } + if loc := rec.Header().Get("Location"); loc != "/" { + t.Errorf("Location = %q; want /", loc) + } + + // Cookie must be present and HttpOnly. + var sessionCookie *http.Cookie + for _, c := range rec.Result().Cookies() { + if c.Name == auth.SessionCookieName { + sessionCookie = c + break + } + } + if sessionCookie == nil { + t.Fatal("session cookie not set") + } + if !sessionCookie.HttpOnly { + t.Error("session cookie must be HttpOnly") + } + + // User row must exist with an argon2id hash. + user, err := q.GetUserByEmail(ctx, "alice@example.com") + if err != nil { + t.Fatalf("GetUserByEmail: %v", err) + } + if !strings.HasPrefix(user.PasswordHash, "$argon2id$") { + t.Errorf("password_hash = %q; want $argon2id$ prefix", user.PasswordHash) + } + + // Session row must exist for the user. + var count int + row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE user_id = $1", user.ID) + if err := row.Scan(&count); err != nil { + t.Fatalf("session count query: %v", err) + } + if count != 1 { + t.Errorf("session count = %d; want 1", count) + } +} + +func TestSignup_Success_HTMX(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + q := sqlc.New(pool) + store := auth.NewStore(q) + router := newTestRouter(q, store) + + form := url.Values{"email": {"bob@example.com"}, "password": {"correct-horse-12"}} + req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("HX-Request", "true") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("HTMX status = %d; want 200", rec.Code) + } + if hxRedir := rec.Header().Get("HX-Redirect"); hxRedir != "/" { + t.Errorf("HX-Redirect = %q; want /", hxRedir) + } +} + +func TestSignup_InvalidEmail(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + q := sqlc.New(pool) + store := auth.NewStore(q) + router := newTestRouter(q, store) + + form := url.Values{"email": {"not-an-email"}, "password": {"correct-horse-12"}} + req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnprocessableEntity { + t.Fatalf("status = %d; want 422", rec.Code) + } + if !strings.Contains(rec.Body.String(), "valid email") { + t.Errorf("body missing 'valid email' error; got: %s", rec.Body.String()) + } + + // No user row must have been inserted. + ctx := context.Background() + var count int + row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM users WHERE email = $1", "not-an-email") + _ = row.Scan(&count) + if count != 0 { + t.Errorf("unexpected user row inserted for invalid email") + } +} + +func TestSignup_PasswordTooShort(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + q := sqlc.New(pool) + store := auth.NewStore(q) + router := newTestRouter(q, store) + + // 11 chars — below the 12-char minimum. + form := url.Values{"email": {"carol@example.com"}, "password": {"short12345!"}} + req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnprocessableEntity { + t.Fatalf("status = %d; want 422", rec.Code) + } + if !strings.Contains(rec.Body.String(), "12") { + t.Errorf("body missing '12' boundary; got: %s", rec.Body.String()) + } + + ctx := context.Background() + var count int + row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM users WHERE email = $1", "carol@example.com") + _ = row.Scan(&count) + if count != 0 { + t.Errorf("unexpected user row inserted for short password") + } +} + +func TestSignup_PasswordTooLong(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + q := sqlc.New(pool) + store := auth.NewStore(q) + router := newTestRouter(q, store) + + longPw := strings.Repeat("a", 129) // 129 chars — above the 128-char maximum. + form := url.Values{"email": {"dave@example.com"}, "password": {longPw}} + req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnprocessableEntity { + t.Fatalf("status = %d; want 422", rec.Code) + } + if !strings.Contains(rec.Body.String(), "128") { + t.Errorf("body missing '128' boundary; got: %s", rec.Body.String()) + } + + ctx := context.Background() + var count int + row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM users WHERE email = $1", "dave@example.com") + _ = row.Scan(&count) + if count != 0 { + t.Errorf("unexpected user row inserted for long password") + } +} + +func TestSignup_DuplicateEmail(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + ctx := context.Background() + q := sqlc.New(pool) + store := auth.NewStore(q) + router := newTestRouter(q, store) + + // Pre-insert a user using TestParams to avoid the slow DefaultParams hash. + preInsertUser(t, ctx, q, "eve@example.com", "correct-horse-12") + + form := url.Values{"email": {"eve@example.com"}, "password": {"correct-horse-12"}} + req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + // Must be a client error (422) with an "already in use" message. + if rec.Code != http.StatusUnprocessableEntity { + t.Fatalf("status = %d; want 422", rec.Code) + } + if !strings.Contains(rec.Body.String(), "already in use") { + t.Errorf("body missing 'already in use'; got: %s", rec.Body.String()) + } + + // No second user row must have been inserted. + var count int + row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM users WHERE email = $1", "eve@example.com") + if err := row.Scan(&count); err != nil { + t.Fatalf("count query: %v", err) + } + if count != 1 { + t.Errorf("user count = %d; want exactly 1 (no duplicate inserted)", count) + } +} + +func TestSignup_EmailNormalized(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + ctx := context.Background() + q := sqlc.New(pool) + store := auth.NewStore(q) + router := newTestRouter(q, store) + + // Uppercase + whitespace email — must be stored trimmed and lowercased. + form := url.Values{"email": {" Frank@Example.COM "}, "password": {"correct-horse-12"}} + req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusSeeOther { + t.Fatalf("status = %d; want 303", rec.Code) + } + + // Stored email must not have leading/trailing whitespace. + user, err := q.GetUserByEmail(ctx, "frank@example.com") + if err != nil { + t.Fatalf("GetUserByEmail: %v", err) + } + if user.Email != "frank@example.com" { + t.Errorf("stored email = %q; want frank@example.com (trimmed + lowercased)", user.Email) + } +} + +func TestSignup_AlreadyAuthedBouncesHome(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + ctx := context.Background() + q := sqlc.New(pool) + store := auth.NewStore(q) + router := newTestRouter(q, store) + + // Pre-insert a user and create a real session. + user := preInsertUser(t, ctx, q, "grace@example.com", "correct-horse-12") + cookieValue, expiresAt, err := store.Create(ctx, user.ID) + if err != nil { + t.Fatalf("store.Create: %v", err) + } + + // GET /signup with a valid session cookie must redirect to /. + req := httptest.NewRequest(http.MethodGet, "/signup", nil) + req.AddCookie(&http.Cookie{ + Name: auth.SessionCookieName, + Value: cookieValue, + }) + _ = expiresAt + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusSeeOther { + t.Fatalf("status = %d; want 303 (RedirectIfAuthed)", rec.Code) + } + if loc := rec.Header().Get("Location"); loc != "/" { + t.Errorf("Location = %q; want /", loc) + } +} diff --git a/backend/internal/web/handlers_test.go b/backend/internal/web/handlers_test.go index f8b3df6..e234e4a 100644 --- a/backend/internal/web/handlers_test.go +++ b/backend/internal/web/handlers_test.go @@ -61,7 +61,7 @@ func TestHealthz_Down(t *testing.T) { } func TestIndex_RendersHxGet(t *testing.T) { - router := NewRouter(stubPinger{}, "./static") + router := NewRouter(stubPinger{}, "./static", AuthDeps{}) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -87,7 +87,7 @@ func TestIndex_RendersHxGet(t *testing.T) { } func TestDemoTime_Fragment(t *testing.T) { - router := NewRouter(stubPinger{}, "./static") + router := NewRouter(stubPinger{}, "./static", AuthDeps{}) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/demo/time", nil) @@ -110,7 +110,7 @@ func TestDemoTime_Fragment(t *testing.T) { } func TestRequestID_HeaderSet(t *testing.T) { - router := NewRouter(stubPinger{}, "./static") + router := NewRouter(stubPinger{}, "./static", AuthDeps{}) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/healthz", nil) diff --git a/backend/internal/web/router.go b/backend/internal/web/router.go index 4f76371..f271d63 100644 --- a/backend/internal/web/router.go +++ b/backend/internal/web/router.go @@ -6,6 +6,8 @@ import ( "net/http" "time" + "backend/internal/auth" + "github.com/go-chi/chi/v5" chimw "github.com/go-chi/chi/v5/middleware" ) @@ -18,22 +20,40 @@ type Pinger interface { } // NewRouter constructs the chi router with the middleware stack locked by -// CONTEXT D-08 + RESEARCH Pattern 2: +// CONTEXT D-24: // -// 1. RequestIDMiddleware (UUIDv4 — NOT chi's base32 RequestID) -// 2. chi RealIP -// 3. SlogLoggerMiddleware (REPLACES chi's middleware.Logger — Pitfall 6) -// 4. chi Recoverer (after Logger so panics carry request_id) +// 1. RequestIDMiddleware (UUIDv4 — NOT chi's base32 RequestID) +// 2. chi RealIP +// 3. SlogLoggerMiddleware (REPLACES chi's middleware.Logger — Pitfall 6) +// 4. chi Recoverer (after Logger so panics carry request_id) +// 5. auth.ResolveSession (reads session cookie, attaches user to context) +// NOTE: csrf.Protect is added in Plan 07. // -// Routes (Phase 1 only): GET / · GET /healthz · GET /demo/time · GET /static/*. +// Routes: GET / · GET /healthz · GET /demo/time · GET /static/* +// GET /signup (auth pages, behind RedirectIfAuthed) · POST /signup. // staticDir is the on-disk path served at /static/*; path traversal is // blocked by http.Dir's default behavior (T-01-08). -func NewRouter(pinger Pinger, staticDir string) http.Handler { +// +// deps.Store may be nil during unit tests for Phase 1 routes (those routes +// never exercise session resolution). ResolveSession guards against nil Store. +func NewRouter(pinger Pinger, staticDir string, deps AuthDeps) http.Handler { r := chi.NewRouter() r.Use(RequestIDMiddleware) r.Use(chimw.RealIP) r.Use(SlogLoggerMiddleware(slog.Default())) r.Use(chimw.Recoverer) + r.Use(auth.ResolveSession(deps.Store)) + + // Auth pages — redirect to / if already authenticated. + r.Group(func(r chi.Router) { + r.Use(auth.RedirectIfAuthed) + r.Get("/signup", SignupPageHandler()) + }) + + // Signup POST is intentionally outside the RedirectIfAuthed group: + // an authed user submitting the form directly should still get a useful + // response; the GET guard handles the common case. + r.Post("/signup", SignupPostHandler(deps)) r.Get("/", IndexHandler()) r.Get("/healthz", HealthzHandler(pinger)) diff --git a/backend/internal/web/testdb_test.go b/backend/internal/web/testdb_test.go new file mode 100644 index 0000000..2ffb60b --- /dev/null +++ b/backend/internal/web/testdb_test.go @@ -0,0 +1,142 @@ +package web + +// testdb_test.go exposes a setupTestDB helper for the web package integration +// tests. The implementation is a verbatim copy of the auth package's +// setupTestDB (~20 LOC kernel) so the web package does not import a test-only +// function from another package (Go does not allow importing _test.go files). +// +// Decision: duplication chosen over a shared non-test helper to keep the +// coupling surface minimal and avoid moving test infra into production code. +// Recorded in 02-04-SUMMARY.md. + +import ( + "context" + "database/sql" + "fmt" + "os" + "path/filepath" + "runtime" + "sync" + "testing" + "time" + + "github.com/google/uuid" + _ "github.com/jackc/pgx/v5/stdlib" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/pressly/goose/v3" +) + +func migrationsDir() string { + _, filename, _, _ := runtime.Caller(0) + // filename: .../backend/internal/web/testdb_test.go + // migrations: .../backend/migrations + return filepath.Join(filepath.Dir(filename), "..", "..", "migrations") +} + +var webGooseMu sync.Mutex + +func setupTestDB(t *testing.T) (*pgxpool.Pool, func()) { + t.Helper() + + dsn := os.Getenv("TEST_DATABASE_URL") + if dsn == "" { + dsn = os.Getenv("DATABASE_URL") + } + if dsn == "" { + t.Skip("TEST_DATABASE_URL (or DATABASE_URL) not set — integration test skipped") + return nil, nil + } + + rawID := uuid.New().String()[:12] + cleanID := make([]byte, len(rawID)) + for i := 0; i < len(rawID); i++ { + if rawID[i] == '-' { + cleanID[i] = '_' + } else { + cleanID[i] = rawID[i] + } + } + schemaName := "test_" + string(cleanID) + + bootstrapDB, err := sql.Open("pgx", dsn) + if err != nil { + t.Fatalf("setupTestDB: sql.Open bootstrap: %v", err) + } + if err := bootstrapDB.Ping(); err != nil { + bootstrapDB.Close() + t.Fatalf("setupTestDB: ping: %v", err) + } + if _, err := bootstrapDB.Exec(fmt.Sprintf("CREATE SCHEMA %q", schemaName)); err != nil { + bootstrapDB.Close() + t.Fatalf("setupTestDB: CREATE SCHEMA: %v", err) + } + bootstrapDB.Close() + + sep := "?" + for i := 0; i < len(dsn); i++ { + if dsn[i] == '?' { + sep = "&" + break + } + } + schemaDSN := fmt.Sprintf("%s%ssearch_path=%s,public", dsn, sep, schemaName) + + versionTable := schemaName + "_goose_version" + { + webGooseMu.Lock() + defer webGooseMu.Unlock() + + prevTable := goose.TableName() + goose.SetTableName(versionTable) + defer goose.SetTableName(prevTable) + + goose.SetBaseFS(nil) + if err := goose.SetDialect("postgres"); err != nil { + webDropSchema(dsn, schemaName) + t.Fatalf("setupTestDB: goose.SetDialect: %v", err) + } + + schemaDB, err := sql.Open("pgx", schemaDSN) + if err != nil { + webDropSchema(dsn, schemaName) + t.Fatalf("setupTestDB: sql.Open schema-scoped: %v", err) + } + if err := goose.Up(schemaDB, migrationsDir()); err != nil { + schemaDB.Close() + webDropSchema(dsn, schemaName) + t.Fatalf("setupTestDB: goose.Up: %v", err) + } + schemaDB.Close() + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + cfg, err := pgxpool.ParseConfig(schemaDSN) + if err != nil { + webDropSchema(dsn, schemaName) + t.Fatalf("setupTestDB: pgxpool.ParseConfig: %v", err) + } + cfg.MaxConns = 5 + + pool, err := pgxpool.NewWithConfig(ctx, cfg) + if err != nil { + webDropSchema(dsn, schemaName) + t.Fatalf("setupTestDB: pgxpool.NewWithConfig: %v", err) + } + + cleanup := func() { + pool.Close() + webDropSchema(dsn, schemaName) + } + return pool, cleanup +} + +func webDropSchema(dsn, schemaName string) { + db, err := sql.Open("pgx", dsn) + if err != nil { + return + } + defer db.Close() + db.Exec(fmt.Sprintf("DROP SCHEMA IF EXISTS %q CASCADE", schemaName)) //nolint:errcheck +}