328 lines
10 KiB
Go
328 lines
10 KiB
Go
package web
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"backend/internal/auth"
|
|
"backend/internal/db/sqlc"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgtype"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
type TxBeginner interface {
|
|
Begin(ctx context.Context) (pgx.Tx, error)
|
|
}
|
|
|
|
// User-facing messages written by the OAuth callback handlers.
const (
	// providerEmailUnverified is sent when the verified ID token carries a
	// blank email or an email that is not marked as verified.
	providerEmailUnverified = "This provider did not return a verified email. Try another sign-in method."

	// providerGenericError is sent for every other callback failure so that
	// the response does not reveal which step went wrong.
	providerGenericError = "Could not sign you in with this provider. Try another sign-in method."
)
|
|
|
|
func GoogleStartHandler(deps AuthDeps) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
cfg := deps.OAuth.Google
|
|
if !cfg.Configured() {
|
|
http.Error(w, "Google sign-in not configured", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
state, err := auth.GenerateOAuthValue()
|
|
if err != nil {
|
|
http.Error(w, "internal server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
nonce, err := auth.GenerateOAuthValue()
|
|
if err != nil {
|
|
http.Error(w, "internal server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
auth.SetOAuthCookie(w, "google", auth.OAuthCookieState, state, deps.Secure)
|
|
auth.SetOAuthCookie(w, "google", auth.OAuthCookieNonce, nonce, deps.Secure)
|
|
|
|
oauthCfg := cfg.OAuth2Config()
|
|
url := oauthCfg.AuthCodeURL(state, oauth2.SetAuthURLParam("nonce", nonce))
|
|
http.Redirect(w, r, url, http.StatusSeeOther)
|
|
}
|
|
}
|
|
|
|
func GoogleCallbackHandler(deps AuthDeps) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
if !auth.ValidateOAuthCookie(r, "google", auth.OAuthCookieState, r.URL.Query().Get("state")) {
|
|
http.Error(w, providerGenericError, http.StatusBadRequest)
|
|
return
|
|
}
|
|
auth.ClearOAuthCookie(w, "google", auth.OAuthCookieState, deps.Secure)
|
|
auth.ClearOAuthCookie(w, "google", auth.OAuthCookieNonce, deps.Secure)
|
|
|
|
code := r.URL.Query().Get("code")
|
|
if code == "" {
|
|
http.Error(w, providerGenericError, http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
exchanger := deps.GoogleTokenExchanger
|
|
if exchanger == nil {
|
|
exchanger = auth.OAuth2CodeExchanger{Config: deps.OAuth.Google.OAuth2Config()}
|
|
}
|
|
token, err := exchanger.Exchange(r.Context(), code)
|
|
if err != nil {
|
|
slog.Default().Warn("google oauth exchange failed", "err", err)
|
|
http.Error(w, providerGenericError, http.StatusUnauthorized)
|
|
return
|
|
}
|
|
rawIDToken, _ := token.Extra("id_token").(string)
|
|
if rawIDToken == "" {
|
|
http.Error(w, providerGenericError, http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
verifier := deps.GoogleVerifier
|
|
if verifier == nil {
|
|
cfg := deps.OAuth.Google
|
|
issuer := cfg.Issuer
|
|
if issuer == "" {
|
|
issuer = "https://accounts.google.com"
|
|
}
|
|
verifier = auth.OIDCVerifier{Provider: "google", Issuer: issuer, ClientID: cfg.ClientID}
|
|
}
|
|
claims, err := verifier.Verify(r.Context(), rawIDToken)
|
|
if err != nil {
|
|
slog.Default().Warn("google id token verification failed", "err", err)
|
|
http.Error(w, providerGenericError, http.StatusUnauthorized)
|
|
return
|
|
}
|
|
if claims.Provider == "" {
|
|
claims.Provider = "google"
|
|
}
|
|
if !auth.ValidateOAuthCookie(r, "google", auth.OAuthCookieNonce, claims.Nonce) {
|
|
http.Error(w, providerGenericError, http.StatusBadRequest)
|
|
return
|
|
}
|
|
if strings.TrimSpace(claims.Email) == "" || !claims.EmailVerified {
|
|
http.Error(w, providerEmailUnverified, http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
userID, err := linkProviderUser(r.Context(), deps, claims)
|
|
if err != nil {
|
|
slog.Default().Error("google account linking failed", "err", err)
|
|
http.Error(w, providerGenericError, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
cookieValue, expiresAt, err := deps.Store.Create(r.Context(), userID)
|
|
if err != nil {
|
|
http.Error(w, providerGenericError, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
auth.SetSessionCookie(w, cookieValue, expiresAt, deps.Secure)
|
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
|
}
|
|
}
|
|
|
|
func AppleStartHandler(deps AuthDeps) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
cfg := deps.OAuth.Apple
|
|
if !cfg.Configured() {
|
|
http.Error(w, "Apple sign-in not configured", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
state, err := auth.GenerateOAuthValue()
|
|
if err != nil {
|
|
http.Error(w, "internal server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
nonce, err := auth.GenerateOAuthValue()
|
|
if err != nil {
|
|
http.Error(w, "internal server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
clientSecret, err := cfg.ClientSecret(timeNow())
|
|
if err != nil {
|
|
http.Error(w, "Apple sign-in not configured", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
auth.SetOAuthCookie(w, "apple", auth.OAuthCookieState, state, deps.Secure)
|
|
auth.SetOAuthCookie(w, "apple", auth.OAuthCookieNonce, nonce, deps.Secure)
|
|
|
|
oauthCfg := cfg.OAuth2Config(clientSecret)
|
|
url := oauthCfg.AuthCodeURL(state,
|
|
oauth2.SetAuthURLParam("nonce", nonce),
|
|
oauth2.SetAuthURLParam("response_mode", "query"),
|
|
)
|
|
http.Redirect(w, r, url, http.StatusSeeOther)
|
|
}
|
|
}
|
|
|
|
func AppleCallbackHandler(deps AuthDeps) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
if !auth.ValidateOAuthCookie(r, "apple", auth.OAuthCookieState, r.URL.Query().Get("state")) {
|
|
http.Error(w, providerGenericError, http.StatusBadRequest)
|
|
return
|
|
}
|
|
auth.ClearOAuthCookie(w, "apple", auth.OAuthCookieState, deps.Secure)
|
|
auth.ClearOAuthCookie(w, "apple", auth.OAuthCookieNonce, deps.Secure)
|
|
|
|
code := r.URL.Query().Get("code")
|
|
if code == "" {
|
|
http.Error(w, providerGenericError, http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
exchanger := deps.AppleTokenExchanger
|
|
if exchanger == nil {
|
|
clientSecret, err := deps.OAuth.Apple.ClientSecret(timeNow())
|
|
if err != nil {
|
|
http.Error(w, providerGenericError, http.StatusUnauthorized)
|
|
return
|
|
}
|
|
exchanger = auth.OAuth2CodeExchanger{Config: deps.OAuth.Apple.OAuth2Config(clientSecret)}
|
|
}
|
|
token, err := exchanger.Exchange(r.Context(), code)
|
|
if err != nil {
|
|
slog.Default().Warn("apple oauth exchange failed", "err", err)
|
|
http.Error(w, providerGenericError, http.StatusUnauthorized)
|
|
return
|
|
}
|
|
rawIDToken, _ := token.Extra("id_token").(string)
|
|
if rawIDToken == "" {
|
|
http.Error(w, providerGenericError, http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
verifier := deps.AppleVerifier
|
|
if verifier == nil {
|
|
cfg := deps.OAuth.Apple
|
|
issuer := cfg.Issuer
|
|
if issuer == "" {
|
|
issuer = "https://appleid.apple.com"
|
|
}
|
|
verifier = auth.OIDCVerifier{Provider: "apple", Issuer: issuer, ClientID: cfg.ClientID}
|
|
}
|
|
claims, err := verifier.Verify(r.Context(), rawIDToken)
|
|
if err != nil {
|
|
slog.Default().Warn("apple id token verification failed", "err", err)
|
|
http.Error(w, providerGenericError, http.StatusUnauthorized)
|
|
return
|
|
}
|
|
if claims.Provider == "" {
|
|
claims.Provider = "apple"
|
|
}
|
|
if !auth.ValidateOAuthCookie(r, "apple", auth.OAuthCookieNonce, claims.Nonce) {
|
|
http.Error(w, providerGenericError, http.StatusBadRequest)
|
|
return
|
|
}
|
|
if strings.TrimSpace(claims.Email) == "" || !claims.EmailVerified {
|
|
http.Error(w, providerEmailUnverified, http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
userID, err := linkProviderUser(r.Context(), deps, claims)
|
|
if err != nil {
|
|
slog.Default().Error("apple account linking failed", "err", err)
|
|
http.Error(w, providerGenericError, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
cookieValue, expiresAt, err := deps.Store.Create(r.Context(), userID)
|
|
if err != nil {
|
|
http.Error(w, providerGenericError, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
auth.SetSessionCookie(w, cookieValue, expiresAt, deps.Secure)
|
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
|
}
|
|
}
|
|
|
|
// timeNow supplies the current time passed to the Apple client-secret
// generator. It is a package-level variable rather than a direct call,
// presumably so that tests can substitute a fixed clock.
var timeNow = time.Now
|
|
|
|
func linkProviderUser(ctx context.Context, deps AuthDeps, claims auth.ProviderClaims) (uuid.UUID, error) {
|
|
if deps.DB == nil {
|
|
return uuid.Nil, errors.New("missing transaction DB")
|
|
}
|
|
tx, err := deps.DB.Begin(ctx)
|
|
if err != nil {
|
|
return uuid.Nil, fmt.Errorf("begin tx: %w", err)
|
|
}
|
|
defer tx.Rollback(ctx) //nolint:errcheck
|
|
|
|
q := deps.Queries.WithTx(tx)
|
|
email := strings.ToLower(strings.TrimSpace(claims.Email))
|
|
provider := claims.Provider
|
|
if provider == "" {
|
|
provider = "google"
|
|
}
|
|
if provider == "google" {
|
|
claims.Provider = provider
|
|
}
|
|
|
|
identity, err := q.GetUserIdentityByProviderSubject(ctx, sqlc.GetUserIdentityByProviderSubjectParams{
|
|
Provider: provider,
|
|
ProviderSubject: claims.Subject,
|
|
})
|
|
if err == nil {
|
|
updated, updateErr := q.UpdateUserIdentityLogin(ctx, sqlc.UpdateUserIdentityLoginParams{
|
|
Provider: provider,
|
|
ProviderSubject: claims.Subject,
|
|
Email: email,
|
|
EmailVerified: claims.EmailVerified,
|
|
DisplayName: textOrNull(claims.DisplayName),
|
|
AvatarUrl: textOrNull(claims.AvatarURL),
|
|
})
|
|
if updateErr != nil {
|
|
return uuid.Nil, fmt.Errorf("update identity login: %w", updateErr)
|
|
}
|
|
if _, emailErr := q.UpdateUserEmailIfAvailable(ctx, sqlc.UpdateUserEmailIfAvailableParams{
|
|
ID: updated.UserID,
|
|
Email: email,
|
|
}); emailErr != nil && !errors.Is(emailErr, pgx.ErrNoRows) {
|
|
return uuid.Nil, fmt.Errorf("update linked user email: %w", emailErr)
|
|
}
|
|
if err := tx.Commit(ctx); err != nil {
|
|
return uuid.Nil, fmt.Errorf("commit existing identity: %w", err)
|
|
}
|
|
return identity.UserID, nil
|
|
}
|
|
if !errors.Is(err, pgx.ErrNoRows) {
|
|
return uuid.Nil, fmt.Errorf("get identity: %w", err)
|
|
}
|
|
|
|
user, err := q.GetUserByEmail(ctx, email)
|
|
if err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
user, err = q.InsertSocialUser(ctx, email)
|
|
}
|
|
if err != nil {
|
|
return uuid.Nil, fmt.Errorf("lookup or create user: %w", err)
|
|
}
|
|
}
|
|
|
|
if _, err := q.InsertUserIdentity(ctx, sqlc.InsertUserIdentityParams{
|
|
UserID: user.ID,
|
|
Provider: provider,
|
|
ProviderSubject: claims.Subject,
|
|
Email: email,
|
|
EmailVerified: claims.EmailVerified,
|
|
DisplayName: textOrNull(claims.DisplayName),
|
|
AvatarUrl: textOrNull(claims.AvatarURL),
|
|
}); err != nil {
|
|
return uuid.Nil, fmt.Errorf("insert identity: %w", err)
|
|
}
|
|
if err := tx.Commit(ctx); err != nil {
|
|
return uuid.Nil, fmt.Errorf("commit new identity: %w", err)
|
|
}
|
|
return user.ID, nil
|
|
}
|
|
|
|
func textOrNull(value string) pgtype.Text {
|
|
if value == "" {
|
|
return pgtype.Text{}
|
|
}
|
|
return pgtype.Text{String: value, Valid: true}
|
|
}
|