From a8b6a03eac6894fd640f8910327a0f5fe09ac6d4 Mon Sep 17 00:00:00 2001 From: Arthur Belleville Date: Fri, 15 May 2026 21:06:08 +0200 Subject: [PATCH] feat(08-03): add apple social sign-in flow --- backend/cmd/web/main.go | 24 +++ backend/internal/auth/oauth.go | 124 ++++++++++++- backend/internal/auth/oauth_test.go | 106 +++++++++++ backend/internal/web/handlers_auth.go | 2 + backend/internal/web/handlers_social.go | 118 ++++++++++++ backend/internal/web/handlers_social_test.go | 184 +++++++++++++++++++ backend/internal/web/router.go | 2 + 7 files changed, 553 insertions(+), 7 deletions(-) diff --git a/backend/cmd/web/main.go b/backend/cmd/web/main.go index 18dc40a..3603dfe 100644 --- a/backend/cmd/web/main.go +++ b/backend/cmd/web/main.go @@ -91,9 +91,18 @@ func main() { ClientSecret: os.Getenv("GOOGLE_CLIENT_SECRET"), RedirectURL: os.Getenv("GOOGLE_REDIRECT_URL"), }, + Apple: auth.AppleProviderConfig{ + ClientID: os.Getenv("APPLE_CLIENT_ID"), + TeamID: os.Getenv("APPLE_TEAM_ID"), + KeyID: os.Getenv("APPLE_KEY_ID"), + PrivateKey: os.Getenv("APPLE_PRIVATE_KEY"), + RedirectURL: os.Getenv("APPLE_REDIRECT_URL"), + }, } var googleExchanger auth.CodeExchanger var googleVerifier auth.IDTokenVerifier + var appleExchanger auth.CodeExchanger + var appleVerifier auth.IDTokenVerifier if oauthCfg.Google.Configured() { googleExchanger = auth.OAuth2CodeExchanger{Config: oauthCfg.Google.OAuth2Config()} googleVerifier = auth.OIDCVerifier{ @@ -102,6 +111,19 @@ func main() { ClientID: oauthCfg.Google.ClientID, } } + if oauthCfg.Apple.Configured() { + appleSecret, err := oauthCfg.Apple.ClientSecret(time.Now()) + if err != nil { + slog.Error("invalid Apple sign-in config", "err", err) + os.Exit(1) + } + appleExchanger = auth.OAuth2CodeExchanger{Config: oauthCfg.Apple.OAuth2Config(appleSecret)} + appleVerifier = auth.OIDCVerifier{ + Provider: "apple", + Issuer: "https://appleid.apple.com", + ClientID: oauthCfg.Apple.ClientID, + } + } deps := web.AuthDeps{ Queries: q, @@ -112,6 +134,8 @@ func main() { OAuth: oauthCfg, GoogleTokenExchanger: googleExchanger, GoogleVerifier: googleVerifier, + AppleTokenExchanger: appleExchanger, + AppleVerifier: appleVerifier, } tabloDeps := web.TablosDeps{Queries: q} taskDeps := web.TasksDeps{Queries: q} diff --git a/backend/internal/auth/oauth.go b/backend/internal/auth/oauth.go index f6e350b..7e45738 100644 --- a/backend/internal/auth/oauth.go +++ b/backend/internal/auth/oauth.go @@ -2,15 +2,22 @@ package auth import ( "context" + "crypto/ecdsa" "crypto/rand" "crypto/subtle" + "crypto/x509" "encoding/base64" + "encoding/json" + "encoding/pem" "errors" "fmt" "net/http" + "strings" "time" "github.com/coreos/go-oidc/v3/oidc" + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" "golang.org/x/oauth2" ) @@ -30,6 +37,92 @@ type GoogleProviderConfig struct { Issuer string } +type AppleProviderConfig struct { + ClientID string + TeamID string + KeyID string + PrivateKey string + RedirectURL string + AuthURL string + TokenURL string + Issuer string +} + +func (c AppleProviderConfig) Configured() bool { + return c.ClientID != "" && c.TeamID != "" && c.KeyID != "" && c.PrivateKey != "" && c.RedirectURL != "" +} + +func (c AppleProviderConfig) withDefaults() AppleProviderConfig { + if c.AuthURL == "" { + c.AuthURL = "https://appleid.apple.com/auth/authorize" + } + if c.TokenURL == "" { + c.TokenURL = "https://appleid.apple.com/auth/token" + } + if c.Issuer == "" { + c.Issuer = "https://appleid.apple.com" + } + return c +} + +func (c AppleProviderConfig) OAuth2Config(clientSecret string) oauth2.Config { + c = c.withDefaults() + return oauth2.Config{ + ClientID: c.ClientID, + ClientSecret: clientSecret, + RedirectURL: c.RedirectURL, + Scopes: []string{"name", "email"}, + Endpoint: oauth2.Endpoint{ + AuthURL: c.AuthURL, + TokenURL: c.TokenURL, + }, + } +} + +func (c AppleProviderConfig) ClientSecret(now time.Time) (string, error) { + privateKey, err := parseApplePrivateKey(c.PrivateKey) + if err != nil { + return "", err + } + opts := (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", c.KeyID) + signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: privateKey}, opts) + if err != nil { + return "", fmt.Errorf("auth: create apple client secret signer: %w", err) + } + claims := jwt.Claims{ + Issuer: c.TeamID, + Subject: c.ClientID, + Audience: jwt.Audience{"https://appleid.apple.com"}, + IssuedAt: jwt.NewNumericDate(now), + Expiry: jwt.NewNumericDate(now.Add(6 * time.Hour)), + } + secret, err := jwt.Signed(signer).Claims(claims).Serialize() + if err != nil { + return "", fmt.Errorf("auth: sign apple client secret: %w", err) + } + return secret, nil +} + +func parseApplePrivateKey(raw string) (*ecdsa.PrivateKey, error) { + normalized := strings.ReplaceAll(raw, `\n`, "\n") + block, _ := pem.Decode([]byte(normalized)) + if block == nil { + return nil, errors.New("auth: apple private key PEM block missing") + } + if key, err := x509.ParseECPrivateKey(block.Bytes); err == nil { + return key, nil + } + parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("auth: parse apple private key: %w", err) + } + key, ok := parsed.(*ecdsa.PrivateKey) + if !ok { + return nil, errors.New("auth: apple private key is not ECDSA") + } + return key, nil +} + func (c GoogleProviderConfig) Configured() bool { return c.ClientID != "" && c.ClientSecret != "" && c.RedirectURL != "" } @@ -63,6 +156,7 @@ func (c GoogleProviderConfig) OAuth2Config() oauth2.Config { type OAuthConfig struct { Google GoogleProviderConfig + Apple AppleProviderConfig } type ProviderClaims struct { @@ -107,12 +201,12 @@ func (v OIDCVerifier) Verify(ctx context.Context, rawIDToken string) (ProviderCl return ProviderClaims{}, fmt.Errorf("auth: verify id token: %w", err) } var claims struct { - Subject string `json:"sub"` - Email string `json:"email"` - EmailVerified bool `json:"email_verified"` - Name string `json:"name"` - Picture string `json:"picture"` - Nonce string `json:"nonce"` + Subject string `json:"sub"` + Email string `json:"email"` + EmailVerified verifiedBool `json:"email_verified"` + Name string `json:"name"` + Picture string `json:"picture"` + Nonce string `json:"nonce"` } if err := verified.Claims(&claims); err != nil { return ProviderClaims{}, fmt.Errorf("auth: decode id token claims: %w", err) @@ -124,13 +218,29 @@ func (v OIDCVerifier) Verify(ctx context.Context, rawIDToken string) (ProviderCl Provider: v.Provider, Subject: claims.Subject, Email: claims.Email, - EmailVerified: claims.EmailVerified, + EmailVerified: bool(claims.EmailVerified), DisplayName: claims.Name, AvatarURL: claims.Picture, Nonce: claims.Nonce, }, nil } +type verifiedBool bool + +func (v *verifiedBool) UnmarshalJSON(data []byte) error { + var b bool + if err := json.Unmarshal(data, &b); err == nil { + *v = verifiedBool(b) + return nil + } + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + *v = verifiedBool(s == "true") + return nil +} + func GenerateOAuthValue() (string, error) { raw := make([]byte, 32) if _, err := rand.Read(raw); err != nil { diff --git a/backend/internal/auth/oauth_test.go b/backend/internal/auth/oauth_test.go index f0f6eb8..48fcead 100644 --- a/backend/internal/auth/oauth_test.go +++ b/backend/internal/auth/oauth_test.go @@ -1,9 +1,18 @@ package auth import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "encoding/pem" "net/http" "net/http/httptest" "testing" + "time" + + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" ) func TestGoogleProviderConfigConfigured(t *testing.T) { @@ -54,3 +63,100 @@ func TestOAuthCookieNameIncludesProviderAndKind(t *testing.T) { t.Fatalf("nonce cookie name = %q", got) } } + +func TestAppleProviderConfigConfigured(t *testing.T) { + empty := AppleProviderConfig{} + if empty.Configured() { + t.Fatal("empty Apple config must not be configured") + } + + cfg := AppleProviderConfig{ + ClientID: "com.xtablo.web", + TeamID: "TEAMID1234", + KeyID: "KEYID1234", + PrivateKey: testApplePrivateKeyPEM(t), + RedirectURL: "https://xtablo.test/auth/apple/callback", + } + if !cfg.Configured() { + t.Fatal("complete Apple config must be configured") + } +} + +func TestAppleClientSecretClaimsAndKeyID(t *testing.T) { + now := time.Date(2026, 5, 15, 12, 0, 0, 0, time.UTC) + privateKeyPEM := testApplePrivateKeyPEM(t) + cfg := AppleProviderConfig{ + ClientID: "com.xtablo.web", + TeamID: "TEAMID1234", + KeyID: "KEYID1234", + PrivateKey: stringsWithEscapedNewlines(privateKeyPEM), + RedirectURL: "https://xtablo.test/auth/apple/callback", + } + + secret, err := cfg.ClientSecret(now) + if err != nil { + t.Fatalf("ClientSecret: %v", err) + } + key := parseApplePrivateKeyForTest(t, privateKeyPEM) + parsed, err := jwt.ParseSigned(secret, []jose.SignatureAlgorithm{jose.ES256}) + if err != nil { + t.Fatalf("ParseSigned: %v", err) + } + var claims jwt.Claims + if err := parsed.Claims(&key.PublicKey, &claims); err != nil { + t.Fatalf("Claims: %v", err) + } + if claims.Issuer != "TEAMID1234" { + t.Fatalf("iss = %q", claims.Issuer) + } + if claims.Subject != "com.xtablo.web" { + t.Fatalf("sub = %q", claims.Subject) + } + if len(claims.Audience) != 1 || claims.Audience[0] != "https://appleid.apple.com" { + t.Fatalf("aud = %#v", claims.Audience) + } + if !claims.Expiry.Time().After(now) { + t.Fatalf("exp = %s; want after %s", claims.Expiry.Time(), now) + } + if parsed.Headers[0].KeyID != "KEYID1234" { + t.Fatalf("kid = %q", parsed.Headers[0].KeyID) + } +} + +func testApplePrivateKeyPEM(t *testing.T) string { + t.Helper() + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + der, err := x509.MarshalECPrivateKey(key) + if err != nil { + t.Fatalf("MarshalECPrivateKey: %v", err) + } + return string(pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: der})) +} + +func parseApplePrivateKeyForTest(t *testing.T, privateKeyPEM string) *ecdsa.PrivateKey { + t.Helper() + block, _ := pem.Decode([]byte(privateKeyPEM)) + if block == nil { + t.Fatal("missing PEM block") + } + key, err := x509.ParseECPrivateKey(block.Bytes) + if err != nil { + t.Fatalf("ParseECPrivateKey: %v", err) + } + return key +} + +func stringsWithEscapedNewlines(value string) string { + out := "" + for _, r := range value { + if r == '\n' { + out += `\n` + } else { + out += string(r) + } + } + return out +} diff --git a/backend/internal/web/handlers_auth.go b/backend/internal/web/handlers_auth.go index 9d1a1c6..8198e5c 100644 --- a/backend/internal/web/handlers_auth.go +++ b/backend/internal/web/handlers_auth.go @@ -32,6 +32,8 @@ type AuthDeps struct { OAuth auth.OAuthConfig GoogleTokenExchanger auth.CodeExchanger GoogleVerifier auth.IDTokenVerifier + AppleTokenExchanger auth.CodeExchanger + AppleVerifier auth.IDTokenVerifier } // errInvalidCreds is the intentionally generic error message for login failures diff --git a/backend/internal/web/handlers_social.go b/backend/internal/web/handlers_social.go index 48b2bf0..55ba5de 100644 --- a/backend/internal/web/handlers_social.go +++ b/backend/internal/web/handlers_social.go @@ -7,6 +7,7 @@ import ( "log/slog" "net/http" "strings" + "time" "backend/internal/auth" "backend/internal/db/sqlc" @@ -124,6 +125,123 @@ func GoogleCallbackHandler(deps AuthDeps) http.HandlerFunc { } } +func AppleStartHandler(deps AuthDeps) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + cfg := deps.OAuth.Apple + if !cfg.Configured() { + http.Error(w, "Apple sign-in not configured", http.StatusServiceUnavailable) + return + } + state, err := auth.GenerateOAuthValue() + if err != nil { + http.Error(w, "internal server error", http.StatusInternalServerError) + return + } + nonce, err := auth.GenerateOAuthValue() + if err != nil { + http.Error(w, "internal server error", http.StatusInternalServerError) + return + } + clientSecret, err := cfg.ClientSecret(timeNow()) + if err != nil { + http.Error(w, "Apple sign-in not configured", http.StatusServiceUnavailable) + return + } + auth.SetOAuthCookie(w, "apple", auth.OAuthCookieState, state, deps.Secure) + auth.SetOAuthCookie(w, "apple", auth.OAuthCookieNonce, nonce, deps.Secure) + + oauthCfg := cfg.OAuth2Config(clientSecret) + url := oauthCfg.AuthCodeURL(state, + oauth2.SetAuthURLParam("nonce", nonce), + oauth2.SetAuthURLParam("response_mode", "query"), + ) + http.Redirect(w, r, url, http.StatusSeeOther) + } +} + +func AppleCallbackHandler(deps AuthDeps) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if !auth.ValidateOAuthCookie(r, "apple", auth.OAuthCookieState, r.URL.Query().Get("state")) { + http.Error(w, providerGenericError, http.StatusBadRequest) + return + } + auth.ClearOAuthCookie(w, "apple", auth.OAuthCookieState, deps.Secure) + auth.ClearOAuthCookie(w, "apple", auth.OAuthCookieNonce, deps.Secure) + + code := r.URL.Query().Get("code") + if code == "" { + http.Error(w, providerGenericError, http.StatusBadRequest) + return + } + + exchanger := deps.AppleTokenExchanger + if exchanger == nil { + clientSecret, err := deps.OAuth.Apple.ClientSecret(timeNow()) + if err != nil { + http.Error(w, providerGenericError, http.StatusUnauthorized) + return + } + exchanger = auth.OAuth2CodeExchanger{Config: deps.OAuth.Apple.OAuth2Config(clientSecret)} + } + token, err := exchanger.Exchange(r.Context(), code) + if err != nil { + slog.Default().Warn("apple oauth exchange failed", "err", err) + http.Error(w, providerGenericError, http.StatusUnauthorized) + return + } + rawIDToken, _ := token.Extra("id_token").(string) + if rawIDToken == "" { + http.Error(w, providerGenericError, http.StatusUnauthorized) + return + } + + verifier := deps.AppleVerifier + if verifier == nil { + cfg := deps.OAuth.Apple + issuer := cfg.Issuer + if issuer == "" { + issuer = "https://appleid.apple.com" + } + verifier = auth.OIDCVerifier{Provider: "apple", Issuer: issuer, ClientID: cfg.ClientID} + } + claims, err := verifier.Verify(r.Context(), rawIDToken) + if err != nil { + slog.Default().Warn("apple id token verification failed", "err", err) + http.Error(w, providerGenericError, http.StatusUnauthorized) + return + } + if claims.Provider == "" { + claims.Provider = "apple" + } + if !auth.ValidateOAuthCookie(r, "apple", auth.OAuthCookieNonce, claims.Nonce) { + http.Error(w, providerGenericError, http.StatusBadRequest) + return + } + if strings.TrimSpace(claims.Email) == "" || !claims.EmailVerified { + http.Error(w, providerEmailUnverified, http.StatusUnauthorized) + return + } + + userID, err := linkProviderUser(r.Context(), deps, claims) + if err != nil { + slog.Default().Error("apple account linking failed", "err", err) + http.Error(w, providerGenericError, http.StatusInternalServerError) + return + } + cookieValue, expiresAt, err := deps.Store.Create(r.Context(), userID) + if err != nil { + http.Error(w, providerGenericError, http.StatusInternalServerError) + return + } + auth.SetSessionCookie(w, cookieValue, expiresAt, deps.Secure) + http.Redirect(w, r, "/", http.StatusSeeOther) + } +} + +var timeNow = func() time.Time { + return time.Now() +} + func linkProviderUser(ctx context.Context, deps AuthDeps, claims auth.ProviderClaims) (uuid.UUID, error) { if deps.DB == nil { return uuid.Nil, errors.New("missing transaction DB") diff --git a/backend/internal/web/handlers_social_test.go b/backend/internal/web/handlers_social_test.go index 4999ecd..f59a00f 100644 --- a/backend/internal/web/handlers_social_test.go +++ b/backend/internal/web/handlers_social_test.go @@ -2,6 +2,11 @@ package web import ( "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "encoding/pem" "net/http" "net/http/httptest" "net/url" @@ -65,11 +70,66 @@ func newGoogleAuthDeps(q *sqlc.Queries, store *auth.Store) AuthDeps { } } +func newAppleAuthDeps(t *testing.T, q *sqlc.Queries, store *auth.Store) AuthDeps { + t.Helper() + return AuthDeps{ + Queries: q, + Store: store, + Secure: false, + OAuth: auth.OAuthConfig{ + Apple: auth.AppleProviderConfig{ + ClientID: "com.xtablo.web", + TeamID: "TEAMID1234", + KeyID: "KEYID1234", + PrivateKey: testApplePrivateKeyPEMForWeb(t), + RedirectURL: "https://xtablo.test/auth/apple/callback", + AuthURL: "https://appleid.apple.test/auth/authorize", + TokenURL: "https://appleid.apple.test/auth/token", + Issuer: "https://appleid.apple.test", + }, + }, + AppleTokenExchanger: fakeCodeExchanger{ + token: (&oauth2.Token{AccessToken: "access"}).WithExtra(map[string]any{"id_token": "raw-apple-id-token"}), + }, + AppleVerifier: fakeIDTokenVerifier{ + claims: auth.ProviderClaims{ + Provider: "apple", + Subject: "apple-subject-1", + Email: "apple@example.com", + EmailVerified: true, + DisplayName: "Apple User", + Nonce: "nonce-value", + }, + }, + } +} + +func testApplePrivateKeyPEMForWeb(t interface { + Helper() + Fatalf(string, ...any) +}) string { + t.Helper() + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + der, err := x509.MarshalECPrivateKey(key) + if err != nil { + t.Fatalf("MarshalECPrivateKey: %v", err) + } + return string(pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: der})) +} + func withGoogleClaims(deps AuthDeps, claims auth.ProviderClaims) AuthDeps { deps.GoogleVerifier = fakeIDTokenVerifier{claims: claims} return deps } +func withAppleClaims(deps AuthDeps, claims auth.ProviderClaims) AuthDeps { + deps.AppleVerifier = fakeIDTokenVerifier{claims: claims} + return deps +} + func newSocialRouter(t *testing.T, deps AuthDeps) http.Handler { t.Helper() router, err := NewRouter(stubPinger{}, os.DirFS("./static"), deps, TablosDeps{}, TasksDeps{}, FilesDeps{}, testCSRFKey, "dev", "localhost") @@ -355,6 +415,120 @@ func TestGoogleCallbackEmailUpdateConflictDoesNotRelinkSubject(t *testing.T) { } } +func TestAppleStartRedirectsAndSetsStateNonceCookies(t *testing.T) { + router := newSocialRouter(t, newAppleAuthDeps(t, nil, nil)) + + req := httptest.NewRequest(http.MethodGet, "/auth/apple/start", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusSeeOther { + t.Fatalf("status = %d; want 303", rec.Code) + } + loc := rec.Header().Get("Location") + if !strings.Contains(loc, "https://appleid.apple.test/auth/authorize") { + t.Fatalf("Location = %q; want Apple auth URL", loc) + } + if !strings.Contains(loc, "client_id=com.xtablo.web") { + t.Fatalf("Location = %q; missing client_id", loc) + } + if !strings.Contains(loc, "scope=name+email") { + t.Fatalf("Location = %q; missing name email scope", loc) + } + if findCookie(rec.Result().Cookies(), auth.OAuthCookieName("apple", auth.OAuthCookieState)) == nil { + t.Fatal("missing Apple state cookie") + } + if findCookie(rec.Result().Cookies(), auth.OAuthCookieName("apple", auth.OAuthCookieNonce)) == nil { + t.Fatal("missing Apple nonce cookie") + } +} + +func TestAppleCallbackInvalidNonceRejectedBeforeLinking(t *testing.T) { + deps := newAppleAuthDeps(t, nil, nil) + deps.AppleVerifier = fakeIDTokenVerifier{claims: auth.ProviderClaims{ + Provider: "apple", + Subject: "apple-subject-1", + Email: "apple@example.com", + EmailVerified: true, + Nonce: "wrong-nonce", + }} + router := newSocialRouter(t, deps) + + callback := "/auth/apple/callback?" + url.Values{"state": {"state-value"}, "code": {"code"}}.Encode() + req := httptest.NewRequest(http.MethodGet, callback, nil) + req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("apple", auth.OAuthCookieState), Value: "state-value"}) + req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("apple", auth.OAuthCookieNonce), Value: "nonce-value"}) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d; want 400", rec.Code) + } +} + +func TestAppleCallbackUnverifiedEmailRejected(t *testing.T) { + deps := newAppleAuthDeps(t, nil, nil) + deps.AppleVerifier = fakeIDTokenVerifier{claims: auth.ProviderClaims{ + Provider: "apple", + Subject: "apple-subject-1", + Email: "apple@example.com", + EmailVerified: false, + Nonce: "nonce-value", + }} + router := newSocialRouter(t, deps) + + rec := serveAppleCallback(router) + + if rec.Code != http.StatusUnauthorized { + t.Fatalf("status = %d; want 401", rec.Code) + } + if !strings.Contains(rec.Body.String(), "This provider did not return a verified email. Try another sign-in method.") { + t.Fatalf("body missing unverified email copy; got: %s", rec.Body.String()) + } +} + +func TestAppleCallbackVerifiedRelayEmailStoresNameAndSetsSession(t *testing.T) { + pool, cleanup := setupTestDB(t) + defer cleanup() + + ctx := context.Background() + q := sqlc.New(pool) + store := auth.NewStore(q) + deps := newAppleAuthDeps(t, q, store) + deps.DB = pool + deps = withAppleClaims(deps, auth.ProviderClaims{ + Provider: "apple", + Subject: "apple-relay-subject", + Email: "relay@privaterelay.appleid.com", + EmailVerified: true, + DisplayName: "Apple Relay", + Nonce: "nonce-value", + }) + router := newSocialRouter(t, deps) + + rec := serveAppleCallback(router) + + if rec.Code != http.StatusSeeOther { + t.Fatalf("status = %d; want 303; body: %s", rec.Code, rec.Body.String()) + } + if c := getSessionCookie(rec); c == nil { + t.Fatal("session cookie not set") + } + identity, err := q.GetUserIdentityByProviderSubject(ctx, sqlc.GetUserIdentityByProviderSubjectParams{ + Provider: "apple", + ProviderSubject: "apple-relay-subject", + }) + if err != nil { + t.Fatalf("GetUserIdentityByProviderSubject: %v", err) + } + if identity.Email != "relay@privaterelay.appleid.com" { + t.Fatalf("identity email = %q", identity.Email) + } + if !identity.DisplayName.Valid || identity.DisplayName.String != "Apple Relay" { + t.Fatalf("display name = %#v; want Apple Relay", identity.DisplayName) + } +} + func serveGoogleCallback(router http.Handler) *httptest.ResponseRecorder { callback := "/auth/google/callback?" + url.Values{"state": {"state-value"}, "code": {"code"}}.Encode() req := httptest.NewRequest(http.MethodGet, callback, nil) @@ -365,6 +539,16 @@ func serveGoogleCallback(router http.Handler) *httptest.ResponseRecorder { return rec } +func serveAppleCallback(router http.Handler) *httptest.ResponseRecorder { + callback := "/auth/apple/callback?" + url.Values{"state": {"state-value"}, "code": {"code"}}.Encode() + req := httptest.NewRequest(http.MethodGet, callback, nil) + req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("apple", auth.OAuthCookieState), Value: "state-value"}) + req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("apple", auth.OAuthCookieNonce), Value: "nonce-value"}) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + return rec +} + func findCookie(cookies []*http.Cookie, name string) *http.Cookie { for _, c := range cookies { if c.Name == name { diff --git a/backend/internal/web/router.go b/backend/internal/web/router.go index 0a984f6..13c975a 100644 --- a/backend/internal/web/router.go +++ b/backend/internal/web/router.go @@ -73,6 +73,8 @@ func NewRouter(pinger Pinger, staticFS fs.FS, deps AuthDeps, tabloDeps TablosDep r.Post("/login", LoginPostHandler(deps)) r.Get("/auth/google/start", GoogleStartHandler(deps)) r.Get("/auth/google/callback", GoogleCallbackHandler(deps)) + r.Get("/auth/apple/start", AppleStartHandler(deps)) + r.Get("/auth/apple/callback", AppleCallbackHandler(deps)) // Protected routes — require an authenticated session (D-23, AUTH-05). // RequireAuth checks the context set by ResolveSession above and redirects