feat: Refactor MCP client to use official SDK
This commit is contained in:
parent
4dbcb5cdfd
commit
9ab2326e79
|
|
@ -1,255 +0,0 @@
|
|||
import type {
|
||||
ApiChatCompletionRequest,
|
||||
ApiChatMessageData,
|
||||
ApiChatCompletionToolCall
|
||||
} from '$lib/types/api';
|
||||
import type { ChatMessagePromptProgress, ChatMessageTimings } from '$lib/types/chat';
|
||||
import type { MCPToolCall } from '$lib/mcp';
|
||||
import { MCPClient } from '$lib/mcp';
|
||||
import { OpenAISseClient, type OpenAISseTurnResult } from './openai-sse-client';
|
||||
import type { AgenticChatCompletionRequest, AgenticMessage, AgenticToolCallList } from './types';
|
||||
import { toAgenticMessages } from './types';
|
||||
|
||||
export type AgenticOrchestratorCallbacks = {
|
||||
onChunk?: (chunk: string) => void;
|
||||
onReasoningChunk?: (chunk: string) => void;
|
||||
onToolCallChunk?: (serializedToolCalls: string) => void;
|
||||
onModel?: (model: string) => void;
|
||||
onFirstValidChunk?: () => void;
|
||||
onComplete?: () => void;
|
||||
onError?: (error: Error) => void;
|
||||
};
|
||||
|
||||
export type AgenticRunParams = {
|
||||
initialMessages: ApiChatMessageData[];
|
||||
requestTemplate: ApiChatCompletionRequest;
|
||||
callbacks: AgenticOrchestratorCallbacks;
|
||||
abortSignal?: AbortSignal;
|
||||
onProcessingUpdate?: (timings?: ChatMessageTimings, progress?: ChatMessagePromptProgress) => void;
|
||||
maxTurns?: number;
|
||||
filterReasoningAfterFirstTurn?: boolean;
|
||||
};
|
||||
|
||||
export type AgenticOrchestratorOptions = {
|
||||
mcpClient: MCPClient;
|
||||
llmClient: OpenAISseClient;
|
||||
maxTurns: number;
|
||||
maxToolPreviewLines: number;
|
||||
};
|
||||
|
||||
export class AgenticOrchestrator {
|
||||
private readonly mcpClient: MCPClient;
|
||||
private readonly llmClient: OpenAISseClient;
|
||||
private readonly maxTurns: number;
|
||||
private readonly maxToolPreviewLines: number;
|
||||
|
||||
constructor(options: AgenticOrchestratorOptions) {
|
||||
this.mcpClient = options.mcpClient;
|
||||
this.llmClient = options.llmClient;
|
||||
this.maxTurns = options.maxTurns;
|
||||
this.maxToolPreviewLines = options.maxToolPreviewLines;
|
||||
}
|
||||
|
||||
async run(params: AgenticRunParams): Promise<void> {
|
||||
const baseMessages = toAgenticMessages(params.initialMessages);
|
||||
const sessionMessages: AgenticMessage[] = [...baseMessages];
|
||||
const tools = await this.mcpClient.getToolsDefinition();
|
||||
|
||||
const requestWithoutMessages = { ...params.requestTemplate };
|
||||
delete (requestWithoutMessages as Partial<ApiChatCompletionRequest>).messages;
|
||||
const requestBase: AgenticChatCompletionRequest = {
|
||||
...(requestWithoutMessages as Omit<ApiChatCompletionRequest, 'messages'>),
|
||||
stream: true,
|
||||
messages: []
|
||||
};
|
||||
|
||||
const maxTurns = params.maxTurns ?? this.maxTurns;
|
||||
|
||||
// Accumulate tool_calls across all turns (not per-turn)
|
||||
const allToolCalls: ApiChatCompletionToolCall[] = [];
|
||||
|
||||
for (let turn = 0; turn < maxTurns; turn++) {
|
||||
if (params.abortSignal?.aborted) {
|
||||
params.callbacks.onComplete?.();
|
||||
return;
|
||||
}
|
||||
|
||||
const llmRequest: AgenticChatCompletionRequest = {
|
||||
...requestBase,
|
||||
messages: sessionMessages,
|
||||
tools: tools.length > 0 ? tools : undefined
|
||||
};
|
||||
|
||||
const shouldFilterReasoningChunks = params.filterReasoningAfterFirstTurn === true && turn > 0;
|
||||
|
||||
let turnResult: OpenAISseTurnResult;
|
||||
try {
|
||||
turnResult = await this.llmClient.stream(
|
||||
llmRequest,
|
||||
{
|
||||
onChunk: params.callbacks.onChunk,
|
||||
onReasoningChunk: shouldFilterReasoningChunks
|
||||
? undefined
|
||||
: params.callbacks.onReasoningChunk,
|
||||
onModel: params.callbacks.onModel,
|
||||
onFirstValidChunk: params.callbacks.onFirstValidChunk,
|
||||
onProcessingUpdate: (timings, progress) =>
|
||||
params.onProcessingUpdate?.(timings, progress)
|
||||
},
|
||||
params.abortSignal
|
||||
);
|
||||
} catch (error) {
|
||||
// Check if error is due to abort signal (stop button)
|
||||
if (params.abortSignal?.aborted) {
|
||||
params.callbacks.onComplete?.();
|
||||
return;
|
||||
}
|
||||
|
||||
const normalizedError = error instanceof Error ? error : new Error('LLM stream error');
|
||||
params.callbacks.onError?.(normalizedError);
|
||||
const errorChunk = `\n\n\`\`\`\nUpstream LLM error:\n${normalizedError.message}\n\`\`\`\n`;
|
||||
params.callbacks.onChunk?.(errorChunk);
|
||||
params.callbacks.onComplete?.();
|
||||
return;
|
||||
}
|
||||
|
||||
if (
|
||||
turnResult.toolCalls.length === 0 ||
|
||||
(turnResult.finishReason && turnResult.finishReason !== 'tool_calls')
|
||||
) {
|
||||
params.callbacks.onComplete?.();
|
||||
return;
|
||||
}
|
||||
|
||||
const normalizedCalls = this.normalizeToolCalls(turnResult.toolCalls);
|
||||
if (normalizedCalls.length === 0) {
|
||||
params.callbacks.onComplete?.();
|
||||
return;
|
||||
}
|
||||
|
||||
// Accumulate tool_calls from this turn
|
||||
for (const call of normalizedCalls) {
|
||||
allToolCalls.push({
|
||||
id: call.id,
|
||||
type: call.type,
|
||||
function: call.function ? { ...call.function } : undefined
|
||||
});
|
||||
}
|
||||
|
||||
// Forward the complete accumulated list
|
||||
params.callbacks.onToolCallChunk?.(JSON.stringify(allToolCalls));
|
||||
|
||||
sessionMessages.push({
|
||||
role: 'assistant',
|
||||
content: turnResult.content || undefined,
|
||||
tool_calls: normalizedCalls
|
||||
});
|
||||
|
||||
for (const toolCall of normalizedCalls) {
|
||||
if (params.abortSignal?.aborted) {
|
||||
params.callbacks.onComplete?.();
|
||||
return;
|
||||
}
|
||||
|
||||
const result = await this.executeTool(toolCall, params.abortSignal).catch(
|
||||
(error: Error) => {
|
||||
// Don't show error for AbortError
|
||||
if (error.name !== 'AbortError') {
|
||||
params.callbacks.onError?.(error);
|
||||
}
|
||||
return `Error: ${error.message}`;
|
||||
}
|
||||
);
|
||||
|
||||
// Stop silently if aborted during tool execution
|
||||
if (params.abortSignal?.aborted) {
|
||||
params.callbacks.onComplete?.();
|
||||
return;
|
||||
}
|
||||
|
||||
this.emitToolPreview(result, params.callbacks.onChunk);
|
||||
|
||||
const contextValue = this.sanitizeToolContent(result);
|
||||
sessionMessages.push({
|
||||
role: 'tool',
|
||||
tool_call_id: toolCall.id,
|
||||
content: contextValue
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
params.callbacks.onChunk?.('\n\n```\nTurn limit reached\n```\n');
|
||||
params.callbacks.onComplete?.();
|
||||
}
|
||||
|
||||
private normalizeToolCalls(toolCalls: ApiChatCompletionToolCall[]): AgenticToolCallList {
|
||||
if (!toolCalls) {
|
||||
return [];
|
||||
}
|
||||
|
||||
return toolCalls.map((call, index) => ({
|
||||
id: call?.id ?? `tool_${index}`,
|
||||
type: (call?.type as 'function') ?? 'function',
|
||||
function: {
|
||||
name: call?.function?.name ?? '',
|
||||
arguments: call?.function?.arguments ?? ''
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
private async executeTool(
|
||||
toolCall: AgenticToolCallList[number],
|
||||
abortSignal?: AbortSignal
|
||||
): Promise<string> {
|
||||
const mcpCall: MCPToolCall = {
|
||||
id: toolCall.id,
|
||||
function: {
|
||||
name: toolCall.function.name,
|
||||
arguments: toolCall.function.arguments
|
||||
}
|
||||
};
|
||||
|
||||
const result = await this.mcpClient.execute(mcpCall, abortSignal);
|
||||
return result;
|
||||
}
|
||||
|
||||
private emitToolPreview(result: string, emit?: (chunk: string) => void): void {
|
||||
if (!emit) return;
|
||||
const preview = this.createPreview(result);
|
||||
emit(preview);
|
||||
}
|
||||
|
||||
private createPreview(result: string): string {
|
||||
if (this.isBase64Image(result)) {
|
||||
return `\n})\n`;
|
||||
}
|
||||
|
||||
const lines = result.split('\n');
|
||||
const trimmedLines =
|
||||
lines.length > this.maxToolPreviewLines ? lines.slice(-this.maxToolPreviewLines) : lines;
|
||||
const preview = trimmedLines.join('\n');
|
||||
return `\n\`\`\`\n${preview}\n\`\`\`\n`;
|
||||
}
|
||||
|
||||
private sanitizeToolContent(result: string): string {
|
||||
if (this.isBase64Image(result)) {
|
||||
return '[Image displayed to user]';
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private isBase64Image(content: string): boolean {
|
||||
const trimmed = content.trim();
|
||||
if (!trimmed.startsWith('data:image/')) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const match = trimmed.match(/^data:image\/(png|jpe?g|gif|webp);base64,([A-Za-z0-9+/]+=*)$/);
|
||||
if (!match) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const base64Payload = match[2];
|
||||
return base64Payload.length > 0 && base64Payload.length % 4 === 0;
|
||||
}
|
||||
}
|
||||
|
|
@ -1,39 +1,46 @@
|
|||
import { getDefaultMcpConfig } from '$lib/config/mcp';
|
||||
import { JsonRpcProtocol } from './protocol';
|
||||
import type {
|
||||
JsonRpcMessage,
|
||||
MCPClientConfig,
|
||||
MCPServerCapabilities,
|
||||
MCPServerConfig,
|
||||
MCPToolCall,
|
||||
MCPToolDefinition,
|
||||
MCPToolsCallResult
|
||||
} from './types';
|
||||
import { MCPError } from './types';
|
||||
import type { MCPTransport } from './transports/types';
|
||||
import { WebSocketTransport } from './transports/websocket';
|
||||
import { StreamableHttpTransport } from './transports/streamable-http';
|
||||
/**
|
||||
* MCP Client implementation using the official @modelcontextprotocol/sdk
|
||||
*
|
||||
* This module provides a wrapper around the SDK's Client class that maintains
|
||||
* backward compatibility with our existing MCPClient API.
|
||||
*/
|
||||
|
||||
const MCP_DEFAULTS = getDefaultMcpConfig();
|
||||
import { Client } from '@modelcontextprotocol/sdk/client';
|
||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import type { Tool } from '@modelcontextprotocol/sdk/types.js';
|
||||
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
||||
import type { MCPClientConfig, MCPServerConfig, MCPToolCall } from '$lib/types/mcp';
|
||||
import { MCPError } from '$lib/types/mcp';
|
||||
import { DEFAULT_MCP_CONFIG } from '$lib/constants/mcp';
|
||||
|
||||
interface PendingRequest {
|
||||
resolve: (value: Record<string, unknown>) => void;
|
||||
reject: (reason?: unknown) => void;
|
||||
timeout: ReturnType<typeof setTimeout>;
|
||||
// Type for tool call result content item
|
||||
interface ToolResultContentItem {
|
||||
type: string;
|
||||
text?: string;
|
||||
data?: string;
|
||||
mimeType?: string;
|
||||
resource?: { text?: string; blob?: string; uri?: string };
|
||||
}
|
||||
|
||||
interface ServerState {
|
||||
transport: MCPTransport;
|
||||
pending: Map<number, PendingRequest>;
|
||||
requestId: number;
|
||||
tools: MCPToolDefinition[];
|
||||
requestTimeoutMs?: number;
|
||||
capabilities?: MCPServerCapabilities;
|
||||
protocolVersion?: string;
|
||||
// Type for tool call result (SDK uses complex union type)
|
||||
interface ToolCallResult {
|
||||
content?: ToolResultContentItem[];
|
||||
isError?: boolean;
|
||||
_meta?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
interface ServerConnection {
|
||||
client: Client;
|
||||
transport: Transport;
|
||||
tools: Tool[];
|
||||
}
|
||||
|
||||
/**
|
||||
* MCP Client using the official @modelcontextprotocol/sdk.
|
||||
*/
|
||||
export class MCPClient {
|
||||
private readonly servers: Map<string, ServerState> = new Map();
|
||||
private readonly servers: Map<string, ServerConnection> = new Map();
|
||||
private readonly toolsToServer: Map<string, string> = new Map();
|
||||
private readonly config: MCPClientConfig;
|
||||
|
||||
|
|
@ -46,9 +53,25 @@ export class MCPClient {
|
|||
|
||||
async initialize(): Promise<void> {
|
||||
const entries = Object.entries(this.config.servers);
|
||||
await Promise.all(
|
||||
const results = await Promise.allSettled(
|
||||
entries.map(([name, serverConfig]) => this.initializeServer(name, serverConfig))
|
||||
);
|
||||
|
||||
// Log any failures but don't throw if at least one server connected
|
||||
const failures = results.filter((r) => r.status === 'rejected');
|
||||
if (failures.length > 0) {
|
||||
for (const failure of failures) {
|
||||
console.error(
|
||||
'[MCP] Server initialization failed:',
|
||||
(failure as PromiseRejectedResult).reason
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const successes = results.filter((r) => r.status === 'fulfilled');
|
||||
if (successes.length === 0) {
|
||||
throw new Error('All MCP server connections failed');
|
||||
}
|
||||
}
|
||||
|
||||
listTools(): string[] {
|
||||
|
|
@ -73,7 +96,7 @@ export class MCPClient {
|
|||
function: {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.inputSchema ?? {
|
||||
parameters: (tool.inputSchema as Record<string, unknown>) ?? {
|
||||
type: 'object',
|
||||
properties: {},
|
||||
required: []
|
||||
|
|
@ -93,10 +116,16 @@ export class MCPClient {
|
|||
throw new MCPError(`Unknown tool: ${toolName}`, -32601);
|
||||
}
|
||||
|
||||
const connection = this.servers.get(serverName);
|
||||
if (!connection) {
|
||||
throw new MCPError(`Server ${serverName} is not connected`, -32000);
|
||||
}
|
||||
|
||||
if (abortSignal?.aborted) {
|
||||
throw new DOMException('Aborted', 'AbortError');
|
||||
}
|
||||
|
||||
// Parse arguments
|
||||
let args: Record<string, unknown>;
|
||||
const originalArgs = toolCall.function.arguments;
|
||||
if (typeof originalArgs === 'string') {
|
||||
|
|
@ -133,234 +162,128 @@ export class MCPClient {
|
|||
throw new MCPError(`Invalid tool arguments type: ${typeof originalArgs}`, -32602);
|
||||
}
|
||||
|
||||
const response = await this.call(
|
||||
serverName,
|
||||
'tools/call',
|
||||
{
|
||||
name: toolName,
|
||||
arguments: args
|
||||
},
|
||||
abortSignal
|
||||
);
|
||||
try {
|
||||
const result = await connection.client.callTool(
|
||||
{ name: toolName, arguments: args },
|
||||
undefined,
|
||||
{ signal: abortSignal }
|
||||
);
|
||||
|
||||
return MCPClient.formatToolResult(response as MCPToolsCallResult);
|
||||
return MCPClient.formatToolResult(result as ToolCallResult);
|
||||
} catch (error) {
|
||||
if (error instanceof DOMException && error.name === 'AbortError') {
|
||||
throw error;
|
||||
}
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
throw new MCPError(`Tool execution failed: ${message}`, -32603);
|
||||
}
|
||||
}
|
||||
|
||||
async shutdown(): Promise<void> {
|
||||
for (const [, state] of this.servers) {
|
||||
await state.transport.stop();
|
||||
const closePromises: Promise<void>[] = [];
|
||||
|
||||
for (const [name, connection] of this.servers) {
|
||||
console.log(`[MCP][${name}] Closing connection...`);
|
||||
closePromises.push(
|
||||
connection.client.close().catch((error) => {
|
||||
console.warn(`[MCP][${name}] Error closing client:`, error);
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
await Promise.allSettled(closePromises);
|
||||
this.servers.clear();
|
||||
this.toolsToServer.clear();
|
||||
}
|
||||
|
||||
private async initializeServer(name: string, config: MCPServerConfig): Promise<void> {
|
||||
const protocolVersion = this.config.protocolVersion ?? MCP_DEFAULTS.protocolVersion;
|
||||
const transport = this.createTransport(config, protocolVersion);
|
||||
await transport.start();
|
||||
console.log(`[MCP][${name}] Starting server initialization...`);
|
||||
|
||||
const state: ServerState = {
|
||||
transport,
|
||||
pending: new Map(),
|
||||
requestId: 0,
|
||||
tools: [],
|
||||
requestTimeoutMs: config.requestTimeoutMs
|
||||
};
|
||||
|
||||
transport.onMessage((message) => this.handleMessage(name, message));
|
||||
this.servers.set(name, state);
|
||||
|
||||
const clientInfo = this.config.clientInfo ?? MCP_DEFAULTS.clientInfo;
|
||||
const clientInfo = this.config.clientInfo ?? DEFAULT_MCP_CONFIG.clientInfo;
|
||||
const capabilities =
|
||||
config.capabilities ?? this.config.capabilities ?? MCP_DEFAULTS.capabilities;
|
||||
config.capabilities ?? this.config.capabilities ?? DEFAULT_MCP_CONFIG.capabilities;
|
||||
|
||||
const initResult = await this.call(name, 'initialize', {
|
||||
protocolVersion,
|
||||
capabilities,
|
||||
clientInfo
|
||||
});
|
||||
// Create SDK client
|
||||
const client = new Client(
|
||||
{ name: clientInfo.name, version: clientInfo.version ?? '1.0.0' },
|
||||
{ capabilities }
|
||||
);
|
||||
|
||||
const negotiatedVersion = (initResult?.protocolVersion as string) ?? protocolVersion;
|
||||
// Create transport with fallback
|
||||
const transport = await this.createTransportWithFallback(name, config);
|
||||
|
||||
state.capabilities = (initResult?.capabilities as MCPServerCapabilities) ?? {};
|
||||
state.protocolVersion = negotiatedVersion;
|
||||
console.log(`[MCP][${name}] Connecting to server...`);
|
||||
await client.connect(transport);
|
||||
console.log(`[MCP][${name}] Connected, listing tools...`);
|
||||
|
||||
const notification = JsonRpcProtocol.createNotification('notifications/initialized');
|
||||
await state.transport.send(notification as JsonRpcMessage);
|
||||
// List available tools
|
||||
const toolsResult = await client.listTools();
|
||||
const tools = toolsResult.tools ?? [];
|
||||
console.log(`[MCP][${name}] Found ${tools.length} tools`);
|
||||
|
||||
await this.refreshTools(name);
|
||||
// Store connection
|
||||
const connection: ServerConnection = {
|
||||
client,
|
||||
transport,
|
||||
tools
|
||||
};
|
||||
this.servers.set(name, connection);
|
||||
|
||||
// Map tools to server
|
||||
for (const tool of tools) {
|
||||
this.toolsToServer.set(tool.name, name);
|
||||
}
|
||||
|
||||
// Note: Tool list changes will be handled by re-calling listTools when needed
|
||||
// The SDK's listChanged handler requires server capability support
|
||||
|
||||
console.log(`[MCP][${name}] Server initialization complete`);
|
||||
}
|
||||
|
||||
private createTransport(config: MCPServerConfig, protocolVersion: string): MCPTransport {
|
||||
private async createTransportWithFallback(
|
||||
name: string,
|
||||
config: MCPServerConfig
|
||||
): Promise<Transport> {
|
||||
if (!config.url) {
|
||||
throw new Error('MCP server configuration is missing url');
|
||||
}
|
||||
|
||||
const transportType = config.transport ?? 'websocket';
|
||||
const url = new URL(config.url);
|
||||
const requestInit: RequestInit = {};
|
||||
|
||||
if (transportType === 'streamable_http') {
|
||||
return new StreamableHttpTransport({
|
||||
url: config.url,
|
||||
headers: config.headers,
|
||||
credentials: config.credentials,
|
||||
protocolVersion,
|
||||
if (config.headers) {
|
||||
requestInit.headers = config.headers;
|
||||
}
|
||||
if (config.credentials) {
|
||||
requestInit.credentials = config.credentials;
|
||||
}
|
||||
|
||||
// Try StreamableHTTP first (modern), fall back to SSE (legacy)
|
||||
try {
|
||||
console.log(`[MCP][${name}] Trying StreamableHTTP transport...`);
|
||||
const transport = new StreamableHTTPClientTransport(url, {
|
||||
requestInit,
|
||||
sessionId: config.sessionId
|
||||
});
|
||||
}
|
||||
|
||||
if (transportType !== 'websocket') {
|
||||
throw new Error(`Unsupported transport "${transportType}" in webui environment`);
|
||||
}
|
||||
|
||||
return new WebSocketTransport({
|
||||
url: config.url,
|
||||
protocols: config.protocols,
|
||||
handshakeTimeoutMs: config.handshakeTimeoutMs
|
||||
});
|
||||
}
|
||||
|
||||
private async refreshTools(serverName: string): Promise<void> {
|
||||
const state = this.servers.get(serverName);
|
||||
if (!state) return;
|
||||
|
||||
const response = await this.call(serverName, 'tools/list');
|
||||
const tools = (response.tools as MCPToolDefinition[]) ?? [];
|
||||
state.tools = tools;
|
||||
|
||||
for (const [tool, owner] of Array.from(this.toolsToServer.entries())) {
|
||||
if (owner === serverName && !tools.find((t) => t.name === tool)) {
|
||||
this.toolsToServer.delete(tool);
|
||||
}
|
||||
}
|
||||
|
||||
for (const tool of tools) {
|
||||
this.toolsToServer.set(tool.name, serverName);
|
||||
}
|
||||
}
|
||||
|
||||
private call(
|
||||
serverName: string,
|
||||
method: string,
|
||||
params?: Record<string, unknown>,
|
||||
abortSignal?: AbortSignal
|
||||
): Promise<Record<string, unknown>> {
|
||||
const state = this.servers.get(serverName);
|
||||
if (!state) {
|
||||
return Promise.reject(new MCPError(`Server ${serverName} is not connected`, -32000));
|
||||
}
|
||||
|
||||
const id = ++state.requestId;
|
||||
const message = JsonRpcProtocol.createRequest(id, method, params);
|
||||
|
||||
const timeoutDuration =
|
||||
state.requestTimeoutMs ??
|
||||
this.config.requestTimeoutMs ??
|
||||
MCP_DEFAULTS.requestTimeoutSeconds * 1000;
|
||||
|
||||
if (abortSignal?.aborted) {
|
||||
return Promise.reject(new DOMException('Aborted', 'AbortError'));
|
||||
}
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
const cleanupTasks: Array<() => void> = [];
|
||||
const cleanup = () => {
|
||||
for (const task of cleanupTasks.splice(0)) {
|
||||
task();
|
||||
}
|
||||
};
|
||||
|
||||
const timeout = setTimeout(() => {
|
||||
cleanup();
|
||||
reject(new Error(`Timeout while waiting for ${method} response from ${serverName}`));
|
||||
}, timeoutDuration);
|
||||
cleanupTasks.push(() => clearTimeout(timeout));
|
||||
cleanupTasks.push(() => state.pending.delete(id));
|
||||
|
||||
if (abortSignal) {
|
||||
const abortHandler = () => {
|
||||
cleanup();
|
||||
reject(new DOMException('Aborted', 'AbortError'));
|
||||
};
|
||||
abortSignal.addEventListener('abort', abortHandler, { once: true });
|
||||
cleanupTasks.push(() => abortSignal.removeEventListener('abort', abortHandler));
|
||||
}
|
||||
|
||||
state.pending.set(id, {
|
||||
resolve: (value) => {
|
||||
cleanup();
|
||||
resolve(value);
|
||||
},
|
||||
reject: (reason) => {
|
||||
cleanup();
|
||||
reject(reason);
|
||||
},
|
||||
timeout
|
||||
});
|
||||
|
||||
const handleSendError = (error: unknown) => {
|
||||
cleanup();
|
||||
reject(error);
|
||||
};
|
||||
return transport;
|
||||
} catch (httpError) {
|
||||
console.warn(`[MCP][${name}] StreamableHTTP failed, trying SSE transport...`, httpError);
|
||||
|
||||
try {
|
||||
void state.transport
|
||||
.send(message as JsonRpcMessage)
|
||||
.catch((error) => handleSendError(error));
|
||||
} catch (error) {
|
||||
handleSendError(error);
|
||||
const transport = new SSEClientTransport(url, {
|
||||
requestInit
|
||||
});
|
||||
return transport;
|
||||
} catch (sseError) {
|
||||
// Both failed, throw combined error
|
||||
const httpMsg = httpError instanceof Error ? httpError.message : String(httpError);
|
||||
const sseMsg = sseError instanceof Error ? sseError.message : String(sseError);
|
||||
throw new Error(`Failed to create transport. StreamableHTTP: ${httpMsg}; SSE: ${sseMsg}`);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private handleMessage(serverName: string, message: JsonRpcMessage): void {
|
||||
const state = this.servers.get(serverName);
|
||||
if (!state) {
|
||||
return;
|
||||
}
|
||||
|
||||
if ('method' in message && !('id' in message)) {
|
||||
this.handleNotification(serverName, message.method, message.params);
|
||||
return;
|
||||
}
|
||||
|
||||
const response = JsonRpcProtocol.parseResponse(message);
|
||||
if (!response) {
|
||||
return;
|
||||
}
|
||||
|
||||
const pending = state.pending.get(response.id as number);
|
||||
if (!pending) {
|
||||
return;
|
||||
}
|
||||
|
||||
state.pending.delete(response.id as number);
|
||||
clearTimeout(pending.timeout);
|
||||
|
||||
if (response.error) {
|
||||
pending.reject(
|
||||
new MCPError(response.error.message, response.error.code, response.error.data)
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
pending.resolve(response.result ?? {});
|
||||
}
|
||||
|
||||
private handleNotification(
|
||||
serverName: string,
|
||||
method: string,
|
||||
params?: Record<string, unknown>
|
||||
): void {
|
||||
if (method === 'notifications/tools/list_changed') {
|
||||
void this.refreshTools(serverName).catch((error) => {
|
||||
console.error(`[MCP] Failed to refresh tools for ${serverName}:`, error);
|
||||
});
|
||||
} else if (method === 'notifications/logging/message' && params) {
|
||||
console.debug(`[MCP][${serverName}]`, params);
|
||||
}
|
||||
}
|
||||
|
||||
private static formatToolResult(result: MCPToolsCallResult): string {
|
||||
private static formatToolResult(result: ToolCallResult): string {
|
||||
const content = result.content;
|
||||
if (Array.isArray(content)) {
|
||||
return content
|
||||
|
|
@ -368,46 +291,30 @@ export class MCPClient {
|
|||
.filter(Boolean)
|
||||
.join('\n');
|
||||
}
|
||||
if (content) {
|
||||
return MCPClient.formatSingleContent(content);
|
||||
}
|
||||
if (result.result !== undefined) {
|
||||
return typeof result.result === 'string' ? result.result : JSON.stringify(result.result);
|
||||
}
|
||||
return '';
|
||||
}
|
||||
|
||||
private static formatSingleContent(content: unknown): string {
|
||||
if (content === null || content === undefined) {
|
||||
return '';
|
||||
private static formatSingleContent(content: ToolResultContentItem): string {
|
||||
if (content.type === 'text' && content.text) {
|
||||
return content.text;
|
||||
}
|
||||
|
||||
if (typeof content === 'string') {
|
||||
return content;
|
||||
if (content.type === 'image' && content.data) {
|
||||
return `data:${content.mimeType ?? 'image/png'};base64,${content.data}`;
|
||||
}
|
||||
|
||||
if (typeof content === 'object') {
|
||||
const typed = content as {
|
||||
type?: string;
|
||||
text?: string;
|
||||
data?: string;
|
||||
mimeType?: string;
|
||||
resource?: unknown;
|
||||
};
|
||||
if (typed.type === 'text' && typeof typed.text === 'string') {
|
||||
return typed.text;
|
||||
if (content.type === 'resource' && content.resource) {
|
||||
const resource = content.resource;
|
||||
if (resource.text) {
|
||||
return resource.text;
|
||||
}
|
||||
if (typed.type === 'image' && typeof typed.data === 'string' && typed.mimeType) {
|
||||
return `data:${typed.mimeType};base64,${typed.data}`;
|
||||
}
|
||||
if (typed.type === 'resource' && typed.resource) {
|
||||
return JSON.stringify(typed.resource);
|
||||
}
|
||||
if (typeof typed.text === 'string') {
|
||||
return typed.text;
|
||||
if (resource.blob) {
|
||||
return resource.blob;
|
||||
}
|
||||
return JSON.stringify(resource);
|
||||
}
|
||||
// audio type
|
||||
if (content.data && content.mimeType) {
|
||||
return `data:${content.mimeType};base64,${content.data}`;
|
||||
}
|
||||
|
||||
return JSON.stringify(content);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,3 @@
|
|||
export { MCPClient } from './client';
|
||||
export { MCPError } from './types';
|
||||
export type { MCPClientConfig, MCPServerConfig, MCPToolCall } from './types';
|
||||
export { MCPError } from '$lib/types/mcp';
|
||||
export type { MCPClientConfig, MCPServerConfig, MCPToolCall, IMCPClient } from '$lib/types/mcp';
|
||||
|
|
|
|||
|
|
@ -1,46 +0,0 @@
|
|||
import type {
|
||||
JsonRpcId,
|
||||
JsonRpcMessage,
|
||||
JsonRpcNotification,
|
||||
JsonRpcRequest,
|
||||
JsonRpcResponse
|
||||
} from './types';
|
||||
|
||||
export class JsonRpcProtocol {
|
||||
static createRequest(
|
||||
id: JsonRpcId,
|
||||
method: string,
|
||||
params?: Record<string, unknown>
|
||||
): JsonRpcRequest {
|
||||
return {
|
||||
jsonrpc: '2.0',
|
||||
id,
|
||||
method,
|
||||
...(params ? { params } : {})
|
||||
};
|
||||
}
|
||||
|
||||
static createNotification(method: string, params?: Record<string, unknown>): JsonRpcNotification {
|
||||
return {
|
||||
jsonrpc: '2.0',
|
||||
method,
|
||||
...(params ? { params } : {})
|
||||
};
|
||||
}
|
||||
|
||||
static parseResponse(message: JsonRpcMessage): JsonRpcResponse | null {
|
||||
if (!message || typeof message !== 'object') {
|
||||
return null;
|
||||
}
|
||||
|
||||
if ((message as JsonRpcResponse).jsonrpc !== '2.0') {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!('id' in message)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return message as JsonRpcResponse;
|
||||
}
|
||||
}
|
||||
|
|
@ -1,129 +0,0 @@
|
|||
import type { JsonRpcMessage } from '$lib/mcp/types';
|
||||
import type { MCPTransport } from './types';
|
||||
|
||||
export type StreamableHttpTransportOptions = {
|
||||
url: string;
|
||||
headers?: Record<string, string>;
|
||||
credentials?: RequestCredentials;
|
||||
protocolVersion?: string;
|
||||
sessionId?: string;
|
||||
};
|
||||
|
||||
export class StreamableHttpTransport implements MCPTransport {
|
||||
private handler: ((message: JsonRpcMessage) => void) | null = null;
|
||||
private activeSessionId: string | undefined;
|
||||
|
||||
constructor(private readonly options: StreamableHttpTransportOptions) {}
|
||||
|
||||
async start(): Promise<void> {
|
||||
this.activeSessionId = this.options.sessionId ?? undefined;
|
||||
}
|
||||
|
||||
async stop(): Promise<void> {}
|
||||
|
||||
async send(message: JsonRpcMessage): Promise<void> {
|
||||
return this.dispatch(message);
|
||||
}
|
||||
|
||||
onMessage(handler: (message: JsonRpcMessage) => void): void {
|
||||
this.handler = handler;
|
||||
}
|
||||
|
||||
private async dispatch(message: JsonRpcMessage): Promise<void> {
|
||||
const headers: Record<string, string> = {
|
||||
'Content-Type': 'application/json',
|
||||
Accept: 'application/json, text/event-stream',
|
||||
...(this.options.headers ?? {})
|
||||
};
|
||||
|
||||
if (this.activeSessionId) {
|
||||
headers['Mcp-Session-Id'] = this.activeSessionId;
|
||||
}
|
||||
|
||||
if (this.options.protocolVersion) {
|
||||
headers['MCP-Protocol-Version'] = this.options.protocolVersion;
|
||||
}
|
||||
|
||||
const credentialsOption =
|
||||
this.options.credentials ?? (this.activeSessionId ? 'include' : 'same-origin');
|
||||
const response = await fetch(this.options.url, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify(message),
|
||||
credentials: credentialsOption
|
||||
});
|
||||
|
||||
const sessionHeader = response.headers.get('mcp-session-id');
|
||||
if (sessionHeader) {
|
||||
this.activeSessionId = sessionHeader;
|
||||
}
|
||||
|
||||
if (!response.ok) {
|
||||
const errorBody = await response.text().catch(() => '');
|
||||
throw new Error(
|
||||
`Failed to send MCP request over Streamable HTTP (${response.status} ${response.statusText}): ${errorBody}`
|
||||
);
|
||||
}
|
||||
|
||||
const contentType = response.headers.get('content-type') ?? '';
|
||||
|
||||
if (contentType.includes('application/json')) {
|
||||
const payload = (await response.json()) as JsonRpcMessage;
|
||||
this.handler?.(payload);
|
||||
return;
|
||||
}
|
||||
|
||||
if (contentType.includes('text/event-stream') && response.body) {
|
||||
const reader = response.body.getReader();
|
||||
await this.consume(reader);
|
||||
return;
|
||||
}
|
||||
|
||||
if (response.status >= 400) {
|
||||
const bodyText = await response.text().catch(() => '');
|
||||
throw new Error(
|
||||
`Unexpected MCP Streamable HTTP response (${response.status}): ${bodyText || 'no body'}`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private async consume(reader: ReadableStreamDefaultReader<Uint8Array>): Promise<void> {
|
||||
const decoder = new TextDecoder('utf-8');
|
||||
let buffer = '';
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
|
||||
const parts = buffer.split('\n\n');
|
||||
buffer = parts.pop() ?? '';
|
||||
|
||||
for (const part of parts) {
|
||||
if (!part.startsWith('data: ')) {
|
||||
continue;
|
||||
}
|
||||
const payload = part.slice(6);
|
||||
if (!payload || payload === '[DONE]') {
|
||||
continue;
|
||||
}
|
||||
|
||||
try {
|
||||
const message = JSON.parse(payload) as JsonRpcMessage;
|
||||
this.handler?.(message);
|
||||
} catch (error) {
|
||||
console.error('[MCP][Streamable HTTP] Failed to parse JSON payload:', error);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
if ((error as Error)?.name === 'AbortError') {
|
||||
return;
|
||||
}
|
||||
throw error;
|
||||
} finally {
|
||||
reader.releaseLock();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,8 +0,0 @@
|
|||
import type { JsonRpcMessage } from '../types';
|
||||
|
||||
export interface MCPTransport {
|
||||
start(): Promise<void>;
|
||||
stop(): Promise<void>;
|
||||
send(message: JsonRpcMessage): Promise<void>;
|
||||
onMessage(handler: (message: JsonRpcMessage) => void): void;
|
||||
}
|
||||
|
|
@ -1,238 +0,0 @@
|
|||
import type { JsonRpcMessage } from '$lib/mcp/types';
|
||||
import type { MCPTransport } from './types';
|
||||
|
||||
export type WebSocketTransportOptions = {
|
||||
url: string;
|
||||
protocols?: string | string[];
|
||||
handshakeTimeoutMs?: number;
|
||||
};
|
||||
|
||||
export type TransportMessageHandler = (message: JsonRpcMessage) => void;
|
||||
|
||||
function ensureWebSocket(): typeof WebSocket | null {
|
||||
if (typeof WebSocket !== 'undefined') {
|
||||
return WebSocket;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function arrayBufferToString(buffer: ArrayBufferLike): string {
|
||||
return new TextDecoder('utf-8').decode(new Uint8Array(buffer));
|
||||
}
|
||||
|
||||
async function normalizePayload(data: unknown): Promise<string> {
|
||||
if (typeof data === 'string') {
|
||||
return data;
|
||||
}
|
||||
|
||||
if (data instanceof ArrayBuffer) {
|
||||
return arrayBufferToString(data);
|
||||
}
|
||||
|
||||
if (ArrayBuffer.isView(data)) {
|
||||
return arrayBufferToString(data.buffer);
|
||||
}
|
||||
|
||||
if (typeof Blob !== 'undefined' && data instanceof Blob) {
|
||||
return await data.text();
|
||||
}
|
||||
|
||||
throw new Error('Unsupported WebSocket message payload type');
|
||||
}
|
||||
|
||||
export class WebSocketTransport implements MCPTransport {
|
||||
private socket: WebSocket | null = null;
|
||||
private handler: TransportMessageHandler | null = null;
|
||||
private openPromise: Promise<void> | null = null;
|
||||
private reconnectAttempts = 0;
|
||||
private readonly maxReconnectAttempts = 5;
|
||||
private readonly reconnectDelay = 1_000;
|
||||
private isReconnecting = false;
|
||||
private shouldAttemptReconnect = true;
|
||||
|
||||
constructor(private readonly options: WebSocketTransportOptions) {}
|
||||
|
||||
start(): Promise<void> {
|
||||
if (this.openPromise) {
|
||||
return this.openPromise;
|
||||
}
|
||||
|
||||
this.shouldAttemptReconnect = true;
|
||||
|
||||
this.openPromise = new Promise((resolve, reject) => {
|
||||
const WebSocketImpl = ensureWebSocket();
|
||||
if (!WebSocketImpl) {
|
||||
this.openPromise = null;
|
||||
reject(new Error('WebSocket is not available in this environment'));
|
||||
return;
|
||||
}
|
||||
|
||||
let handshakeTimeout: ReturnType<typeof setTimeout> | undefined;
|
||||
const socket = this.options.protocols
|
||||
? new WebSocketImpl(this.options.url, this.options.protocols)
|
||||
: new WebSocketImpl(this.options.url);
|
||||
|
||||
const cleanup = () => {
|
||||
if (!socket) return;
|
||||
socket.onopen = null;
|
||||
socket.onclose = null;
|
||||
socket.onerror = null;
|
||||
socket.onmessage = null;
|
||||
if (handshakeTimeout) {
|
||||
clearTimeout(handshakeTimeout);
|
||||
handshakeTimeout = undefined;
|
||||
}
|
||||
};
|
||||
|
||||
const fail = (error: unknown) => {
|
||||
cleanup();
|
||||
this.openPromise = null;
|
||||
reject(error instanceof Error ? error : new Error('WebSocket connection error'));
|
||||
};
|
||||
|
||||
socket.onopen = () => {
|
||||
cleanup();
|
||||
this.socket = socket;
|
||||
this.reconnectAttempts = 0;
|
||||
this.attachMessageHandler();
|
||||
this.attachCloseHandler(socket);
|
||||
resolve();
|
||||
this.openPromise = null;
|
||||
};
|
||||
|
||||
socket.onerror = (event) => {
|
||||
const error = event instanceof Event ? new Error('WebSocket connection error') : event;
|
||||
fail(error);
|
||||
};
|
||||
|
||||
socket.onclose = (event) => {
|
||||
if (!this.socket) {
|
||||
fail(new Error(`WebSocket closed before opening (code: ${event.code})`));
|
||||
}
|
||||
};
|
||||
|
||||
if (this.options.handshakeTimeoutMs) {
|
||||
handshakeTimeout = setTimeout(() => {
|
||||
if (!this.socket) {
|
||||
try {
|
||||
socket.close();
|
||||
} catch (error) {
|
||||
console.warn('[MCP][Transport] Failed to close socket after timeout:', error);
|
||||
}
|
||||
fail(new Error('WebSocket handshake timed out'));
|
||||
}
|
||||
}, this.options.handshakeTimeoutMs);
|
||||
}
|
||||
});
|
||||
|
||||
return this.openPromise;
|
||||
}
|
||||
|
||||
async send(message: JsonRpcMessage): Promise<void> {
|
||||
if (!this.socket || this.socket.readyState !== WebSocket.OPEN) {
|
||||
throw new Error('WebSocket transport is not connected');
|
||||
}
|
||||
this.socket.send(JSON.stringify(message));
|
||||
}
|
||||
|
||||
async stop(): Promise<void> {
|
||||
this.shouldAttemptReconnect = false;
|
||||
this.reconnectAttempts = 0;
|
||||
this.isReconnecting = false;
|
||||
|
||||
const socket = this.socket;
|
||||
if (!socket) {
|
||||
this.openPromise = null;
|
||||
return;
|
||||
}
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
const onClose = () => {
|
||||
socket.removeEventListener('close', onClose);
|
||||
resolve();
|
||||
};
|
||||
socket.addEventListener('close', onClose);
|
||||
try {
|
||||
socket.close();
|
||||
} catch (error) {
|
||||
socket.removeEventListener('close', onClose);
|
||||
console.warn('[MCP][Transport] Failed to close WebSocket:', error);
|
||||
resolve();
|
||||
}
|
||||
});
|
||||
|
||||
this.socket = null;
|
||||
this.openPromise = null;
|
||||
}
|
||||
|
||||
onMessage(handler: TransportMessageHandler): void {
|
||||
this.handler = handler;
|
||||
this.attachMessageHandler();
|
||||
}
|
||||
|
||||
private attachMessageHandler(): void {
|
||||
if (!this.socket) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.socket.onmessage = (event: MessageEvent) => {
|
||||
const payload = event.data;
|
||||
void (async () => {
|
||||
try {
|
||||
const text = await normalizePayload(payload);
|
||||
const parsed = JSON.parse(text);
|
||||
this.handler?.(parsed as JsonRpcMessage);
|
||||
} catch (error) {
|
||||
console.error('[MCP][Transport] Failed to handle message:', error);
|
||||
}
|
||||
})();
|
||||
};
|
||||
}
|
||||
|
||||
private attachCloseHandler(socket: WebSocket): void {
|
||||
socket.onclose = (event) => {
|
||||
this.socket = null;
|
||||
|
||||
if (event.code === 1000 || !this.shouldAttemptReconnect) {
|
||||
return;
|
||||
}
|
||||
|
||||
console.warn('[MCP][WebSocket] Connection closed unexpectedly, attempting reconnect');
|
||||
void this.reconnect();
|
||||
};
|
||||
}
|
||||
|
||||
private async reconnect(): Promise<void> {
|
||||
if (
|
||||
this.isReconnecting ||
|
||||
this.reconnectAttempts >= this.maxReconnectAttempts ||
|
||||
!this.shouldAttemptReconnect
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.isReconnecting = true;
|
||||
this.reconnectAttempts++;
|
||||
|
||||
const delay = this.reconnectDelay * Math.pow(2, this.reconnectAttempts - 1);
|
||||
await new Promise((resolve) => setTimeout(resolve, delay));
|
||||
|
||||
try {
|
||||
this.openPromise = null;
|
||||
await this.start();
|
||||
this.reconnectAttempts = 0;
|
||||
console.log('[MCP][WebSocket] Reconnected successfully');
|
||||
} catch (error) {
|
||||
console.error('[MCP][WebSocket] Reconnection failed:', error);
|
||||
} finally {
|
||||
this.isReconnecting = false;
|
||||
if (
|
||||
!this.socket &&
|
||||
this.shouldAttemptReconnect &&
|
||||
this.reconnectAttempts < this.maxReconnectAttempts
|
||||
) {
|
||||
void this.reconnect();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,140 +0,0 @@
|
|||
import { browser } from '$app/environment';
|
||||
import { MCPClient } from '$lib/mcp';
|
||||
import { buildMcpClientConfig } from '$lib/config/mcp';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
|
||||
const globalState = globalThis as typeof globalThis & {
|
||||
__llamaMcpClient?: MCPClient;
|
||||
__llamaMcpInitPromise?: Promise<MCPClient | undefined>;
|
||||
__llamaMcpConfigSignature?: string;
|
||||
__llamaMcpInitConfigSignature?: string;
|
||||
};
|
||||
|
||||
function serializeConfigSignature(): string | undefined {
|
||||
const mcpConfig = buildMcpClientConfig(config());
|
||||
return mcpConfig ? JSON.stringify(mcpConfig) : undefined;
|
||||
}
|
||||
|
||||
async function shutdownClient(): Promise<void> {
|
||||
if (!globalState.__llamaMcpClient) return;
|
||||
|
||||
const clientToShutdown = globalState.__llamaMcpClient;
|
||||
globalState.__llamaMcpClient = undefined;
|
||||
globalState.__llamaMcpConfigSignature = undefined;
|
||||
|
||||
try {
|
||||
await clientToShutdown.shutdown();
|
||||
} catch (error) {
|
||||
console.error('[MCP] Failed to shutdown client:', error);
|
||||
}
|
||||
}
|
||||
|
||||
async function bootstrapClient(
|
||||
signature: string,
|
||||
mcpConfig: ReturnType<typeof buildMcpClientConfig>
|
||||
): Promise<MCPClient | undefined> {
|
||||
if (!browser || !mcpConfig) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const client = new MCPClient(mcpConfig);
|
||||
globalState.__llamaMcpInitConfigSignature = signature;
|
||||
|
||||
const initPromise = client
|
||||
.initialize()
|
||||
.then(() => {
|
||||
// Ignore initialization if config changed during bootstrap
|
||||
if (globalState.__llamaMcpInitConfigSignature !== signature) {
|
||||
void client.shutdown().catch((shutdownError) => {
|
||||
console.error(
|
||||
'[MCP] Failed to shutdown stale client after config change:',
|
||||
shutdownError
|
||||
);
|
||||
});
|
||||
return undefined;
|
||||
}
|
||||
|
||||
globalState.__llamaMcpClient = client;
|
||||
globalState.__llamaMcpConfigSignature = signature;
|
||||
return client;
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('[MCP] Failed to initialize client:', error);
|
||||
|
||||
// Cleanup global references on error
|
||||
if (globalState.__llamaMcpClient === client) {
|
||||
globalState.__llamaMcpClient = undefined;
|
||||
}
|
||||
if (globalState.__llamaMcpConfigSignature === signature) {
|
||||
globalState.__llamaMcpConfigSignature = undefined;
|
||||
}
|
||||
|
||||
void client.shutdown().catch((shutdownError) => {
|
||||
console.error('[MCP] Failed to shutdown client after init error:', shutdownError);
|
||||
});
|
||||
return undefined;
|
||||
})
|
||||
.finally(() => {
|
||||
// Clear init promise only if it's OUR promise
|
||||
if (globalState.__llamaMcpInitPromise === initPromise) {
|
||||
globalState.__llamaMcpInitPromise = undefined;
|
||||
// Clear init signature only if it's still ours
|
||||
if (globalState.__llamaMcpInitConfigSignature === signature) {
|
||||
globalState.__llamaMcpInitConfigSignature = undefined;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
globalState.__llamaMcpInitPromise = initPromise;
|
||||
return initPromise;
|
||||
}
|
||||
|
||||
export function getMcpClient(): MCPClient | undefined {
|
||||
return globalState.__llamaMcpClient;
|
||||
}
|
||||
|
||||
export async function ensureMcpClient(): Promise<MCPClient | undefined> {
|
||||
const signature = serializeConfigSignature();
|
||||
|
||||
// Configuration removed: shut down active client if present
|
||||
if (!signature) {
|
||||
// Wait for any in-flight init to complete before shutdown
|
||||
if (globalState.__llamaMcpInitPromise) {
|
||||
await globalState.__llamaMcpInitPromise;
|
||||
}
|
||||
await shutdownClient();
|
||||
globalState.__llamaMcpInitPromise = undefined;
|
||||
globalState.__llamaMcpInitConfigSignature = undefined;
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Client already initialized with correct config
|
||||
if (globalState.__llamaMcpClient && globalState.__llamaMcpConfigSignature === signature) {
|
||||
return globalState.__llamaMcpClient;
|
||||
}
|
||||
|
||||
// Init in progress with correct config
|
||||
if (
|
||||
globalState.__llamaMcpInitPromise &&
|
||||
globalState.__llamaMcpInitConfigSignature === signature
|
||||
) {
|
||||
return globalState.__llamaMcpInitPromise;
|
||||
}
|
||||
|
||||
// Config changed - wait for in-flight init before shutdown
|
||||
if (
|
||||
globalState.__llamaMcpInitPromise &&
|
||||
globalState.__llamaMcpInitConfigSignature !== signature
|
||||
) {
|
||||
await globalState.__llamaMcpInitPromise;
|
||||
}
|
||||
|
||||
// Shutdown if config changed
|
||||
if (globalState.__llamaMcpConfigSignature !== signature) {
|
||||
await shutdownClient();
|
||||
}
|
||||
|
||||
// Bootstrap new client
|
||||
const mcpConfig = buildMcpClientConfig(config());
|
||||
return bootstrapClient(signature, mcpConfig);
|
||||
}
|
||||
Loading…
Reference in New Issue