diff --git a/.github/workflows/winget.yml b/.github/workflows/winget.yml index 5c28615595..17b55762a9 100644 --- a/.github/workflows/winget.yml +++ b/.github/workflows/winget.yml @@ -9,6 +9,7 @@ jobs: update: name: Update Winget Package runs-on: ubuntu-latest + if: ${{ github.repository_owner == 'ggml-org' }} steps: - name: Install cargo binstall diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index a54cce887b..8ddb6d04cd 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2842,6 +2842,10 @@ class Mistral3Model(LlamaModel): self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"]) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): + # TODO: probably not worth supporting quantized weight, as official BF16 is also available + if name.endswith("weight_scale_inv"): + raise ValueError("This is a quantized weight, please use BF16 weight instead") + name = name.replace("language_model.", "") if "multi_modal_projector" in name or "vision_tower" in name: return [] diff --git a/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp b/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp index 67369147ce..c460c54911 100644 --- a/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +++ b/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp @@ -8,6 +8,10 @@ #include #endif +#if !defined(HWCAP2_SVE2) +#define HWCAP2_SVE2 (1 << 1) +#endif + #if !defined(HWCAP2_I8MM) #define HWCAP2_I8MM (1 << 13) #endif diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index 637b6a788e..3050a86445 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index 0bbc4e858f..f48ea5b62a 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -1263,7 +1263,11 @@ json convert_anthropic_to_oai(const json & body) { return oai_body; } -json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64) { +json format_embeddings_response_oaicompat( + const json & request, + const std::string & model_name, + const json & embeddings, + bool use_base64) { json data = json::array(); int32_t n_tokens = 0; int i = 0; @@ -1293,7 +1297,7 @@ json format_embeddings_response_oaicompat(const json & request, const json & emb } json res = json { - {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"model", json_value(request, "model", model_name)}, {"object", "list"}, {"usage", json { {"prompt_tokens", n_tokens}, @@ -1307,6 +1311,7 @@ json format_embeddings_response_oaicompat(const json & request, const json & emb json format_response_rerank( const json & request, + const std::string & model_name, const json & ranks, bool is_tei_format, std::vector<std::string> & texts, @@ -1338,7 +1343,7 @@ json format_response_rerank( if (is_tei_format) return results; json res = json{ - {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"model", json_value(request, "model", model_name)}, {"object", "list"}, {"usage", json{ {"prompt_tokens", n_tokens}, diff --git a/tools/server/server-common.h b/tools/server/server-common.h index ab8aabbad0..51ae9ea8a9 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -13,8 +13,6 @@ #include #include -#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo" - const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); using json = nlohmann::ordered_json; @@ -298,11 +296,16 @@ 
json oaicompat_chat_params_parse( json convert_anthropic_to_oai(const json & body); // TODO: move it to server-task.cpp -json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false); +json format_embeddings_response_oaicompat( + const json & request, + const std::string & model_name, + const json & embeddings, + bool use_base64 = false); // TODO: move it to server-task.cpp json format_response_rerank( const json & request, + const std::string & model_name, const json & ranks, bool is_tei_format, std::vector & texts, diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 85961c078e..0e8a5dee85 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -17,6 +17,7 @@ #include #include #include +#include // fix problem with std::min and std::max #if defined(_WIN32) @@ -518,6 +519,8 @@ struct server_context_impl { // Necessary similarity of prompt for slot selection float slot_prompt_similarity = 0.0f; + std::string model_name; // name of the loaded model, to be used by API + common_chat_templates_ptr chat_templates; oaicompat_parser_options oai_parser_opt; @@ -755,6 +758,18 @@ struct server_context_impl { } SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n"); + if (!params_base.model_alias.empty()) { + // user explicitly specified model name + model_name = params_base.model_alias; + } else if (!params_base.model.name.empty()) { + // use model name in registry format (for models in cache) + model_name = params_base.model.name; + } else { + // fallback: derive model name from file name + auto model_path = std::filesystem::path(params_base.model.path); + model_name = model_path.filename().string(); + } + // thinking is enabled if: // 1. It's not explicitly disabled (reasoning_budget == 0) // 2. The chat template supports it @@ -2608,7 +2623,7 @@ static std::unique_ptr handle_completions_impl( // OAI-compat task.params.res_type = res_type; task.params.oaicompat_cmpl_id = completion_id; - // oaicompat_model is already populated by params_from_json_cmpl + task.params.oaicompat_model = ctx_server.model_name; tasks.push_back(std::move(task)); } @@ -2936,7 +2951,7 @@ void server_routes::init_routes() { json data = { { "default_generation_settings", default_generation_settings_for_props }, { "total_slots", ctx_server.params_base.n_parallel }, - { "model_alias", ctx_server.params_base.model_alias }, + { "model_alias", ctx_server.model_name }, { "model_path", ctx_server.params_base.model.path }, { "modalities", json { {"vision", ctx_server.oai_parser_opt.allow_image}, @@ -3178,8 +3193,8 @@ void server_routes::init_routes() { json models = { {"models", { { - {"name", params.model_alias.empty() ? params.model.path : params.model_alias}, - {"model", params.model_alias.empty() ? params.model.path : params.model_alias}, + {"name", ctx_server.model_name}, + {"model", ctx_server.model_name}, {"modified_at", ""}, {"size", ""}, {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash @@ -3201,7 +3216,7 @@ void server_routes::init_routes() { {"object", "list"}, {"data", { { - {"id", params.model_alias.empty() ? 
params.model.path : params.model_alias}, + {"id", ctx_server.model_name}, {"object", "model"}, {"created", std::time(0)}, {"owned_by", "llamacpp"}, @@ -3348,6 +3363,7 @@ void server_routes::init_routes() { // write JSON response json root = format_response_rerank( body, + ctx_server.model_name, responses, is_tei_format, documents, @@ -3610,7 +3626,7 @@ std::unique_ptr server_routes::handle_embeddings_impl(cons // write JSON response json root = res_type == TASK_RESPONSE_TYPE_OAI_EMBD - ? format_embeddings_response_oaicompat(body, responses, use_base64) + ? format_embeddings_response_oaicompat(body, ctx_server.model_name, responses, use_base64) : json(responses); res->ok(root); return res; diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 0f812ed411..ac7f6b86bf 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -24,8 +24,55 @@ #include #endif +#if defined(__APPLE__) && defined(__MACH__) +// macOS: use _NSGetExecutablePath to get the executable path +#include +#include +#endif + #define CMD_EXIT "exit" +static std::filesystem::path get_server_exec_path() { +#if defined(_WIN32) + wchar_t buf[32768] = { 0 }; // Large buffer to handle long paths + DWORD len = GetModuleFileNameW(nullptr, buf, _countof(buf)); + if (len == 0 || len >= _countof(buf)) { + throw std::runtime_error("GetModuleFileNameW failed or path too long"); + } + return std::filesystem::path(buf); +#elif defined(__APPLE__) && defined(__MACH__) + char small_path[PATH_MAX]; + uint32_t size = sizeof(small_path); + + if (_NSGetExecutablePath(small_path, &size) == 0) { + // resolve any symlinks to get absolute path + try { + return std::filesystem::canonical(std::filesystem::path(small_path)); + } catch (...) { + return std::filesystem::path(small_path); + } + } else { + // buffer was too small, allocate required size and call again + std::vector buf(size); + if (_NSGetExecutablePath(buf.data(), &size) == 0) { + try { + return std::filesystem::canonical(std::filesystem::path(buf.data())); + } catch (...) 
{ + return std::filesystem::path(buf.data()); + } + } + throw std::runtime_error("_NSGetExecutablePath failed after buffer resize"); + } +#else + char path[FILENAME_MAX]; + ssize_t count = readlink("/proc/self/exe", path, FILENAME_MAX); + if (count <= 0) { + throw std::runtime_error("failed to resolve /proc/self/exe"); + } + return std::filesystem::path(std::string(path, count)); +#endif +} + struct local_model { std::string name; std::string path; @@ -99,6 +146,14 @@ server_models::server_models( for (char ** env = envp; *env != nullptr; env++) { base_env.push_back(std::string(*env)); } + GGML_ASSERT(!base_args.empty()); + // set binary path + try { + base_args[0] = get_server_exec_path().string(); + } catch (const std::exception & e) { + LOG_WRN("failed to get server executable path: %s\n", e.what()); + LOG_WRN("using original argv[0] as fallback: %s\n", base_args[0].c_str()); + } // TODO: allow refreshing cached model list // add cached models auto cached_models = common_list_cached_models(); @@ -587,26 +642,26 @@ static void res_ok(std::unique_ptr & res, const json & response res->data = safe_json_to_str(response_data); } -static void res_error(std::unique_ptr & res, const json & error_data) { +static void res_err(std::unique_ptr & res, const json & error_data) { res->status = json_value(error_data, "code", 500); res->data = safe_json_to_str({{ "error", error_data }}); } static bool router_validate_model(const std::string & name, server_models & models, bool models_autoload, std::unique_ptr & res) { if (name.empty()) { - res_error(res, format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST)); + res_err(res, format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST)); return false; } auto meta = models.get_meta(name); if (!meta.has_value()) { - res_error(res, format_error_response("model not found", ERROR_TYPE_INVALID_REQUEST)); + res_err(res, format_error_response("model not found", ERROR_TYPE_INVALID_REQUEST)); return false; } if (models_autoload) { models.ensure_model_loaded(name); } else { if (meta->status != SERVER_MODEL_STATUS_LOADED) { - res_error(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST)); + res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST)); return false; } } @@ -706,11 +761,11 @@ void server_models_routes::init_routes() { std::string name = json_value(body, "model", std::string()); auto model = models.get_meta(name); if (!model.has_value()) { - res_error(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND)); + res_err(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND)); return res; } if (model->status == SERVER_MODEL_STATUS_LOADED) { - res_error(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST)); + res_err(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST)); return res; } models.load(name, false); @@ -768,11 +823,11 @@ void server_models_routes::init_routes() { std::string name = json_value(body, "model", std::string()); auto model = models.get_meta(name); if (!model.has_value()) { - res_error(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST)); + res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST)); return res; } if (model->status != SERVER_MODEL_STATUS_LOADED) { - res_error(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST)); + res_err(res, 
format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST)); return res; } models.unload(name); diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index a0210c42ef..b6a5d8b799 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -455,9 +455,6 @@ task_params server_task::params_from_json_cmpl( } } - std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; - params.oaicompat_model = json_value(data, "model", model_name); - return params; } diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py index 392e0efecd..093cec9155 100644 --- a/tools/server/tests/unit/test_chat_completion.py +++ b/tools/server/tests/unit/test_chat_completion.py @@ -41,7 +41,8 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte assert res.status_code == 200 assert "cmpl" in res.body["id"] # make sure the completion id has the expected format assert res.body["system_fingerprint"].startswith("b") - assert res.body["model"] == model if model is not None else server.model_alias + # we no longer reflect back the model name, see https://github.com/ggml-org/llama.cpp/pull/17668 + # assert res.body["model"] == model if model is not None else server.model_alias assert res.body["usage"]["prompt_tokens"] == n_prompt assert res.body["usage"]["completion_tokens"] == n_predicted choice = res.body["choices"][0] @@ -59,7 +60,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte ) def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason): global server - server.model_alias = None # try using DEFAULT_OAICOMPAT_MODEL + server.model_alias = "llama-test-model" server.start() res = server.make_stream_request("POST", "/chat/completions", data={ "max_tokens": max_tokens, @@ -81,7 +82,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte else: assert "role" not in choice["delta"] assert data["system_fingerprint"].startswith("b") - assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future + assert data["model"] == "llama-test-model" if last_cmpl_id is None: last_cmpl_id = data["id"] assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream diff --git a/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte b/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte index 2d2f515317..57a2edac58 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte @@ -575,6 +575,7 @@ void; } - let { open = $bindable(), type, message, onOpenChange }: Props = $props(); + let { open = $bindable(), type, message, contextInfo, onOpenChange }: Props = $props(); const isTimeout = $derived(type === 'timeout'); const title = $derived(isTimeout ? 'TCP Timeout' : 'Server Error'); @@ -51,6 +52,15 @@
 				{message}
+				{#if contextInfo}
+					<div>
+						<span>Prompt tokens:</span>
+						<span>{contextInfo.n_prompt_tokens.toLocaleString()}</span>
+					</div>
+					<div>
+						<span>Context size: {contextInfo.n_ctx.toLocaleString()}</span>
+					</div>
+				{/if}
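For reference, the `contextInfo` rendered above comes from the server's context-overflow error body. A minimal TypeScript sketch of that narrowing follows; only the `n_prompt_tokens` / `n_ctx` field names come from the diff, while the `ContextInfo` / `ApiErrorBody` names, the helper, and the example values are illustrative assumptions.

```ts
// Sketch only: mirrors how the webui derives contextInfo from an error payload.
// Field names n_prompt_tokens / n_ctx match the diff; the envelope and helper are assumptions.
interface ContextInfo {
	n_prompt_tokens: number;
	n_ctx: number;
}

interface ApiErrorBody {
	error?: { message?: string } & Partial<ContextInfo>;
}

function extractContextInfo(body: ApiErrorBody): ContextInfo | undefined {
	const err = body.error;

	if (err && typeof err.n_prompt_tokens === 'number' && typeof err.n_ctx === 'number') {
		return { n_prompt_tokens: err.n_prompt_tokens, n_ctx: err.n_ctx };
	}

	return undefined;
}

// Illustrative 400 response: prompt longer than the configured context window.
const info = extractContextInfo({
	error: { message: 'request exceeds the available context size', n_prompt_tokens: 9216, n_ctx: 8192 }
});

console.log(info?.n_prompt_tokens.toLocaleString(), '/', info?.n_ctx.toLocaleString());
```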
diff --git a/tools/server/webui/src/lib/services/chat.ts b/tools/server/webui/src/lib/services/chat.ts index 1525885db8..f47ac14a3c 100644 --- a/tools/server/webui/src/lib/services/chat.ts +++ b/tools/server/webui/src/lib/services/chat.ts @@ -767,18 +767,33 @@ export class ChatService { * @param response - HTTP response object * @returns Promise - Parsed error with context info if available */ - private static async parseErrorResponse(response: Response): Promise { + private static async parseErrorResponse( + response: Response + ): Promise { try { const errorText = await response.text(); const errorData: ApiErrorResponse = JSON.parse(errorText); const message = errorData.error?.message || 'Unknown server error'; - const error = new Error(message); + const error = new Error(message) as Error & { + contextInfo?: { n_prompt_tokens: number; n_ctx: number }; + }; error.name = response.status === 400 ? 'ServerError' : 'HttpError'; + if (errorData.error && 'n_prompt_tokens' in errorData.error && 'n_ctx' in errorData.error) { + error.contextInfo = { + n_prompt_tokens: errorData.error.n_prompt_tokens, + n_ctx: errorData.error.n_ctx + }; + } + return error; } catch { - const fallback = new Error(`Server error (${response.status}): ${response.statusText}`); + const fallback = new Error( + `Server error (${response.status}): ${response.statusText}` + ) as Error & { + contextInfo?: { n_prompt_tokens: number; n_ctx: number }; + }; fallback.name = 'HttpError'; return fallback; } diff --git a/tools/server/webui/src/lib/stores/chat.svelte.ts b/tools/server/webui/src/lib/stores/chat.svelte.ts index 0c17b06bc1..f21e291163 100644 --- a/tools/server/webui/src/lib/stores/chat.svelte.ts +++ b/tools/server/webui/src/lib/stores/chat.svelte.ts @@ -58,7 +58,11 @@ class ChatStore { activeProcessingState = $state(null); currentResponse = $state(''); - errorDialogState = $state<{ type: 'timeout' | 'server'; message: string } | null>(null); + errorDialogState = $state<{ + type: 'timeout' | 'server'; + message: string; + contextInfo?: { n_prompt_tokens: number; n_ctx: number }; + } | null>(null); isLoading = $state(false); chatLoadingStates = new SvelteMap(); chatStreamingStates = new SvelteMap(); @@ -335,8 +339,12 @@ class ChatStore { return error instanceof Error && (error.name === 'AbortError' || error instanceof DOMException); } - private showErrorDialog(type: 'timeout' | 'server', message: string): void { - this.errorDialogState = { type, message }; + private showErrorDialog( + type: 'timeout' | 'server', + message: string, + contextInfo?: { n_prompt_tokens: number; n_ctx: number } + ): void { + this.errorDialogState = { type, message, contextInfo }; } dismissErrorDialog(): void { @@ -347,6 +355,23 @@ class ChatStore { // Message Operations // ───────────────────────────────────────────────────────────────────────────── + /** + * Finds a message by ID and optionally validates its role. + * Returns message and index, or null if not found or role doesn't match. 
+ */ + private getMessageByIdWithRole( + messageId: string, + expectedRole?: ChatRole + ): { message: DatabaseMessage; index: number } | null { + const index = conversationsStore.findMessageIndex(messageId); + if (index === -1) return null; + + const message = conversationsStore.activeMessages[index]; + if (expectedRole && message.role !== expectedRole) return null; + + return { message, index }; + } + async addMessage( role: ChatRole, content: string, @@ -508,7 +533,6 @@ class ChatStore { ) => { this.stopStreaming(); - // Build update data - only include model if not already persisted const updateData: Record = { content: finalContent || streamedContent, thinking: reasoningContent || streamedReasoningContent, @@ -520,7 +544,6 @@ class ChatStore { } await DatabaseService.updateMessage(assistantMessage.id, updateData); - // Update UI state - always include model and timings if available const idx = conversationsStore.findMessageIndex(assistantMessage.id); const uiUpdate: Partial = { content: updateData.content as string, @@ -543,22 +566,38 @@ class ChatStore { }, onError: (error: Error) => { this.stopStreaming(); + if (this.isAbortError(error)) { this.setChatLoading(assistantMessage.convId, false); this.clearChatStreaming(assistantMessage.convId); this.clearProcessingState(assistantMessage.convId); + return; } + console.error('Streaming error:', error); + this.setChatLoading(assistantMessage.convId, false); this.clearChatStreaming(assistantMessage.convId); this.clearProcessingState(assistantMessage.convId); + const idx = conversationsStore.findMessageIndex(assistantMessage.id); + if (idx !== -1) { const failedMessage = conversationsStore.removeMessageAtIndex(idx); if (failedMessage) DatabaseService.deleteMessage(failedMessage.id).catch(console.error); } - this.showErrorDialog(error.name === 'TimeoutError' ? 'timeout' : 'server', error.message); + + const contextInfo = ( + error as Error & { contextInfo?: { n_prompt_tokens: number; n_ctx: number } } + ).contextInfo; + + this.showErrorDialog( + error.name === 'TimeoutError' ? 'timeout' : 'server', + error.message, + contextInfo + ); + if (onError) onError(error); } }, @@ -591,7 +630,9 @@ class ChatStore { await conversationsStore.updateConversationName(currentConv.id, content.trim()); const assistantMessage = await this.createAssistantMessage(userMessage.id); + if (!assistantMessage) throw new Error('Failed to create assistant message'); + conversationsStore.addMessageToActive(assistantMessage); await this.streamChatCompletion( conversationsStore.activeMessages.slice(0, -1), @@ -607,15 +648,26 @@ class ChatStore { if (!this.errorDialogState) { const dialogType = error instanceof Error && error.name === 'TimeoutError' ? 'timeout' : 'server'; - this.showErrorDialog(dialogType, error instanceof Error ? error.message : 'Unknown error'); + const contextInfo = ( + error as Error & { contextInfo?: { n_prompt_tokens: number; n_ctx: number } } + ).contextInfo; + + this.showErrorDialog( + dialogType, + error instanceof Error ? 
error.message : 'Unknown error', + contextInfo + ); } } } async stopGeneration(): Promise { const activeConv = conversationsStore.activeConversation; + if (!activeConv) return; + await this.savePartialResponseIfNeeded(activeConv.id); + this.stopStreaming(); this.abortRequest(activeConv.id); this.setChatLoading(activeConv.id, false); @@ -655,17 +707,22 @@ class ChatStore { private async savePartialResponseIfNeeded(convId?: string): Promise { const conversationId = convId || conversationsStore.activeConversation?.id; + if (!conversationId) return; + const streamingState = this.chatStreamingStates.get(conversationId); + if (!streamingState || !streamingState.response.trim()) return; const messages = conversationId === conversationsStore.activeConversation?.id ? conversationsStore.activeMessages : await conversationsStore.getConversationMessages(conversationId); + if (!messages.length) return; const lastMessage = messages[messages.length - 1]; + if (lastMessage?.role === 'assistant') { try { const updateData: { content: string; thinking?: string; timings?: ChatMessageTimings } = { @@ -684,9 +741,13 @@ class ChatStore { : undefined }; } + await DatabaseService.updateMessage(lastMessage.id, updateData); + lastMessage.content = this.currentResponse; + if (updateData.thinking) lastMessage.thinking = updateData.thinking; + if (updateData.timings) lastMessage.timings = updateData.timings; } catch (error) { lastMessage.content = this.currentResponse; @@ -700,14 +761,12 @@ class ChatStore { if (!activeConv) return; if (this.isLoading) this.stopGeneration(); + const result = this.getMessageByIdWithRole(messageId, 'user'); + if (!result) return; + const { message: messageToUpdate, index: messageIndex } = result; + const originalContent = messageToUpdate.content; + try { - const messageIndex = conversationsStore.findMessageIndex(messageId); - if (messageIndex === -1) return; - - const messageToUpdate = conversationsStore.activeMessages[messageIndex]; - const originalContent = messageToUpdate.content; - if (messageToUpdate.role !== 'user') return; - const allMessages = await conversationsStore.getConversationMessages(activeConv.id); const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null); const isFirstUserMessage = rootMessage && messageToUpdate.parent === rootMessage.id; @@ -724,7 +783,9 @@ class ChatStore { } const messagesToRemove = conversationsStore.activeMessages.slice(messageIndex + 1); + for (const message of messagesToRemove) await DatabaseService.deleteMessage(message.id); + conversationsStore.sliceActiveMessages(messageIndex + 1); conversationsStore.updateConversationTimestamp(); @@ -732,8 +793,11 @@ class ChatStore { this.clearChatStreaming(activeConv.id); const assistantMessage = await this.createAssistantMessage(); + if (!assistantMessage) throw new Error('Failed to create assistant message'); + conversationsStore.addMessageToActive(assistantMessage); + await conversationsStore.updateCurrentNode(assistantMessage.id); await this.streamChatCompletion( conversationsStore.activeMessages.slice(0, -1), @@ -758,12 +822,11 @@ class ChatStore { const activeConv = conversationsStore.activeConversation; if (!activeConv || this.isLoading) return; - try { - const messageIndex = conversationsStore.findMessageIndex(messageId); - if (messageIndex === -1) return; - const messageToRegenerate = conversationsStore.activeMessages[messageIndex]; - if (messageToRegenerate.role !== 'assistant') return; + const result = this.getMessageByIdWithRole(messageId, 'assistant'); + if (!result) 
return; + const { index: messageIndex } = result; + try { const messagesToRemove = conversationsStore.activeMessages.slice(messageIndex); for (const message of messagesToRemove) await DatabaseService.deleteMessage(message.id); conversationsStore.sliceActiveMessages(messageIndex); @@ -832,6 +895,7 @@ class ChatStore { const siblings = allMessages.filter( (m) => m.parent === messageToDelete.parent && m.id !== messageId ); + if (siblings.length > 0) { const latestSibling = siblings.reduce((latest, sibling) => sibling.timestamp > latest.timestamp ? sibling : latest @@ -845,6 +909,7 @@ class ChatStore { } await DatabaseService.deleteMessageCascading(activeConv.id, messageId); await conversationsStore.refreshActiveMessages(); + conversationsStore.updateConversationTimestamp(); } catch (error) { console.error('Failed to delete message:', error); @@ -862,12 +927,12 @@ class ChatStore { ): Promise { const activeConv = conversationsStore.activeConversation; if (!activeConv || this.isLoading) return; - try { - const idx = conversationsStore.findMessageIndex(messageId); - if (idx === -1) return; - const msg = conversationsStore.activeMessages[idx]; - if (msg.role !== 'assistant') return; + const result = this.getMessageByIdWithRole(messageId, 'assistant'); + if (!result) return; + const { message: msg, index: idx } = result; + + try { if (shouldBranch) { const newMessage = await DatabaseService.createMessageBranch( { @@ -902,12 +967,12 @@ class ChatStore { async editUserMessagePreserveResponses(messageId: string, newContent: string): Promise { const activeConv = conversationsStore.activeConversation; if (!activeConv) return; - try { - const idx = conversationsStore.findMessageIndex(messageId); - if (idx === -1) return; - const msg = conversationsStore.activeMessages[idx]; - if (msg.role !== 'user') return; + const result = this.getMessageByIdWithRole(messageId, 'user'); + if (!result) return; + const { message: msg, index: idx } = result; + + try { await DatabaseService.updateMessage(messageId, { content: newContent, timestamp: Date.now() @@ -916,6 +981,7 @@ class ChatStore { const allMessages = await conversationsStore.getConversationMessages(activeConv.id); const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null); + if (rootMessage && msg.parent === rootMessage.id && newContent.trim()) { await conversationsStore.updateConversationTitleWithConfirmation( activeConv.id, @@ -932,15 +998,16 @@ class ChatStore { async editMessageWithBranching(messageId: string, newContent: string): Promise { const activeConv = conversationsStore.activeConversation; if (!activeConv || this.isLoading) return; - try { - const idx = conversationsStore.findMessageIndex(messageId); - if (idx === -1) return; - const msg = conversationsStore.activeMessages[idx]; - if (msg.role !== 'user') return; + const result = this.getMessageByIdWithRole(messageId, 'user'); + if (!result) return; + const { message: msg } = result; + + try { const allMessages = await conversationsStore.getConversationMessages(activeConv.id); const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null); const isFirstUserMessage = rootMessage && msg.parent === rootMessage.id; + const parentId = msg.parent || rootMessage?.id; if (!parentId) return; @@ -1034,7 +1101,9 @@ class ChatStore { private async generateResponseForMessage(userMessageId: string): Promise { const activeConv = conversationsStore.activeConversation; + if (!activeConv) return; + this.errorDialogState = null; this.setChatLoading(activeConv.id, 
true); this.clearChatStreaming(activeConv.id); @@ -1071,26 +1140,30 @@ class ChatStore { async continueAssistantMessage(messageId: string): Promise { const activeConv = conversationsStore.activeConversation; if (!activeConv || this.isLoading) return; - try { - const idx = conversationsStore.findMessageIndex(messageId); - if (idx === -1) return; - const msg = conversationsStore.activeMessages[idx]; - if (msg.role !== 'assistant') return; - if (this.isChatLoading(activeConv.id)) return; + const result = this.getMessageByIdWithRole(messageId, 'assistant'); + if (!result) return; + const { message: msg, index: idx } = result; + + if (this.isChatLoading(activeConv.id)) return; + + try { this.errorDialogState = null; this.setChatLoading(activeConv.id, true); this.clearChatStreaming(activeConv.id); const allMessages = await conversationsStore.getConversationMessages(activeConv.id); const dbMessage = allMessages.find((m) => m.id === messageId); + if (!dbMessage) { this.setChatLoading(activeConv.id, false); + return; } const originalContent = dbMessage.content; const originalThinking = dbMessage.thinking || ''; + const conversationContext = conversationsStore.activeMessages.slice(0, idx); const contextWithContinue = [ ...conversationContext, @@ -1107,6 +1180,7 @@ class ChatStore { contextWithContinue, { ...this.getApiOptions(), + onChunk: (chunk: string) => { hasReceivedContent = true; appendedContent += chunk; @@ -1114,6 +1188,7 @@ class ChatStore { this.setChatStreaming(msg.convId, fullContent, msg.id); conversationsStore.updateMessageAtIndex(idx, { content: fullContent }); }, + onReasoningChunk: (reasoningChunk: string) => { hasReceivedContent = true; appendedThinking += reasoningChunk; @@ -1121,6 +1196,7 @@ class ChatStore { thinking: originalThinking + appendedThinking }); }, + onTimings: (timings: ChatMessageTimings, promptProgress?: ChatMessagePromptProgress) => { const tokensPerSecond = timings?.predicted_ms && timings?.predicted_n @@ -1137,6 +1213,7 @@ class ChatStore { msg.convId ); }, + onComplete: async ( finalContent?: string, reasoningContent?: string, @@ -1161,6 +1238,7 @@ class ChatStore { this.clearChatStreaming(msg.convId); this.clearProcessingState(msg.convId); }, + onError: async (error: Error) => { if (this.isAbortError(error)) { if (hasReceivedContent && appendedContent) {
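Tying back to the server-side change earlier in the diff, the resolved `model_name` (alias if set, else the registry name, else the GGUF file name) is what `/v1/models` now reports instead of the removed `DEFAULT_OAICOMPAT_MODEL`. A small client-side sketch, assuming the server's default address; the response fields follow the route shown above.

```ts
// Sketch: read the resolved model name from the OAI-compatible /v1/models route.
// The base URL is an assumption (llama-server default); the response shape follows the diff.
interface ModelsResponse {
	object: 'list';
	data: { id: string; object: 'model'; created: number; owned_by: string }[];
}

async function fetchModelName(baseUrl = 'http://localhost:8080'): Promise<string | undefined> {
	const res = await fetch(`${baseUrl}/v1/models`);
	if (!res.ok) return undefined;

	const body = (await res.json()) as ModelsResponse;

	// With no --alias, this falls back to the registry name or the model file name.
	return body.data[0]?.id;
}

fetchModelName().then((name) => console.log('serving model:', name));
```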