package web import ( "context" "errors" "fmt" "log/slog" "net/http" "strings" "backend/internal/auth" "backend/internal/db/sqlc" "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "golang.org/x/oauth2" ) type TxBeginner interface { Begin(ctx context.Context) (pgx.Tx, error) } const providerEmailUnverified = "This provider did not return a verified email. Try another sign-in method." const providerGenericError = "Could not sign you in with this provider. Try another sign-in method." func GoogleStartHandler(deps AuthDeps) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { cfg := deps.OAuth.Google if !cfg.Configured() { http.Error(w, "Google sign-in not configured", http.StatusServiceUnavailable) return } state, err := auth.GenerateOAuthValue() if err != nil { http.Error(w, "internal server error", http.StatusInternalServerError) return } nonce, err := auth.GenerateOAuthValue() if err != nil { http.Error(w, "internal server error", http.StatusInternalServerError) return } auth.SetOAuthCookie(w, "google", auth.OAuthCookieState, state, deps.Secure) auth.SetOAuthCookie(w, "google", auth.OAuthCookieNonce, nonce, deps.Secure) oauthCfg := cfg.OAuth2Config() url := oauthCfg.AuthCodeURL(state, oauth2.SetAuthURLParam("nonce", nonce)) http.Redirect(w, r, url, http.StatusSeeOther) } } func GoogleCallbackHandler(deps AuthDeps) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if !auth.ValidateOAuthCookie(r, "google", auth.OAuthCookieState, r.URL.Query().Get("state")) { http.Error(w, providerGenericError, http.StatusBadRequest) return } auth.ClearOAuthCookie(w, "google", auth.OAuthCookieState, deps.Secure) auth.ClearOAuthCookie(w, "google", auth.OAuthCookieNonce, deps.Secure) code := r.URL.Query().Get("code") if code == "" { http.Error(w, providerGenericError, http.StatusBadRequest) return } exchanger := deps.GoogleTokenExchanger if exchanger == nil { exchanger = auth.OAuth2CodeExchanger{Config: deps.OAuth.Google.OAuth2Config()} } token, err := exchanger.Exchange(r.Context(), code) if err != nil { slog.Default().Warn("google oauth exchange failed", "err", err) http.Error(w, providerGenericError, http.StatusUnauthorized) return } rawIDToken, _ := token.Extra("id_token").(string) if rawIDToken == "" { http.Error(w, providerGenericError, http.StatusUnauthorized) return } verifier := deps.GoogleVerifier if verifier == nil { cfg := deps.OAuth.Google issuer := cfg.Issuer if issuer == "" { issuer = "https://accounts.google.com" } verifier = auth.OIDCVerifier{Provider: "google", Issuer: issuer, ClientID: cfg.ClientID} } claims, err := verifier.Verify(r.Context(), rawIDToken) if err != nil { slog.Default().Warn("google id token verification failed", "err", err) http.Error(w, providerGenericError, http.StatusUnauthorized) return } if claims.Provider == "" { claims.Provider = "google" } if !auth.ValidateOAuthCookie(r, "google", auth.OAuthCookieNonce, claims.Nonce) { http.Error(w, providerGenericError, http.StatusBadRequest) return } if strings.TrimSpace(claims.Email) == "" || !claims.EmailVerified { http.Error(w, providerEmailUnverified, http.StatusUnauthorized) return } userID, err := linkProviderUser(r.Context(), deps, claims) if err != nil { slog.Default().Error("google account linking failed", "err", err) http.Error(w, providerGenericError, http.StatusInternalServerError) return } cookieValue, expiresAt, err := deps.Store.Create(r.Context(), userID) if err != nil { http.Error(w, providerGenericError, http.StatusInternalServerError) return } auth.SetSessionCookie(w, cookieValue, expiresAt, deps.Secure) http.Redirect(w, r, "/", http.StatusSeeOther) } } func linkProviderUser(ctx context.Context, deps AuthDeps, claims auth.ProviderClaims) (uuid.UUID, error) { if deps.DB == nil { return uuid.Nil, errors.New("missing transaction DB") } tx, err := deps.DB.Begin(ctx) if err != nil { return uuid.Nil, fmt.Errorf("begin tx: %w", err) } defer tx.Rollback(ctx) //nolint:errcheck q := deps.Queries.WithTx(tx) email := strings.ToLower(strings.TrimSpace(claims.Email)) provider := claims.Provider if provider == "" { provider = "google" } if provider == "google" { claims.Provider = provider } identity, err := q.GetUserIdentityByProviderSubject(ctx, sqlc.GetUserIdentityByProviderSubjectParams{ Provider: provider, ProviderSubject: claims.Subject, }) if err == nil { updated, updateErr := q.UpdateUserIdentityLogin(ctx, sqlc.UpdateUserIdentityLoginParams{ Provider: provider, ProviderSubject: claims.Subject, Email: email, EmailVerified: claims.EmailVerified, DisplayName: textOrNull(claims.DisplayName), AvatarUrl: textOrNull(claims.AvatarURL), }) if updateErr != nil { return uuid.Nil, fmt.Errorf("update identity login: %w", updateErr) } if _, emailErr := q.UpdateUserEmailIfAvailable(ctx, sqlc.UpdateUserEmailIfAvailableParams{ ID: updated.UserID, Email: email, }); emailErr != nil && !errors.Is(emailErr, pgx.ErrNoRows) { return uuid.Nil, fmt.Errorf("update linked user email: %w", emailErr) } if err := tx.Commit(ctx); err != nil { return uuid.Nil, fmt.Errorf("commit existing identity: %w", err) } return identity.UserID, nil } if !errors.Is(err, pgx.ErrNoRows) { return uuid.Nil, fmt.Errorf("get identity: %w", err) } user, err := q.GetUserByEmail(ctx, email) if err != nil { if errors.Is(err, pgx.ErrNoRows) { user, err = q.InsertSocialUser(ctx, email) } if err != nil { return uuid.Nil, fmt.Errorf("lookup or create user: %w", err) } } if _, err := q.InsertUserIdentity(ctx, sqlc.InsertUserIdentityParams{ UserID: user.ID, Provider: provider, ProviderSubject: claims.Subject, Email: email, EmailVerified: claims.EmailVerified, DisplayName: textOrNull(claims.DisplayName), AvatarUrl: textOrNull(claims.AvatarURL), }); err != nil { return uuid.Nil, fmt.Errorf("insert identity: %w", err) } if err := tx.Commit(ctx); err != nil { return uuid.Nil, fmt.Errorf("commit new identity: %w", err) } return user.ID, nil } func textOrNull(value string) pgtype.Text { if value == "" { return pgtype.Text{} } return pgtype.Text{String: value, Valid: true} }