xtablo-source/backend/internal/web/handlers_social.go
2026-05-15 21:03:30 +02:00

210 lines
6.4 KiB
Go

package web
import (
"context"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"backend/internal/auth"
"backend/internal/db/sqlc"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"golang.org/x/oauth2"
)
// TxBeginner is the minimal database dependency needed to open a pgx
// transaction. linkProviderUser calls Begin on deps.DB, so AuthDeps.DB is
// presumably declared with this type (NOTE(review): confirm against AuthDeps).
// Both *pgxpool.Pool and *pgx.Conn expose a method with this signature.
type TxBeginner interface {
	// Begin starts a new transaction and returns it as a pgx.Tx.
	Begin(ctx context.Context) (pgx.Tx, error)
}
// User-facing messages returned by the social sign-in handlers. They are
// intentionally generic so the browser is never told which internal step of
// the sign-in failed.
const (
	// providerEmailUnverified is returned when the provider did not supply a
	// non-empty, verified email address.
	providerEmailUnverified = "This provider did not return a verified email. Try another sign-in method."
	// providerGenericError is returned for every other sign-in failure.
	providerGenericError = "Could not sign you in with this provider. Try another sign-in method."
)
// GoogleStartHandler begins the Google sign-in flow.
//
// It responds 503 when Google OAuth is not configured. Otherwise it generates
// a fresh state value and a fresh nonce, stores each one in a "google" OAuth
// cookie (Secure flag taken from deps.Secure), and redirects with 303 See Other
// to Google's authorization URL. That URL carries the state plus a "nonce"
// parameter. GoogleCallbackHandler later compares both values with the cookies.
// A failure to generate either value yields a 500.
func GoogleStartHandler(deps AuthDeps) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		google := deps.OAuth.Google
		if !google.Configured() {
			http.Error(w, "Google sign-in not configured", http.StatusServiceUnavailable)
			return
		}

		// Generate the state first, then the nonce; stop at the first failure.
		var state, nonce string
		for _, target := range []*string{&state, &nonce} {
			value, err := auth.GenerateOAuthValue()
			if err != nil {
				http.Error(w, "internal server error", http.StatusInternalServerError)
				return
			}
			*target = value
		}

		auth.SetOAuthCookie(w, "google", auth.OAuthCookieState, state, deps.Secure)
		auth.SetOAuthCookie(w, "google", auth.OAuthCookieNonce, nonce, deps.Secure)

		// Keep the config in an addressable local: AuthCodeURL may have a
		// pointer receiver.
		oauth2Config := google.OAuth2Config()
		authURL := oauth2Config.AuthCodeURL(state, oauth2.SetAuthURLParam("nonce", nonce))
		http.Redirect(w, r, authURL, http.StatusSeeOther)
	}
}
// GoogleCallbackHandler completes the Google sign-in started by
// GoogleStartHandler. It runs these steps in order:
//
//  1. Check the "state" query parameter against the state cookie.
//  2. Clear the state and nonce cookies.
//  3. Exchange the authorization "code" for tokens.
//  4. Verify the returned ID token.
//  5. Check the token's nonce against the nonce cookie.
//  6. Require a non-empty, verified email.
//  7. Link or create the local user.
//  8. Start a session and redirect to "/".
//
// Failures are answered with providerGenericError, or with
// providerEmailUnverified for the email check, so that provider details are
// not exposed to the browser. Server-side failures are logged.
//
// deps.GoogleTokenExchanger and deps.GoogleVerifier are optional overrides.
// When they are nil, an oauth2 code exchanger and an OIDC verifier are built
// from deps.OAuth.Google. The issuer defaults to https://accounts.google.com.
func GoogleCallbackHandler(deps AuthDeps) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if !auth.ValidateOAuthCookie(r, "google", auth.OAuthCookieState, r.URL.Query().Get("state")) {
			http.Error(w, providerGenericError, http.StatusBadRequest)
			return
		}
		// Clear both cookies on the response so they are single-use.
		// ClearOAuthCookie writes to w, while the nonce check below reads the
		// cookie sent with this request (r).
		auth.ClearOAuthCookie(w, "google", auth.OAuthCookieState, deps.Secure)
		auth.ClearOAuthCookie(w, "google", auth.OAuthCookieNonce, deps.Secure)

		code := r.URL.Query().Get("code")
		if code == "" {
			http.Error(w, providerGenericError, http.StatusBadRequest)
			return
		}

		exchanger := deps.GoogleTokenExchanger
		if exchanger == nil {
			exchanger = auth.OAuth2CodeExchanger{Config: deps.OAuth.Google.OAuth2Config()}
		}
		token, err := exchanger.Exchange(r.Context(), code)
		if err != nil {
			slog.Default().Warn("google oauth exchange failed", "err", err)
			http.Error(w, providerGenericError, http.StatusUnauthorized)
			return
		}
		rawIDToken, _ := token.Extra("id_token").(string)
		if rawIDToken == "" {
			slog.Default().Warn("google oauth token response missing id_token")
			http.Error(w, providerGenericError, http.StatusUnauthorized)
			return
		}

		verifier := deps.GoogleVerifier
		if verifier == nil {
			cfg := deps.OAuth.Google
			issuer := cfg.Issuer
			if issuer == "" {
				issuer = "https://accounts.google.com"
			}
			verifier = auth.OIDCVerifier{Provider: "google", Issuer: issuer, ClientID: cfg.ClientID}
		}
		claims, err := verifier.Verify(r.Context(), rawIDToken)
		if err != nil {
			slog.Default().Warn("google id token verification failed", "err", err)
			http.Error(w, providerGenericError, http.StatusUnauthorized)
			return
		}
		if claims.Provider == "" {
			claims.Provider = "google"
		}
		// The nonce binds this ID token to the login attempt that set the cookie.
		if !auth.ValidateOAuthCookie(r, "google", auth.OAuthCookieNonce, claims.Nonce) {
			slog.Default().Warn("google id token nonce mismatch")
			http.Error(w, providerGenericError, http.StatusBadRequest)
			return
		}
		if strings.TrimSpace(claims.Email) == "" || !claims.EmailVerified {
			http.Error(w, providerEmailUnverified, http.StatusUnauthorized)
			return
		}

		userID, err := linkProviderUser(r.Context(), deps, claims)
		if err != nil {
			slog.Default().Error("google account linking failed", "err", err)
			http.Error(w, providerGenericError, http.StatusInternalServerError)
			return
		}
		cookieValue, expiresAt, err := deps.Store.Create(r.Context(), userID)
		if err != nil {
			// Previously swallowed silently; log it so session-store outages
			// are visible.
			slog.Default().Error("google session creation failed", "err", err)
			http.Error(w, providerGenericError, http.StatusInternalServerError)
			return
		}
		auth.SetSessionCookie(w, cookieValue, expiresAt, deps.Secure)
		http.Redirect(w, r, "/", http.StatusSeeOther)
	}
}
// linkProviderUser resolves the local user ID for a provider identity. All
// reads and writes happen in one transaction on deps.DB.
//
// Resolution order:
//  1. If a (provider, subject) identity already exists, refresh its login
//     metadata (email, verification flag, display name, avatar). Then try to
//     copy the email onto the linked user; "not available" (pgx.ErrNoRows) is
//     tolerated. Return that identity's user.
//  2. Otherwise, require a non-empty verified email. Find the user with that
//     lower-cased email, or insert a new social user, then attach a new
//     identity to it.
//
// An empty claims.Provider is treated as "google". Errors are returned when:
//   - deps.DB is nil;
//   - the claims have no subject;
//   - a new identity would be linked by an empty or unverified email;
//   - any query or the commit fails.
func linkProviderUser(ctx context.Context, deps AuthDeps, claims auth.ProviderClaims) (uuid.UUID, error) {
	if deps.DB == nil {
		return uuid.Nil, errors.New("missing transaction DB")
	}
	// The subject is the identity key. An empty one would make unrelated
	// accounts collide on the same (provider, "") identity row.
	if claims.Subject == "" {
		return uuid.Nil, errors.New("provider claims missing subject")
	}

	tx, err := deps.DB.Begin(ctx)
	if err != nil {
		return uuid.Nil, fmt.Errorf("begin tx: %w", err)
	}
	// After a successful Commit, Rollback returns pgx.ErrTxClosed and is
	// ignored. On early returns it undoes the partial work.
	defer tx.Rollback(ctx) //nolint:errcheck
	q := deps.Queries.WithTx(tx)

	email := strings.ToLower(strings.TrimSpace(claims.Email))
	provider := claims.Provider
	if provider == "" {
		provider = "google"
	}

	identity, err := q.GetUserIdentityByProviderSubject(ctx, sqlc.GetUserIdentityByProviderSubjectParams{
		Provider:        provider,
		ProviderSubject: claims.Subject,
	})
	if err == nil {
		updated, updateErr := q.UpdateUserIdentityLogin(ctx, sqlc.UpdateUserIdentityLoginParams{
			Provider:        provider,
			ProviderSubject: claims.Subject,
			Email:           email,
			EmailVerified:   claims.EmailVerified,
			DisplayName:     textOrNull(claims.DisplayName),
			AvatarUrl:       textOrNull(claims.AvatarURL),
		})
		if updateErr != nil {
			return uuid.Nil, fmt.Errorf("update identity login: %w", updateErr)
		}
		// ErrNoRows means the email could not be applied; keep the user's
		// current email in that case.
		if _, emailErr := q.UpdateUserEmailIfAvailable(ctx, sqlc.UpdateUserEmailIfAvailableParams{
			ID:    updated.UserID,
			Email: email,
		}); emailErr != nil && !errors.Is(emailErr, pgx.ErrNoRows) {
			return uuid.Nil, fmt.Errorf("update linked user email: %w", emailErr)
		}
		if err := tx.Commit(ctx); err != nil {
			return uuid.Nil, fmt.Errorf("commit existing identity: %w", err)
		}
		return identity.UserID, nil
	}
	if !errors.Is(err, pgx.ErrNoRows) {
		return uuid.Nil, fmt.Errorf("get identity: %w", err)
	}

	// No identity yet. Linking by email attaches this login to whichever
	// account owns the address, so it is only safe when the provider vouches
	// for that address.
	if email == "" || !claims.EmailVerified {
		return uuid.Nil, errors.New("cannot link identity without a verified email")
	}

	user, err := q.GetUserByEmail(ctx, email)
	if errors.Is(err, pgx.ErrNoRows) {
		user, err = q.InsertSocialUser(ctx, email)
	}
	if err != nil {
		return uuid.Nil, fmt.Errorf("lookup or create user: %w", err)
	}

	if _, err := q.InsertUserIdentity(ctx, sqlc.InsertUserIdentityParams{
		UserID:          user.ID,
		Provider:        provider,
		ProviderSubject: claims.Subject,
		Email:           email,
		EmailVerified:   claims.EmailVerified,
		DisplayName:     textOrNull(claims.DisplayName),
		AvatarUrl:       textOrNull(claims.AvatarURL),
	}); err != nil {
		return uuid.Nil, fmt.Errorf("insert identity: %w", err)
	}
	if err := tx.Commit(ctx); err != nil {
		return uuid.Nil, fmt.Errorf("commit new identity: %w", err)
	}
	return user.ID, nil
}
// textOrNull maps an optional string to a pgtype.Text. The empty string
// becomes SQL NULL (zero value, Valid == false). Any other string is stored
// unchanged with Valid set to true.
func textOrNull(value string) pgtype.Text {
	var out pgtype.Text
	if value != "" {
		out.String = value
		out.Valid = true
	}
	return out
}