mirror of https://github.com/usememos/memos.git
feat: implement OAuth state management with CSRF protection and cleanup functionality
This commit is contained in:
parent
fb01b49ecf
commit
dc9470f71c
|
|
@ -4,6 +4,7 @@ import { useTranslation } from "react-i18next";
|
|||
import { Outlet } from "react-router-dom";
|
||||
import useNavigateTo from "./hooks/useNavigateTo";
|
||||
import { userStore, workspaceStore } from "./store";
|
||||
import { cleanupExpiredOAuthState } from "./utils/oauth";
|
||||
import { loadTheme } from "./utils/theme";
|
||||
|
||||
const App = observer(() => {
|
||||
|
|
@ -13,6 +14,11 @@ const App = observer(() => {
|
|||
const userGeneralSetting = userStore.state.userGeneralSetting;
|
||||
const workspaceGeneralSetting = workspaceStore.state.generalSetting;
|
||||
|
||||
// Clean up expired OAuth states on app initialization
|
||||
useEffect(() => {
|
||||
cleanupExpiredOAuthState();
|
||||
}, []);
|
||||
|
||||
// Redirect to sign up page if no instance owner.
|
||||
useEffect(() => {
|
||||
if (!workspaceProfile.owner) {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import { last } from "lodash-es";
|
||||
import { LoaderIcon } from "lucide-react";
|
||||
import { observer } from "mobx-react-lite";
|
||||
import { ClientError } from "nice-grpc-web";
|
||||
|
|
@ -8,6 +7,7 @@ import { authServiceClient } from "@/grpcweb";
|
|||
import { absolutifyLink } from "@/helpers/utils";
|
||||
import useNavigateTo from "@/hooks/useNavigateTo";
|
||||
import { initialUserStore } from "@/store/user";
|
||||
import { validateOAuthState } from "@/utils/oauth";
|
||||
|
||||
interface State {
|
||||
loading: boolean;
|
||||
|
|
@ -29,21 +29,24 @@ const AuthCallback = observer(() => {
|
|||
if (!code || !state) {
|
||||
setState({
|
||||
loading: false,
|
||||
errorMessage: "Failed to authorize. Invalid state passed to the auth callback.",
|
||||
errorMessage: "Failed to authorize. Missing authorization code or state parameter.",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const identityProviderId = Number(last(state.split("-")));
|
||||
if (!identityProviderId) {
|
||||
// Validate OAuth state (CSRF protection)
|
||||
const validatedState = validateOAuthState(state);
|
||||
if (!validatedState) {
|
||||
setState({
|
||||
loading: false,
|
||||
errorMessage: "No identity provider ID found in the state parameter.",
|
||||
errorMessage: "Failed to authorize. Invalid or expired state parameter. This may indicate a CSRF attack attempt.",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const { identityProviderId, returnUrl } = validatedState;
|
||||
const redirectUri = absolutifyLink("/auth/callback");
|
||||
|
||||
(async () => {
|
||||
try {
|
||||
await authServiceClient.createSession({
|
||||
|
|
@ -58,7 +61,8 @@ const AuthCallback = observer(() => {
|
|||
errorMessage: "",
|
||||
});
|
||||
await initialUserStore();
|
||||
navigateTo("/");
|
||||
// Redirect to return URL if specified, otherwise home
|
||||
navigateTo(returnUrl || "/");
|
||||
} catch (error: any) {
|
||||
console.error(error);
|
||||
setState({
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ import { workspaceStore } from "@/store";
|
|||
import { extractIdentityProviderIdFromName } from "@/store/common";
|
||||
import { IdentityProvider, IdentityProvider_Type } from "@/types/proto/api/v1/idp_service";
|
||||
import { useTranslate } from "@/utils/i18n";
|
||||
import { storeOAuthState } from "@/utils/oauth";
|
||||
|
||||
const SignIn = observer(() => {
|
||||
const t = useTranslate();
|
||||
|
|
@ -38,7 +39,6 @@ const SignIn = observer(() => {
|
|||
}, []);
|
||||
|
||||
const handleSignInWithIdentityProvider = async (identityProvider: IdentityProvider) => {
|
||||
const stateQueryParameter = `auth.signin.${identityProvider.title}-${extractIdentityProviderIdFromName(identityProvider.name)}`;
|
||||
if (identityProvider.type === IdentityProvider_Type.OAUTH2) {
|
||||
const redirectUri = absolutifyLink("/auth/callback");
|
||||
const oauth2Config = identityProvider.config?.oauth2Config;
|
||||
|
|
@ -46,12 +46,24 @@ const SignIn = observer(() => {
|
|||
toast.error("Identity provider configuration is invalid.");
|
||||
return;
|
||||
}
|
||||
const authUrl = `${oauth2Config.authUrl}?client_id=${
|
||||
oauth2Config.clientId
|
||||
}&redirect_uri=${encodeURIComponent(redirectUri)}&state=${stateQueryParameter}&response_type=code&scope=${encodeURIComponent(
|
||||
oauth2Config.scopes.join(" "),
|
||||
)}`;
|
||||
window.location.href = authUrl;
|
||||
|
||||
try {
|
||||
// Generate and store secure state parameter with CSRF protection
|
||||
const identityProviderId = extractIdentityProviderIdFromName(identityProvider.name);
|
||||
const state = storeOAuthState(identityProviderId);
|
||||
|
||||
// Build OAuth authorization URL with secure state
|
||||
const authUrl = `${oauth2Config.authUrl}?client_id=${
|
||||
oauth2Config.clientId
|
||||
}&redirect_uri=${encodeURIComponent(redirectUri)}&state=${state}&response_type=code&scope=${encodeURIComponent(
|
||||
oauth2Config.scopes.join(" "),
|
||||
)}`;
|
||||
|
||||
window.location.href = authUrl;
|
||||
} catch (error) {
|
||||
console.error("Failed to initiate OAuth flow:", error);
|
||||
toast.error("Failed to initiate sign-in. Please try again.");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,111 @@
|
|||
/**
|
||||
* OAuth state management utilities
|
||||
* Implements secure state parameter handling following Auth0 best practices
|
||||
* @see https://auth0.com/docs/secure/attack-protection/state-parameters
|
||||
*/
|
||||
|
||||
const STATE_STORAGE_KEY = "oauth_state";
|
||||
const STATE_EXPIRY_MS = 10 * 60 * 1000; // 10 minutes
|
||||
|
||||
interface OAuthState {
|
||||
state: string;
|
||||
identityProviderId: number;
|
||||
timestamp: number;
|
||||
returnUrl?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a cryptographically secure random state value
|
||||
* Uses Web Crypto API for strong randomness
|
||||
*/
|
||||
function generateSecureState(): string {
|
||||
const array = new Uint8Array(32);
|
||||
crypto.getRandomValues(array);
|
||||
return Array.from(array, (byte) => byte.toString(16).padStart(2, "0")).join("");
|
||||
}
|
||||
|
||||
/**
|
||||
* Store OAuth state in sessionStorage with metadata
|
||||
* State is stored temporarily and will be validated on callback
|
||||
*/
|
||||
export function storeOAuthState(identityProviderId: number, returnUrl?: string): string {
|
||||
const state = generateSecureState();
|
||||
const stateData: OAuthState = {
|
||||
state,
|
||||
identityProviderId,
|
||||
timestamp: Date.now(),
|
||||
returnUrl,
|
||||
};
|
||||
|
||||
try {
|
||||
sessionStorage.setItem(STATE_STORAGE_KEY, JSON.stringify(stateData));
|
||||
} catch (error) {
|
||||
console.error("Failed to store OAuth state:", error);
|
||||
throw new Error("Failed to initialize OAuth flow");
|
||||
}
|
||||
|
||||
return state;
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate and retrieve OAuth state from storage
|
||||
* Implements CSRF protection by verifying state matches
|
||||
* Cleans up expired or used states
|
||||
*/
|
||||
export function validateOAuthState(stateParam: string): { identityProviderId: number; returnUrl?: string } | null {
|
||||
try {
|
||||
const storedData = sessionStorage.getItem(STATE_STORAGE_KEY);
|
||||
if (!storedData) {
|
||||
console.error("No OAuth state found in storage");
|
||||
return null;
|
||||
}
|
||||
|
||||
const stateData: OAuthState = JSON.parse(storedData);
|
||||
|
||||
// Check if state has expired
|
||||
if (Date.now() - stateData.timestamp > STATE_EXPIRY_MS) {
|
||||
console.error("OAuth state has expired");
|
||||
sessionStorage.removeItem(STATE_STORAGE_KEY);
|
||||
return null;
|
||||
}
|
||||
|
||||
// Validate state matches (CSRF protection)
|
||||
if (stateData.state !== stateParam) {
|
||||
console.error("OAuth state mismatch - possible CSRF attack");
|
||||
sessionStorage.removeItem(STATE_STORAGE_KEY);
|
||||
return null;
|
||||
}
|
||||
|
||||
// State is valid, clean up and return data
|
||||
sessionStorage.removeItem(STATE_STORAGE_KEY);
|
||||
return {
|
||||
identityProviderId: stateData.identityProviderId,
|
||||
returnUrl: stateData.returnUrl,
|
||||
};
|
||||
} catch (error) {
|
||||
console.error("Failed to validate OAuth state:", error);
|
||||
sessionStorage.removeItem(STATE_STORAGE_KEY);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clean up expired OAuth states
|
||||
* Should be called on app initialization
|
||||
*/
|
||||
export function cleanupExpiredOAuthState(): void {
|
||||
try {
|
||||
const storedData = sessionStorage.getItem(STATE_STORAGE_KEY);
|
||||
if (!storedData) {
|
||||
return;
|
||||
}
|
||||
|
||||
const stateData: OAuthState = JSON.parse(storedData);
|
||||
if (Date.now() - stateData.timestamp > STATE_EXPIRY_MS) {
|
||||
sessionStorage.removeItem(STATE_STORAGE_KEY);
|
||||
}
|
||||
} catch {
|
||||
// If parsing fails, remove the corrupted data
|
||||
sessionStorage.removeItem(STATE_STORAGE_KEY);
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue