feat(02-04): signup handler, router wiring, and integration tests

- Add handlers_auth.go: SignupPageHandler + SignupPostHandler (validate -> hash -> insert -> session -> redirect)
- Add AuthDeps struct; wire argon2id hash, InsertUser, Store.Create, SetSessionCookie
- Update router.go: NewRouter accepts AuthDeps; mount ResolveSession (D-24); wire /signup routes behind RedirectIfAuthed
- Update cmd/web/main.go: build AuthDeps (sqlc.Queries + auth.Store + secure flag) and pass to NewRouter
- Add nil-Store guard to auth.ResolveSession for Phase 1 unit-test compatibility
- Update handlers_test.go: pass AuthDeps{} zero value to NewRouter (Phase 1 routes unaffected)
- Add testdb_test.go: isolated-schema test helper for web package integration tests
- Add handlers_auth_test.go: 8 TestSignup_* integration tests (all pass against real Postgres)
This commit is contained in:
Arthur Belleville 2026-05-14 22:17:50 +02:00
parent 73935ed11c
commit efdc16babe
No known key found for this signature in database
7 changed files with 640 additions and 11 deletions

View file

@ -17,7 +17,9 @@ import (
"syscall"
"time"
"backend/internal/auth"
"backend/internal/db"
"backend/internal/db/sqlc"
"backend/internal/web"
)
@ -54,7 +56,12 @@ func main() {
os.Exit(1)
}
router := web.NewRouter(pool, "./static")
q := sqlc.New(pool)
store := auth.NewStore(q)
secure := env != "development" && env != "dev"
deps := web.AuthDeps{Queries: q, Store: store, Secure: secure}
router := web.NewRouter(pool, "./static", deps)
srv := &http.Server{
Addr: ":" + port,

View file

@ -35,11 +35,21 @@ func Authed(ctx context.Context) (*Session, *User, bool) {
// attaches them to the request context. It NEVER blocks the request — missing
// or invalid sessions are silently ignored; RequireAuth enforces access.
//
// When store is nil (e.g. in Phase 1 unit tests that pass a zero AuthDeps),
// the middleware is a no-op pass-through. Cookie resolution only happens when
// store is non-nil and a cookie is present.
//
// On a valid session hit, MaybeExtend is called best-effort (logged but not
// fatal) to implement the sliding 30-day TTL (D-09).
func ResolveSession(store *Store) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Nil-store guard: Phase 1 route tests pass AuthDeps{} (zero value).
if store == nil {
next.ServeHTTP(w, r)
return
}
cookie, err := r.Cookie(SessionCookieName)
if err != nil || cookie.Value == "" {
// No cookie — pass through unauthenticated.

View file

@ -0,0 +1,131 @@
package web
import (
"errors"
"net/http"
"net/mail"
"strings"
"backend/internal/auth"
"backend/internal/db/sqlc"
"backend/templates"
"github.com/jackc/pgx/v5/pgconn"
)
// AuthDeps holds the dependencies shared by all auth handlers.
// Secure should be true in all environments except "dev"/"development";
// it gates the cookie Secure attribute (D-12).
type AuthDeps struct {
Queries *sqlc.Queries
Store *auth.Store
Secure bool
}
// SignupPageHandler renders the GET /signup page with an empty form.
func SignupPageHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
_ = templates.SignupPage(templates.SignupForm{}, templates.SignupErrors{}).Render(r.Context(), w)
}
}
// SignupPostHandler handles POST /signup: validate → hash → insert → create session → redirect.
//
// Security invariants (threat model):
// - Password length is validated BEFORE calling auth.Hash to prevent long-password DoS (T-2-14).
// - The raw password is never passed to any template (T-2-01).
// - Email is not logged on validation errors (T-2-18).
// - Duplicate email is detected via pgconn error code 23505 (T-2-19).
// - A fresh session token is created on every signup (T-2-04).
func SignupPostHandler(deps AuthDeps) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// 1. Read form values.
email := strings.TrimSpace(r.PostFormValue("email"))
password := r.PostFormValue("password")
var errs templates.SignupErrors
// 2. Validate email.
if _, err := mail.ParseAddress(email); err != nil {
errs.Email = "Enter a valid email address"
}
// 3. Validate password length BEFORE calling Hash (DoS guard T-2-14, D-25).
if len(password) < 12 {
errs.Password = "Password must be at least 12 characters"
} else if len(password) > 128 {
errs.Password = "Password must be at most 128 characters"
}
if errs.Email != "" || errs.Password != "" {
// Re-populate the email field but NOT the password (T-2-01).
renderSignupError(w, r, templates.SignupForm{Email: email}, errs, http.StatusUnprocessableEntity)
return
}
// 4. Normalize email (lowercase) before insert; citext handles case-insensitive
// uniqueness at DB level but we store canonical lowercase for consistency (D-01).
normalized := strings.ToLower(email)
// 5. Hash password with production cost parameters.
hash, err := auth.Hash(password, auth.DefaultParams)
if err != nil {
http.Error(w, "internal server error", http.StatusInternalServerError)
return
}
// 6. Insert user row.
user, err := deps.Queries.InsertUser(ctx, sqlc.InsertUserParams{
Email: normalized,
PasswordHash: hash,
})
if err != nil {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
// Unique-constraint violation on email (T-2-19).
// Specific error message is acceptable on signup per CONTEXT.md specifics.
errs.Email = "That email is already in use."
renderSignupError(w, r, templates.SignupForm{Email: email}, errs, http.StatusUnprocessableEntity)
return
}
http.Error(w, "internal server error", http.StatusInternalServerError)
return
}
// 7. Create session (D-10: fresh token on every signup auto-login).
cookieValue, expiresAt, err := deps.Store.Create(ctx, user.ID)
if err != nil {
http.Error(w, "internal server error", http.StatusInternalServerError)
return
}
// 8. Set session cookie (D-12).
auth.SetSessionCookie(w, cookieValue, expiresAt, deps.Secure)
// 9. Redirect to home (D-11, D-21).
// HTMX form submissions receive HX-Redirect so HTMX handles navigation client-side.
// Plain (no-JS) form submissions receive 303 See Other (NOT 302 — Pitfall 9).
if r.Header.Get("HX-Request") == "true" {
w.Header().Set("HX-Redirect", "/")
w.WriteHeader(http.StatusOK)
return
}
http.Redirect(w, r, "/", http.StatusSeeOther)
}
}
// renderSignupError writes a validation-error response.
// For HTMX requests it renders only the form fragment; for plain requests it
// renders the full page (D-19, D-25).
func renderSignupError(w http.ResponseWriter, r *http.Request, form templates.SignupForm, errs templates.SignupErrors, status int) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(status)
if r.Header.Get("HX-Request") == "true" {
_ = templates.SignupFormFragment(form, errs).Render(r.Context(), w)
} else {
_ = templates.SignupPage(form, errs).Render(r.Context(), w)
}
}

View file

@ -0,0 +1,319 @@
package web
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"backend/internal/auth"
"backend/internal/db/sqlc"
)
// newTestRouter builds a router backed by a real DB for integration tests.
func newTestRouter(q *sqlc.Queries, store *auth.Store) http.Handler {
deps := AuthDeps{Queries: q, Store: store, Secure: false}
return NewRouter(stubPinger{}, "./static", deps)
}
// preInsertUser inserts a user with TestParams-hashed password directly via sqlc
// (avoids slow DefaultParams hash in test setup — W4 / Pitfall 4).
func preInsertUser(t *testing.T, ctx context.Context, q *sqlc.Queries, email, password string) sqlc.User {
t.Helper()
hash, err := auth.Hash(password, auth.TestParams)
if err != nil {
t.Fatalf("preInsertUser: hash: %v", err)
}
user, err := q.InsertUser(ctx, sqlc.InsertUserParams{
Email: strings.ToLower(email),
PasswordHash: hash,
})
if err != nil {
t.Fatalf("preInsertUser: InsertUser: %v", err)
}
return user
}
func TestSignup_Success(t *testing.T) {
pool, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
q := sqlc.New(pool)
store := auth.NewStore(q)
router := newTestRouter(q, store)
form := url.Values{"email": {"alice@example.com"}, "password": {"correct-horse-12"}}
req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusSeeOther {
t.Fatalf("status = %d; want 303", rec.Code)
}
if loc := rec.Header().Get("Location"); loc != "/" {
t.Errorf("Location = %q; want /", loc)
}
// Cookie must be present and HttpOnly.
var sessionCookie *http.Cookie
for _, c := range rec.Result().Cookies() {
if c.Name == auth.SessionCookieName {
sessionCookie = c
break
}
}
if sessionCookie == nil {
t.Fatal("session cookie not set")
}
if !sessionCookie.HttpOnly {
t.Error("session cookie must be HttpOnly")
}
// User row must exist with an argon2id hash.
user, err := q.GetUserByEmail(ctx, "alice@example.com")
if err != nil {
t.Fatalf("GetUserByEmail: %v", err)
}
if !strings.HasPrefix(user.PasswordHash, "$argon2id$") {
t.Errorf("password_hash = %q; want $argon2id$ prefix", user.PasswordHash)
}
// Session row must exist for the user.
var count int
row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE user_id = $1", user.ID)
if err := row.Scan(&count); err != nil {
t.Fatalf("session count query: %v", err)
}
if count != 1 {
t.Errorf("session count = %d; want 1", count)
}
}
func TestSignup_Success_HTMX(t *testing.T) {
pool, cleanup := setupTestDB(t)
defer cleanup()
q := sqlc.New(pool)
store := auth.NewStore(q)
router := newTestRouter(q, store)
form := url.Values{"email": {"bob@example.com"}, "password": {"correct-horse-12"}}
req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("HX-Request", "true")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("HTMX status = %d; want 200", rec.Code)
}
if hxRedir := rec.Header().Get("HX-Redirect"); hxRedir != "/" {
t.Errorf("HX-Redirect = %q; want /", hxRedir)
}
}
func TestSignup_InvalidEmail(t *testing.T) {
pool, cleanup := setupTestDB(t)
defer cleanup()
q := sqlc.New(pool)
store := auth.NewStore(q)
router := newTestRouter(q, store)
form := url.Values{"email": {"not-an-email"}, "password": {"correct-horse-12"}}
req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusUnprocessableEntity {
t.Fatalf("status = %d; want 422", rec.Code)
}
if !strings.Contains(rec.Body.String(), "valid email") {
t.Errorf("body missing 'valid email' error; got: %s", rec.Body.String())
}
// No user row must have been inserted.
ctx := context.Background()
var count int
row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM users WHERE email = $1", "not-an-email")
_ = row.Scan(&count)
if count != 0 {
t.Errorf("unexpected user row inserted for invalid email")
}
}
func TestSignup_PasswordTooShort(t *testing.T) {
pool, cleanup := setupTestDB(t)
defer cleanup()
q := sqlc.New(pool)
store := auth.NewStore(q)
router := newTestRouter(q, store)
// 11 chars — below the 12-char minimum.
form := url.Values{"email": {"carol@example.com"}, "password": {"short12345!"}}
req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusUnprocessableEntity {
t.Fatalf("status = %d; want 422", rec.Code)
}
if !strings.Contains(rec.Body.String(), "12") {
t.Errorf("body missing '12' boundary; got: %s", rec.Body.String())
}
ctx := context.Background()
var count int
row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM users WHERE email = $1", "carol@example.com")
_ = row.Scan(&count)
if count != 0 {
t.Errorf("unexpected user row inserted for short password")
}
}
func TestSignup_PasswordTooLong(t *testing.T) {
pool, cleanup := setupTestDB(t)
defer cleanup()
q := sqlc.New(pool)
store := auth.NewStore(q)
router := newTestRouter(q, store)
longPw := strings.Repeat("a", 129) // 129 chars — above the 128-char maximum.
form := url.Values{"email": {"dave@example.com"}, "password": {longPw}}
req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusUnprocessableEntity {
t.Fatalf("status = %d; want 422", rec.Code)
}
if !strings.Contains(rec.Body.String(), "128") {
t.Errorf("body missing '128' boundary; got: %s", rec.Body.String())
}
ctx := context.Background()
var count int
row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM users WHERE email = $1", "dave@example.com")
_ = row.Scan(&count)
if count != 0 {
t.Errorf("unexpected user row inserted for long password")
}
}
func TestSignup_DuplicateEmail(t *testing.T) {
pool, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
q := sqlc.New(pool)
store := auth.NewStore(q)
router := newTestRouter(q, store)
// Pre-insert a user using TestParams to avoid the slow DefaultParams hash.
preInsertUser(t, ctx, q, "eve@example.com", "correct-horse-12")
form := url.Values{"email": {"eve@example.com"}, "password": {"correct-horse-12"}}
req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
// Must be a client error (422) with an "already in use" message.
if rec.Code != http.StatusUnprocessableEntity {
t.Fatalf("status = %d; want 422", rec.Code)
}
if !strings.Contains(rec.Body.String(), "already in use") {
t.Errorf("body missing 'already in use'; got: %s", rec.Body.String())
}
// No second user row must have been inserted.
var count int
row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM users WHERE email = $1", "eve@example.com")
if err := row.Scan(&count); err != nil {
t.Fatalf("count query: %v", err)
}
if count != 1 {
t.Errorf("user count = %d; want exactly 1 (no duplicate inserted)", count)
}
}
func TestSignup_EmailNormalized(t *testing.T) {
pool, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
q := sqlc.New(pool)
store := auth.NewStore(q)
router := newTestRouter(q, store)
// Uppercase + whitespace email — must be stored trimmed and lowercased.
form := url.Values{"email": {" Frank@Example.COM "}, "password": {"correct-horse-12"}}
req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusSeeOther {
t.Fatalf("status = %d; want 303", rec.Code)
}
// Stored email must not have leading/trailing whitespace.
user, err := q.GetUserByEmail(ctx, "frank@example.com")
if err != nil {
t.Fatalf("GetUserByEmail: %v", err)
}
if user.Email != "frank@example.com" {
t.Errorf("stored email = %q; want frank@example.com (trimmed + lowercased)", user.Email)
}
}
func TestSignup_AlreadyAuthedBouncesHome(t *testing.T) {
pool, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
q := sqlc.New(pool)
store := auth.NewStore(q)
router := newTestRouter(q, store)
// Pre-insert a user and create a real session.
user := preInsertUser(t, ctx, q, "grace@example.com", "correct-horse-12")
cookieValue, expiresAt, err := store.Create(ctx, user.ID)
if err != nil {
t.Fatalf("store.Create: %v", err)
}
// GET /signup with a valid session cookie must redirect to /.
req := httptest.NewRequest(http.MethodGet, "/signup", nil)
req.AddCookie(&http.Cookie{
Name: auth.SessionCookieName,
Value: cookieValue,
})
_ = expiresAt
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusSeeOther {
t.Fatalf("status = %d; want 303 (RedirectIfAuthed)", rec.Code)
}
if loc := rec.Header().Get("Location"); loc != "/" {
t.Errorf("Location = %q; want /", loc)
}
}

