package auth import ( "context" "crypto/rand" "crypto/subtle" "encoding/base64" "errors" "fmt" "net/http" "time" "github.com/coreos/go-oidc/v3/oidc" "golang.org/x/oauth2" ) const ( OAuthCookieState = "state" OAuthCookieNonce = "nonce" ) const oauthCookieTTL = 5 * time.Minute type GoogleProviderConfig struct { ClientID string ClientSecret string RedirectURL string AuthURL string TokenURL string Issuer string } func (c GoogleProviderConfig) Configured() bool { return c.ClientID != "" && c.ClientSecret != "" && c.RedirectURL != "" } func (c GoogleProviderConfig) withDefaults() GoogleProviderConfig { if c.AuthURL == "" { c.AuthURL = "https://accounts.google.com/o/oauth2/v2/auth" } if c.TokenURL == "" { c.TokenURL = "https://oauth2.googleapis.com/token" } if c.Issuer == "" { c.Issuer = "https://accounts.google.com" } return c } func (c GoogleProviderConfig) OAuth2Config() oauth2.Config { c = c.withDefaults() return oauth2.Config{ ClientID: c.ClientID, ClientSecret: c.ClientSecret, RedirectURL: c.RedirectURL, Scopes: []string{oidc.ScopeOpenID, "email", "profile"}, Endpoint: oauth2.Endpoint{ AuthURL: c.AuthURL, TokenURL: c.TokenURL, }, } } type OAuthConfig struct { Google GoogleProviderConfig } type ProviderClaims struct { Provider string Subject string Email string EmailVerified bool DisplayName string AvatarURL string Nonce string } type CodeExchanger interface { Exchange(ctx context.Context, code string) (*oauth2.Token, error) } type OAuth2CodeExchanger struct { Config oauth2.Config } func (e OAuth2CodeExchanger) Exchange(ctx context.Context, code string) (*oauth2.Token, error) { return e.Config.Exchange(ctx, code) } type IDTokenVerifier interface { Verify(ctx context.Context, rawIDToken string) (ProviderClaims, error) } type OIDCVerifier struct { Provider string Issuer string ClientID string } func (v OIDCVerifier) Verify(ctx context.Context, rawIDToken string) (ProviderClaims, error) { provider, err := oidc.NewProvider(ctx, v.Issuer) if err != nil { return ProviderClaims{}, fmt.Errorf("auth: oidc provider discovery: %w", err) } verified, err := provider.Verifier(&oidc.Config{ClientID: v.ClientID}).Verify(ctx, rawIDToken) if err != nil { return ProviderClaims{}, fmt.Errorf("auth: verify id token: %w", err) } var claims struct { Subject string `json:"sub"` Email string `json:"email"` EmailVerified bool `json:"email_verified"` Name string `json:"name"` Picture string `json:"picture"` Nonce string `json:"nonce"` } if err := verified.Claims(&claims); err != nil { return ProviderClaims{}, fmt.Errorf("auth: decode id token claims: %w", err) } if claims.Subject == "" { return ProviderClaims{}, errors.New("auth: id token missing subject") } return ProviderClaims{ Provider: v.Provider, Subject: claims.Subject, Email: claims.Email, EmailVerified: claims.EmailVerified, DisplayName: claims.Name, AvatarURL: claims.Picture, Nonce: claims.Nonce, }, nil } func GenerateOAuthValue() (string, error) { raw := make([]byte, 32) if _, err := rand.Read(raw); err != nil { return "", fmt.Errorf("auth: generate oauth value: %w", err) } return base64.RawURLEncoding.EncodeToString(raw), nil } func OAuthCookieName(provider, kind string) string { return "xtablo_oauth_" + provider + "_" + kind } func SetOAuthCookie(w http.ResponseWriter, provider, kind, value string, secure bool) { http.SetCookie(w, &http.Cookie{ Name: OAuthCookieName(provider, kind), Value: value, Path: "/", MaxAge: int(oauthCookieTTL.Seconds()), Expires: time.Now().Add(oauthCookieTTL), HttpOnly: true, Secure: secure, SameSite: http.SameSiteLaxMode, }) } func ClearOAuthCookie(w http.ResponseWriter, provider, kind string, secure bool) { http.SetCookie(w, &http.Cookie{ Name: OAuthCookieName(provider, kind), Value: "", Path: "/", MaxAge: -1, Expires: time.Unix(0, 0), HttpOnly: true, Secure: secure, SameSite: http.SameSiteLaxMode, }) } func ValidateOAuthCookie(r *http.Request, provider, kind, value string) bool { c, err := r.Cookie(OAuthCookieName(provider, kind)) if err != nil || c.Value == "" || value == "" { return false } if len(c.Value) != len(value) { return false } return subtle.ConstantTimeCompare([]byte(c.Value), []byte(value)) == 1 }