388 lines
12 KiB
Go
388 lines
12 KiB
Go
package web
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
|
|
"backend/internal/auth"
|
|
"backend/internal/db/sqlc"
|
|
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
type fakeCodeExchanger struct {
|
|
token *oauth2.Token
|
|
err error
|
|
}
|
|
|
|
func (f fakeCodeExchanger) Exchange(ctx context.Context, code string) (*oauth2.Token, error) {
|
|
return f.token, f.err
|
|
}
|
|
|
|
type fakeIDTokenVerifier struct {
|
|
claims auth.ProviderClaims
|
|
err error
|
|
}
|
|
|
|
func (f fakeIDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (auth.ProviderClaims, error) {
|
|
return f.claims, f.err
|
|
}
|
|
|
|
func newGoogleAuthDeps(q *sqlc.Queries, store *auth.Store) AuthDeps {
|
|
return AuthDeps{
|
|
Queries: q,
|
|
Store: store,
|
|
Secure: false,
|
|
OAuth: auth.OAuthConfig{
|
|
Google: auth.GoogleProviderConfig{
|
|
ClientID: "google-client",
|
|
ClientSecret: "google-secret",
|
|
RedirectURL: "https://xtablo.test/auth/google/callback",
|
|
AuthURL: "https://accounts.google.test/o/oauth2/v2/auth",
|
|
TokenURL: "https://oauth2.google.test/token",
|
|
Issuer: "https://accounts.google.test",
|
|
},
|
|
},
|
|
GoogleTokenExchanger: fakeCodeExchanger{
|
|
token: (&oauth2.Token{AccessToken: "access"}).WithExtra(map[string]any{"id_token": "raw-id-token"}),
|
|
},
|
|
GoogleVerifier: fakeIDTokenVerifier{
|
|
claims: auth.ProviderClaims{
|
|
Provider: "google",
|
|
Subject: "google-subject-1",
|
|
Email: "google@example.com",
|
|
EmailVerified: true,
|
|
DisplayName: "Google User",
|
|
AvatarURL: "https://example.com/avatar.png",
|
|
Nonce: "nonce-value",
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func withGoogleClaims(deps AuthDeps, claims auth.ProviderClaims) AuthDeps {
|
|
deps.GoogleVerifier = fakeIDTokenVerifier{claims: claims}
|
|
return deps
|
|
}
|
|
|
|
func newSocialRouter(t *testing.T, deps AuthDeps) http.Handler {
|
|
t.Helper()
|
|
router, err := NewRouter(stubPinger{}, os.DirFS("./static"), deps, TablosDeps{}, TasksDeps{}, EtapesDeps{}, EventsDeps{}, DiscussionDeps{}, PlanningDeps{}, FilesDeps{}, testCSRFKey, "dev", "localhost")
|
|
if err != nil {
|
|
t.Fatalf("NewRouter: %v", err)
|
|
}
|
|
return router
|
|
}
|
|
|
|
func TestGoogleStartRedirectsAndSetsStateNonceCookies(t *testing.T) {
|
|
router := newSocialRouter(t, newGoogleAuthDeps(nil, nil))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/auth/google/start", nil)
|
|
rec := httptest.NewRecorder()
|
|
router.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusSeeOther {
|
|
t.Fatalf("status = %d; want 303", rec.Code)
|
|
}
|
|
loc := rec.Header().Get("Location")
|
|
if !strings.Contains(loc, "https://accounts.google.test/o/oauth2/v2/auth") {
|
|
t.Fatalf("Location = %q; want Google auth URL", loc)
|
|
}
|
|
if !strings.Contains(loc, "client_id=google-client") {
|
|
t.Fatalf("Location = %q; missing client_id", loc)
|
|
}
|
|
if !strings.Contains(loc, "scope=openid+email+profile") {
|
|
t.Fatalf("Location = %q; missing openid email profile scope", loc)
|
|
}
|
|
|
|
cookies := rec.Result().Cookies()
|
|
if findCookie(cookies, auth.OAuthCookieName("google", auth.OAuthCookieState)) == nil {
|
|
t.Fatal("missing Google state cookie")
|
|
}
|
|
if findCookie(cookies, auth.OAuthCookieName("google", auth.OAuthCookieNonce)) == nil {
|
|
t.Fatal("missing Google nonce cookie")
|
|
}
|
|
}
|
|
|
|
func TestGoogleCallbackInvalidStateRejectedBeforeExchange(t *testing.T) {
|
|
router := newSocialRouter(t, newGoogleAuthDeps(nil, nil))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/auth/google/callback?state=wrong&code=code", nil)
|
|
req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("google", auth.OAuthCookieState), Value: "expected"})
|
|
rec := httptest.NewRecorder()
|
|
router.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusBadRequest {
|
|
t.Fatalf("status = %d; want 400", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestGoogleCallbackUnverifiedEmailRejected(t *testing.T) {
|
|
deps := newGoogleAuthDeps(nil, nil)
|
|
deps.GoogleVerifier = fakeIDTokenVerifier{claims: auth.ProviderClaims{
|
|
Provider: "google",
|
|
Subject: "google-subject-1",
|
|
Email: "google@example.com",
|
|
EmailVerified: false,
|
|
Nonce: "nonce-value",
|
|
}}
|
|
router := newSocialRouter(t, deps)
|
|
|
|
callback := "/auth/google/callback?" + url.Values{"state": {"state-value"}, "code": {"code"}}.Encode()
|
|
req := httptest.NewRequest(http.MethodGet, callback, nil)
|
|
req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("google", auth.OAuthCookieState), Value: "state-value"})
|
|
req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("google", auth.OAuthCookieNonce), Value: "nonce-value"})
|
|
rec := httptest.NewRecorder()
|
|
router.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusUnauthorized {
|
|
t.Fatalf("status = %d; want 401", rec.Code)
|
|
}
|
|
if !strings.Contains(rec.Body.String(), "This provider did not return a verified email. Try another sign-in method.") {
|
|
t.Fatalf("body missing unverified email copy; got: %s", rec.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestGoogleCallbackVerifiedEmailLinksExistingUserAndSetsSession(t *testing.T) {
|
|
pool, cleanup := setupTestDB(t)
|
|
defer cleanup()
|
|
|
|
ctx := context.Background()
|
|
q := sqlc.New(pool)
|
|
store := auth.NewStore(q)
|
|
user := preInsertUser(t, ctx, q, "google@example.com", "correct-horse-12chars")
|
|
deps := newGoogleAuthDeps(q, store)
|
|
deps.DB = pool
|
|
router := newSocialRouter(t, deps)
|
|
|
|
callback := "/auth/google/callback?" + url.Values{"state": {"state-value"}, "code": {"code"}}.Encode()
|
|
req := httptest.NewRequest(http.MethodGet, callback, nil)
|
|
req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("google", auth.OAuthCookieState), Value: "state-value"})
|
|
req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("google", auth.OAuthCookieNonce), Value: "nonce-value"})
|
|
rec := httptest.NewRecorder()
|
|
router.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusSeeOther {
|
|
t.Fatalf("status = %d; want 303; body: %s", rec.Code, rec.Body.String())
|
|
}
|
|
if loc := rec.Header().Get("Location"); loc != "/" {
|
|
t.Fatalf("Location = %q; want /", loc)
|
|
}
|
|
if c := getSessionCookie(rec); c == nil {
|
|
t.Fatal("session cookie not set")
|
|
}
|
|
|
|
var count int
|
|
if err := pool.QueryRow(ctx, `
|
|
SELECT COUNT(*)
|
|
FROM user_identities
|
|
WHERE user_id = $1 AND provider = 'google' AND provider_subject = 'google-subject-1'
|
|
`, user.ID).Scan(&count); err != nil {
|
|
t.Fatalf("count identity: %v", err)
|
|
}
|
|
if count != 1 {
|
|
t.Fatalf("identity count = %d; want 1", count)
|
|
}
|
|
}
|
|
|
|
func TestGoogleCallbackNewVerifiedEmailCreatesSocialOnlyUserAndSession(t *testing.T) {
|
|
pool, cleanup := setupTestDB(t)
|
|
defer cleanup()
|
|
|
|
ctx := context.Background()
|
|
q := sqlc.New(pool)
|
|
store := auth.NewStore(q)
|
|
deps := newGoogleAuthDeps(q, store)
|
|
deps.DB = pool
|
|
deps = withGoogleClaims(deps, auth.ProviderClaims{
|
|
Provider: "google",
|
|
Subject: "google-new-subject",
|
|
Email: "new-google@example.com",
|
|
EmailVerified: true,
|
|
DisplayName: "New Google",
|
|
AvatarURL: "https://example.com/new.png",
|
|
Nonce: "nonce-value",
|
|
})
|
|
router := newSocialRouter(t, deps)
|
|
|
|
rec := serveGoogleCallback(router)
|
|
|
|
if rec.Code != http.StatusSeeOther {
|
|
t.Fatalf("status = %d; want 303; body: %s", rec.Code, rec.Body.String())
|
|
}
|
|
if c := getSessionCookie(rec); c == nil {
|
|
t.Fatal("session cookie not set")
|
|
}
|
|
user, err := q.GetUserByEmail(ctx, "new-google@example.com")
|
|
if err != nil {
|
|
t.Fatalf("GetUserByEmail: %v", err)
|
|
}
|
|
if user.PasswordHash.Valid {
|
|
t.Fatalf("new social user password hash Valid = true; want false")
|
|
}
|
|
identity, err := q.GetUserIdentityByProviderSubject(ctx, sqlc.GetUserIdentityByProviderSubjectParams{
|
|
Provider: "google",
|
|
ProviderSubject: "google-new-subject",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("GetUserIdentityByProviderSubject: %v", err)
|
|
}
|
|
if identity.UserID != user.ID {
|
|
t.Fatalf("identity user id = %s; want %s", identity.UserID, user.ID)
|
|
}
|
|
}
|
|
|
|
func TestGoogleCallbackExistingSubjectWinsWhenEmailChanges(t *testing.T) {
|
|
pool, cleanup := setupTestDB(t)
|
|
defer cleanup()
|
|
|
|
ctx := context.Background()
|
|
q := sqlc.New(pool)
|
|
store := auth.NewStore(q)
|
|
user := preInsertSocialOnlyUser(t, ctx, q, "old-google@example.com")
|
|
if _, err := q.InsertUserIdentity(ctx, sqlc.InsertUserIdentityParams{
|
|
UserID: user.ID,
|
|
Provider: "google",
|
|
ProviderSubject: "stable-google-subject",
|
|
Email: "old-google@example.com",
|
|
EmailVerified: true,
|
|
}); err != nil {
|
|
t.Fatalf("InsertUserIdentity: %v", err)
|
|
}
|
|
deps := newGoogleAuthDeps(q, store)
|
|
deps.DB = pool
|
|
deps = withGoogleClaims(deps, auth.ProviderClaims{
|
|
Provider: "google",
|
|
Subject: "stable-google-subject",
|
|
Email: "changed-google@example.com",
|
|
EmailVerified: true,
|
|
Nonce: "nonce-value",
|
|
})
|
|
router := newSocialRouter(t, deps)
|
|
|
|
rec := serveGoogleCallback(router)
|
|
|
|
if rec.Code != http.StatusSeeOther {
|
|
t.Fatalf("status = %d; want 303; body: %s", rec.Code, rec.Body.String())
|
|
}
|
|
identity, err := q.GetUserIdentityByProviderSubject(ctx, sqlc.GetUserIdentityByProviderSubjectParams{
|
|
Provider: "google",
|
|
ProviderSubject: "stable-google-subject",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("GetUserIdentityByProviderSubject: %v", err)
|
|
}
|
|
if identity.UserID != user.ID {
|
|
t.Fatalf("identity relinked to %s; want original %s", identity.UserID, user.ID)
|
|
}
|
|
if identity.Email != "changed-google@example.com" {
|
|
t.Fatalf("identity email = %q; want changed-google@example.com", identity.Email)
|
|
}
|
|
updatedUser, err := q.GetUserByID(ctx, user.ID)
|
|
if err != nil {
|
|
t.Fatalf("GetUserByID: %v", err)
|
|
}
|
|
if updatedUser.Email != "changed-google@example.com" {
|
|
t.Fatalf("user email = %q; want changed-google@example.com", updatedUser.Email)
|
|
}
|
|
}
|
|
|
|
func TestGoogleCallbackEmailUpdateConflictDoesNotRelinkSubject(t *testing.T) {
|
|
pool, cleanup := setupTestDB(t)
|
|
defer cleanup()
|
|
|
|
ctx := context.Background()
|
|
q := sqlc.New(pool)
|
|
store := auth.NewStore(q)
|
|
linkedUser := preInsertSocialOnlyUser(t, ctx, q, "linked-google@example.com")
|
|
conflictUser := preInsertUser(t, ctx, q, "conflict-google@example.com", "correct-horse-12chars")
|
|
if _, err := q.InsertUserIdentity(ctx, sqlc.InsertUserIdentityParams{
|
|
UserID: linkedUser.ID,
|
|
Provider: "google",
|
|
ProviderSubject: "conflict-google-subject",
|
|
Email: "linked-google@example.com",
|
|
EmailVerified: true,
|
|
}); err != nil {
|
|
t.Fatalf("InsertUserIdentity: %v", err)
|
|
}
|
|
deps := newGoogleAuthDeps(q, store)
|
|
deps.DB = pool
|
|
deps = withGoogleClaims(deps, auth.ProviderClaims{
|
|
Provider: "google",
|
|
Subject: "conflict-google-subject",
|
|
Email: "conflict-google@example.com",
|
|
EmailVerified: true,
|
|
Nonce: "nonce-value",
|
|
})
|
|
router := newSocialRouter(t, deps)
|
|
|
|
rec := serveGoogleCallback(router)
|
|
|
|
if rec.Code != http.StatusSeeOther {
|
|
t.Fatalf("status = %d; want 303; body: %s", rec.Code, rec.Body.String())
|
|
}
|
|
identity, err := q.GetUserIdentityByProviderSubject(ctx, sqlc.GetUserIdentityByProviderSubjectParams{
|
|
Provider: "google",
|
|
ProviderSubject: "conflict-google-subject",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("GetUserIdentityByProviderSubject: %v", err)
|
|
}
|
|
if identity.UserID != linkedUser.ID {
|
|
t.Fatalf("identity user id = %s; want linked user %s", identity.UserID, linkedUser.ID)
|
|
}
|
|
if identity.Email != "conflict-google@example.com" {
|
|
t.Fatalf("identity email = %q; want conflict-google@example.com", identity.Email)
|
|
}
|
|
stillLinked, err := q.GetUserByID(ctx, linkedUser.ID)
|
|
if err != nil {
|
|
t.Fatalf("GetUserByID linked: %v", err)
|
|
}
|
|
if stillLinked.Email != "linked-google@example.com" {
|
|
t.Fatalf("linked user email = %q; want linked-google@example.com", stillLinked.Email)
|
|
}
|
|
stillConflict, err := q.GetUserByID(ctx, conflictUser.ID)
|
|
if err != nil {
|
|
t.Fatalf("GetUserByID conflict: %v", err)
|
|
}
|
|
if stillConflict.Email != "conflict-google@example.com" {
|
|
t.Fatalf("conflict user email = %q; want conflict-google@example.com", stillConflict.Email)
|
|
}
|
|
}
|
|
|
|
func TestAppleRoutesAreDisabled(t *testing.T) {
|
|
router := newSocialRouter(t, AuthDeps{})
|
|
|
|
for _, path := range []string{"/auth/apple/start", "/auth/apple/callback"} {
|
|
req := httptest.NewRequest(http.MethodGet, path, nil)
|
|
rec := httptest.NewRecorder()
|
|
router.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusNotFound {
|
|
t.Fatalf("%s status = %d; want 404", path, rec.Code)
|
|
}
|
|
}
|
|
}
|
|
|
|
func serveGoogleCallback(router http.Handler) *httptest.ResponseRecorder {
|
|
callback := "/auth/google/callback?" + url.Values{"state": {"state-value"}, "code": {"code"}}.Encode()
|
|
req := httptest.NewRequest(http.MethodGet, callback, nil)
|
|
req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("google", auth.OAuthCookieState), Value: "state-value"})
|
|
req.AddCookie(&http.Cookie{Name: auth.OAuthCookieName("google", auth.OAuthCookieNonce), Value: "nonce-value"})
|
|
rec := httptest.NewRecorder()
|
|
router.ServeHTTP(rec, req)
|
|
return rec
|
|
}
|
|
|
|
// findCookie returns the first cookie in cookies whose Name equals name. It
// returns nil when nothing matches, including when cookies is empty or nil.
func findCookie(cookies []*http.Cookie, name string) *http.Cookie {
	for i := 0; i < len(cookies); i++ {
		if cookies[i].Name != name {
			continue
		}
		return cookies[i]
	}
	return nil
}
|