feat(08-02): add google social sign-in flow
This commit is contained in:
parent
2d004cd251
commit
6779663c8a
9 changed files with 872 additions and 6 deletions
|
|
@ -85,7 +85,34 @@ func main() {
|
||||||
stopJanitor := make(chan struct{})
|
stopJanitor := make(chan struct{})
|
||||||
rl.StartJanitor(time.Minute, stopJanitor)
|
rl.StartJanitor(time.Minute, stopJanitor)
|
||||||
|
|
||||||
deps := web.AuthDeps{Queries: q, Store: store, Secure: secure, Limiter: rl}
|
oauthCfg := auth.OAuthConfig{
|
||||||
|
Google: auth.GoogleProviderConfig{
|
||||||
|
ClientID: os.Getenv("GOOGLE_CLIENT_ID"),
|
||||||
|
ClientSecret: os.Getenv("GOOGLE_CLIENT_SECRET"),
|
||||||
|
RedirectURL: os.Getenv("GOOGLE_REDIRECT_URL"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
var googleExchanger auth.CodeExchanger
|
||||||
|
var googleVerifier auth.IDTokenVerifier
|
||||||
|
if oauthCfg.Google.Configured() {
|
||||||
|
googleExchanger = auth.OAuth2CodeExchanger{Config: oauthCfg.Google.OAuth2Config()}
|
||||||
|
googleVerifier = auth.OIDCVerifier{
|
||||||
|
Provider: "google",
|
||||||
|
Issuer: "https://accounts.google.com",
|
||||||
|
ClientID: oauthCfg.Google.ClientID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
deps := web.AuthDeps{
|
||||||
|
Queries: q,
|
||||||
|
Store: store,
|
||||||
|
Secure: secure,
|
||||||
|
Limiter: rl,
|
||||||
|
DB: pool,
|
||||||
|
OAuth: oauthCfg,
|
||||||
|
GoogleTokenExchanger: googleExchanger,
|
||||||
|
GoogleVerifier: googleVerifier,
|
||||||
|
}
|
||||||
tabloDeps := web.TablosDeps{Queries: q}
|
tabloDeps := web.TablosDeps{Queries: q}
|
||||||
taskDeps := web.TasksDeps{Queries: q}
|
taskDeps := web.TasksDeps{Queries: q}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,9 @@ require (
|
||||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21 // indirect
|
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21 // indirect
|
||||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 // indirect
|
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 // indirect
|
||||||
github.com/aws/smithy-go v1.25.1 // indirect
|
github.com/aws/smithy-go v1.25.1 // indirect
|
||||||
|
github.com/coreos/go-oidc/v3 v3.18.0 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/go-jose/go-jose/v4 v4.1.4 // indirect
|
||||||
github.com/gorilla/securecookie v1.1.2 // indirect
|
github.com/gorilla/securecookie v1.1.2 // indirect
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||||
|
|
@ -52,6 +54,7 @@ require (
|
||||||
github.com/tidwall/sjson v1.2.5 // indirect
|
github.com/tidwall/sjson v1.2.5 // indirect
|
||||||
go.uber.org/goleak v1.3.0 // indirect
|
go.uber.org/goleak v1.3.0 // indirect
|
||||||
go.uber.org/multierr v1.11.0 // indirect
|
go.uber.org/multierr v1.11.0 // indirect
|
||||||
|
golang.org/x/oauth2 v0.36.0 // indirect
|
||||||
golang.org/x/sync v0.20.0 // indirect
|
golang.org/x/sync v0.20.0 // indirect
|
||||||
golang.org/x/sys v0.44.0 // indirect
|
golang.org/x/sys v0.44.0 // indirect
|
||||||
golang.org/x/text v0.37.0 // indirect
|
golang.org/x/text v0.37.0 // indirect
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 h1:F/M5Y9I3nwr2IEpshZgh1GeHpOIt
|
||||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1/go.mod h1:mTNxImtovCOEEuD65mKW7DCsL+2gjEH+RPEAexAzAio=
|
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1/go.mod h1:mTNxImtovCOEEuD65mKW7DCsL+2gjEH+RPEAexAzAio=
|
||||||
github.com/aws/smithy-go v1.25.1 h1:J8ERsGSU7d+aCmdQur5Txg6bVoYelvQJgtZehD12GkI=
|
github.com/aws/smithy-go v1.25.1 h1:J8ERsGSU7d+aCmdQur5Txg6bVoYelvQJgtZehD12GkI=
|
||||||
github.com/aws/smithy-go v1.25.1/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
github.com/aws/smithy-go v1.25.1/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||||
|
github.com/coreos/go-oidc/v3 v3.18.0 h1:V9orjXynvu5wiC9SemFTWnG4F45v403aIcjWo0d41+A=
|
||||||
|
github.com/coreos/go-oidc/v3 v3.18.0/go.mod h1:DYCf24+ncYi+XkIH97GY1+dqoRlbaSI26KVTCI9SrY4=
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
|
@ -43,6 +45,8 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp
|
||||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
|
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
|
||||||
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
|
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
|
||||||
|
github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA=
|
||||||
|
github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
|
||||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
|
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
|
||||||
|
|
@ -107,6 +111,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||||
golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI=
|
golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI=
|
||||||
golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8=
|
golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8=
|
||||||
|
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||||
|
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||||
golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ=
|
golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ=
|
||||||
|
|
|
||||||
181
backend/internal/auth/oauth.go
Normal file
181
backend/internal/auth/oauth.go
Normal file
|
|
@ -0,0 +1,181 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
OAuthCookieState = "state"
|
||||||
|
OAuthCookieNonce = "nonce"
|
||||||
|
)
|
||||||
|
|
||||||
|
const oauthCookieTTL = 5 * time.Minute
|
||||||
|
|
||||||
|
type GoogleProviderConfig struct {
|
||||||
|
ClientID string
|
||||||
|
ClientSecret string
|
||||||
|
RedirectURL string
|
||||||
|
AuthURL string
|
||||||
|
TokenURL string
|
||||||
|
Issuer string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c GoogleProviderConfig) Configured() bool {
|
||||||
|
return c.ClientID != "" && c.ClientSecret != "" && c.RedirectURL != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c GoogleProviderConfig) withDefaults() GoogleProviderConfig {
|
||||||
|
if c.AuthURL == "" {
|
||||||
|
c.AuthURL = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||||
|
}
|
||||||
|
if c.TokenURL == "" {
|
||||||
|
c.TokenURL = "https://oauth2.googleapis.com/token"
|
||||||
|
}
|
||||||
|
if c.Issuer == "" {
|
||||||
|
c.Issuer = "https://accounts.google.com"
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c GoogleProviderConfig) OAuth2Config() oauth2.Config {
|
||||||
|
c = c.withDefaults()
|
||||||
|
return oauth2.Config{
|
||||||
|
ClientID: c.ClientID,
|
||||||
|
ClientSecret: c.ClientSecret,
|
||||||
|
RedirectURL: c.RedirectURL,
|
||||||
|
Scopes: []string{oidc.ScopeOpenID, "email", "profile"},
|
||||||
|
Endpoint: oauth2.Endpoint{
|
||||||
|
AuthURL: c.AuthURL,
|
||||||
|
TokenURL: c.TokenURL,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type OAuthConfig struct {
|
||||||
|
Google GoogleProviderConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProviderClaims struct {
|
||||||
|
Provider string
|
||||||
|
Subject string
|
||||||
|
Email string
|
||||||
|
EmailVerified bool
|
||||||
|
DisplayName string
|
||||||
|
AvatarURL string
|
||||||
|
Nonce string
|
||||||
|
}
|
||||||
|
|
||||||
|
type CodeExchanger interface {
|
||||||
|
Exchange(ctx context.Context, code string) (*oauth2.Token, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type OAuth2CodeExchanger struct {
|
||||||
|
Config oauth2.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e OAuth2CodeExchanger) Exchange(ctx context.Context, code string) (*oauth2.Token, error) {
|
||||||
|
return e.Config.Exchange(ctx, code)
|
||||||
|
}
|
||||||
|
|
||||||
|
type IDTokenVerifier interface {
|
||||||
|
Verify(ctx context.Context, rawIDToken string) (ProviderClaims, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type OIDCVerifier struct {
|
||||||
|
Provider string
|
||||||
|
Issuer string
|
||||||
|
ClientID string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v OIDCVerifier) Verify(ctx context.Context, rawIDToken string) (ProviderClaims, error) {
|
||||||
|
provider, err := oidc.NewProvider(ctx, v.Issuer)
|
||||||
|
if err != nil {
|
||||||
|
return ProviderClaims{}, fmt.Errorf("auth: oidc provider discovery: %w", err)
|
||||||
|
}
|
||||||
|
verified, err := provider.Verifier(&oidc.Config{ClientID: v.ClientID}).Verify(ctx, rawIDToken)
|
||||||
|
if err != nil {
|
||||||
|
return ProviderClaims{}, fmt.Errorf("auth: verify id token: %w", err)
|
||||||
|
}
|
||||||
|
var claims struct {
|
||||||
|
Subject string `json:"sub"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
EmailVerified bool `json:"email_verified"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Picture string `json:"picture"`
|
||||||
|
Nonce string `json:"nonce"`
|
||||||
|
}
|
||||||
|
if err := verified.Claims(&claims); err != nil {
|
||||||
|
return ProviderClaims{}, fmt.Errorf("auth: decode id token claims: %w", err)
|
||||||
|
}
|
||||||
|
if claims.Subject == "" {
|
||||||
|
return ProviderClaims{}, errors.New("auth: id token missing subject")
|
||||||
|
}
|
||||||
|
return ProviderClaims{
|
||||||
|
Provider: v.Provider,
|
||||||
|
Subject: claims.Subject,
|
||||||
|
Email: claims.Email,
|
||||||
|
EmailVerified: claims.EmailVerified,
|
||||||
|
DisplayName: claims.Name,
|
||||||
|
AvatarURL: claims.Picture,
|
||||||
|
Nonce: claims.Nonce,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateOAuthValue() (string, error) {
|
||||||
|
raw := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(raw); err != nil {
|
||||||
|
return "", fmt.Errorf("auth: generate oauth value: %w", err)
|
||||||
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(raw), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func OAuthCookieName(provider, kind string) string {
|
||||||
|
return "xtablo_oauth_" + provider + "_" + kind
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetOAuthCookie(w http.ResponseWriter, provider, kind, value string, secure bool) {
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: OAuthCookieName(provider, kind),
|
||||||
|
Value: value,
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: int(oauthCookieTTL.Seconds()),
|
||||||
|
Expires: time.Now().Add(oauthCookieTTL),
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: secure,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func ClearOAuthCookie(w http.ResponseWriter, provider, kind string, secure bool) {
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: OAuthCookieName(provider, kind),
|
||||||
|
Value: "",
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: -1,
|
||||||
|
Expires: time.Unix(0, 0),
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: secure,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateOAuthCookie(r *http.Request, provider, kind, value string) bool {
|
||||||
|
c, err := r.Cookie(OAuthCookieName(provider, kind))
|
||||||
|
if err != nil || c.Value == "" || value == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(c.Value) != len(value) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return subtle.ConstantTimeCompare([]byte(c.Value), []byte(value)) == 1
|
||||||
|
}
|
||||||
56
backend/internal/auth/oauth_test.go
Normal file
56
backend/internal/auth/oauth_test.go
Normal file
|
|
@ -0,0 +1,56 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGoogleProviderConfigConfigured(t *testing.T) {
|
||||||
|
empty := GoogleProviderConfig{}
|
||||||
|
if empty.Configured() {
|
||||||
|
t.Fatal("empty Google config must not be configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := GoogleProviderConfig{
|
||||||
|
ClientID: "google-client",
|
||||||
|
ClientSecret: "google-secret",
|
||||||
|
RedirectURL: "https://xtablo.test/auth/google/callback",
|
||||||
|
}
|
||||||
|
if !cfg.Configured() {
|
||||||
|
t.Fatal("complete Google config must be configured")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOAuthStateAndNonceCookiesValidateExactValue(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
SetOAuthCookie(rec, "google", OAuthCookieState, "state-value", false)
|
||||||
|
SetOAuthCookie(rec, "google", OAuthCookieNonce, "nonce-value", false)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/auth/google/callback", nil)
|
||||||
|
for _, c := range rec.Result().Cookies() {
|
||||||
|
req.AddCookie(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ValidateOAuthCookie(req, "google", OAuthCookieState, "state-value") {
|
||||||
|
t.Fatal("state cookie should validate matching value")
|
||||||
|
}
|
||||||
|
if ValidateOAuthCookie(req, "google", OAuthCookieState, "wrong-state") {
|
||||||
|
t.Fatal("state cookie should reject mismatched value")
|
||||||
|
}
|
||||||
|
if !ValidateOAuthCookie(req, "google", OAuthCookieNonce, "nonce-value") {
|
||||||
|
t.Fatal("nonce cookie should validate matching value")
|
||||||
|
}
|
||||||
|
if ValidateOAuthCookie(req, "google", OAuthCookieNonce, "wrong-nonce") {
|
||||||
|
t.Fatal("nonce cookie should reject mismatched value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOAuthCookieNameIncludesProviderAndKind(t *testing.T) {
|
||||||
|
if got := OAuthCookieName("google", OAuthCookieState); got != "xtablo_oauth_google_state" {
|
||||||
|
t.Fatalf("state cookie name = %q", got)
|
||||||
|
}
|
||||||
|
if got := OAuthCookieName("google", OAuthCookieNonce); got != "xtablo_oauth_google_nonce" {
|
||||||
|
t.Fatalf("nonce cookie name = %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -24,10 +24,14 @@ import (
|
||||||
// Limiter is the in-memory rate limiter for POST /login (D-16, AUTH-07).
|
// Limiter is the in-memory rate limiter for POST /login (D-16, AUTH-07).
|
||||||
// When nil, rate limiting is skipped (unit tests that don't exercise that path).
|
// When nil, rate limiting is skipped (unit tests that don't exercise that path).
|
||||||
type AuthDeps struct {
|
type AuthDeps struct {
|
||||||
Queries *sqlc.Queries
|
Queries *sqlc.Queries
|
||||||
Store *auth.Store
|
Store *auth.Store
|
||||||
Secure bool
|
Secure bool
|
||||||
Limiter *auth.LimiterStore
|
Limiter *auth.LimiterStore
|
||||||
|
DB TxBeginner
|
||||||
|
OAuth auth.OAuthConfig
|
||||||
|
GoogleTokenExchanger auth.CodeExchanger
|
||||||
|
GoogleVerifier auth.IDTokenVerifier
|
||||||
}
|
}
|
||||||
|
|
||||||
// errInvalidCreds is the intentionally generic error message for login failures
|
// errInvalidCreds is the intentionally generic error message for login failures
|
||||||
|
|
|
||||||
210
backend/internal/web/handlers_social.go
Normal file
210
backend/internal/web/handlers_social.go
Normal file
|
|
@ -0,0 +1,210 @@
|
||||||
|
package web
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"backend/internal/auth"
|
||||||
|
"backend/internal/db/sqlc"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TxBeginner interface {
|
||||||
|
Begin(ctx context.Context) (pgx.Tx, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
const providerEmailUnverified = "This provider did not return a verified email. Try another sign-in method."
|
||||||
|
const providerGenericError = "Could not sign you in with this provider. Try another sign-in method."
|
||||||
|
|
||||||
|
func GoogleStartHandler(deps AuthDeps) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
cfg := deps.OAuth.Google
|
||||||
|
if !cfg.Configured() {
|
||||||
|
http.Error(w, "Google sign-in not configured", http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
state, err := auth.GenerateOAuthValue()
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nonce, err := auth.GenerateOAuthValue()
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
auth.SetOAuthCookie(w, "google", auth.OAuthCookieState, state, deps.Secure)
|
||||||
|
auth.SetOAuthCookie(w, "google", auth.OAuthCookieNonce, nonce, deps.Secure)
|
||||||
|
|
||||||
|
oauthCfg := cfg.OAuth2Config()
|
||||||
|
url := oauthCfg.AuthCodeURL(state, oauth2.SetAuthURLParam("nonce", nonce))
|
||||||
|
http.Redirect(w, r, url, http.StatusSeeOther)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GoogleCallbackHandler(deps AuthDeps) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if !auth.ValidateOAuthCookie(r, "google", auth.OAuthCookieState, r.URL.Query().Get("state")) {
|
||||||
|
http.Error(w, providerGenericError, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
auth.ClearOAuthCookie(w, "google", auth.OAuthCookieState, deps.Secure)
|
||||||
|
auth.ClearOAuthCookie(w, "google", auth.OAuthCookieNonce, deps.Secure)
|
||||||
|
|
||||||
|
code := r.URL.Query().Get("code")
|
||||||
|
if code == "" {
|
||||||
|
http.Error(w, providerGenericError, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
exchanger := deps.GoogleTokenExchanger
|
||||||
|
if exchanger == nil {
|
||||||
|
exchanger = auth.OAuth2CodeExchanger{Config: deps.OAuth.Google.OAuth2Config()}
|
||||||
|
}
|
||||||
|
token, err := exchanger.Exchange(r.Context(), code)
|
||||||
|
if err != nil {
|
||||||
|
slog.Default().Warn("google oauth exchange failed", "err", err)
|
||||||
|
http.Error(w, providerGenericError, http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rawIDToken, _ := token.Extra("id_token").(string)
|
||||||
|
if rawIDToken == "" {
|
||||||
|
http.Error(w, providerGenericError, http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
verifier := deps.GoogleVerifier
|
||||||
|
if verifier == nil {
|
||||||
|
cfg := deps.OAuth.Google
|
||||||
|
issuer := cfg.Issuer
|
||||||
|
if issuer == "" {
|
||||||
|
issuer = "https://accounts.google.com"
|
||||||
|
}
|
||||||
|
verifier = auth.OIDCVerifier{Provider: "google", Issuer: issuer, ClientID: cfg.ClientID}
|
||||||
|
}
|
||||||
|
claims, err := verifier.Verify(r.Context(), rawIDToken)
|
||||||
|
if err != nil {
|
||||||
|
slog.Default().Warn("google id token verification failed", "err", err)
|
||||||
|
http.Error(w, providerGenericError, http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if claims.Provider == "" {
|
||||||
|
claims.Provider = "google"
|
||||||
|
}
|
||||||
|
if !auth.ValidateOAuthCookie(r, "google", auth.OAuthCookieNonce, claims.Nonce) {
|
||||||
|
http.Error(w, providerGenericError, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(claims.Email) == "" || !claims.EmailVerified {
|
||||||
|
http.Error(w, providerEmailUnverified, http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, err := linkProviderUser(r.Context(), deps, claims)
|
||||||
|
if err != nil {
|
||||||
|
slog.Default().Error("google account linking failed", "err", err)
|
||||||
|
http.Error(w, providerGenericError, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cookieValue, expiresAt, err := deps.Store.Create(r.Context(), userID)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, providerGenericError, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
auth.SetSessionCookie(w, cookieValue, expiresAt, deps.Secure)
|
||||||
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func linkProviderUser(ctx context.Context, deps AuthDeps, claims auth.ProviderClaims) (uuid.UUID, error) {
|
||||||
|
if deps.DB == nil {
|
||||||
|
return uuid.Nil, errors.New("missing transaction DB")
|
||||||
|
}
|
||||||
|
tx, err := deps.DB.Begin(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return uuid.Nil, fmt.Errorf("begin tx: %w", err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback(ctx) //nolint:errcheck
|
||||||
|
|
||||||
|
q := deps.Queries.WithTx(tx)
|
||||||
|
email := strings.ToLower(strings.TrimSpace(claims.Email))
|
||||||
|
provider := claims.Provider
|
||||||
|
if provider == "" {
|
||||||
|
provider = "google"
|
||||||
|
}
|
||||||
|
if provider == "google" {
|
||||||
|
claims.Provider = provider
|
||||||
|
}
|
||||||
|
|
||||||
|
identity, err := q.GetUserIdentityByProviderSubject(ctx, sqlc.GetUserIdentityByProviderSubjectParams{
|
||||||
|
Provider: provider,
|
||||||
|
ProviderSubject: claims.Subject,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
updated, updateErr := q.UpdateUserIdentityLogin(ctx, sqlc.UpdateUserIdentityLoginParams{
|
||||||
|
Provider: provider,
|
||||||
|
ProviderSubject: claims.Subject,
|
||||||
|
Email: email,
|
||||||
|
EmailVerified: claims.EmailVerified,
|
||||||
|
DisplayName: textOrNull(claims.DisplayName),
|
||||||
|
AvatarUrl: textOrNull(claims.AvatarURL),
|
||||||
|
})
|
||||||
|
if updateErr != nil {
|
||||||
|
return uuid.Nil, fmt.Errorf("update identity login: %w", updateErr)
|
||||||
|
}
|
||||||
|
if _, emailErr := q.UpdateUserEmailIfAvailable(ctx, sqlc.UpdateUserEmailIfAvailableParams{
|
||||||
|
ID: updated.UserID,
|
||||||
|
Email: email,
|
||||||
|
}); emailErr != nil && !errors.Is(emailErr, pgx.ErrNoRows) {
|
||||||
|
return uuid.Nil, fmt.Errorf("update linked user email: %w", emailErr)
|
||||||
|
}
|
||||||
|
if err := tx.Commit(ctx); err != nil {
|
||||||
|
return uuid.Nil, fmt.Errorf("commit existing identity: %w", err)
|
||||||
|
}
|
||||||
|
return identity.UserID, nil
|
||||||
|
}
|
||||||
|
if !errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
return uuid.Nil, fmt.Errorf("get identity: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := q.GetUserByEmail(ctx, email)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
user, err = q.InsertSocialUser(ctx, email)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return uuid.Nil, fmt.Errorf("lookup or create user: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := q.InsertUserIdentity(ctx, sqlc.InsertUserIdentityParams{
|
||||||
|
UserID: user.ID,
|
||||||
|
Provider: provider,
|
||||||
|
ProviderSubject: claims.Subject,
|
||||||
|
Email: email,
|
||||||
|
EmailVerified: claims.EmailVerified,
|
||||||
|
DisplayName: textOrNull(claims.DisplayName),
|
||||||
|
AvatarUrl: textOrNull(claims.AvatarURL),
|
||||||
|
}); err != nil {
|
||||||
|
return uuid.Nil, fmt.Errorf("insert identity: %w", err)
|
||||||
|
}
|
||||||
|
if err := tx.Commit(ctx); err != nil {
|
||||||
|
return uuid.Nil, fmt.Errorf("commit new identity: %w", err)
|
||||||
|
}
|
||||||
|
return user.ID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func textOrNull(value string) pgtype.Text {
|
||||||
|
if value == "" {
|
||||||
|
return pgtype.Text{}
|
||||||
|
}
|
||||||
|
return pgtype.Text{String: value, Valid: true}
|
||||||
|
}
|
||||||
375
backend/internal/web/handlers_social_test.go
Normal file
375
backend/internal/web/handlers_social_test.go
Normal file
|
|
@ -0,0 +1,375 @@
|
||||||
|
package web
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"backend/internal/auth"
|
||||||
|
"backend/internal/db/sqlc"
|
||||||
|
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeCodeExchanger struct {
|
||||||
|
token *oauth2.Token
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f fakeCodeExchanger) Exchange(ctx context.Context, code string) (*oauth2.Token, error) {
|
||||||
|
return f.token, f.err
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeIDTokenVerifier struct {
|
||||||
|
claims auth.ProviderClaims
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f fakeIDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (auth.ProviderClaims, error) {
|
||||||
|
return f.claims, f.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func newGoogleAuthDeps(q *sqlc.Queries, store *auth.Store) AuthDeps {
|
||||||
|
return AuthDeps{
|
||||||
|
Queries: q,
|
||||||
|
Store: store,
|
||||||
|
Secure: false,
|
||||||
|
OAuth: auth.OAuthConfig{
|
||||||
|
Google: auth.GoogleProviderConfig{
|
||||||
|
ClientID: "google-client",
|
||||||
|
ClientSecret: "google-secret",
|
||||||
|
RedirectURL: "https://xtablo.test/auth/google/callback",
|
||||||
|
AuthURL: "https://accounts.google.test/o/oauth2/v2/auth",
|
||||||
|
TokenURL: "https://oauth2.google.test/token",
|
||||||
|
Issuer: "https://accounts.google.test",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
GoogleTokenExchanger: fakeCodeExchanger{
|
||||||
|
token: (&oauth2.Token{AccessToken: "access"}).WithExtra(map[string]any{"id_token": "raw-id-token"}),
|
||||||
|
},
|
||||||
|
GoogleVerifier: fakeIDTokenVerifier{
|
||||||
|
claims: auth.ProviderClaims{
|
||||||
|
Provider: "google",
|
||||||
|
Subject: "google-subject-1",
|
||||||
|
Email: "google@example.com",
|
||||||
|
EmailVerified: true,
|
||||||
|
DisplayName: "Google User",
|
||||||
|
AvatarURL: "https://example.com/avatar.png",
|
||||||
|
Nonce: "nonce-value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func withGoogleClaims(deps AuthDeps, claims auth.ProviderClaims) AuthDeps {
|
||||||
|
deps.GoogleVerifier = fakeIDTokenVerifier{claims: claims}
|
||||||
|
return deps
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSocialRouter(t *testing.T, deps AuthDeps) http.Handler {
|
||||||
|
t.Helper()
|
||||||
|
router, err := NewRouter(stubPinger{}, os.DirFS("./static"), deps, TablosDeps{}, TasksDeps{}, FilesDeps{}, testCSRFKey, "dev", "localhost")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRouter: %v", err)
|
||||||
|
}
|
||||||
|
return router
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoogleStartRedirectsAndSetsStateNonceCookies(t *testing.T) {
|
||||||
|
router := newSocialRouter(t, newGoogleAuthDeps(nil, nil))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/auth/google/start", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusSeeOther {
|
||||||
|
t.Fatalf("status = %d; want 303", rec.Code)
|
||||||
|
}
|
||||||
|
loc := rec.Header().Get("Location")
|
||||||
|
if !strings.Contains(loc, "https://accounts.google.test/o/oauth2/v2/auth") {
|
||||||
|
t.Fatalf("Location = %q; want Google auth URL", loc)
|
||||||
|
}
|
||||||
|
if !strings.Contains(loc, "client_id=google-client") {
|
||||||
|
t.Fatalf("Location = %q; missing client_id", loc)
|
||||||
|
}
|
||||||
|
if !strings.Contains(loc, "scope=openid+email+profile") {
|
||||||
|
t.Fatalf("Location = %q; missing openid email profile scope", loc)
|
||||||
|
}
|
||||||
|
|
||||||
|
cookies := rec.Result().Cookies()
|
||||||
|
if findCookie(cookies, auth.OAuthCookieName("google", auth.OAuthCookieState)) == nil {
|
||||||
|
t.Fatal("missing Google state cookie")
|
||||||
|
}
|
||||||
|
if findCookie(cookies, auth.OAuthCookieName("google", auth.OAuthCookieNonce)) == nil {
|
||||||
|
t.Fatal("missing Google nonce cookie")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoogleCallbackInvalidStateRejectedBeforeExchange(t *testing.T) {
|
||||||
|
router := newSocialRouter(t, newGoogleAuthDeps(nil, nil))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/auth/google/callback?state=wrong&code=code", nil)
|
||||||
|
req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("google", auth.OAuthCookieState), Value: "expected"})
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("status = %d; want 400", rec.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoogleCallbackUnverifiedEmailRejected(t *testing.T) {
|
||||||
|
deps := newGoogleAuthDeps(nil, nil)
|
||||||
|
deps.GoogleVerifier = fakeIDTokenVerifier{claims: auth.ProviderClaims{
|
||||||
|
Provider: "google",
|
||||||
|
Subject: "google-subject-1",
|
||||||
|
Email: "google@example.com",
|
||||||
|
EmailVerified: false,
|
||||||
|
Nonce: "nonce-value",
|
||||||
|
}}
|
||||||
|
router := newSocialRouter(t, deps)
|
||||||
|
|
||||||
|
callback := "/auth/google/callback?" + url.Values{"state": {"state-value"}, "code": {"code"}}.Encode()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, callback, nil)
|
||||||
|
req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("google", auth.OAuthCookieState), Value: "state-value"})
|
||||||
|
req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("google", auth.OAuthCookieNonce), Value: "nonce-value"})
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusUnauthorized {
|
||||||
|
t.Fatalf("status = %d; want 401", rec.Code)
|
||||||
|
}
|
||||||
|
if !strings.Contains(rec.Body.String(), "This provider did not return a verified email. Try another sign-in method.") {
|
||||||
|
t.Fatalf("body missing unverified email copy; got: %s", rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoogleCallbackVerifiedEmailLinksExistingUserAndSetsSession(t *testing.T) {
|
||||||
|
pool, cleanup := setupTestDB(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
q := sqlc.New(pool)
|
||||||
|
store := auth.NewStore(q)
|
||||||
|
user := preInsertUser(t, ctx, q, "google@example.com", "correct-horse-12chars")
|
||||||
|
deps := newGoogleAuthDeps(q, store)
|
||||||
|
deps.DB = pool
|
||||||
|
router := newSocialRouter(t, deps)
|
||||||
|
|
||||||
|
callback := "/auth/google/callback?" + url.Values{"state": {"state-value"}, "code": {"code"}}.Encode()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, callback, nil)
|
||||||
|
req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("google", auth.OAuthCookieState), Value: "state-value"})
|
||||||
|
req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("google", auth.OAuthCookieNonce), Value: "nonce-value"})
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusSeeOther {
|
||||||
|
t.Fatalf("status = %d; want 303; body: %s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
if loc := rec.Header().Get("Location"); loc != "/" {
|
||||||
|
t.Fatalf("Location = %q; want /", loc)
|
||||||
|
}
|
||||||
|
if c := getSessionCookie(rec); c == nil {
|
||||||
|
t.Fatal("session cookie not set")
|
||||||
|
}
|
||||||
|
|
||||||
|
var count int
|
||||||
|
if err := pool.QueryRow(ctx, `
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM user_identities
|
||||||
|
WHERE user_id = $1 AND provider = 'google' AND provider_subject = 'google-subject-1'
|
||||||
|
`, user.ID).Scan(&count); err != nil {
|
||||||
|
t.Fatalf("count identity: %v", err)
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
t.Fatalf("identity count = %d; want 1", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoogleCallbackNewVerifiedEmailCreatesSocialOnlyUserAndSession(t *testing.T) {
|
||||||
|
pool, cleanup := setupTestDB(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
q := sqlc.New(pool)
|
||||||
|
store := auth.NewStore(q)
|
||||||
|
deps := newGoogleAuthDeps(q, store)
|
||||||
|
deps.DB = pool
|
||||||
|
deps = withGoogleClaims(deps, auth.ProviderClaims{
|
||||||
|
Provider: "google",
|
||||||
|
Subject: "google-new-subject",
|
||||||
|
Email: "new-google@example.com",
|
||||||
|
EmailVerified: true,
|
||||||
|
DisplayName: "New Google",
|
||||||
|
AvatarURL: "https://example.com/new.png",
|
||||||
|
Nonce: "nonce-value",
|
||||||
|
})
|
||||||
|
router := newSocialRouter(t, deps)
|
||||||
|
|
||||||
|
rec := serveGoogleCallback(router)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusSeeOther {
|
||||||
|
t.Fatalf("status = %d; want 303; body: %s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
if c := getSessionCookie(rec); c == nil {
|
||||||
|
t.Fatal("session cookie not set")
|
||||||
|
}
|
||||||
|
user, err := q.GetUserByEmail(ctx, "new-google@example.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetUserByEmail: %v", err)
|
||||||
|
}
|
||||||
|
if user.PasswordHash.Valid {
|
||||||
|
t.Fatalf("new social user password hash Valid = true; want false")
|
||||||
|
}
|
||||||
|
identity, err := q.GetUserIdentityByProviderSubject(ctx, sqlc.GetUserIdentityByProviderSubjectParams{
|
||||||
|
Provider: "google",
|
||||||
|
ProviderSubject: "google-new-subject",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetUserIdentityByProviderSubject: %v", err)
|
||||||
|
}
|
||||||
|
if identity.UserID != user.ID {
|
||||||
|
t.Fatalf("identity user id = %s; want %s", identity.UserID, user.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoogleCallbackExistingSubjectWinsWhenEmailChanges(t *testing.T) {
|
||||||
|
pool, cleanup := setupTestDB(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
q := sqlc.New(pool)
|
||||||
|
store := auth.NewStore(q)
|
||||||
|
user := preInsertSocialOnlyUser(t, ctx, q, "old-google@example.com")
|
||||||
|
if _, err := q.InsertUserIdentity(ctx, sqlc.InsertUserIdentityParams{
|
||||||
|
UserID: user.ID,
|
||||||
|
Provider: "google",
|
||||||
|
ProviderSubject: "stable-google-subject",
|
||||||
|
Email: "old-google@example.com",
|
||||||
|
EmailVerified: true,
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("InsertUserIdentity: %v", err)
|
||||||
|
}
|
||||||
|
deps := newGoogleAuthDeps(q, store)
|
||||||
|
deps.DB = pool
|
||||||
|
deps = withGoogleClaims(deps, auth.ProviderClaims{
|
||||||
|
Provider: "google",
|
||||||
|
Subject: "stable-google-subject",
|
||||||
|
Email: "changed-google@example.com",
|
||||||
|
EmailVerified: true,
|
||||||
|
Nonce: "nonce-value",
|
||||||
|
})
|
||||||
|
router := newSocialRouter(t, deps)
|
||||||
|
|
||||||
|
rec := serveGoogleCallback(router)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusSeeOther {
|
||||||
|
t.Fatalf("status = %d; want 303; body: %s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
identity, err := q.GetUserIdentityByProviderSubject(ctx, sqlc.GetUserIdentityByProviderSubjectParams{
|
||||||
|
Provider: "google",
|
||||||
|
ProviderSubject: "stable-google-subject",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetUserIdentityByProviderSubject: %v", err)
|
||||||
|
}
|
||||||
|
if identity.UserID != user.ID {
|
||||||
|
t.Fatalf("identity relinked to %s; want original %s", identity.UserID, user.ID)
|
||||||
|
}
|
||||||
|
if identity.Email != "changed-google@example.com" {
|
||||||
|
t.Fatalf("identity email = %q; want changed-google@example.com", identity.Email)
|
||||||
|
}
|
||||||
|
updatedUser, err := q.GetUserByID(ctx, user.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetUserByID: %v", err)
|
||||||
|
}
|
||||||
|
if updatedUser.Email != "changed-google@example.com" {
|
||||||
|
t.Fatalf("user email = %q; want changed-google@example.com", updatedUser.Email)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoogleCallbackEmailUpdateConflictDoesNotRelinkSubject(t *testing.T) {
|
||||||
|
pool, cleanup := setupTestDB(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
q := sqlc.New(pool)
|
||||||
|
store := auth.NewStore(q)
|
||||||
|
linkedUser := preInsertSocialOnlyUser(t, ctx, q, "linked-google@example.com")
|
||||||
|
conflictUser := preInsertUser(t, ctx, q, "conflict-google@example.com", "correct-horse-12chars")
|
||||||
|
if _, err := q.InsertUserIdentity(ctx, sqlc.InsertUserIdentityParams{
|
||||||
|
UserID: linkedUser.ID,
|
||||||
|
Provider: "google",
|
||||||
|
ProviderSubject: "conflict-google-subject",
|
||||||
|
Email: "linked-google@example.com",
|
||||||
|
EmailVerified: true,
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("InsertUserIdentity: %v", err)
|
||||||
|
}
|
||||||
|
deps := newGoogleAuthDeps(q, store)
|
||||||
|
deps.DB = pool
|
||||||
|
deps = withGoogleClaims(deps, auth.ProviderClaims{
|
||||||
|
Provider: "google",
|
||||||
|
Subject: "conflict-google-subject",
|
||||||
|
Email: "conflict-google@example.com",
|
||||||
|
EmailVerified: true,
|
||||||
|
Nonce: "nonce-value",
|
||||||
|
})
|
||||||
|
router := newSocialRouter(t, deps)
|
||||||
|
|
||||||
|
rec := serveGoogleCallback(router)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusSeeOther {
|
||||||
|
t.Fatalf("status = %d; want 303; body: %s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
identity, err := q.GetUserIdentityByProviderSubject(ctx, sqlc.GetUserIdentityByProviderSubjectParams{
|
||||||
|
Provider: "google",
|
||||||
|
ProviderSubject: "conflict-google-subject",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetUserIdentityByProviderSubject: %v", err)
|
||||||
|
}
|
||||||
|
if identity.UserID != linkedUser.ID {
|
||||||
|
t.Fatalf("identity user id = %s; want linked user %s", identity.UserID, linkedUser.ID)
|
||||||
|
}
|
||||||
|
if identity.Email != "conflict-google@example.com" {
|
||||||
|
t.Fatalf("identity email = %q; want conflict-google@example.com", identity.Email)
|
||||||
|
}
|
||||||
|
stillLinked, err := q.GetUserByID(ctx, linkedUser.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetUserByID linked: %v", err)
|
||||||
|
}
|
||||||
|
if stillLinked.Email != "linked-google@example.com" {
|
||||||
|
t.Fatalf("linked user email = %q; want linked-google@example.com", stillLinked.Email)
|
||||||
|
}
|
||||||
|
stillConflict, err := q.GetUserByID(ctx, conflictUser.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetUserByID conflict: %v", err)
|
||||||
|
}
|
||||||
|
if stillConflict.Email != "conflict-google@example.com" {
|
||||||
|
t.Fatalf("conflict user email = %q; want conflict-google@example.com", stillConflict.Email)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func serveGoogleCallback(router http.Handler) *httptest.ResponseRecorder {
|
||||||
|
callback := "/auth/google/callback?" + url.Values{"state": {"state-value"}, "code": {"code"}}.Encode()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, callback, nil)
|
||||||
|
req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("google", auth.OAuthCookieState), Value: "state-value"})
|
||||||
|
req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("google", auth.OAuthCookieNonce), Value: "nonce-value"})
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
return rec
|
||||||
|
}
|
||||||
|
|
||||||
|
func findCookie(cookies []*http.Cookie, name string) *http.Cookie {
|
||||||
|
for _, c := range cookies {
|
||||||
|
if c.Name == name {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -32,7 +32,9 @@ type Pinger interface {
|
||||||
// 6. auth.Mount (gorilla/csrf — MUST come after ResolveSession, before routes) — D-24, Pitfall 7
|
// 6. auth.Mount (gorilla/csrf — MUST come after ResolveSession, before routes) — D-24, Pitfall 7
|
||||||
//
|
//
|
||||||
// Routes: GET / · GET /healthz (liveness) · GET /readyz (readiness) · GET /demo/time · GET /static/*
|
// Routes: GET / · GET /healthz (liveness) · GET /readyz (readiness) · GET /demo/time · GET /static/*
|
||||||
// GET /signup (auth pages, behind RedirectIfAuthed) · POST /signup.
|
//
|
||||||
|
// GET /signup (auth pages, behind RedirectIfAuthed) · POST /signup.
|
||||||
|
//
|
||||||
// staticFS is the embedded FS (or os.DirFS in tests) served at /static/*; the
|
// staticFS is the embedded FS (or os.DirFS in tests) served at /static/*; the
|
||||||
// embedded FS pattern blocks path traversal at the http.FS layer (T-01-08).
|
// embedded FS pattern blocks path traversal at the http.FS layer (T-01-08).
|
||||||
//
|
//
|
||||||
|
|
@ -69,6 +71,8 @@ func NewRouter(pinger Pinger, staticFS fs.FS, deps AuthDeps, tabloDeps TablosDep
|
||||||
// response; the GET guard handles the common case.
|
// response; the GET guard handles the common case.
|
||||||
r.Post("/signup", SignupPostHandler(deps))
|
r.Post("/signup", SignupPostHandler(deps))
|
||||||
r.Post("/login", LoginPostHandler(deps))
|
r.Post("/login", LoginPostHandler(deps))
|
||||||
|
r.Get("/auth/google/start", GoogleStartHandler(deps))
|
||||||
|
r.Get("/auth/google/callback", GoogleCallbackHandler(deps))
|
||||||
|
|
||||||
// Protected routes — require an authenticated session (D-23, AUTH-05).
|
// Protected routes — require an authenticated session (D-23, AUTH-05).
|
||||||
// RequireAuth checks the context set by ResolveSession above and redirects
|
// RequireAuth checks the context set by ResolveSession above and redirects
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue