diff --git a/backend/cmd/web/main.go b/backend/cmd/web/main.go index 6f5fc28..18dc40a 100644 --- a/backend/cmd/web/main.go +++ b/backend/cmd/web/main.go @@ -85,7 +85,34 @@ func main() { stopJanitor := make(chan struct{}) rl.StartJanitor(time.Minute, stopJanitor) - deps := web.AuthDeps{Queries: q, Store: store, Secure: secure, Limiter: rl} + oauthCfg := auth.OAuthConfig{ + Google: auth.GoogleProviderConfig{ + ClientID: os.Getenv("GOOGLE_CLIENT_ID"), + ClientSecret: os.Getenv("GOOGLE_CLIENT_SECRET"), + RedirectURL: os.Getenv("GOOGLE_REDIRECT_URL"), + }, + } + var googleExchanger auth.CodeExchanger + var googleVerifier auth.IDTokenVerifier + if oauthCfg.Google.Configured() { + googleExchanger = auth.OAuth2CodeExchanger{Config: oauthCfg.Google.OAuth2Config()} + googleVerifier = auth.OIDCVerifier{ + Provider: "google", + Issuer: "https://accounts.google.com", + ClientID: oauthCfg.Google.ClientID, + } + } + + deps := web.AuthDeps{ + Queries: q, + Store: store, + Secure: secure, + Limiter: rl, + DB: pool, + OAuth: oauthCfg, + GoogleTokenExchanger: googleExchanger, + GoogleVerifier: googleVerifier, + } tabloDeps := web.TablosDeps{Queries: q} taskDeps := web.TasksDeps{Queries: q} diff --git a/backend/go.mod b/backend/go.mod index 752a07a..e9cddd8 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -32,7 +32,9 @@ require ( github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 // indirect github.com/aws/smithy-go v1.25.1 // indirect + github.com/coreos/go-oidc/v3 v3.18.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-jose/go-jose/v4 v4.1.4 // indirect github.com/gorilla/securecookie v1.1.2 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect @@ -52,6 +54,7 @@ require ( github.com/tidwall/sjson v1.2.5 // indirect go.uber.org/goleak v1.3.0 // indirect go.uber.org/multierr v1.11.0 // indirect + golang.org/x/oauth2 v0.36.0 // indirect golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.44.0 // indirect golang.org/x/text v0.37.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index cc00de1..ceab2a3 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -36,6 +36,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 h1:F/M5Y9I3nwr2IEpshZgh1GeHpOIt github.com/aws/aws-sdk-go-v2/service/sts v1.42.1/go.mod h1:mTNxImtovCOEEuD65mKW7DCsL+2gjEH+RPEAexAzAio= github.com/aws/smithy-go v1.25.1 h1:J8ERsGSU7d+aCmdQur5Txg6bVoYelvQJgtZehD12GkI= github.com/aws/smithy-go v1.25.1/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= +github.com/coreos/go-oidc/v3 v3.18.0 h1:V9orjXynvu5wiC9SemFTWnG4F45v403aIcjWo0d41+A= +github.com/coreos/go-oidc/v3 v3.18.0/go.mod h1:DYCf24+ncYi+XkIH97GY1+dqoRlbaSI26KVTCI9SrY4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -43,6 +45,8 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= +github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA= +github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= @@ -107,6 +111,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= diff --git a/backend/internal/auth/oauth.go b/backend/internal/auth/oauth.go new file mode 100644 index 0000000..f6e350b --- /dev/null +++ b/backend/internal/auth/oauth.go @@ -0,0 +1,181 @@ +package auth + +import ( + "context" + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "errors" + "fmt" + "net/http" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + "golang.org/x/oauth2" +) + +const ( + OAuthCookieState = "state" + OAuthCookieNonce = "nonce" +) + +const oauthCookieTTL = 5 * time.Minute + +type GoogleProviderConfig struct { + ClientID string + ClientSecret string + RedirectURL string + AuthURL string + TokenURL string + Issuer string +} + +func (c GoogleProviderConfig) Configured() bool { + return c.ClientID != "" && c.ClientSecret != "" && c.RedirectURL != "" +} + +func (c GoogleProviderConfig) withDefaults() GoogleProviderConfig { + if c.AuthURL == "" { + c.AuthURL = "https://accounts.google.com/o/oauth2/v2/auth" + } + if c.TokenURL == "" { + c.TokenURL = "https://oauth2.googleapis.com/token" + } + if c.Issuer == "" { + c.Issuer = "https://accounts.google.com" + } + return c +} + +func (c GoogleProviderConfig) OAuth2Config() oauth2.Config { + c = c.withDefaults() + return oauth2.Config{ + ClientID: c.ClientID, + ClientSecret: c.ClientSecret, + RedirectURL: c.RedirectURL, + Scopes: []string{oidc.ScopeOpenID, "email", "profile"}, + Endpoint: oauth2.Endpoint{ + AuthURL: c.AuthURL, + TokenURL: c.TokenURL, + }, + } +} + +type OAuthConfig struct { + Google GoogleProviderConfig +} + +type ProviderClaims struct { + Provider string + Subject string + Email string + EmailVerified bool + DisplayName string + AvatarURL string + Nonce string +} + +type CodeExchanger interface { + Exchange(ctx context.Context, code string) (*oauth2.Token, error) +} + +type OAuth2CodeExchanger struct { + Config oauth2.Config +} + +func (e OAuth2CodeExchanger) Exchange(ctx context.Context, code string) (*oauth2.Token, error) { + return e.Config.Exchange(ctx, code) +} + +type IDTokenVerifier interface { + Verify(ctx context.Context, rawIDToken string) (ProviderClaims, error) +} + +type OIDCVerifier struct { + Provider string + Issuer string + ClientID string +} + +func (v OIDCVerifier) Verify(ctx context.Context, rawIDToken string) (ProviderClaims, error) { + provider, err := oidc.NewProvider(ctx, v.Issuer) + if err != nil { + return ProviderClaims{}, fmt.Errorf("auth: oidc provider discovery: %w", err) + } + verified, err := provider.Verifier(&oidc.Config{ClientID: v.ClientID}).Verify(ctx, rawIDToken) + if err != nil { + return ProviderClaims{}, fmt.Errorf("auth: verify id token: %w", err) + } + var claims struct { + Subject string `json:"sub"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Name string `json:"name"` + Picture string `json:"picture"` + Nonce string `json:"nonce"` + } + if err := verified.Claims(&claims); err != nil { + return ProviderClaims{}, fmt.Errorf("auth: decode id token claims: %w", err) + } + if claims.Subject == "" { + return ProviderClaims{}, errors.New("auth: id token missing subject") + } + return ProviderClaims{ + Provider: v.Provider, + Subject: claims.Subject, + Email: claims.Email, + EmailVerified: claims.EmailVerified, + DisplayName: claims.Name, + AvatarURL: claims.Picture, + Nonce: claims.Nonce, + }, nil +} + +func GenerateOAuthValue() (string, error) { + raw := make([]byte, 32) + if _, err := rand.Read(raw); err != nil { + return "", fmt.Errorf("auth: generate oauth value: %w", err) + } + return base64.RawURLEncoding.EncodeToString(raw), nil +} + +func OAuthCookieName(provider, kind string) string { + return "xtablo_oauth_" + provider + "_" + kind +} + +func SetOAuthCookie(w http.ResponseWriter, provider, kind, value string, secure bool) { + http.SetCookie(w, &http.Cookie{ + Name: OAuthCookieName(provider, kind), + Value: value, + Path: "/", + MaxAge: int(oauthCookieTTL.Seconds()), + Expires: time.Now().Add(oauthCookieTTL), + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func ClearOAuthCookie(w http.ResponseWriter, provider, kind string, secure bool) { + http.SetCookie(w, &http.Cookie{ + Name: OAuthCookieName(provider, kind), + Value: "", + Path: "/", + MaxAge: -1, + Expires: time.Unix(0, 0), + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func ValidateOAuthCookie(r *http.Request, provider, kind, value string) bool { + c, err := r.Cookie(OAuthCookieName(provider, kind)) + if err != nil || c.Value == "" || value == "" { + return false + } + if len(c.Value) != len(value) { + return false + } + return subtle.ConstantTimeCompare([]byte(c.Value), []byte(value)) == 1 +} diff --git a/backend/internal/auth/oauth_test.go b/backend/internal/auth/oauth_test.go new file mode 100644 index 0000000..f0f6eb8 --- /dev/null +++ b/backend/internal/auth/oauth_test.go @@ -0,0 +1,56 @@ +package auth + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestGoogleProviderConfigConfigured(t *testing.T) { + empty := GoogleProviderConfig{} + if empty.Configured() { + t.Fatal("empty Google config must not be configured") + } + + cfg := GoogleProviderConfig{ + ClientID: "google-client", + ClientSecret: "google-secret", + RedirectURL: "https://xtablo.test/auth/google/callback", + } + if !cfg.Configured() { + t.Fatal("complete Google config must be configured") + } +} + +func TestOAuthStateAndNonceCookiesValidateExactValue(t *testing.T) { + rec := httptest.NewRecorder() + SetOAuthCookie(rec, "google", OAuthCookieState, "state-value", false) + SetOAuthCookie(rec, "google", OAuthCookieNonce, "nonce-value", false) + + req := httptest.NewRequest(http.MethodGet, "/auth/google/callback", nil) + for _, c := range rec.Result().Cookies() { + req.AddCookie(c) + } + + if !ValidateOAuthCookie(req, "google", OAuthCookieState, "state-value") { + t.Fatal("state cookie should validate matching value") + } + if ValidateOAuthCookie(req, "google", OAuthCookieState, "wrong-state") { + t.Fatal("state cookie should reject mismatched value") + } + if !ValidateOAuthCookie(req, "google", OAuthCookieNonce, "nonce-value") { + t.Fatal("nonce cookie should validate matching value") + } + if ValidateOAuthCookie(req, "google", OAuthCookieNonce, "wrong-nonce") { + t.Fatal("nonce cookie should reject mismatched value") + } +} + +func TestOAuthCookieNameIncludesProviderAndKind(t *testing.T) { + if got := OAuthCookieName("google", OAuthCookieState); got != "xtablo_oauth_google_state" { + t.Fatalf("state cookie name = %q", got) + } + if got := OAuthCookieName("google", OAuthCookieNonce); got != "xtablo_oauth_google_nonce" { + t.Fatalf("nonce cookie name = %q", got) + } +} diff --git a/backend/internal/web/handlers_auth.go b/backend/internal/web/handlers_auth.go index 269893d..9d1a1c6 100644 --- a/backend/internal/web/handlers_auth.go +++ b/backend/internal/web/handlers_auth.go @@ -24,10 +24,14 @@ import ( // Limiter is the in-memory rate limiter for POST /login (D-16, AUTH-07). // When nil, rate limiting is skipped (unit tests that don't exercise that path). type AuthDeps struct { - Queries *sqlc.Queries - Store *auth.Store - Secure bool - Limiter *auth.LimiterStore + Queries *sqlc.Queries + Store *auth.Store + Secure bool + Limiter *auth.LimiterStore + DB TxBeginner + OAuth auth.OAuthConfig + GoogleTokenExchanger auth.CodeExchanger + GoogleVerifier auth.IDTokenVerifier } // errInvalidCreds is the intentionally generic error message for login failures diff --git a/backend/internal/web/handlers_social.go b/backend/internal/web/handlers_social.go new file mode 100644 index 0000000..48b2bf0 --- /dev/null +++ b/backend/internal/web/handlers_social.go @@ -0,0 +1,210 @@ +package web + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net/http" + "strings" + + "backend/internal/auth" + "backend/internal/db/sqlc" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "golang.org/x/oauth2" +) + +type TxBeginner interface { + Begin(ctx context.Context) (pgx.Tx, error) +} + +const providerEmailUnverified = "This provider did not return a verified email. Try another sign-in method." +const providerGenericError = "Could not sign you in with this provider. Try another sign-in method." + +func GoogleStartHandler(deps AuthDeps) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + cfg := deps.OAuth.Google + if !cfg.Configured() { + http.Error(w, "Google sign-in not configured", http.StatusServiceUnavailable) + return + } + state, err := auth.GenerateOAuthValue() + if err != nil { + http.Error(w, "internal server error", http.StatusInternalServerError) + return + } + nonce, err := auth.GenerateOAuthValue() + if err != nil { + http.Error(w, "internal server error", http.StatusInternalServerError) + return + } + auth.SetOAuthCookie(w, "google", auth.OAuthCookieState, state, deps.Secure) + auth.SetOAuthCookie(w, "google", auth.OAuthCookieNonce, nonce, deps.Secure) + + oauthCfg := cfg.OAuth2Config() + url := oauthCfg.AuthCodeURL(state, oauth2.SetAuthURLParam("nonce", nonce)) + http.Redirect(w, r, url, http.StatusSeeOther) + } +} + +func GoogleCallbackHandler(deps AuthDeps) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if !auth.ValidateOAuthCookie(r, "google", auth.OAuthCookieState, r.URL.Query().Get("state")) { + http.Error(w, providerGenericError, http.StatusBadRequest) + return + } + auth.ClearOAuthCookie(w, "google", auth.OAuthCookieState, deps.Secure) + auth.ClearOAuthCookie(w, "google", auth.OAuthCookieNonce, deps.Secure) + + code := r.URL.Query().Get("code") + if code == "" { + http.Error(w, providerGenericError, http.StatusBadRequest) + return + } + + exchanger := deps.GoogleTokenExchanger + if exchanger == nil { + exchanger = auth.OAuth2CodeExchanger{Config: deps.OAuth.Google.OAuth2Config()} + } + token, err := exchanger.Exchange(r.Context(), code) + if err != nil { + slog.Default().Warn("google oauth exchange failed", "err", err) + http.Error(w, providerGenericError, http.StatusUnauthorized) + return + } + rawIDToken, _ := token.Extra("id_token").(string) + if rawIDToken == "" { + http.Error(w, providerGenericError, http.StatusUnauthorized) + return + } + + verifier := deps.GoogleVerifier + if verifier == nil { + cfg := deps.OAuth.Google + issuer := cfg.Issuer + if issuer == "" { + issuer = "https://accounts.google.com" + } + verifier = auth.OIDCVerifier{Provider: "google", Issuer: issuer, ClientID: cfg.ClientID} + } + claims, err := verifier.Verify(r.Context(), rawIDToken) + if err != nil { + slog.Default().Warn("google id token verification failed", "err", err) + http.Error(w, providerGenericError, http.StatusUnauthorized) + return + } + if claims.Provider == "" { + claims.Provider = "google" + } + if !auth.ValidateOAuthCookie(r, "google", auth.OAuthCookieNonce, claims.Nonce) { + http.Error(w, providerGenericError, http.StatusBadRequest) + return + } + if strings.TrimSpace(claims.Email) == "" || !claims.EmailVerified { + http.Error(w, providerEmailUnverified, http.StatusUnauthorized) + return + } + + userID, err := linkProviderUser(r.Context(), deps, claims) + if err != nil { + slog.Default().Error("google account linking failed", "err", err) + http.Error(w, providerGenericError, http.StatusInternalServerError) + return + } + cookieValue, expiresAt, err := deps.Store.Create(r.Context(), userID) + if err != nil { + http.Error(w, providerGenericError, http.StatusInternalServerError) + return + } + auth.SetSessionCookie(w, cookieValue, expiresAt, deps.Secure) + http.Redirect(w, r, "/", http.StatusSeeOther) + } +} + +func linkProviderUser(ctx context.Context, deps AuthDeps, claims auth.ProviderClaims) (uuid.UUID, error) { + if deps.DB == nil { + return uuid.Nil, errors.New("missing transaction DB") + } + tx, err := deps.DB.Begin(ctx) + if err != nil { + return uuid.Nil, fmt.Errorf("begin tx: %w", err) + } + defer tx.Rollback(ctx) //nolint:errcheck + + q := deps.Queries.WithTx(tx) + email := strings.ToLower(strings.TrimSpace(claims.Email)) + provider := claims.Provider + if provider == "" { + provider = "google" + } + if provider == "google" { + claims.Provider = provider + } + + identity, err := q.GetUserIdentityByProviderSubject(ctx, sqlc.GetUserIdentityByProviderSubjectParams{ + Provider: provider, + ProviderSubject: claims.Subject, + }) + if err == nil { + updated, updateErr := q.UpdateUserIdentityLogin(ctx, sqlc.UpdateUserIdentityLoginParams{ + Provider: provider, + ProviderSubject: claims.Subject, + Email: email, + EmailVerified: claims.EmailVerified, + DisplayName: textOrNull(claims.DisplayName), + AvatarUrl: textOrNull(claims.AvatarURL), + }) + if updateErr != nil { + return uuid.Nil, fmt.Errorf("update identity login: %w", updateErr) + } + if _, emailErr := q.UpdateUserEmailIfAvailable(ctx, sqlc.UpdateUserEmailIfAvailableParams{ + ID: updated.UserID, + Email: email, + }); emailErr != nil && !errors.Is(emailErr, pgx.ErrNoRows) { + return uuid.Nil, fmt.Errorf("update linked user email: %w", emailErr) + } + if err := tx.Commit(ctx); err != nil { + return uuid.Nil, fmt.Errorf("commit existing identity: %w", err) + } + return identity.UserID, nil + } + if !errors.Is(err, pgx.ErrNoRows) { + return uuid.Nil, fmt.Errorf("get identity: %w", err) + } + + user, err := q.GetUserByEmail(ctx, email) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + user, err = q.InsertSocialUser(ctx, email) + } + if err != nil { + return uuid.Nil, fmt.Errorf("lookup or create user: %w", err) + } + } + + if _, err := q.InsertUserIdentity(ctx, sqlc.InsertUserIdentityParams{ + UserID: user.ID, + Provider: provider, + ProviderSubject: claims.Subject, + Email: email, + EmailVerified: claims.EmailVerified, + DisplayName: textOrNull(claims.DisplayName), + AvatarUrl: textOrNull(claims.AvatarURL), + }); err != nil { + return uuid.Nil, fmt.Errorf("insert identity: %w", err) + } + if err := tx.Commit(ctx); err != nil { + return uuid.Nil, fmt.Errorf("commit new identity: %w", err) + } + return user.ID, nil +} + +func textOrNull(value string) pgtype.Text { + if value == "" { + return pgtype.Text{} + } + return pgtype.Text{String: value, Valid: true} +} diff --git a/backend/internal/web/handlers_social_test.go b/backend/internal/web/handlers_social_test.go new file mode 100644 index 0000000..4999ecd --- /dev/null +++ b/backend/internal/web/handlers_social_test.go @@ -0,0 +1,375 @@ +package web + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strings" + "testing" + + "backend/internal/auth" + "backend/internal/db/sqlc" + + "golang.org/x/oauth2" +) + +type fakeCodeExchanger struct { + token *oauth2.Token + err error +} + +func (f fakeCodeExchanger) Exchange(ctx context.Context, code string) (*oauth2.Token, error) { + return f.token, f.err +} + +type fakeIDTokenVerifier struct { + claims auth.ProviderClaims + err error +} + +func (f fakeIDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (auth.ProviderClaims, error) { + return f.claims, f.err +} + +func newGoogleAuthDeps(q *sqlc.Queries, store *auth.Store) AuthDeps { + return AuthDeps{ + Queries: q, + Store: store, + Secure: false, + OAuth: auth.OAuthConfig{ + Google: auth.GoogleProviderConfig{ + ClientID: "google-client", + ClientSecret: "google-secret", + RedirectURL: "https://xtablo.test/auth/google/callback", + AuthURL: "https://accounts.google.test/o/oauth2/v2/auth", + TokenURL: "https://oauth2.google.test/token", + Issuer: "https://accounts.google.test", + }, + }, + GoogleTokenExchanger: fakeCodeExchanger{ + token: (&oauth2.Token{AccessToken: "access"}).WithExtra(map[string]any{"id_token": "raw-id-token"}), + }, + GoogleVerifier: fakeIDTokenVerifier{ + claims: auth.ProviderClaims{ + Provider: "google", + Subject: "google-subject-1", + Email: "google@example.com", + EmailVerified: true, + DisplayName: "Google User", + AvatarURL: "https://example.com/avatar.png", + Nonce: "nonce-value", + }, + }, + } +} + +func withGoogleClaims(deps AuthDeps, claims auth.ProviderClaims) AuthDeps { + deps.GoogleVerifier = fakeIDTokenVerifier{claims: claims} + return deps +} + +func newSocialRouter(t *testing.T, deps AuthDeps) http.Handler { + t.Helper() + router, err := NewRouter(stubPinger{}, os.DirFS("./static"), deps, TablosDeps{}, TasksDeps{}, FilesDeps{}, testCSRFKey, "dev", "localhost") + if err != nil { + t.Fatalf("NewRouter: %v", err) + } + return router +} + +func TestGoogleStartRedirectsAndSetsStateNonceCookies(t *testing.T) { + router := newSocialRouter(t, newGoogleAuthDeps(nil, nil)) + + req := httptest.NewRequest(http.MethodGet, "/auth/google/start", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusSeeOther { + t.Fatalf("status = %d; want 303", rec.Code) + } + loc := rec.Header().Get("Location") + if !strings.Contains(loc, "https://accounts.google.test/o/oauth2/v2/auth") { + t.Fatalf("Location = %q; want Google auth URL", loc) + } + if !strings.Contains(loc, "client_id=google-client") { + t.Fatalf("Location = %q; missing client_id", loc) + } + if !strings.Contains(loc, "scope=openid+email+profile") { + t.Fatalf("Location = %q; missing openid email profile scope", loc) + } + + cookies := rec.Result().Cookies() + if findCookie(cookies, auth.OAuthCookieName("google", auth.OAuthCookieState)) == nil { + t.Fatal("missing Google state cookie") + } + if findCookie(cookies, auth.OAuthCookieName("google", auth.OAuthCookieNonce)) == nil { + t.Fatal("missing Google nonce cookie") + } +} + +func TestGoogleCallbackInvalidStateRejectedBeforeExchange(t *testing.T) { + router := newSocialRouter(t, newGoogleAuthDeps(nil, nil)) + + req := httptest.NewRequest(http.MethodGet, "/auth/google/callback?state=wrong&code=code", nil) + req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("google", auth.OAuthCookieState), Value: "expected"}) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d; want 400", rec.Code) + } +} + +func TestGoogleCallbackUnverifiedEmailRejected(t *testing.T) { + deps := newGoogleAuthDeps(nil, nil) + deps.GoogleVerifier = fakeIDTokenVerifier{claims: auth.ProviderClaims{ + Provider: "google", + Subject: "google-subject-1", + Email: "google@example.com", + EmailVerified: false, + Nonce: "nonce-value", + }} + router := newSocialRouter(t, deps) + + callback := "/auth/google/callback?" + url.Values{"state": {"state-value"}, "code": {"code"}}.Encode() + req := httptest.NewRequest(http.MethodGet, callback, nil) + req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("google", auth.OAuthCookieState), Value: "state-value"}) + req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("google", auth.OAuthCookieNonce), Value: "nonce-value"}) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Fatalf("status = %d; want 401", rec.Code) + } + if !strings.Contains(rec.Body.String(), "This provider did not return a verified email. Try another sign-in method.") { + t.Fatalf("body missing unverified email copy; got: %s", rec.Body.String()) + } +} + +func TestGoogleCallbackVerifiedEmailLinksExistingUserAndSetsSession(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + ctx := context.Background() + q := sqlc.New(pool) + store := auth.NewStore(q) + user := preInsertUser(t, ctx, q, "google@example.com", "correct-horse-12chars") + deps := newGoogleAuthDeps(q, store) + deps.DB = pool + router := newSocialRouter(t, deps) + + callback := "/auth/google/callback?" + url.Values{"state": {"state-value"}, "code": {"code"}}.Encode() + req := httptest.NewRequest(http.MethodGet, callback, nil) + req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("google", auth.OAuthCookieState), Value: "state-value"}) + req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("google", auth.OAuthCookieNonce), Value: "nonce-value"}) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusSeeOther { + t.Fatalf("status = %d; want 303; body: %s", rec.Code, rec.Body.String()) + } + if loc := rec.Header().Get("Location"); loc != "/" { + t.Fatalf("Location = %q; want /", loc) + } + if c := getSessionCookie(rec); c == nil { + t.Fatal("session cookie not set") + } + + var count int + if err := pool.QueryRow(ctx, ` + SELECT COUNT(*) + FROM user_identities + WHERE user_id = $1 AND provider = 'google' AND provider_subject = 'google-subject-1' + `, user.ID).Scan(&count); err != nil { + t.Fatalf("count identity: %v", err) + } + if count != 1 { + t.Fatalf("identity count = %d; want 1", count) + } +} + +func TestGoogleCallbackNewVerifiedEmailCreatesSocialOnlyUserAndSession(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + ctx := context.Background() + q := sqlc.New(pool) + store := auth.NewStore(q) + deps := newGoogleAuthDeps(q, store) + deps.DB = pool + deps = withGoogleClaims(deps, auth.ProviderClaims{ + Provider: "google", + Subject: "google-new-subject", + Email: "new-google@example.com", + EmailVerified: true, + DisplayName: "New Google", + AvatarURL: "https://example.com/new.png", + Nonce: "nonce-value", + }) + router := newSocialRouter(t, deps) + + rec := serveGoogleCallback(router) + + if rec.Code != http.StatusSeeOther { + t.Fatalf("status = %d; want 303; body: %s", rec.Code, rec.Body.String()) + } + if c := getSessionCookie(rec); c == nil { + t.Fatal("session cookie not set") + } + user, err := q.GetUserByEmail(ctx, "new-google@example.com") + if err != nil { + t.Fatalf("GetUserByEmail: %v", err) + } + if user.PasswordHash.Valid { + t.Fatalf("new social user password hash Valid = true; want false") + } + identity, err := q.GetUserIdentityByProviderSubject(ctx, sqlc.GetUserIdentityByProviderSubjectParams{ + Provider: "google", + ProviderSubject: "google-new-subject", + }) + if err != nil { + t.Fatalf("GetUserIdentityByProviderSubject: %v", err) + } + if identity.UserID != user.ID { + t.Fatalf("identity user id = %s; want %s", identity.UserID, user.ID) + } +} + +func TestGoogleCallbackExistingSubjectWinsWhenEmailChanges(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + ctx := context.Background() + q := sqlc.New(pool) + store := auth.NewStore(q) + user := preInsertSocialOnlyUser(t, ctx, q, "old-google@example.com") + if _, err := q.InsertUserIdentity(ctx, sqlc.InsertUserIdentityParams{ + UserID: user.ID, + Provider: "google", + ProviderSubject: "stable-google-subject", + Email: "old-google@example.com", + EmailVerified: true, + }); err != nil { + t.Fatalf("InsertUserIdentity: %v", err) + } + deps := newGoogleAuthDeps(q, store) + deps.DB = pool + deps = withGoogleClaims(deps, auth.ProviderClaims{ + Provider: "google", + Subject: "stable-google-subject", + Email: "changed-google@example.com", + EmailVerified: true, + Nonce: "nonce-value", + }) + router := newSocialRouter(t, deps) + + rec := serveGoogleCallback(router) + + if rec.Code != http.StatusSeeOther { + t.Fatalf("status = %d; want 303; body: %s", rec.Code, rec.Body.String()) + } + identity, err := q.GetUserIdentityByProviderSubject(ctx, sqlc.GetUserIdentityByProviderSubjectParams{ + Provider: "google", + ProviderSubject: "stable-google-subject", + }) + if err != nil { + t.Fatalf("GetUserIdentityByProviderSubject: %v", err) + } + if identity.UserID != user.ID { + t.Fatalf("identity relinked to %s; want original %s", identity.UserID, user.ID) + } + if identity.Email != "changed-google@example.com" { + t.Fatalf("identity email = %q; want changed-google@example.com", identity.Email) + } + updatedUser, err := q.GetUserByID(ctx, user.ID) + if err != nil { + t.Fatalf("GetUserByID: %v", err) + } + if updatedUser.Email != "changed-google@example.com" { + t.Fatalf("user email = %q; want changed-google@example.com", updatedUser.Email) + } +} + +func TestGoogleCallbackEmailUpdateConflictDoesNotRelinkSubject(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + ctx := context.Background() + q := sqlc.New(pool) + store := auth.NewStore(q) + linkedUser := preInsertSocialOnlyUser(t, ctx, q, "linked-google@example.com") + conflictUser := preInsertUser(t, ctx, q, "conflict-google@example.com", "correct-horse-12chars") + if _, err := q.InsertUserIdentity(ctx, sqlc.InsertUserIdentityParams{ + UserID: linkedUser.ID, + Provider: "google", + ProviderSubject: "conflict-google-subject", + Email: "linked-google@example.com", + EmailVerified: true, + }); err != nil { + t.Fatalf("InsertUserIdentity: %v", err) + } + deps := newGoogleAuthDeps(q, store) + deps.DB = pool + deps = withGoogleClaims(deps, auth.ProviderClaims{ + Provider: "google", + Subject: "conflict-google-subject", + Email: "conflict-google@example.com", + EmailVerified: true, + Nonce: "nonce-value", + }) + router := newSocialRouter(t, deps) + + rec := serveGoogleCallback(router) + + if rec.Code != http.StatusSeeOther { + t.Fatalf("status = %d; want 303; body: %s", rec.Code, rec.Body.String()) + } + identity, err := q.GetUserIdentityByProviderSubject(ctx, sqlc.GetUserIdentityByProviderSubjectParams{ + Provider: "google", + ProviderSubject: "conflict-google-subject", + }) + if err != nil { + t.Fatalf("GetUserIdentityByProviderSubject: %v", err) + } + if identity.UserID != linkedUser.ID { + t.Fatalf("identity user id = %s; want linked user %s", identity.UserID, linkedUser.ID) + } + if identity.Email != "conflict-google@example.com" { + t.Fatalf("identity email = %q; want conflict-google@example.com", identity.Email) + } + stillLinked, err := q.GetUserByID(ctx, linkedUser.ID) + if err != nil { + t.Fatalf("GetUserByID linked: %v", err) + } + if stillLinked.Email != "linked-google@example.com" { + t.Fatalf("linked user email = %q; want linked-google@example.com", stillLinked.Email) + } + stillConflict, err := q.GetUserByID(ctx, conflictUser.ID) + if err != nil { + t.Fatalf("GetUserByID conflict: %v", err) + } + if stillConflict.Email != "conflict-google@example.com" { + t.Fatalf("conflict user email = %q; want conflict-google@example.com", stillConflict.Email) + } +} + +func serveGoogleCallback(router http.Handler) *httptest.ResponseRecorder { + callback := "/auth/google/callback?" + url.Values{"state": {"state-value"}, "code": {"code"}}.Encode() + req := httptest.NewRequest(http.MethodGet, callback, nil) + req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("google", auth.OAuthCookieState), Value: "state-value"}) + req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("google", auth.OAuthCookieNonce), Value: "nonce-value"}) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + return rec +} + +func findCookie(cookies []*http.Cookie, name string) *http.Cookie { + for _, c := range cookies { + if c.Name == name { + return c + } + } + return nil +} diff --git a/backend/internal/web/router.go b/backend/internal/web/router.go index 45ced85..0a984f6 100644 --- a/backend/internal/web/router.go +++ b/backend/internal/web/router.go @@ -32,7 +32,9 @@ type Pinger interface { // 6. auth.Mount (gorilla/csrf — MUST come after ResolveSession, before routes) — D-24, Pitfall 7 // // Routes: GET / · GET /healthz (liveness) · GET /readyz (readiness) · GET /demo/time · GET /static/* -// GET /signup (auth pages, behind RedirectIfAuthed) · POST /signup. +// +// GET /signup (auth pages, behind RedirectIfAuthed) · POST /signup. +// // staticFS is the embedded FS (or os.DirFS in tests) served at /static/*; the // embedded FS pattern blocks path traversal at the http.FS layer (T-01-08). // @@ -69,6 +71,8 @@ func NewRouter(pinger Pinger, staticFS fs.FS, deps AuthDeps, tabloDeps TablosDep // response; the GET guard handles the common case. r.Post("/signup", SignupPostHandler(deps)) r.Post("/login", LoginPostHandler(deps)) + r.Get("/auth/google/start", GoogleStartHandler(deps)) + r.Get("/auth/google/callback", GoogleCallbackHandler(deps)) // Protected routes — require an authenticated session (D-23, AUTH-05). // RequireAuth checks the context set by ResolveSession above and redirects