diff --git a/web/src/App.tsx b/web/src/App.tsx index 9de74c605..b9569234e 100644 --- a/web/src/App.tsx +++ b/web/src/App.tsx @@ -4,6 +4,7 @@ import { useTranslation } from "react-i18next"; import { Outlet } from "react-router-dom"; import useNavigateTo from "./hooks/useNavigateTo"; import { userStore, workspaceStore } from "./store"; +import { cleanupExpiredOAuthState } from "./utils/oauth"; import { loadTheme } from "./utils/theme"; const App = observer(() => { @@ -13,6 +14,11 @@ const App = observer(() => { const userGeneralSetting = userStore.state.userGeneralSetting; const workspaceGeneralSetting = workspaceStore.state.generalSetting; + // Clean up expired OAuth states on app initialization + useEffect(() => { + cleanupExpiredOAuthState(); + }, []); + // Redirect to sign up page if no instance owner. useEffect(() => { if (!workspaceProfile.owner) { diff --git a/web/src/pages/AuthCallback.tsx b/web/src/pages/AuthCallback.tsx index 539750187..fc54b3897 100644 --- a/web/src/pages/AuthCallback.tsx +++ b/web/src/pages/AuthCallback.tsx @@ -1,4 +1,3 @@ -import { last } from "lodash-es"; import { LoaderIcon } from "lucide-react"; import { observer } from "mobx-react-lite"; import { ClientError } from "nice-grpc-web"; @@ -8,6 +7,7 @@ import { authServiceClient } from "@/grpcweb"; import { absolutifyLink } from "@/helpers/utils"; import useNavigateTo from "@/hooks/useNavigateTo"; import { initialUserStore } from "@/store/user"; +import { validateOAuthState } from "@/utils/oauth"; interface State { loading: boolean; @@ -29,21 +29,24 @@ const AuthCallback = observer(() => { if (!code || !state) { setState({ loading: false, - errorMessage: "Failed to authorize. Invalid state passed to the auth callback.", + errorMessage: "Failed to authorize. Missing authorization code or state parameter.", }); return; } - const identityProviderId = Number(last(state.split("-"))); - if (!identityProviderId) { + // Validate OAuth state (CSRF protection) + const validatedState = validateOAuthState(state); + if (!validatedState) { setState({ loading: false, - errorMessage: "No identity provider ID found in the state parameter.", + errorMessage: "Failed to authorize. Invalid or expired state parameter. This may indicate a CSRF attack attempt.", }); return; } + const { identityProviderId, returnUrl } = validatedState; const redirectUri = absolutifyLink("/auth/callback"); + (async () => { try { await authServiceClient.createSession({ @@ -58,7 +61,8 @@ const AuthCallback = observer(() => { errorMessage: "", }); await initialUserStore(); - navigateTo("/"); + // Redirect to return URL if specified, otherwise home + navigateTo(returnUrl || "/"); } catch (error: any) { console.error(error); setState({ diff --git a/web/src/pages/SignIn.tsx b/web/src/pages/SignIn.tsx index 00e968559..aa369f9cf 100644 --- a/web/src/pages/SignIn.tsx +++ b/web/src/pages/SignIn.tsx @@ -14,6 +14,7 @@ import { workspaceStore } from "@/store"; import { extractIdentityProviderIdFromName } from "@/store/common"; import { IdentityProvider, IdentityProvider_Type } from "@/types/proto/api/v1/idp_service"; import { useTranslate } from "@/utils/i18n"; +import { storeOAuthState } from "@/utils/oauth"; const SignIn = observer(() => { const t = useTranslate(); @@ -38,7 +39,6 @@ const SignIn = observer(() => { }, []); const handleSignInWithIdentityProvider = async (identityProvider: IdentityProvider) => { - const stateQueryParameter = `auth.signin.${identityProvider.title}-${extractIdentityProviderIdFromName(identityProvider.name)}`; if (identityProvider.type === IdentityProvider_Type.OAUTH2) { const redirectUri = absolutifyLink("/auth/callback"); const oauth2Config = identityProvider.config?.oauth2Config; @@ -46,12 +46,24 @@ const SignIn = observer(() => { toast.error("Identity provider configuration is invalid."); return; } - const authUrl = `${oauth2Config.authUrl}?client_id=${ - oauth2Config.clientId - }&redirect_uri=${encodeURIComponent(redirectUri)}&state=${stateQueryParameter}&response_type=code&scope=${encodeURIComponent( - oauth2Config.scopes.join(" "), - )}`; - window.location.href = authUrl; + + try { + // Generate and store secure state parameter with CSRF protection + const identityProviderId = extractIdentityProviderIdFromName(identityProvider.name); + const state = storeOAuthState(identityProviderId); + + // Build OAuth authorization URL with secure state + const authUrl = `${oauth2Config.authUrl}?client_id=${ + oauth2Config.clientId + }&redirect_uri=${encodeURIComponent(redirectUri)}&state=${state}&response_type=code&scope=${encodeURIComponent( + oauth2Config.scopes.join(" "), + )}`; + + window.location.href = authUrl; + } catch (error) { + console.error("Failed to initiate OAuth flow:", error); + toast.error("Failed to initiate sign-in. Please try again."); + } } }; diff --git a/web/src/utils/oauth.ts b/web/src/utils/oauth.ts new file mode 100644 index 000000000..045fddc82 --- /dev/null +++ b/web/src/utils/oauth.ts @@ -0,0 +1,111 @@ +/** + * OAuth state management utilities + * Implements secure state parameter handling following Auth0 best practices + * @see https://auth0.com/docs/secure/attack-protection/state-parameters + */ + +const STATE_STORAGE_KEY = "oauth_state"; +const STATE_EXPIRY_MS = 10 * 60 * 1000; // 10 minutes + +interface OAuthState { + state: string; + identityProviderId: number; + timestamp: number; + returnUrl?: string; +} + +/** + * Generate a cryptographically secure random state value + * Uses Web Crypto API for strong randomness + */ +function generateSecureState(): string { + const array = new Uint8Array(32); + crypto.getRandomValues(array); + return Array.from(array, (byte) => byte.toString(16).padStart(2, "0")).join(""); +} + +/** + * Store OAuth state in sessionStorage with metadata + * State is stored temporarily and will be validated on callback + */ +export function storeOAuthState(identityProviderId: number, returnUrl?: string): string { + const state = generateSecureState(); + const stateData: OAuthState = { + state, + identityProviderId, + timestamp: Date.now(), + returnUrl, + }; + + try { + sessionStorage.setItem(STATE_STORAGE_KEY, JSON.stringify(stateData)); + } catch (error) { + console.error("Failed to store OAuth state:", error); + throw new Error("Failed to initialize OAuth flow"); + } + + return state; +} + +/** + * Validate and retrieve OAuth state from storage + * Implements CSRF protection by verifying state matches + * Cleans up expired or used states + */ +export function validateOAuthState(stateParam: string): { identityProviderId: number; returnUrl?: string } | null { + try { + const storedData = sessionStorage.getItem(STATE_STORAGE_KEY); + if (!storedData) { + console.error("No OAuth state found in storage"); + return null; + } + + const stateData: OAuthState = JSON.parse(storedData); + + // Check if state has expired + if (Date.now() - stateData.timestamp > STATE_EXPIRY_MS) { + console.error("OAuth state has expired"); + sessionStorage.removeItem(STATE_STORAGE_KEY); + return null; + } + + // Validate state matches (CSRF protection) + if (stateData.state !== stateParam) { + console.error("OAuth state mismatch - possible CSRF attack"); + sessionStorage.removeItem(STATE_STORAGE_KEY); + return null; + } + + // State is valid, clean up and return data + sessionStorage.removeItem(STATE_STORAGE_KEY); + return { + identityProviderId: stateData.identityProviderId, + returnUrl: stateData.returnUrl, + }; + } catch (error) { + console.error("Failed to validate OAuth state:", error); + sessionStorage.removeItem(STATE_STORAGE_KEY); + return null; + } +} + +/** + * Clean up expired OAuth states + * Should be called on app initialization + */ +export function cleanupExpiredOAuthState(): void { + try { + const storedData = sessionStorage.getItem(STATE_STORAGE_KEY); + if (!storedData) { + return; + } + + const stateData: OAuthState = JSON.parse(storedData); + if (Date.now() - stateData.timestamp > STATE_EXPIRY_MS) { + sessionStorage.removeItem(STATE_STORAGE_KEY); + } + } catch { + // If parsing fails, remove the corrupted data + sessionStorage.removeItem(STATE_STORAGE_KEY); + } +}