feat(auth): add PKCE support and enhance OAuth security

Implements critical OAuth 2.0 security improvements to protect against authorization code interception attacks and improve provider compatibility:

- Add PKCE (RFC 7636) support with SHA-256 code challenge/verifier
- Fix access token extraction to use standard field instead of Extra()
- Add OAuth error parameter handling (access_denied, invalid_scope, etc.)
- Maintain backward compatibility for non-PKCE flows

This brings the OAuth implementation up to modern security standards as recommended by Auth0, Okta, and the OAuth 2.0 Security Best Current Practice (RFC 8252).

Backend changes:
- Add code_verifier parameter to ExchangeToken with PKCE support
- Use token.AccessToken for better provider compatibility
- Update proto definition with optional code_verifier field

Frontend changes:
- Generate cryptographically secure PKCE parameters
- Include code_challenge in authorization requests
- Handle and display OAuth provider errors gracefully
- Pass code_verifier during token exchange

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Johnny 2025-12-01 00:04:26 +08:00
parent a6a8997f4c
commit 1a9bd32cf1
10 changed files with 127 additions and 22 deletions

View File

@ -41,7 +41,8 @@ func NewIdentityProvider(config *storepb.OAuth2Config) (*IdentityProvider, error
}
// ExchangeToken returns the exchanged OAuth2 token using the given authorization code.
func (p *IdentityProvider) ExchangeToken(ctx context.Context, redirectURL, code string) (string, error) {
// If codeVerifier is provided, it will be used for PKCE (Proof Key for Code Exchange) validation.
func (p *IdentityProvider) ExchangeToken(ctx context.Context, redirectURL, code, codeVerifier string) (string, error) {
conf := &oauth2.Config{
ClientID: p.config.ClientId,
ClientSecret: p.config.ClientSecret,
@ -54,17 +55,26 @@ func (p *IdentityProvider) ExchangeToken(ctx context.Context, redirectURL, code
},
}
token, err := conf.Exchange(ctx, code)
// Prepare token exchange options
opts := []oauth2.AuthCodeOption{}
// Add PKCE code_verifier if provided
if codeVerifier != "" {
opts = append(opts, oauth2.SetAuthURLParam("code_verifier", codeVerifier))
}
token, err := conf.Exchange(ctx, code, opts...)
if err != nil {
return "", errors.Wrap(err, "failed to exchange access token")
}
accessToken, ok := token.Extra("access_token").(string)
if !ok {
return "", errors.New(`missing "access_token" from authorization response`)
// Use the standard AccessToken field instead of Extra()
// This is more reliable across different OAuth providers
if token.AccessToken == "" {
return "", errors.New("missing access token from authorization response")
}
return accessToken, nil
return token.AccessToken, nil
}
// UserInfo returns the parsed user information using the given OAuth2 token.

View File

@ -147,7 +147,8 @@ func TestIdentityProvider(t *testing.T) {
require.NoError(t, err)
redirectURL := "https://example.com/oauth/callback"
oauthToken, err := oauth2.ExchangeToken(ctx, redirectURL, testCode)
// Test without PKCE (backward compatibility)
oauthToken, err := oauth2.ExchangeToken(ctx, redirectURL, testCode, "")
require.NoError(t, err)
require.Equal(t, testAccessToken, oauthToken)

View File

@ -68,6 +68,10 @@ message CreateSessionRequest {
// The redirect URI used in the SSO flow.
// Required field for security validation.
string redirect_uri = 3 [(google.api.field_behavior) = REQUIRED];
// The PKCE code verifier for enhanced security (RFC 7636).
// Optional field - if provided, enables PKCE flow protection against authorization code interception.
string code_verifier = 4 [(google.api.field_behavior) = OPTIONAL];
}
// Provide one authentication method (username/password or SSO).

View File

@ -360,7 +360,10 @@ type CreateSessionRequest_SSOCredentials struct {
Code string `protobuf:"bytes,2,opt,name=code,proto3" json:"code,omitempty"`
// The redirect URI used in the SSO flow.
// Required field for security validation.
RedirectUri string `protobuf:"bytes,3,opt,name=redirect_uri,json=redirectUri,proto3" json:"redirect_uri,omitempty"`
RedirectUri string `protobuf:"bytes,3,opt,name=redirect_uri,json=redirectUri,proto3" json:"redirect_uri,omitempty"`
// The PKCE code verifier for enhanced security (RFC 7636).
// Optional field - if provided, enables PKCE flow protection against authorization code interception.
CodeVerifier string `protobuf:"bytes,4,opt,name=code_verifier,json=codeVerifier,proto3" json:"code_verifier,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@ -416,6 +419,13 @@ func (x *CreateSessionRequest_SSOCredentials) GetRedirectUri() string {
return ""
}
func (x *CreateSessionRequest_SSOCredentials) GetCodeVerifier() string {
if x != nil {
return x.CodeVerifier
}
return ""
}
var File_api_v1_auth_service_proto protoreflect.FileDescriptor
const file_api_v1_auth_service_proto_rawDesc = "" +
@ -424,17 +434,18 @@ const file_api_v1_auth_service_proto_rawDesc = "" +
"\x18GetCurrentSessionRequest\"\x89\x01\n" +
"\x19GetCurrentSessionResponse\x12&\n" +
"\x04user\x18\x01 \x01(\v2\x12.memos.api.v1.UserR\x04user\x12D\n" +
"\x10last_accessed_at\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\x0elastAccessedAt\"\xb8\x03\n" +
"\x10last_accessed_at\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\x0elastAccessedAt\"\xe3\x03\n" +
"\x14CreateSessionRequest\x12k\n" +
"\x14password_credentials\x18\x01 \x01(\v26.memos.api.v1.CreateSessionRequest.PasswordCredentialsH\x00R\x13passwordCredentials\x12\\\n" +
"\x0fsso_credentials\x18\x02 \x01(\v21.memos.api.v1.CreateSessionRequest.SSOCredentialsH\x00R\x0essoCredentials\x1aW\n" +
"\x13PasswordCredentials\x12\x1f\n" +
"\busername\x18\x01 \x01(\tB\x03\xe0A\x02R\busername\x12\x1f\n" +
"\bpassword\x18\x02 \x01(\tB\x03\xe0A\x02R\bpassword\x1am\n" +
"\bpassword\x18\x02 \x01(\tB\x03\xe0A\x02R\bpassword\x1a\x97\x01\n" +
"\x0eSSOCredentials\x12\x1a\n" +
"\x06idp_id\x18\x01 \x01(\x05B\x03\xe0A\x02R\x05idpId\x12\x17\n" +
"\x04code\x18\x02 \x01(\tB\x03\xe0A\x02R\x04code\x12&\n" +
"\fredirect_uri\x18\x03 \x01(\tB\x03\xe0A\x02R\vredirectUriB\r\n" +
"\fredirect_uri\x18\x03 \x01(\tB\x03\xe0A\x02R\vredirectUri\x12(\n" +
"\rcode_verifier\x18\x04 \x01(\tB\x03\xe0A\x01R\fcodeVerifierB\r\n" +
"\vcredentials\"\x85\x01\n" +
"\x15CreateSessionResponse\x12&\n" +
"\x04user\x18\x01 \x01(\v2\x12.memos.api.v1.UserR\x04user\x12D\n" +

View File

@ -2159,6 +2159,11 @@ components:
description: |-
The redirect URI used in the SSO flow.
Required field for security validation.
codeVerifier:
type: string
description: |-
The PKCE code verifier for enhanced security (RFC 7636).
Optional field - if provided, enables PKCE flow protection against authorization code interception.
description: Nested message for SSO authentication credentials.
CreateSessionResponse:
type: object

View File

@ -126,7 +126,8 @@ func (s *APIV1Service) CreateSession(ctx context.Context, request *v1pb.CreateSe
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create oauth2 identity provider, error: %v", err)
}
token, err := oauth2IdentityProvider.ExchangeToken(ctx, ssoCredentials.RedirectUri, ssoCredentials.Code)
// Pass code_verifier for PKCE support (empty string if not provided for backward compatibility)
token, err := oauth2IdentityProvider.ExchangeToken(ctx, ssoCredentials.RedirectUri, ssoCredentials.Code, ssoCredentials.CodeVerifier)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to exchange token, error: %v", err)
}

View File

@ -23,6 +23,28 @@ const AuthCallback = observer(() => {
});
useEffect(() => {
// Check for OAuth error response first (e.g., user denied access)
const error = searchParams.get("error");
const errorDescription = searchParams.get("error_description");
const errorUri = searchParams.get("error_uri");
if (error) {
// OAuth provider returned an error
let errorMessage = `OAuth error: ${error}`;
if (errorDescription) {
errorMessage += `\n${decodeURIComponent(errorDescription)}`;
}
if (errorUri) {
errorMessage += `\nMore info: ${errorUri}`;
}
setState({
loading: false,
errorMessage,
});
return;
}
const code = searchParams.get("code");
const state = searchParams.get("state");
@ -34,7 +56,7 @@ const AuthCallback = observer(() => {
return;
}
// Validate OAuth state (CSRF protection)
// Validate OAuth state (CSRF protection) and retrieve PKCE code_verifier
const validatedState = validateOAuthState(state);
if (!validatedState) {
setState({
@ -44,7 +66,7 @@ const AuthCallback = observer(() => {
return;
}
const { identityProviderId, returnUrl } = validatedState;
const { identityProviderId, returnUrl, codeVerifier } = validatedState;
const redirectUri = absolutifyLink("/auth/callback");
(async () => {
@ -54,6 +76,7 @@ const AuthCallback = observer(() => {
idpId: identityProviderId,
code,
redirectUri,
codeVerifier: codeVerifier || "", // Pass PKCE code_verifier for token exchange
},
});
setState({

View File

@ -49,15 +49,17 @@ const SignIn = observer(() => {
try {
// Generate and store secure state parameter with CSRF protection
// Also generate PKCE parameters (code_challenge) for enhanced security
const identityProviderId = extractIdentityProviderIdFromName(identityProvider.name);
const state = storeOAuthState(identityProviderId);
const { state, codeChallenge } = await storeOAuthState(identityProviderId);
// Build OAuth authorization URL with secure state
// Build OAuth authorization URL with secure state and PKCE
// Using S256 (SHA-256) as the code_challenge_method per RFC 7636
const authUrl = `${oauth2Config.authUrl}?client_id=${
oauth2Config.clientId
}&redirect_uri=${encodeURIComponent(redirectUri)}&state=${state}&response_type=code&scope=${encodeURIComponent(
oauth2Config.scopes.join(" "),
)}`;
)}&code_challenge=${codeChallenge}&code_challenge_method=S256`;
window.location.href = authUrl;
} catch (error) {

View File

@ -66,6 +66,11 @@ export interface CreateSessionRequest_SSOCredentials {
* Required field for security validation.
*/
redirectUri: string;
/**
* The PKCE code verifier for enhanced security (RFC 7636).
* Optional field - if provided, enables PKCE flow protection against authorization code interception.
*/
codeVerifier: string;
}
export interface CreateSessionResponse {
@ -296,7 +301,7 @@ export const CreateSessionRequest_PasswordCredentials: MessageFns<CreateSessionR
};
function createBaseCreateSessionRequest_SSOCredentials(): CreateSessionRequest_SSOCredentials {
return { idpId: 0, code: "", redirectUri: "" };
return { idpId: 0, code: "", redirectUri: "", codeVerifier: "" };
}
export const CreateSessionRequest_SSOCredentials: MessageFns<CreateSessionRequest_SSOCredentials> = {
@ -310,6 +315,9 @@ export const CreateSessionRequest_SSOCredentials: MessageFns<CreateSessionReques
if (message.redirectUri !== "") {
writer.uint32(26).string(message.redirectUri);
}
if (message.codeVerifier !== "") {
writer.uint32(34).string(message.codeVerifier);
}
return writer;
},
@ -344,6 +352,14 @@ export const CreateSessionRequest_SSOCredentials: MessageFns<CreateSessionReques
message.redirectUri = reader.string();
continue;
}
case 4: {
if (tag !== 34) {
break;
}
message.codeVerifier = reader.string();
continue;
}
}
if ((tag & 7) === 4 || tag === 0) {
break;
@ -361,6 +377,7 @@ export const CreateSessionRequest_SSOCredentials: MessageFns<CreateSessionReques
message.idpId = object.idpId ?? 0;
message.code = object.code ?? "";
message.redirectUri = object.redirectUri ?? "";
message.codeVerifier = object.codeVerifier ?? "";
return message;
},
};

View File

@ -6,6 +6,7 @@ interface OAuthState {
identityProviderId: number;
timestamp: number;
returnUrl?: string;
codeVerifier?: string; // PKCE code_verifier
}
// Generate a cryptographically secure random state value
@ -15,14 +16,42 @@ function generateSecureState(): string {
return Array.from(array, (byte) => byte.toString(16).padStart(2, "0")).join("");
}
// Store OAuth state in sessionStorage
export function storeOAuthState(identityProviderId: number, returnUrl?: string): string {
// Generate a cryptographically secure random code_verifier for PKCE (RFC 7636)
// Returns a URL-safe base64 string (43-128 characters)
function generateCodeVerifier(): string {
const array = new Uint8Array(32); // 256 bits = 32 bytes
crypto.getRandomValues(array);
// Convert to base64url (URL-safe base64 without padding)
return base64UrlEncode(array);
}
// Generate code_challenge from code_verifier using SHA-256
async function generateCodeChallenge(codeVerifier: string): Promise<string> {
const encoder = new TextEncoder();
const data = encoder.encode(codeVerifier);
const hash = await crypto.subtle.digest("SHA-256", data);
return base64UrlEncode(new Uint8Array(hash));
}
// Base64URL encoding (RFC 4648 base64url without padding)
function base64UrlEncode(buffer: Uint8Array): string {
const base64 = btoa(String.fromCharCode(...buffer));
return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, "");
}
// Store OAuth state and PKCE parameters in sessionStorage
// Returns both state and codeChallenge for use in authorization URL
export async function storeOAuthState(identityProviderId: number, returnUrl?: string): Promise<{ state: string; codeChallenge: string }> {
const state = generateSecureState();
const codeVerifier = generateCodeVerifier();
const codeChallenge = await generateCodeChallenge(codeVerifier);
const stateData: OAuthState = {
state,
identityProviderId,
timestamp: Date.now(),
returnUrl,
codeVerifier, // Store for later retrieval in callback
};
try {
@ -32,11 +61,12 @@ export function storeOAuthState(identityProviderId: number, returnUrl?: string):
throw new Error("Failed to initialize OAuth flow");
}
return state;
return { state, codeChallenge };
}
// Validate and retrieve OAuth state from storage (CSRF protection)
export function validateOAuthState(stateParam: string): { identityProviderId: number; returnUrl?: string } | null {
// Returns identityProviderId, returnUrl, and codeVerifier for PKCE
export function validateOAuthState(stateParam: string): { identityProviderId: number; returnUrl?: string; codeVerifier?: string } | null {
try {
const storedData = sessionStorage.getItem(STATE_STORAGE_KEY);
if (!storedData) {
@ -65,6 +95,7 @@ export function validateOAuthState(stateParam: string): { identityProviderId: nu
return {
identityProviderId: stateData.identityProviderId,
returnUrl: stateData.returnUrl,
codeVerifier: stateData.codeVerifier, // Return PKCE code_verifier
};
} catch (error) {
console.error("Failed to validate OAuth state:", error);