View file

@ -61,7 +61,7 @@ func TestHealthz_Down(t *testing.T) {
}
func TestIndex_RendersHxGet(t *testing.T) {
router := NewRouter(stubPinger{}, "./static")
router := NewRouter(stubPinger{}, "./static", AuthDeps{})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
@ -87,7 +87,7 @@ func TestIndex_RendersHxGet(t *testing.T) {
}
func TestDemoTime_Fragment(t *testing.T) {
router := NewRouter(stubPinger{}, "./static")
router := NewRouter(stubPinger{}, "./static", AuthDeps{})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/demo/time", nil)
@ -110,7 +110,7 @@ func TestDemoTime_Fragment(t *testing.T) {
}
func TestRequestID_HeaderSet(t *testing.T) {
router := NewRouter(stubPinger{}, "./static")
router := NewRouter(stubPinger{}, "./static", AuthDeps{})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)

View file

@ -6,6 +6,8 @@ import (
"net/http"
"time"
"backend/internal/auth"
"github.com/go-chi/chi/v5"
chimw "github.com/go-chi/chi/v5/middleware"
)
@ -18,22 +20,40 @@ type Pinger interface {
}
// NewRouter constructs the chi router with the middleware stack locked by
// CONTEXT D-08 + RESEARCH Pattern 2:
// CONTEXT D-24:
//
// 1. RequestIDMiddleware (UUIDv4 — NOT chi's base32 RequestID)
// 2. chi RealIP
// 3. SlogLoggerMiddleware (REPLACES chi's middleware.Logger — Pitfall 6)
// 4. chi Recoverer (after Logger so panics carry request_id)
// 1. RequestIDMiddleware (UUIDv4 — NOT chi's base32 RequestID)
// 2. chi RealIP
// 3. SlogLoggerMiddleware (REPLACES chi's middleware.Logger — Pitfall 6)
// 4. chi Recoverer (after Logger so panics carry request_id)
// 5. auth.ResolveSession (reads session cookie, attaches user to context)
// NOTE: csrf.Protect is added in Plan 07.
//
// Routes (Phase 1 only): GET / · GET /healthz · GET /demo/time · GET /static/*.
// Routes: GET / · GET /healthz · GET /demo/time · GET /static/*
// GET /signup (auth pages, behind RedirectIfAuthed) · POST /signup.
// staticDir is the on-disk path served at /static/*; path traversal is
// blocked by http.Dir's default behavior (T-01-08).
func NewRouter(pinger Pinger, staticDir string) http.Handler {
//
// deps.Store may be nil during unit tests for Phase 1 routes (those routes
// never exercise session resolution). ResolveSession guards against nil Store.
func NewRouter(pinger Pinger, staticDir string, deps AuthDeps) http.Handler {
r := chi.NewRouter()
r.Use(RequestIDMiddleware)
r.Use(chimw.RealIP)
r.Use(SlogLoggerMiddleware(slog.Default()))
r.Use(chimw.Recoverer)
r.Use(auth.ResolveSession(deps.Store))
// Auth pages — redirect to / if already authenticated.
r.Group(func(r chi.Router) {
r.Use(auth.RedirectIfAuthed)
r.Get("/signup", SignupPageHandler())
})
// Signup POST is intentionally outside the RedirectIfAuthed group:
// an authed user submitting the form directly should still get a useful
// response; the GET guard handles the common case.
r.Post("/signup", SignupPostHandler(deps))
r.Get("/", IndexHandler())
r.Get("/healthz", HealthzHandler(pinger))

View file

@ -0,0 +1,142 @@
package web
// testdb_test.go exposes a setupTestDB helper for the web package integration
// tests. The implementation is a verbatim copy of the auth package's
// setupTestDB (~20 LOC kernel) so the web package does not import a test-only
// function from another package (Go does not allow importing _test.go files).
//
// Decision: duplication chosen over a shared non-test helper to keep the
// coupling surface minimal and avoid moving test infra into production code.
// Recorded in 02-04-SUMMARY.md.
import (
"context"
"database/sql"
"fmt"
"os"
"path/filepath"
"runtime"
"sync"
"testing"
"time"
"github.com/google/uuid"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/pressly/goose/v3"
)
func migrationsDir() string {
_, filename, _, _ := runtime.Caller(0)
// filename: .../backend/internal/web/testdb_test.go
// migrations: .../backend/migrations
return filepath.Join(filepath.Dir(filename), "..", "..", "migrations")
}
var webGooseMu sync.Mutex
func setupTestDB(t *testing.T) (*pgxpool.Pool, func()) {
t.Helper()
dsn := os.Getenv("TEST_DATABASE_URL")
if dsn == "" {
dsn = os.Getenv("DATABASE_URL")
}
if dsn == "" {
t.Skip("TEST_DATABASE_URL (or DATABASE_URL) not set — integration test skipped")
return nil, nil
}
rawID := uuid.New().String()[:12]
cleanID := make([]byte, len(rawID))
for i := 0; i < len(rawID); i++ {
if rawID[i] == '-' {
cleanID[i] = '_'
} else {
cleanID[i] = rawID[i]
}
}
schemaName := "test_" + string(cleanID)
bootstrapDB, err := sql.Open("pgx", dsn)
if err != nil {
t.Fatalf("setupTestDB: sql.Open bootstrap: %v", err)
}
if err := bootstrapDB.Ping(); err != nil {
bootstrapDB.Close()
t.Fatalf("setupTestDB: ping: %v", err)
}
if _, err := bootstrapDB.Exec(fmt.Sprintf("CREATE SCHEMA %q", schemaName)); err != nil {
bootstrapDB.Close()
t.Fatalf("setupTestDB: CREATE SCHEMA: %v", err)
}
bootstrapDB.Close()
sep := "?"
for i := 0; i < len(dsn); i++ {
if dsn[i] == '?' {
sep = "&"
break
}
}
schemaDSN := fmt.Sprintf("%s%ssearch_path=%s,public", dsn, sep, schemaName)
versionTable := schemaName + "_goose_version"
{
webGooseMu.Lock()
defer webGooseMu.Unlock()
prevTable := goose.TableName()
goose.SetTableName(versionTable)
defer goose.SetTableName(prevTable)
goose.SetBaseFS(nil)
if err := goose.SetDialect("postgres"); err != nil {
webDropSchema(dsn, schemaName)
t.Fatalf("setupTestDB: goose.SetDialect: %v", err)
}
schemaDB, err := sql.Open("pgx", schemaDSN)
if err != nil {
webDropSchema(dsn, schemaName)
t.Fatalf("setupTestDB: sql.Open schema-scoped: %v", err)
}
if err := goose.Up(schemaDB, migrationsDir()); err != nil {
schemaDB.Close()
webDropSchema(dsn, schemaName)
t.Fatalf("setupTestDB: goose.Up: %v", err)
}
schemaDB.Close()
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cfg, err := pgxpool.ParseConfig(schemaDSN)
if err != nil {
webDropSchema(dsn, schemaName)
t.Fatalf("setupTestDB: pgxpool.ParseConfig: %v", err)
}
cfg.MaxConns = 5
pool, err := pgxpool.NewWithConfig(ctx, cfg)
if err != nil {
webDropSchema(dsn, schemaName)
t.Fatalf("setupTestDB: pgxpool.NewWithConfig: %v", err)
}
cleanup := func() {
pool.Close()
webDropSchema(dsn, schemaName)
}
return pool, cleanup
}
func webDropSchema(dsn, schemaName string) {
db, err := sql.Open("pgx", dsn)
if err != nil {
return
}
defer db.Close()
db.Exec(fmt.Sprintf("DROP SCHEMA IF EXISTS %q CASCADE", schemaName)) //nolint:errcheck
}