From 1a9bd32cf1e074a6a06e7e3482660c956aa5cdf6 Mon Sep 17 00:00:00 2001 From: Johnny Date: Mon, 1 Dec 2025 00:04:26 +0800 Subject: [PATCH] feat(auth): add PKCE support and enhance OAuth security MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements critical OAuth 2.0 security improvements to protect against authorization code interception attacks and improve provider compatibility: - Add PKCE (RFC 7636) support with SHA-256 code challenge/verifier - Fix access token extraction to use standard field instead of Extra() - Add OAuth error parameter handling (access_denied, invalid_scope, etc.) - Maintain backward compatibility for non-PKCE flows This brings the OAuth implementation up to modern security standards as recommended by Auth0, Okta, and the OAuth 2.0 Security Best Current Practice (RFC 8252). Backend changes: - Add code_verifier parameter to ExchangeToken with PKCE support - Use token.AccessToken for better provider compatibility - Update proto definition with optional code_verifier field Frontend changes: - Generate cryptographically secure PKCE parameters - Include code_challenge in authorization requests - Handle and display OAuth provider errors gracefully - Pass code_verifier during token exchange 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- plugin/idp/oauth2/oauth2.go | 22 ++++++++---- plugin/idp/oauth2/oauth2_test.go | 3 +- proto/api/v1/auth_service.proto | 4 +++ proto/gen/api/v1/auth_service.pb.go | 19 ++++++++--- proto/gen/openapi.yaml | 5 +++ server/router/api/v1/auth_service.go | 3 +- web/src/pages/AuthCallback.tsx | 27 +++++++++++++-- web/src/pages/SignIn.tsx | 8 +++-- web/src/types/proto/api/v1/auth_service.ts | 19 ++++++++++- web/src/utils/oauth.ts | 39 +++++++++++++++++++--- 10 files changed, 127 insertions(+), 22 deletions(-) diff --git a/plugin/idp/oauth2/oauth2.go b/plugin/idp/oauth2/oauth2.go index 6d10075ed..651bd8431 100644 --- a/plugin/idp/oauth2/oauth2.go +++ b/plugin/idp/oauth2/oauth2.go @@ -41,7 +41,8 @@ func NewIdentityProvider(config *storepb.OAuth2Config) (*IdentityProvider, error } // ExchangeToken returns the exchanged OAuth2 token using the given authorization code. -func (p *IdentityProvider) ExchangeToken(ctx context.Context, redirectURL, code string) (string, error) { +// If codeVerifier is provided, it will be used for PKCE (Proof Key for Code Exchange) validation. +func (p *IdentityProvider) ExchangeToken(ctx context.Context, redirectURL, code, codeVerifier string) (string, error) { conf := &oauth2.Config{ ClientID: p.config.ClientId, ClientSecret: p.config.ClientSecret, @@ -54,17 +55,26 @@ func (p *IdentityProvider) ExchangeToken(ctx context.Context, redirectURL, code }, } - token, err := conf.Exchange(ctx, code) + // Prepare token exchange options + opts := []oauth2.AuthCodeOption{} + + // Add PKCE code_verifier if provided + if codeVerifier != "" { + opts = append(opts, oauth2.SetAuthURLParam("code_verifier", codeVerifier)) + } + + token, err := conf.Exchange(ctx, code, opts...) if err != nil { return "", errors.Wrap(err, "failed to exchange access token") } - accessToken, ok := token.Extra("access_token").(string) - if !ok { - return "", errors.New(`missing "access_token" from authorization response`) + // Use the standard AccessToken field instead of Extra() + // This is more reliable across different OAuth providers + if token.AccessToken == "" { + return "", errors.New("missing access token from authorization response") } - return accessToken, nil + return token.AccessToken, nil } // UserInfo returns the parsed user information using the given OAuth2 token. diff --git a/plugin/idp/oauth2/oauth2_test.go b/plugin/idp/oauth2/oauth2_test.go index b91f03725..cd7fd640f 100644 --- a/plugin/idp/oauth2/oauth2_test.go +++ b/plugin/idp/oauth2/oauth2_test.go @@ -147,7 +147,8 @@ func TestIdentityProvider(t *testing.T) { require.NoError(t, err) redirectURL := "https://example.com/oauth/callback" - oauthToken, err := oauth2.ExchangeToken(ctx, redirectURL, testCode) + // Test without PKCE (backward compatibility) + oauthToken, err := oauth2.ExchangeToken(ctx, redirectURL, testCode, "") require.NoError(t, err) require.Equal(t, testAccessToken, oauthToken) diff --git a/proto/api/v1/auth_service.proto b/proto/api/v1/auth_service.proto index eb3477f99..941e307b5 100644 --- a/proto/api/v1/auth_service.proto +++ b/proto/api/v1/auth_service.proto @@ -68,6 +68,10 @@ message CreateSessionRequest { // The redirect URI used in the SSO flow. // Required field for security validation. string redirect_uri = 3 [(google.api.field_behavior) = REQUIRED]; + + // The PKCE code verifier for enhanced security (RFC 7636). + // Optional field - if provided, enables PKCE flow protection against authorization code interception. + string code_verifier = 4 [(google.api.field_behavior) = OPTIONAL]; } // Provide one authentication method (username/password or SSO). diff --git a/proto/gen/api/v1/auth_service.pb.go b/proto/gen/api/v1/auth_service.pb.go index 8dfa299aa..4753d5b98 100644 --- a/proto/gen/api/v1/auth_service.pb.go +++ b/proto/gen/api/v1/auth_service.pb.go @@ -360,7 +360,10 @@ type CreateSessionRequest_SSOCredentials struct { Code string `protobuf:"bytes,2,opt,name=code,proto3" json:"code,omitempty"` // The redirect URI used in the SSO flow. // Required field for security validation. - RedirectUri string `protobuf:"bytes,3,opt,name=redirect_uri,json=redirectUri,proto3" json:"redirect_uri,omitempty"` + RedirectUri string `protobuf:"bytes,3,opt,name=redirect_uri,json=redirectUri,proto3" json:"redirect_uri,omitempty"` + // The PKCE code verifier for enhanced security (RFC 7636). + // Optional field - if provided, enables PKCE flow protection against authorization code interception. + CodeVerifier string `protobuf:"bytes,4,opt,name=code_verifier,json=codeVerifier,proto3" json:"code_verifier,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -416,6 +419,13 @@ func (x *CreateSessionRequest_SSOCredentials) GetRedirectUri() string { return "" } +func (x *CreateSessionRequest_SSOCredentials) GetCodeVerifier() string { + if x != nil { + return x.CodeVerifier + } + return "" +} + var File_api_v1_auth_service_proto protoreflect.FileDescriptor const file_api_v1_auth_service_proto_rawDesc = "" + @@ -424,17 +434,18 @@ const file_api_v1_auth_service_proto_rawDesc = "" + "\x18GetCurrentSessionRequest\"\x89\x01\n" + "\x19GetCurrentSessionResponse\x12&\n" + "\x04user\x18\x01 \x01(\v2\x12.memos.api.v1.UserR\x04user\x12D\n" + - "\x10last_accessed_at\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\x0elastAccessedAt\"\xb8\x03\n" + + "\x10last_accessed_at\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\x0elastAccessedAt\"\xe3\x03\n" + "\x14CreateSessionRequest\x12k\n" + "\x14password_credentials\x18\x01 \x01(\v26.memos.api.v1.CreateSessionRequest.PasswordCredentialsH\x00R\x13passwordCredentials\x12\\\n" + "\x0fsso_credentials\x18\x02 \x01(\v21.memos.api.v1.CreateSessionRequest.SSOCredentialsH\x00R\x0essoCredentials\x1aW\n" + "\x13PasswordCredentials\x12\x1f\n" + "\busername\x18\x01 \x01(\tB\x03\xe0A\x02R\busername\x12\x1f\n" + - "\bpassword\x18\x02 \x01(\tB\x03\xe0A\x02R\bpassword\x1am\n" + + "\bpassword\x18\x02 \x01(\tB\x03\xe0A\x02R\bpassword\x1a\x97\x01\n" + "\x0eSSOCredentials\x12\x1a\n" + "\x06idp_id\x18\x01 \x01(\x05B\x03\xe0A\x02R\x05idpId\x12\x17\n" + "\x04code\x18\x02 \x01(\tB\x03\xe0A\x02R\x04code\x12&\n" + - "\fredirect_uri\x18\x03 \x01(\tB\x03\xe0A\x02R\vredirectUriB\r\n" + + "\fredirect_uri\x18\x03 \x01(\tB\x03\xe0A\x02R\vredirectUri\x12(\n" + + "\rcode_verifier\x18\x04 \x01(\tB\x03\xe0A\x01R\fcodeVerifierB\r\n" + "\vcredentials\"\x85\x01\n" + "\x15CreateSessionResponse\x12&\n" + "\x04user\x18\x01 \x01(\v2\x12.memos.api.v1.UserR\x04user\x12D\n" + diff --git a/proto/gen/openapi.yaml b/proto/gen/openapi.yaml index c8bcbc4bd..ca0b288e9 100644 --- a/proto/gen/openapi.yaml +++ b/proto/gen/openapi.yaml @@ -2159,6 +2159,11 @@ components: description: |- The redirect URI used in the SSO flow. Required field for security validation. + codeVerifier: + type: string + description: |- + The PKCE code verifier for enhanced security (RFC 7636). + Optional field - if provided, enables PKCE flow protection against authorization code interception. description: Nested message for SSO authentication credentials. CreateSessionResponse: type: object diff --git a/server/router/api/v1/auth_service.go b/server/router/api/v1/auth_service.go index 461bf8b72..c98967139 100644 --- a/server/router/api/v1/auth_service.go +++ b/server/router/api/v1/auth_service.go @@ -126,7 +126,8 @@ func (s *APIV1Service) CreateSession(ctx context.Context, request *v1pb.CreateSe if err != nil { return nil, status.Errorf(codes.Internal, "failed to create oauth2 identity provider, error: %v", err) } - token, err := oauth2IdentityProvider.ExchangeToken(ctx, ssoCredentials.RedirectUri, ssoCredentials.Code) + // Pass code_verifier for PKCE support (empty string if not provided for backward compatibility) + token, err := oauth2IdentityProvider.ExchangeToken(ctx, ssoCredentials.RedirectUri, ssoCredentials.Code, ssoCredentials.CodeVerifier) if err != nil { return nil, status.Errorf(codes.Internal, "failed to exchange token, error: %v", err) } diff --git a/web/src/pages/AuthCallback.tsx b/web/src/pages/AuthCallback.tsx index fc54b3897..6f61e6086 100644 --- a/web/src/pages/AuthCallback.tsx +++ b/web/src/pages/AuthCallback.tsx @@ -23,6 +23,28 @@ const AuthCallback = observer(() => { }); useEffect(() => { + // Check for OAuth error response first (e.g., user denied access) + const error = searchParams.get("error"); + const errorDescription = searchParams.get("error_description"); + const errorUri = searchParams.get("error_uri"); + + if (error) { + // OAuth provider returned an error + let errorMessage = `OAuth error: ${error}`; + if (errorDescription) { + errorMessage += `\n${decodeURIComponent(errorDescription)}`; + } + if (errorUri) { + errorMessage += `\nMore info: ${errorUri}`; + } + + setState({ + loading: false, + errorMessage, + }); + return; + } + const code = searchParams.get("code"); const state = searchParams.get("state"); @@ -34,7 +56,7 @@ const AuthCallback = observer(() => { return; } - // Validate OAuth state (CSRF protection) + // Validate OAuth state (CSRF protection) and retrieve PKCE code_verifier const validatedState = validateOAuthState(state); if (!validatedState) { setState({ @@ -44,7 +66,7 @@ const AuthCallback = observer(() => { return; } - const { identityProviderId, returnUrl } = validatedState; + const { identityProviderId, returnUrl, codeVerifier } = validatedState; const redirectUri = absolutifyLink("/auth/callback"); (async () => { @@ -54,6 +76,7 @@ const AuthCallback = observer(() => { idpId: identityProviderId, code, redirectUri, + codeVerifier: codeVerifier || "", // Pass PKCE code_verifier for token exchange }, }); setState({ diff --git a/web/src/pages/SignIn.tsx b/web/src/pages/SignIn.tsx index cd44a95e4..d1b29bcf5 100644 --- a/web/src/pages/SignIn.tsx +++ b/web/src/pages/SignIn.tsx @@ -49,15 +49,17 @@ const SignIn = observer(() => { try { // Generate and store secure state parameter with CSRF protection + // Also generate PKCE parameters (code_challenge) for enhanced security const identityProviderId = extractIdentityProviderIdFromName(identityProvider.name); - const state = storeOAuthState(identityProviderId); + const { state, codeChallenge } = await storeOAuthState(identityProviderId); - // Build OAuth authorization URL with secure state + // Build OAuth authorization URL with secure state and PKCE + // Using S256 (SHA-256) as the code_challenge_method per RFC 7636 const authUrl = `${oauth2Config.authUrl}?client_id=${ oauth2Config.clientId }&redirect_uri=${encodeURIComponent(redirectUri)}&state=${state}&response_type=code&scope=${encodeURIComponent( oauth2Config.scopes.join(" "), - )}`; + )}&code_challenge=${codeChallenge}&code_challenge_method=S256`; window.location.href = authUrl; } catch (error) { diff --git a/web/src/types/proto/api/v1/auth_service.ts b/web/src/types/proto/api/v1/auth_service.ts index 618670ba7..37f94dc40 100644 --- a/web/src/types/proto/api/v1/auth_service.ts +++ b/web/src/types/proto/api/v1/auth_service.ts @@ -66,6 +66,11 @@ export interface CreateSessionRequest_SSOCredentials { * Required field for security validation. */ redirectUri: string; + /** + * The PKCE code verifier for enhanced security (RFC 7636). + * Optional field - if provided, enables PKCE flow protection against authorization code interception. + */ + codeVerifier: string; } export interface CreateSessionResponse { @@ -296,7 +301,7 @@ export const CreateSessionRequest_PasswordCredentials: MessageFns = { @@ -310,6 +315,9 @@ export const CreateSessionRequest_SSOCredentials: MessageFns byte.toString(16).padStart(2, "0")).join(""); } -// Store OAuth state in sessionStorage -export function storeOAuthState(identityProviderId: number, returnUrl?: string): string { +// Generate a cryptographically secure random code_verifier for PKCE (RFC 7636) +// Returns a URL-safe base64 string (43-128 characters) +function generateCodeVerifier(): string { + const array = new Uint8Array(32); // 256 bits = 32 bytes + crypto.getRandomValues(array); + // Convert to base64url (URL-safe base64 without padding) + return base64UrlEncode(array); +} + +// Generate code_challenge from code_verifier using SHA-256 +async function generateCodeChallenge(codeVerifier: string): Promise { + const encoder = new TextEncoder(); + const data = encoder.encode(codeVerifier); + const hash = await crypto.subtle.digest("SHA-256", data); + return base64UrlEncode(new Uint8Array(hash)); +} + +// Base64URL encoding (RFC 4648 base64url without padding) +function base64UrlEncode(buffer: Uint8Array): string { + const base64 = btoa(String.fromCharCode(...buffer)); + return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, ""); +} + +// Store OAuth state and PKCE parameters in sessionStorage +// Returns both state and codeChallenge for use in authorization URL +export async function storeOAuthState(identityProviderId: number, returnUrl?: string): Promise<{ state: string; codeChallenge: string }> { const state = generateSecureState(); + const codeVerifier = generateCodeVerifier(); + const codeChallenge = await generateCodeChallenge(codeVerifier); + const stateData: OAuthState = { state, identityProviderId, timestamp: Date.now(), returnUrl, + codeVerifier, // Store for later retrieval in callback }; try { @@ -32,11 +61,12 @@ export function storeOAuthState(identityProviderId: number, returnUrl?: string): throw new Error("Failed to initialize OAuth flow"); } - return state; + return { state, codeChallenge }; } // Validate and retrieve OAuth state from storage (CSRF protection) -export function validateOAuthState(stateParam: string): { identityProviderId: number; returnUrl?: string } | null { +// Returns identityProviderId, returnUrl, and codeVerifier for PKCE +export function validateOAuthState(stateParam: string): { identityProviderId: number; returnUrl?: string; codeVerifier?: string } | null { try { const storedData = sessionStorage.getItem(STATE_STORAGE_KEY); if (!storedData) { @@ -65,6 +95,7 @@ export function validateOAuthState(stateParam: string): { identityProviderId: nu return { identityProviderId: stateData.identityProviderId, returnUrl: stateData.returnUrl, + codeVerifier: stateData.codeVerifier, // Return PKCE code_verifier }; } catch (error) { console.error("Failed to validate OAuth state:", error);