xtablo-source/backend/internal/auth/oauth.go
2026-05-15 21:03:30 +02:00

181 lines
4.4 KiB
Go

package auth
import (
"context"
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"errors"
"fmt"
"net/http"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/oauth2"
)
// Cookie kinds and lifetime for the short-lived values that protect an OAuth
// round trip. The kind constants are presumably passed as the kind argument of
// OAuthCookieName / SetOAuthCookie — confirm against callers.
const (
	// OAuthCookieState names the cookie kind carrying the OAuth "state" value.
	OAuthCookieState = "state"
	// OAuthCookieNonce names the cookie kind carrying the OIDC "nonce" value.
	OAuthCookieNonce = "nonce"

	// oauthCookieTTL is both the MaxAge and the Expires offset that
	// SetOAuthCookie applies to these cookies.
	oauthCookieTTL = 5 * time.Minute
)
// GoogleProviderConfig holds the client settings for signing in with Google.
// The endpoint fields are optional: withDefaults fills any empty one with
// Google's public value.
type GoogleProviderConfig struct {
	// ClientID is the OAuth client identifier issued by Google.
	ClientID string
	// ClientSecret is the OAuth client secret paired with ClientID.
	ClientSecret string
	// RedirectURL is the callback URL the provider redirects back to.
	RedirectURL string
	// AuthURL overrides the authorization endpoint; empty means Google's default.
	AuthURL string
	// TokenURL overrides the token endpoint; empty means Google's default.
	TokenURL string
	// Issuer overrides the OIDC issuer; empty means "https://accounts.google.com".
	Issuer string
}

// Configured reports whether the client ID, client secret and redirect URL are
// all non-empty. Endpoint URLs are not checked because withDefaults supplies them.
func (c GoogleProviderConfig) Configured() bool {
	return !(c.ClientID == "" || c.ClientSecret == "" || c.RedirectURL == "")
}

// withDefaults returns a copy of c in which every empty endpoint field
// (AuthURL, TokenURL, Issuer) is replaced by Google's standard value.
// Non-empty fields are kept as-is, and the receiver itself is not modified.
func (c GoogleProviderConfig) withDefaults() GoogleProviderConfig {
	fill := func(field *string, fallback string) {
		if *field == "" {
			*field = fallback
		}
	}
	fill(&c.AuthURL, "https://accounts.google.com/o/oauth2/v2/auth")
	fill(&c.TokenURL, "https://oauth2.googleapis.com/token")
	fill(&c.Issuer, "https://accounts.google.com")
	return c
}
// OAuth2Config builds the oauth2.Config for the Google authorization-code flow.
// Empty endpoint fields are first filled with Google's defaults (see
// withDefaults). The requested scopes are always openid, email and profile.
func (c GoogleProviderConfig) OAuth2Config() oauth2.Config {
	resolved := c.withDefaults()
	endpoint := oauth2.Endpoint{
		AuthURL:  resolved.AuthURL,
		TokenURL: resolved.TokenURL,
	}
	scopes := []string{oidc.ScopeOpenID, "email", "profile"}
	return oauth2.Config{
		ClientID:     resolved.ClientID,
		ClientSecret: resolved.ClientSecret,
		RedirectURL:  resolved.RedirectURL,
		Scopes:       scopes,
		Endpoint:     endpoint,
	}
}
// OAuthConfig groups the settings for every supported external identity
// provider. Google is currently the only one.
type OAuthConfig struct {
	// Google configures "Sign in with Google".
	Google GoogleProviderConfig
}
// ProviderClaims is the provider-neutral view of a verified ID token, as
// produced by OIDCVerifier.Verify.
type ProviderClaims struct {
	// Provider is the label of the identity provider (OIDCVerifier.Provider).
	Provider string
	// Subject is the token's "sub" claim; Verify guarantees it is non-empty.
	Subject string
	// Email is the token's "email" claim; it may be empty.
	Email string
	// EmailVerified is the token's "email_verified" claim.
	EmailVerified bool
	// DisplayName is the token's "name" claim.
	DisplayName string
	// AvatarURL is the token's "picture" claim.
	AvatarURL string
	// Nonce is the token's "nonce" claim. Verify does not compare it, so
	// callers must check it against the value they issued.
	Nonce string
}
// CodeExchanger trades an OAuth authorization code for a token.
// OAuth2CodeExchanger implements it with golang.org/x/oauth2. The interface
// presumably exists so tests can substitute a fake; confirm against callers.
type CodeExchanger interface {
	// Exchange returns the token issued for code, or an error if the
	// exchange fails.
	Exchange(ctx context.Context, code string) (*oauth2.Token, error)
}
// OAuth2CodeExchanger implements CodeExchanger on top of an oauth2.Config.
type OAuth2CodeExchanger struct {
	// Config carries the client credentials, redirect URL and endpoints
	// used for the token request.
	Config oauth2.Config
}

