feat(08-03): add apple social sign-in flow
This commit is contained in:
parent
6779663c8a
commit
a8b6a03eac
7 changed files with 553 additions and 7 deletions
|
|
@ -91,9 +91,18 @@ func main() {
|
||||||
ClientSecret: os.Getenv("GOOGLE_CLIENT_SECRET"),
|
ClientSecret: os.Getenv("GOOGLE_CLIENT_SECRET"),
|
||||||
RedirectURL: os.Getenv("GOOGLE_REDIRECT_URL"),
|
RedirectURL: os.Getenv("GOOGLE_REDIRECT_URL"),
|
||||||
},
|
},
|
||||||
|
Apple: auth.AppleProviderConfig{
|
||||||
|
ClientID: os.Getenv("APPLE_CLIENT_ID"),
|
||||||
|
TeamID: os.Getenv("APPLE_TEAM_ID"),
|
||||||
|
KeyID: os.Getenv("APPLE_KEY_ID"),
|
||||||
|
PrivateKey: os.Getenv("APPLE_PRIVATE_KEY"),
|
||||||
|
RedirectURL: os.Getenv("APPLE_REDIRECT_URL"),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
var googleExchanger auth.CodeExchanger
|
var googleExchanger auth.CodeExchanger
|
||||||
var googleVerifier auth.IDTokenVerifier
|
var googleVerifier auth.IDTokenVerifier
|
||||||
|
var appleExchanger auth.CodeExchanger
|
||||||
|
var appleVerifier auth.IDTokenVerifier
|
||||||
if oauthCfg.Google.Configured() {
|
if oauthCfg.Google.Configured() {
|
||||||
googleExchanger = auth.OAuth2CodeExchanger{Config: oauthCfg.Google.OAuth2Config()}
|
googleExchanger = auth.OAuth2CodeExchanger{Config: oauthCfg.Google.OAuth2Config()}
|
||||||
googleVerifier = auth.OIDCVerifier{
|
googleVerifier = auth.OIDCVerifier{
|
||||||
|
|
@ -102,6 +111,19 @@ func main() {
|
||||||
ClientID: oauthCfg.Google.ClientID,
|
ClientID: oauthCfg.Google.ClientID,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if oauthCfg.Apple.Configured() {
|
||||||
|
appleSecret, err := oauthCfg.Apple.ClientSecret(time.Now())
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("invalid Apple sign-in config", "err", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
appleExchanger = auth.OAuth2CodeExchanger{Config: oauthCfg.Apple.OAuth2Config(appleSecret)}
|
||||||
|
appleVerifier = auth.OIDCVerifier{
|
||||||
|
Provider: "apple",
|
||||||
|
Issuer: "https://appleid.apple.com",
|
||||||
|
ClientID: oauthCfg.Apple.ClientID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
deps := web.AuthDeps{
|
deps := web.AuthDeps{
|
||||||
Queries: q,
|
Queries: q,
|
||||||
|
|
@ -112,6 +134,8 @@ func main() {
|
||||||
OAuth: oauthCfg,
|
OAuth: oauthCfg,
|
||||||
GoogleTokenExchanger: googleExchanger,
|
GoogleTokenExchanger: googleExchanger,
|
||||||
GoogleVerifier: googleVerifier,
|
GoogleVerifier: googleVerifier,
|
||||||
|
AppleTokenExchanger: appleExchanger,
|
||||||
|
AppleVerifier: appleVerifier,
|
||||||
}
|
}
|
||||||
tabloDeps := web.TablosDeps{Queries: q}
|
tabloDeps := web.TablosDeps{Queries: q}
|
||||||
taskDeps := web.TasksDeps{Queries: q}
|
taskDeps := web.TasksDeps{Queries: q}
|
||||||
|
|
|
||||||
|
|
@ -2,15 +2,22 @@ package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/ecdsa"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"github.com/go-jose/go-jose/v4"
|
||||||
|
"github.com/go-jose/go-jose/v4/jwt"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -30,6 +37,92 @@ type GoogleProviderConfig struct {
|
||||||
Issuer string
|
Issuer string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AppleProviderConfig struct {
|
||||||
|
ClientID string
|
||||||
|
TeamID string
|
||||||
|
KeyID string
|
||||||
|
PrivateKey string
|
||||||
|
RedirectURL string
|
||||||
|
AuthURL string
|
||||||
|
TokenURL string
|
||||||
|
Issuer string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c AppleProviderConfig) Configured() bool {
|
||||||
|
return c.ClientID != "" && c.TeamID != "" && c.KeyID != "" && c.PrivateKey != "" && c.RedirectURL != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c AppleProviderConfig) withDefaults() AppleProviderConfig {
|
||||||
|
if c.AuthURL == "" {
|
||||||
|
c.AuthURL = "https://appleid.apple.com/auth/authorize"
|
||||||
|
}
|
||||||
|
if c.TokenURL == "" {
|
||||||
|
c.TokenURL = "https://appleid.apple.com/auth/token"
|
||||||
|
}
|
||||||
|
if c.Issuer == "" {
|
||||||
|
c.Issuer = "https://appleid.apple.com"
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c AppleProviderConfig) OAuth2Config(clientSecret string) oauth2.Config {
|
||||||
|
c = c.withDefaults()
|
||||||
|
return oauth2.Config{
|
||||||
|
ClientID: c.ClientID,
|
||||||
|
ClientSecret: clientSecret,
|
||||||
|
RedirectURL: c.RedirectURL,
|
||||||
|
Scopes: []string{"name", "email"},
|
||||||
|
Endpoint: oauth2.Endpoint{
|
||||||
|
AuthURL: c.AuthURL,
|
||||||
|
TokenURL: c.TokenURL,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c AppleProviderConfig) ClientSecret(now time.Time) (string, error) {
|
||||||
|
privateKey, err := parseApplePrivateKey(c.PrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
opts := (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", c.KeyID)
|
||||||
|
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: privateKey}, opts)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("auth: create apple client secret signer: %w", err)
|
||||||
|
}
|
||||||
|
claims := jwt.Claims{
|
||||||
|
Issuer: c.TeamID,
|
||||||
|
Subject: c.ClientID,
|
||||||
|
Audience: jwt.Audience{"https://appleid.apple.com"},
|
||||||
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
|
Expiry: jwt.NewNumericDate(now.Add(6 * time.Hour)),
|
||||||
|
}
|
||||||
|
secret, err := jwt.Signed(signer).Claims(claims).Serialize()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("auth: sign apple client secret: %w", err)
|
||||||
|
}
|
||||||
|
return secret, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseApplePrivateKey(raw string) (*ecdsa.PrivateKey, error) {
|
||||||
|
normalized := strings.ReplaceAll(raw, `\n`, "\n")
|
||||||
|
block, _ := pem.Decode([]byte(normalized))
|
||||||
|
if block == nil {
|
||||||
|
return nil, errors.New("auth: apple private key PEM block missing")
|
||||||
|
}
|
||||||
|
if key, err := x509.ParseECPrivateKey(block.Bytes); err == nil {
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("auth: parse apple private key: %w", err)
|
||||||
|
}
|
||||||
|
key, ok := parsed.(*ecdsa.PrivateKey)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("auth: apple private key is not ECDSA")
|
||||||
|
}
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c GoogleProviderConfig) Configured() bool {
|
func (c GoogleProviderConfig) Configured() bool {
|
||||||
return c.ClientID != "" && c.ClientSecret != "" && c.RedirectURL != ""
|
return c.ClientID != "" && c.ClientSecret != "" && c.RedirectURL != ""
|
||||||
}
|
}
|
||||||
|
|
@ -63,6 +156,7 @@ func (c GoogleProviderConfig) OAuth2Config() oauth2.Config {
|
||||||
|
|
||||||
type OAuthConfig struct {
|
type OAuthConfig struct {
|
||||||
Google GoogleProviderConfig
|
Google GoogleProviderConfig
|
||||||
|
Apple AppleProviderConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProviderClaims struct {
|
type ProviderClaims struct {
|
||||||
|
|
@ -107,12 +201,12 @@ func (v OIDCVerifier) Verify(ctx context.Context, rawIDToken string) (ProviderCl
|
||||||
return ProviderClaims{}, fmt.Errorf("auth: verify id token: %w", err)
|
return ProviderClaims{}, fmt.Errorf("auth: verify id token: %w", err)
|
||||||
}
|
}
|
||||||
var claims struct {
|
var claims struct {
|
||||||
Subject string `json:"sub"`
|
Subject string `json:"sub"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
EmailVerified bool `json:"email_verified"`
|
EmailVerified verifiedBool `json:"email_verified"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Picture string `json:"picture"`
|
Picture string `json:"picture"`
|
||||||
Nonce string `json:"nonce"`
|
Nonce string `json:"nonce"`
|
||||||
}
|
}
|
||||||
if err := verified.Claims(&claims); err != nil {
|
if err := verified.Claims(&claims); err != nil {
|
||||||
return ProviderClaims{}, fmt.Errorf("auth: decode id token claims: %w", err)
|
return ProviderClaims{}, fmt.Errorf("auth: decode id token claims: %w", err)
|
||||||
|
|
@ -124,13 +218,29 @@ func (v OIDCVerifier) Verify(ctx context.Context, rawIDToken string) (ProviderCl
|
||||||
Provider: v.Provider,
|
Provider: v.Provider,
|
||||||
Subject: claims.Subject,
|
Subject: claims.Subject,
|
||||||
Email: claims.Email,
|
Email: claims.Email,
|
||||||
EmailVerified: claims.EmailVerified,
|
EmailVerified: bool(claims.EmailVerified),
|
||||||
DisplayName: claims.Name,
|
DisplayName: claims.Name,
|
||||||
AvatarURL: claims.Picture,
|
AvatarURL: claims.Picture,
|
||||||
Nonce: claims.Nonce,
|
Nonce: claims.Nonce,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type verifiedBool bool
|
||||||
|
|
||||||
|
func (v *verifiedBool) UnmarshalJSON(data []byte) error {
|
||||||
|
var b bool
|
||||||
|
if err := json.Unmarshal(data, &b); err == nil {
|
||||||
|
*v = verifiedBool(b)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(data, &s); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*v = verifiedBool(s == "true")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func GenerateOAuthValue() (string, error) {
|
func GenerateOAuthValue() (string, error) {
|
||||||
raw := make([]byte, 32)
|
raw := make([]byte, 32)
|
||||||
if _, err := rand.Read(raw); err != nil {
|
if _, err := rand.Read(raw); err != nil {
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,18 @@
|
||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-jose/go-jose/v4"
|
||||||
|
"github.com/go-jose/go-jose/v4/jwt"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGoogleProviderConfigConfigured(t *testing.T) {
|
func TestGoogleProviderConfigConfigured(t *testing.T) {
|
||||||
|
|
@ -54,3 +63,100 @@ func TestOAuthCookieNameIncludesProviderAndKind(t *testing.T) {
|
||||||
t.Fatalf("nonce cookie name = %q", got)
|
t.Fatalf("nonce cookie name = %q", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAppleProviderConfigConfigured(t *testing.T) {
|
||||||
|
empty := AppleProviderConfig{}
|
||||||
|
if empty.Configured() {
|
||||||
|
t.Fatal("empty Apple config must not be configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := AppleProviderConfig{
|
||||||
|
ClientID: "com.xtablo.web",
|
||||||
|
TeamID: "TEAMID1234",
|
||||||
|
KeyID: "KEYID1234",
|
||||||
|
PrivateKey: testApplePrivateKeyPEM(t),
|
||||||
|
RedirectURL: "https://xtablo.test/auth/apple/callback",
|
||||||
|
}
|
||||||
|
if !cfg.Configured() {
|
||||||
|
t.Fatal("complete Apple config must be configured")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleClientSecretClaimsAndKeyID(t *testing.T) {
|
||||||
|
now := time.Date(2026, 5, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
privateKeyPEM := testApplePrivateKeyPEM(t)
|
||||||
|
cfg := AppleProviderConfig{
|
||||||
|
ClientID: "com.xtablo.web",
|
||||||
|
TeamID: "TEAMID1234",
|
||||||
|
KeyID: "KEYID1234",
|
||||||
|
PrivateKey: stringsWithEscapedNewlines(privateKeyPEM),
|
||||||
|
RedirectURL: "https://xtablo.test/auth/apple/callback",
|
||||||
|
}
|
||||||
|
|
||||||
|
secret, err := cfg.ClientSecret(now)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ClientSecret: %v", err)
|
||||||
|
}
|
||||||
|
key := parseApplePrivateKeyForTest(t, privateKeyPEM)
|
||||||
|
parsed, err := jwt.ParseSigned(secret, []jose.SignatureAlgorithm{jose.ES256})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseSigned: %v", err)
|
||||||
|
}
|
||||||
|
var claims jwt.Claims
|
||||||
|
if err := parsed.Claims(&key.PublicKey, &claims); err != nil {
|
||||||
|
t.Fatalf("Claims: %v", err)
|
||||||
|
}
|
||||||
|
if claims.Issuer != "TEAMID1234" {
|
||||||
|
t.Fatalf("iss = %q", claims.Issuer)
|
||||||
|
}
|
||||||
|
if claims.Subject != "com.xtablo.web" {
|
||||||
|
t.Fatalf("sub = %q", claims.Subject)
|
||||||
|
}
|
||||||
|
if len(claims.Audience) != 1 || claims.Audience[0] != "https://appleid.apple.com" {
|
||||||
|
t.Fatalf("aud = %#v", claims.Audience)
|
||||||
|
}
|
||||||
|
if !claims.Expiry.Time().After(now) {
|
||||||
|
t.Fatalf("exp = %s; want after %s", claims.Expiry.Time(), now)
|
||||||
|
}
|
||||||
|
if parsed.Headers[0].KeyID != "KEYID1234" {
|
||||||
|
t.Fatalf("kid = %q", parsed.Headers[0].KeyID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testApplePrivateKeyPEM(t *testing.T) string {
|
||||||
|
t.Helper()
|
||||||
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateKey: %v", err)
|
||||||
|
}
|
||||||
|
der, err := x509.MarshalECPrivateKey(key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MarshalECPrivateKey: %v", err)
|
||||||
|
}
|
||||||
|
return string(pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: der}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseApplePrivateKeyForTest(t *testing.T, privateKeyPEM string) *ecdsa.PrivateKey {
|
||||||
|
t.Helper()
|
||||||
|
block, _ := pem.Decode([]byte(privateKeyPEM))
|
||||||
|
if block == nil {
|
||||||
|
t.Fatal("missing PEM block")
|
||||||
|
}
|
||||||
|
key, err := x509.ParseECPrivateKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseECPrivateKey: %v", err)
|
||||||
|
}
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringsWithEscapedNewlines(value string) string {
|
||||||
|
out := ""
|
||||||
|
for _, r := range value {
|
||||||
|
if r == '\n' {
|
||||||
|
out += `\n`
|
||||||
|
} else {
|
||||||
|
out += string(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,8 @@ type AuthDeps struct {
|
||||||
OAuth auth.OAuthConfig
|
OAuth auth.OAuthConfig
|
||||||
GoogleTokenExchanger auth.CodeExchanger
|
GoogleTokenExchanger auth.CodeExchanger
|
||||||
GoogleVerifier auth.IDTokenVerifier
|
GoogleVerifier auth.IDTokenVerifier
|
||||||
|
AppleTokenExchanger auth.CodeExchanger
|
||||||
|
AppleVerifier auth.IDTokenVerifier
|
||||||
}
|
}
|
||||||
|
|
||||||
// errInvalidCreds is the intentionally generic error message for login failures
|
// errInvalidCreds is the intentionally generic error message for login failures
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"backend/internal/auth"
|
"backend/internal/auth"
|
||||||
"backend/internal/db/sqlc"
|
"backend/internal/db/sqlc"
|
||||||
|
|
@ -124,6 +125,123 @@ func GoogleCallbackHandler(deps AuthDeps) http.HandlerFunc {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func AppleStartHandler(deps AuthDeps) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
cfg := deps.OAuth.Apple
|
||||||
|
if !cfg.Configured() {
|
||||||
|
http.Error(w, "Apple sign-in not configured", http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
state, err := auth.GenerateOAuthValue()
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nonce, err := auth.GenerateOAuthValue()
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
clientSecret, err := cfg.ClientSecret(timeNow())
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Apple sign-in not configured", http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
auth.SetOAuthCookie(w, "apple", auth.OAuthCookieState, state, deps.Secure)
|
||||||
|
auth.SetOAuthCookie(w, "apple", auth.OAuthCookieNonce, nonce, deps.Secure)
|
||||||
|
|
||||||
|
oauthCfg := cfg.OAuth2Config(clientSecret)
|
||||||
|
url := oauthCfg.AuthCodeURL(state,
|
||||||
|
oauth2.SetAuthURLParam("nonce", nonce),
|
||||||
|
oauth2.SetAuthURLParam("response_mode", "query"),
|
||||||
|
)
|
||||||
|
http.Redirect(w, r, url, http.StatusSeeOther)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func AppleCallbackHandler(deps AuthDeps) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if !auth.ValidateOAuthCookie(r, "apple", auth.OAuthCookieState, r.URL.Query().Get("state")) {
|
||||||
|
http.Error(w, providerGenericError, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
auth.ClearOAuthCookie(w, "apple", auth.OAuthCookieState, deps.Secure)
|
||||||
|
auth.ClearOAuthCookie(w, "apple", auth.OAuthCookieNonce, deps.Secure)
|
||||||
|
|
||||||
|
code := r.URL.Query().Get("code")
|
||||||
|
if code == "" {
|
||||||
|
http.Error(w, providerGenericError, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
exchanger := deps.AppleTokenExchanger
|
||||||
|
if exchanger == nil {
|
||||||
|
clientSecret, err := deps.OAuth.Apple.ClientSecret(timeNow())
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, providerGenericError, http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
exchanger = auth.OAuth2CodeExchanger{Config: deps.OAuth.Apple.OAuth2Config(clientSecret)}
|
||||||
|
}
|
||||||
|
token, err := exchanger.Exchange(r.Context(), code)
|
||||||
|
if err != nil {
|
||||||
|
slog.Default().Warn("apple oauth exchange failed", "err", err)
|
||||||
|
http.Error(w, providerGenericError, http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rawIDToken, _ := token.Extra("id_token").(string)
|
||||||
|
if rawIDToken == "" {
|
||||||
|
http.Error(w, providerGenericError, http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
verifier := deps.AppleVerifier
|
||||||
|
if verifier == nil {
|
||||||
|
cfg := deps.OAuth.Apple
|
||||||
|
issuer := cfg.Issuer
|
||||||
|
if issuer == "" {
|
||||||
|
issuer = "https://appleid.apple.com"
|
||||||
|
}
|
||||||
|
verifier = auth.OIDCVerifier{Provider: "apple", Issuer: issuer, ClientID: cfg.ClientID}
|
||||||
|
}
|
||||||
|
claims, err := verifier.Verify(r.Context(), rawIDToken)
|
||||||
|
if err != nil {
|
||||||
|
slog.Default().Warn("apple id token verification failed", "err", err)
|
||||||
|
http.Error(w, providerGenericError, http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if claims.Provider == "" {
|
||||||
|
claims.Provider = "apple"
|
||||||
|
}
|
||||||
|
if !auth.ValidateOAuthCookie(r, "apple", auth.OAuthCookieNonce, claims.Nonce) {
|
||||||
|
http.Error(w, providerGenericError, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(claims.Email) == "" || !claims.EmailVerified {
|
||||||
|
http.Error(w, providerEmailUnverified, http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, err := linkProviderUser(r.Context(), deps, claims)
|
||||||
|
if err != nil {
|
||||||
|
slog.Default().Error("apple account linking failed", "err", err)
|
||||||
|
http.Error(w, providerGenericError, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cookieValue, expiresAt, err := deps.Store.Create(r.Context(), userID)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, providerGenericError, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
auth.SetSessionCookie(w, cookieValue, expiresAt, deps.Secure)
|
||||||
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var timeNow = func() time.Time {
|
||||||
|
return time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
func linkProviderUser(ctx context.Context, deps AuthDeps, claims auth.ProviderClaims) (uuid.UUID, error) {
|
func linkProviderUser(ctx context.Context, deps AuthDeps, claims auth.ProviderClaims) (uuid.UUID, error) {
|
||||||
if deps.DB == nil {
|
if deps.DB == nil {
|
||||||
return uuid.Nil, errors.New("missing transaction DB")
|
return uuid.Nil, errors.New("missing transaction DB")
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,11 @@ package web
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
@ -65,11 +70,66 @@ func newGoogleAuthDeps(q *sqlc.Queries, store *auth.Store) AuthDeps {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newAppleAuthDeps(t *testing.T, q *sqlc.Queries, store *auth.Store) AuthDeps {
|
||||||
|
t.Helper()
|
||||||
|
return AuthDeps{
|
||||||
|
Queries: q,
|
||||||
|
Store: store,
|
||||||
|
Secure: false,
|
||||||
|
OAuth: auth.OAuthConfig{
|
||||||
|
Apple: auth.AppleProviderConfig{
|
||||||
|
ClientID: "com.xtablo.web",
|
||||||
|
TeamID: "TEAMID1234",
|
||||||
|
KeyID: "KEYID1234",
|
||||||
|
PrivateKey: testApplePrivateKeyPEMForWeb(t),
|
||||||
|
RedirectURL: "https://xtablo.test/auth/apple/callback",
|
||||||
|
AuthURL: "https://appleid.apple.test/auth/authorize",
|
||||||
|
TokenURL: "https://appleid.apple.test/auth/token",
|
||||||
|
Issuer: "https://appleid.apple.test",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
AppleTokenExchanger: fakeCodeExchanger{
|
||||||
|
token: (&oauth2.Token{AccessToken: "access"}).WithExtra(map[string]any{"id_token": "raw-apple-id-token"}),
|
||||||
|
},
|
||||||
|
AppleVerifier: fakeIDTokenVerifier{
|
||||||
|
claims: auth.ProviderClaims{
|
||||||
|
Provider: "apple",
|
||||||
|
Subject: "apple-subject-1",
|
||||||
|
Email: "apple@example.com",
|
||||||
|
EmailVerified: true,
|
||||||
|
DisplayName: "Apple User",
|
||||||
|
Nonce: "nonce-value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testApplePrivateKeyPEMForWeb(t interface {
|
||||||
|
Helper()
|
||||||
|
Fatalf(string, ...any)
|
||||||
|
}) string {
|
||||||
|
t.Helper()
|
||||||
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateKey: %v", err)
|
||||||
|
}
|
||||||
|
der, err := x509.MarshalECPrivateKey(key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MarshalECPrivateKey: %v", err)
|
||||||
|
}
|
||||||
|
return string(pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: der}))
|
||||||
|
}
|
||||||
|
|
||||||
func withGoogleClaims(deps AuthDeps, claims auth.ProviderClaims) AuthDeps {
|
func withGoogleClaims(deps AuthDeps, claims auth.ProviderClaims) AuthDeps {
|
||||||
deps.GoogleVerifier = fakeIDTokenVerifier{claims: claims}
|
deps.GoogleVerifier = fakeIDTokenVerifier{claims: claims}
|
||||||
return deps
|
return deps
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func withAppleClaims(deps AuthDeps, claims auth.ProviderClaims) AuthDeps {
|
||||||
|
deps.AppleVerifier = fakeIDTokenVerifier{claims: claims}
|
||||||
|
return deps
|
||||||
|
}
|
||||||
|
|
||||||
func newSocialRouter(t *testing.T, deps AuthDeps) http.Handler {
|
func newSocialRouter(t *testing.T, deps AuthDeps) http.Handler {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
router, err := NewRouter(stubPinger{}, os.DirFS("./static"), deps, TablosDeps{}, TasksDeps{}, FilesDeps{}, testCSRFKey, "dev", "localhost")
|
router, err := NewRouter(stubPinger{}, os.DirFS("./static"), deps, TablosDeps{}, TasksDeps{}, FilesDeps{}, testCSRFKey, "dev", "localhost")
|
||||||
|
|
@ -355,6 +415,120 @@ func TestGoogleCallbackEmailUpdateConflictDoesNotRelinkSubject(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAppleStartRedirectsAndSetsStateNonceCookies(t *testing.T) {
|
||||||
|
router := newSocialRouter(t, newAppleAuthDeps(t, nil, nil))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/auth/apple/start", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusSeeOther {
|
||||||
|
t.Fatalf("status = %d; want 303", rec.Code)
|
||||||
|
}
|
||||||
|
loc := rec.Header().Get("Location")
|
||||||
|
if !strings.Contains(loc, "https://appleid.apple.test/auth/authorize") {
|
||||||
|
t.Fatalf("Location = %q; want Apple auth URL", loc)
|
||||||
|
}
|
||||||
|
if !strings.Contains(loc, "client_id=com.xtablo.web") {
|
||||||
|
t.Fatalf("Location = %q; missing client_id", loc)
|
||||||
|
}
|
||||||
|
if !strings.Contains(loc, "scope=name+email") {
|
||||||
|
t.Fatalf("Location = %q; missing name email scope", loc)
|
||||||
|
}
|
||||||
|
if findCookie(rec.Result().Cookies(), auth.OAuthCookieName("apple", auth.OAuthCookieState)) == nil {
|
||||||
|
t.Fatal("missing Apple state cookie")
|
||||||
|
}
|
||||||
|
if findCookie(rec.Result().Cookies(), auth.OAuthCookieName("apple", auth.OAuthCookieNonce)) == nil {
|
||||||
|
t.Fatal("missing Apple nonce cookie")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleCallbackInvalidNonceRejectedBeforeLinking(t *testing.T) {
|
||||||
|
deps := newAppleAuthDeps(t, nil, nil)
|
||||||
|
deps.AppleVerifier = fakeIDTokenVerifier{claims: auth.ProviderClaims{
|
||||||
|
Provider: "apple",
|
||||||
|
Subject: "apple-subject-1",
|
||||||
|
Email: "apple@example.com",
|
||||||
|
EmailVerified: true,
|
||||||
|
Nonce: "wrong-nonce",
|
||||||
|
}}
|
||||||
|
router := newSocialRouter(t, deps)
|
||||||
|
|
||||||
|
callback := "/auth/apple/callback?" + url.Values{"state": {"state-value"}, "code": {"code"}}.Encode()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, callback, nil)
|
||||||
|
req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("apple", auth.OAuthCookieState), Value: "state-value"})
|
||||||
|
req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("apple", auth.OAuthCookieNonce), Value: "nonce-value"})
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("status = %d; want 400", rec.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleCallbackUnverifiedEmailRejected(t *testing.T) {
|
||||||
|
deps := newAppleAuthDeps(t, nil, nil)
|
||||||
|
deps.AppleVerifier = fakeIDTokenVerifier{claims: auth.ProviderClaims{
|
||||||
|
Provider: "apple",
|
||||||
|
Subject: "apple-subject-1",
|
||||||
|
Email: "apple@example.com",
|
||||||
|
EmailVerified: false,
|
||||||
|
Nonce: "nonce-value",
|
||||||
|
}}
|
||||||
|
router := newSocialRouter(t, deps)
|
||||||
|
|
||||||
|
rec := serveAppleCallback(router)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusUnauthorized {
|
||||||
|
t.Fatalf("status = %d; want 401", rec.Code)
|
||||||
|
}
|
||||||
|
if !strings.Contains(rec.Body.String(), "This provider did not return a verified email. Try another sign-in method.") {
|
||||||
|
t.Fatalf("body missing unverified email copy; got: %s", rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleCallbackVerifiedRelayEmailStoresNameAndSetsSession(t *testing.T) {
|
||||||
|
pool, cleanup := setupTestDB(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
q := sqlc.New(pool)
|
||||||
|
store := auth.NewStore(q)
|
||||||
|
deps := newAppleAuthDeps(t, q, store)
|
||||||
|
deps.DB = pool
|
||||||
|
deps = withAppleClaims(deps, auth.ProviderClaims{
|
||||||
|
Provider: "apple",
|
||||||
|
Subject: "apple-relay-subject",
|
||||||
|
Email: "relay@privaterelay.appleid.com",
|
||||||
|
EmailVerified: true,
|
||||||
|
DisplayName: "Apple Relay",
|
||||||
|
Nonce: "nonce-value",
|
||||||
|
})
|
||||||
|
router := newSocialRouter(t, deps)
|
||||||
|
|
||||||
|
rec := serveAppleCallback(router)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusSeeOther {
|
||||||
|
t.Fatalf("status = %d; want 303; body: %s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
if c := getSessionCookie(rec); c == nil {
|
||||||
|
t.Fatal("session cookie not set")
|
||||||
|
}
|
||||||
|
identity, err := q.GetUserIdentityByProviderSubject(ctx, sqlc.GetUserIdentityByProviderSubjectParams{
|
||||||
|
Provider: "apple",
|
||||||
|
ProviderSubject: "apple-relay-subject",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetUserIdentityByProviderSubject: %v", err)
|
||||||
|
}
|
||||||
|
if identity.Email != "relay@privaterelay.appleid.com" {
|
||||||
|
t.Fatalf("identity email = %q", identity.Email)
|
||||||
|
}
|
||||||
|
if !identity.DisplayName.Valid || identity.DisplayName.String != "Apple Relay" {
|
||||||
|
t.Fatalf("display name = %#v; want Apple Relay", identity.DisplayName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func serveGoogleCallback(router http.Handler) *httptest.ResponseRecorder {
|
func serveGoogleCallback(router http.Handler) *httptest.ResponseRecorder {
|
||||||
callback := "/auth/google/callback?" + url.Values{"state": {"state-value"}, "code": {"code"}}.Encode()
|
callback := "/auth/google/callback?" + url.Values{"state": {"state-value"}, "code": {"code"}}.Encode()
|
||||||
req := httptest.NewRequest(http.MethodGet, callback, nil)
|
req := httptest.NewRequest(http.MethodGet, callback, nil)
|
||||||
|
|
@ -365,6 +539,16 @@ func serveGoogleCallback(router http.Handler) *httptest.ResponseRecorder {
|
||||||
return rec
|
return rec
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func serveAppleCallback(router http.Handler) *httptest.ResponseRecorder {
|
||||||
|
callback := "/auth/apple/callback?" + url.Values{"state": {"state-value"}, "code": {"code"}}.Encode()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, callback, nil)
|
||||||
|
req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("apple", auth.OAuthCookieState), Value: "state-value"})
|
||||||
|
req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("apple", auth.OAuthCookieNonce), Value: "nonce-value"})
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
return rec
|
||||||
|
}
|
||||||
|
|
||||||
func findCookie(cookies []*http.Cookie, name string) *http.Cookie {
|
func findCookie(cookies []*http.Cookie, name string) *http.Cookie {
|
||||||
for _, c := range cookies {
|
for _, c := range cookies {
|
||||||
if c.Name == name {
|
if c.Name == name {
|
||||||
|
|
|
||||||
|
|
@ -73,6 +73,8 @@ func NewRouter(pinger Pinger, staticFS fs.FS, deps AuthDeps, tabloDeps TablosDep
|
||||||
r.Post("/login", LoginPostHandler(deps))
|
r.Post("/login", LoginPostHandler(deps))
|
||||||
r.Get("/auth/google/start", GoogleStartHandler(deps))
|
r.Get("/auth/google/start", GoogleStartHandler(deps))
|
||||||
r.Get("/auth/google/callback", GoogleCallbackHandler(deps))
|
r.Get("/auth/google/callback", GoogleCallbackHandler(deps))
|
||||||
|
r.Get("/auth/apple/start", AppleStartHandler(deps))
|
||||||
|
r.Get("/auth/apple/callback", AppleCallbackHandler(deps))
|
||||||
|
|
||||||
// Protected routes — require an authenticated session (D-23, AUTH-05).
|
// Protected routes — require an authenticated session (D-23, AUTH-05).
|
||||||
// RequireAuth checks the context set by ResolveSession above and redirects
|
// RequireAuth checks the context set by ResolveSession above and redirects
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue