mirror of https://github.com/usememos/memos.git
feat(auth): add PKCE support and enhance OAuth security
Implements critical OAuth 2.0 security improvements to protect against authorization code interception attacks and improve provider compatibility: - Add PKCE (RFC 7636) support with SHA-256 code challenge/verifier - Fix access token extraction to use standard field instead of Extra() - Add OAuth error parameter handling (access_denied, invalid_scope, etc.) - Maintain backward compatibility for non-PKCE flows This brings the OAuth implementation up to modern security standards as recommended by Auth0, Okta, and the OAuth 2.0 Security Best Current Practice (RFC 8252). Backend changes: - Add code_verifier parameter to ExchangeToken with PKCE support - Use token.AccessToken for better provider compatibility - Update proto definition with optional code_verifier field Frontend changes: - Generate cryptographically secure PKCE parameters - Include code_challenge in authorization requests - Handle and display OAuth provider errors gracefully - Pass code_verifier during token exchange 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
a6a8997f4c
commit
1a9bd32cf1
|
|
@ -41,7 +41,8 @@ func NewIdentityProvider(config *storepb.OAuth2Config) (*IdentityProvider, error
|
|||
}
|
||||
|
||||
// ExchangeToken returns the exchanged OAuth2 token using the given authorization code.
|
||||
func (p *IdentityProvider) ExchangeToken(ctx context.Context, redirectURL, code string) (string, error) {
|
||||
// If codeVerifier is provided, it will be used for PKCE (Proof Key for Code Exchange) validation.
|
||||
func (p *IdentityProvider) ExchangeToken(ctx context.Context, redirectURL, code, codeVerifier string) (string, error) {
|
||||
conf := &oauth2.Config{
|
||||
ClientID: p.config.ClientId,
|
||||
ClientSecret: p.config.ClientSecret,
|
||||
|
|
@ -54,17 +55,26 @@ func (p *IdentityProvider) ExchangeToken(ctx context.Context, redirectURL, code
|
|||
},
|
||||
}
|
||||
|
||||
token, err := conf.Exchange(ctx, code)
|
||||
// Prepare token exchange options
|
||||
opts := []oauth2.AuthCodeOption{}
|
||||
|
||||
// Add PKCE code_verifier if provided
|
||||
if codeVerifier != "" {
|
||||
opts = append(opts, oauth2.SetAuthURLParam("code_verifier", codeVerifier))
|
||||
}
|
||||
|
||||
token, err := conf.Exchange(ctx, code, opts...)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to exchange access token")
|
||||
}
|
||||
|
||||
accessToken, ok := token.Extra("access_token").(string)
|
||||
if !ok {
|
||||
return "", errors.New(`missing "access_token" from authorization response`)
|
||||
// Use the standard AccessToken field instead of Extra()
|
||||
// This is more reliable across different OAuth providers
|
||||
if token.AccessToken == "" {
|
||||
return "", errors.New("missing access token from authorization response")
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
return token.AccessToken, nil
|
||||
}
|
||||
|
||||
// UserInfo returns the parsed user information using the given OAuth2 token.
|
||||
|
|
|
|||
|
|
@ -147,7 +147,8 @@ func TestIdentityProvider(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
redirectURL := "https://example.com/oauth/callback"
|
||||
oauthToken, err := oauth2.ExchangeToken(ctx, redirectURL, testCode)
|
||||
// Test without PKCE (backward compatibility)
|
||||
oauthToken, err := oauth2.ExchangeToken(ctx, redirectURL, testCode, "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, testAccessToken, oauthToken)
|
||||
|
||||
|
|
|
|||
|
|
@ -68,6 +68,10 @@ message CreateSessionRequest {
|
|||
// The redirect URI used in the SSO flow.
|
||||
// Required field for security validation.
|
||||
string redirect_uri = 3 [(google.api.field_behavior) = REQUIRED];
|
||||
|
||||
// The PKCE code verifier for enhanced security (RFC 7636).
|
||||
// Optional field - if provided, enables PKCE flow protection against authorization code interception.
|
||||
string code_verifier = 4 [(google.api.field_behavior) = OPTIONAL];
|
||||
}
|
||||
|
||||
// Provide one authentication method (username/password or SSO).
|
||||
|
|
|
|||
|
|
@ -361,6 +361,9 @@ type CreateSessionRequest_SSOCredentials struct {
|
|||
// The redirect URI used in the SSO flow.
|
||||
// Required field for security validation.
|
||||
RedirectUri string `protobuf:"bytes,3,opt,name=redirect_uri,json=redirectUri,proto3" json:"redirect_uri,omitempty"`
|
||||
// The PKCE code verifier for enhanced security (RFC 7636).
|
||||
// Optional field - if provided, enables PKCE flow protection against authorization code interception.
|
||||
CodeVerifier string `protobuf:"bytes,4,opt,name=code_verifier,json=codeVerifier,proto3" json:"code_verifier,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
|
@ -416,6 +419,13 @@ func (x *CreateSessionRequest_SSOCredentials) GetRedirectUri() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
func (x *CreateSessionRequest_SSOCredentials) GetCodeVerifier() string {
|
||||
if x != nil {
|
||||
return x.CodeVerifier
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var File_api_v1_auth_service_proto protoreflect.FileDescriptor
|
||||
|
||||
const file_api_v1_auth_service_proto_rawDesc = "" +
|
||||
|
|
@ -424,17 +434,18 @@ const file_api_v1_auth_service_proto_rawDesc = "" +
|
|||
"\x18GetCurrentSessionRequest\"\x89\x01\n" +
|
||||
"\x19GetCurrentSessionResponse\x12&\n" +
|
||||
"\x04user\x18\x01 \x01(\v2\x12.memos.api.v1.UserR\x04user\x12D\n" +
|
||||
"\x10last_accessed_at\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\x0elastAccessedAt\"\xb8\x03\n" +
|
||||
"\x10last_accessed_at\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\x0elastAccessedAt\"\xe3\x03\n" +
|
||||
"\x14CreateSessionRequest\x12k\n" +
|
||||
"\x14password_credentials\x18\x01 \x01(\v26.memos.api.v1.CreateSessionRequest.PasswordCredentialsH\x00R\x13passwordCredentials\x12\\\n" +
|
||||
"\x0fsso_credentials\x18\x02 \x01(\v21.memos.api.v1.CreateSessionRequest.SSOCredentialsH\x00R\x0essoCredentials\x1aW\n" +
|
||||
"\x13PasswordCredentials\x12\x1f\n" +
|
||||
"\busername\x18\x01 \x01(\tB\x03\xe0A\x02R\busername\x12\x1f\n" +
|
||||
"\bpassword\x18\x02 \x01(\tB\x03\xe0A\x02R\bpassword\x1am\n" +
|
||||
"\bpassword\x18\x02 \x01(\tB\x03\xe0A\x02R\bpassword\x1a\x97\x01\n" +
|
||||
"\x0eSSOCredentials\x12\x1a\n" +
|
||||
"\x06idp_id\x18\x01 \x01(\x05B\x03\xe0A\x02R\x05idpId\x12\x17\n" +
|
||||
"\x04code\x18\x02 \x01(\tB\x03\xe0A\x02R\x04code\x12&\n" +
|
||||
"\fredirect_uri\x18\x03 \x01(\tB\x03\xe0A\x02R\vredirectUriB\r\n" +
|
||||
"\fredirect_uri\x18\x03 \x01(\tB\x03\xe0A\x02R\vredirectUri\x12(\n" +
|
||||
"\rcode_verifier\x18\x04 \x01(\tB\x03\xe0A\x01R\fcodeVerifierB\r\n" +
|
||||
"\vcredentials\"\x85\x01\n" +
|
||||
"\x15CreateSessionResponse\x12&\n" +
|
||||
"\x04user\x18\x01 \x01(\v2\x12.memos.api.v1.UserR\x04user\x12D\n" +
|
||||
|
|
|
|||
|
|
@ -2159,6 +2159,11 @@ components:
|
|||
description: |-
|
||||
The redirect URI used in the SSO flow.
|
||||
Required field for security validation.
|
||||
codeVerifier:
|
||||
type: string
|
||||
description: |-
|
||||
The PKCE code verifier for enhanced security (RFC 7636).
|
||||
Optional field - if provided, enables PKCE flow protection against authorization code interception.
|
||||
description: Nested message for SSO authentication credentials.
|
||||
CreateSessionResponse:
|
||||
type: object
|
||||
|
|
|
|||
|
|
@ -126,7 +126,8 @@ func (s *APIV1Service) CreateSession(ctx context.Context, request *v1pb.CreateSe
|
|||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create oauth2 identity provider, error: %v", err)
|
||||
}
|
||||
token, err := oauth2IdentityProvider.ExchangeToken(ctx, ssoCredentials.RedirectUri, ssoCredentials.Code)
|
||||
// Pass code_verifier for PKCE support (empty string if not provided for backward compatibility)
|
||||
token, err := oauth2IdentityProvider.ExchangeToken(ctx, ssoCredentials.RedirectUri, ssoCredentials.Code, ssoCredentials.CodeVerifier)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to exchange token, error: %v", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,6 +23,28 @@ const AuthCallback = observer(() => {
|
|||
});
|
||||
|
||||
useEffect(() => {
|
||||
// Check for OAuth error response first (e.g., user denied access)
|
||||
const error = searchParams.get("error");
|
||||
const errorDescription = searchParams.get("error_description");
|
||||
const errorUri = searchParams.get("error_uri");
|
||||
|
||||
if (error) {
|
||||
// OAuth provider returned an error
|
||||
let errorMessage = `OAuth error: ${error}`;
|
||||
if (errorDescription) {
|
||||
errorMessage += `\n${decodeURIComponent(errorDescription)}`;
|
||||
}
|
||||
if (errorUri) {
|
||||
errorMessage += `\nMore info: ${errorUri}`;
|
||||
}
|
||||
|
||||
setState({
|
||||
loading: false,
|
||||
errorMessage,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const code = searchParams.get("code");
|
||||
const state = searchParams.get("state");
|
||||
|
||||
|
|
@ -34,7 +56,7 @@ const AuthCallback = observer(() => {
|
|||
return;
|
||||
}
|
||||
|
||||
// Validate OAuth state (CSRF protection)
|
||||
// Validate OAuth state (CSRF protection) and retrieve PKCE code_verifier
|
||||
const validatedState = validateOAuthState(state);
|
||||
if (!validatedState) {
|
||||
setState({
|
||||
|
|
@ -44,7 +66,7 @@ const AuthCallback = observer(() => {
|
|||
return;
|
||||
}
|
||||
|
||||
const { identityProviderId, returnUrl } = validatedState;
|
||||
const { identityProviderId, returnUrl, codeVerifier } = validatedState;
|
||||
const redirectUri = absolutifyLink("/auth/callback");
|
||||
|
||||
(async () => {
|
||||
|
|
@ -54,6 +76,7 @@ const AuthCallback = observer(() => {
|
|||
idpId: identityProviderId,
|
||||
code,
|
||||
redirectUri,
|
||||
codeVerifier: codeVerifier || "", // Pass PKCE code_verifier for token exchange
|
||||
},
|
||||
});
|
||||
setState({
|
||||
|
|
|
|||
|
|
@ -49,15 +49,17 @@ const SignIn = observer(() => {
|
|||
|
||||
try {
|
||||
// Generate and store secure state parameter with CSRF protection
|
||||
// Also generate PKCE parameters (code_challenge) for enhanced security
|
||||
const identityProviderId = extractIdentityProviderIdFromName(identityProvider.name);
|
||||
const state = storeOAuthState(identityProviderId);
|
||||
const { state, codeChallenge } = await storeOAuthState(identityProviderId);
|
||||
|
||||
// Build OAuth authorization URL with secure state
|
||||
// Build OAuth authorization URL with secure state and PKCE
|
||||
// Using S256 (SHA-256) as the code_challenge_method per RFC 7636
|
||||
const authUrl = `${oauth2Config.authUrl}?client_id=${
|
||||
oauth2Config.clientId
|
||||
}&redirect_uri=${encodeURIComponent(redirectUri)}&state=${state}&response_type=code&scope=${encodeURIComponent(
|
||||
oauth2Config.scopes.join(" "),
|
||||
)}`;
|
||||
)}&code_challenge=${codeChallenge}&code_challenge_method=S256`;
|
||||
|
||||
window.location.href = authUrl;
|
||||
} catch (error) {
|
||||
|
|
|
|||
|
|
@ -66,6 +66,11 @@ export interface CreateSessionRequest_SSOCredentials {
|
|||
* Required field for security validation.
|
||||
*/
|
||||
redirectUri: string;
|
||||
/**
|
||||
* The PKCE code verifier for enhanced security (RFC 7636).
|
||||
* Optional field - if provided, enables PKCE flow protection against authorization code interception.
|
||||
*/
|
||||
codeVerifier: string;
|
||||
}
|
||||
|
||||
export interface CreateSessionResponse {
|
||||
|
|
@ -296,7 +301,7 @@ export const CreateSessionRequest_PasswordCredentials: MessageFns<CreateSessionR
|
|||
};
|
||||
|
||||
function createBaseCreateSessionRequest_SSOCredentials(): CreateSessionRequest_SSOCredentials {
|
||||
return { idpId: 0, code: "", redirectUri: "" };
|
||||
return { idpId: 0, code: "", redirectUri: "", codeVerifier: "" };
|
||||
}
|
||||
|
||||
export const CreateSessionRequest_SSOCredentials: MessageFns<CreateSessionRequest_SSOCredentials> = {
|
||||
|
|
@ -310,6 +315,9 @@ export const CreateSessionRequest_SSOCredentials: MessageFns<CreateSessionReques
|
|||
if (message.redirectUri !== "") {
|
||||
writer.uint32(26).string(message.redirectUri);
|
||||
}
|
||||
if (message.codeVerifier !== "") {
|
||||
writer.uint32(34).string(message.codeVerifier);
|
||||
}
|
||||
return writer;
|
||||
},
|
||||
|
||||
|
|
@ -344,6 +352,14 @@ export const CreateSessionRequest_SSOCredentials: MessageFns<CreateSessionReques
|
|||
message.redirectUri = reader.string();
|
||||
continue;
|
||||
}
|
||||
case 4: {
|
||||
if (tag !== 34) {
|
||||
break;
|
||||
}
|
||||
|
||||
message.codeVerifier = reader.string();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if ((tag & 7) === 4 || tag === 0) {
|
||||
break;
|
||||
|
|
@ -361,6 +377,7 @@ export const CreateSessionRequest_SSOCredentials: MessageFns<CreateSessionReques
|
|||
message.idpId = object.idpId ?? 0;
|
||||
message.code = object.code ?? "";
|
||||
message.redirectUri = object.redirectUri ?? "";
|
||||
message.codeVerifier = object.codeVerifier ?? "";
|
||||
return message;
|
||||
},
|
||||
};
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ interface OAuthState {
|
|||
identityProviderId: number;
|
||||
timestamp: number;
|
||||
returnUrl?: string;
|
||||
codeVerifier?: string; // PKCE code_verifier
|
||||
}
|
||||
|
||||
// Generate a cryptographically secure random state value
|
||||
|
|
@ -15,14 +16,42 @@ function generateSecureState(): string {
|
|||
return Array.from(array, (byte) => byte.toString(16).padStart(2, "0")).join("");
|
||||
}
|
||||
|
||||
// Store OAuth state in sessionStorage
|
||||
export function storeOAuthState(identityProviderId: number, returnUrl?: string): string {
|
||||
// Generate a cryptographically secure random code_verifier for PKCE (RFC 7636)
|
||||
// Returns a URL-safe base64 string (43-128 characters)
|
||||
function generateCodeVerifier(): string {
|
||||
const array = new Uint8Array(32); // 256 bits = 32 bytes
|
||||
crypto.getRandomValues(array);
|
||||
// Convert to base64url (URL-safe base64 without padding)
|
||||
return base64UrlEncode(array);
|
||||
}
|
||||
|
||||
// Generate code_challenge from code_verifier using SHA-256
|
||||
async function generateCodeChallenge(codeVerifier: string): Promise<string> {
|
||||
const encoder = new TextEncoder();
|
||||
const data = encoder.encode(codeVerifier);
|
||||
const hash = await crypto.subtle.digest("SHA-256", data);
|
||||
return base64UrlEncode(new Uint8Array(hash));
|
||||
}
|
||||
|
||||
// Base64URL encoding (RFC 4648 base64url without padding)
|
||||
function base64UrlEncode(buffer: Uint8Array): string {
|
||||
const base64 = btoa(String.fromCharCode(...buffer));
|
||||
return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, "");
|
||||
}
|
||||
|
||||
// Store OAuth state and PKCE parameters in sessionStorage
|
||||
// Returns both state and codeChallenge for use in authorization URL
|
||||
export async function storeOAuthState(identityProviderId: number, returnUrl?: string): Promise<{ state: string; codeChallenge: string }> {
|
||||
const state = generateSecureState();
|
||||
const codeVerifier = generateCodeVerifier();
|
||||
const codeChallenge = await generateCodeChallenge(codeVerifier);
|
||||
|
||||
const stateData: OAuthState = {
|
||||
state,
|
||||
identityProviderId,
|
||||
timestamp: Date.now(),
|
||||
returnUrl,
|
||||
codeVerifier, // Store for later retrieval in callback
|
||||
};
|
||||
|
||||
try {
|
||||
|
|
@ -32,11 +61,12 @@ export function storeOAuthState(identityProviderId: number, returnUrl?: string):
|
|||
throw new Error("Failed to initialize OAuth flow");
|
||||
}
|
||||
|
||||
return state;
|
||||
return { state, codeChallenge };
|
||||
}
|
||||
|
||||
// Validate and retrieve OAuth state from storage (CSRF protection)
|
||||
export function validateOAuthState(stateParam: string): { identityProviderId: number; returnUrl?: string } | null {
|
||||
// Returns identityProviderId, returnUrl, and codeVerifier for PKCE
|
||||
export function validateOAuthState(stateParam: string): { identityProviderId: number; returnUrl?: string; codeVerifier?: string } | null {
|
||||
try {
|
||||
const storedData = sessionStorage.getItem(STATE_STORAGE_KEY);
|
||||
if (!storedData) {
|
||||
|
|
@ -65,6 +95,7 @@ export function validateOAuthState(stateParam: string): { identityProviderId: nu
|
|||
return {
|
||||
identityProviderId: stateData.identityProviderId,
|
||||
returnUrl: stateData.returnUrl,
|
||||
codeVerifier: stateData.codeVerifier, // Return PKCE code_verifier
|
||||
};
|
||||
} catch (error) {
|
||||
console.error("Failed to validate OAuth state:", error);
|
||||
|
|
|
|||
Loading…
Reference in New Issue