// Exchange delegates to oauth2.Config.Exchange with no extra auth-code options.
func (x OAuth2CodeExchanger) Exchange(ctx context.Context, code string) (*oauth2.Token, error) {
	cfg := x.Config
	return cfg.Exchange(ctx, code)
}
// IDTokenVerifier validates a raw OIDC ID token and extracts the identity
// claims from it. OIDCVerifier is the go-oidc backed implementation.
type IDTokenVerifier interface {
	// Verify checks rawIDToken and returns its claims, or an error if the
	// token is invalid.
	Verify(ctx context.Context, rawIDToken string) (ProviderClaims, error)
}
// OIDCVerifier implements IDTokenVerifier using OpenID Connect discovery
// against Issuer.
type OIDCVerifier struct {
	// Provider is copied verbatim into ProviderClaims.Provider.
	Provider string
	// Issuer is the OIDC issuer URL passed to discovery.
	Issuer string
	// ClientID is the audience the ID token must have been issued for.
	ClientID string
}
// Verify discovers the provider at v.Issuer and verifies rawIDToken with
// go-oidc, using v.ClientID as the expected audience. It then maps the token's
// claims onto ProviderClaims, tagged with v.Provider.
//
// Errors are wrapped with an "auth:" prefix. Verify fails if discovery,
// verification or claim decoding fails, or if the token has no "sub" claim.
// The "nonce" claim is returned but NOT compared here; the caller must check it.
//
// NOTE(review): oidc.NewProvider issues a discovery HTTP request (and builds a
// fresh key set) on every call. Consider caching the provider if Verify runs
// frequently.
func (v OIDCVerifier) Verify(ctx context.Context, rawIDToken string) (ProviderClaims, error) {
	provider, err := oidc.NewProvider(ctx, v.Issuer)
	if err != nil {
		return ProviderClaims{}, fmt.Errorf("auth: oidc provider discovery: %w", err)
	}

	tokenVerifier := provider.Verifier(&oidc.Config{ClientID: v.ClientID})
	idToken, err := tokenVerifier.Verify(ctx, rawIDToken)
	if err != nil {
		return ProviderClaims{}, fmt.Errorf("auth: verify id token: %w", err)
	}

	// Only the claims this package needs are decoded.
	var payload struct {
		Subject       string `json:"sub"`
		Email         string `json:"email"`
		EmailVerified bool   `json:"email_verified"`
		Name          string `json:"name"`
		Picture       string `json:"picture"`
		Nonce         string `json:"nonce"`
	}
	if decodeErr := idToken.Claims(&payload); decodeErr != nil {
		return ProviderClaims{}, fmt.Errorf("auth: decode id token claims: %w", decodeErr)
	}
	if payload.Subject == "" {
		return ProviderClaims{}, errors.New("auth: id token missing subject")
	}

	result := ProviderClaims{Provider: v.Provider}
	result.Subject = payload.Subject
	result.Email = payload.Email
	result.EmailVerified = payload.EmailVerified
	result.DisplayName = payload.Name
	result.AvatarURL = payload.Picture
	result.Nonce = payload.Nonce
	return result, nil
}
// GenerateOAuthValue returns 32 bytes from crypto/rand encoded as unpadded
// URL-safe base64 (43 characters). It is suitable for OAuth state and nonce
// values. It returns an error only if the random source fails.
func GenerateOAuthValue() (string, error) {
	var entropy [32]byte
	if _, err := rand.Read(entropy[:]); err != nil {
		return "", fmt.Errorf("auth: generate oauth value: %w", err)
	}
	encoded := base64.RawURLEncoding.EncodeToString(entropy[:])
	return encoded, nil
}
// OAuthCookieName returns the cookie name for one provider/kind pair, in the
// form "xtablo_oauth_<provider>_<kind>". The inputs are not escaped.
func OAuthCookieName(provider, kind string) string {
	const (
		prefix    = "xtablo_oauth_"
		separator = "_"
	)
	return prefix + provider + separator + kind
}
// SetOAuthCookie writes the provider/kind cookie holding value. The cookie is
// scoped to path "/", expires after oauthCookieTTL (set as both MaxAge and
// Expires), is HttpOnly, and is Secure only when secure is true.
func SetOAuthCookie(w http.ResponseWriter, provider, kind, value string, secure bool) {
	cookie := &http.Cookie{
		Name:     OAuthCookieName(provider, kind),
		Value:    value,
		Path:     "/",
		MaxAge:   int(oauthCookieTTL / time.Second),
		Expires:  time.Now().Add(oauthCookieTTL),
		HttpOnly: true,
		Secure:   secure,
		// Lax rather than Strict, so the browser still sends the cookie on the
		// provider's top-level redirect back to the callback.
		SameSite: http.SameSiteLaxMode,
	}
	http.SetCookie(w, cookie)
}
// ClearOAuthCookie tells the browser to delete the provider/kind cookie. It
// sends an empty value with MaxAge -1 and an Expires at the Unix epoch. Path,
// HttpOnly, Secure and SameSite match SetOAuthCookie so the same cookie is
// targeted.
func ClearOAuthCookie(w http.ResponseWriter, provider, kind string, secure bool) {
	expired := &http.Cookie{
		Name:     OAuthCookieName(provider, kind),
		Value:    "",
		Path:     "/",
		MaxAge:   -1,
		Expires:  time.Unix(0, 0),
		HttpOnly: true,
		Secure:   secure,
		SameSite: http.SameSiteLaxMode,
	}
	http.SetCookie(w, expired)
}
// ValidateOAuthCookie reports whether the request carries the provider/kind
// cookie and its value equals value. It returns false if the cookie is absent,
// if either value is empty, or if the values differ. The comparison is
// constant-time over equal-length inputs.
func ValidateOAuthCookie(r *http.Request, provider, kind, value string) bool {
	stored, err := r.Cookie(OAuthCookieName(provider, kind))
	switch {
	case err != nil, stored.Value == "", value == "":
		return false
	}
	// ConstantTimeCompare already returns 0 immediately when the lengths
	// differ, so no separate length check is needed.
	return subtle.ConstantTimeCompare([]byte(stored.Value), []byte(value)) == 1
}