Merge remote-tracking branch 'upstream/master' into backend-sampling

Daniel Bevenius 2025-12-02 12:07:01 +01:00
commit 2595818a68
14 changed files with 260 additions and 70 deletions

View File

@@ -9,6 +9,7 @@ jobs:
   update:
     name: Update Winget Package
     runs-on: ubuntu-latest
+    if: ${{ github.repository.owner.login == 'ggml-org' }}
     steps:
       - name: Install cargo binstall

View File

@@ -2842,6 +2842,10 @@ class Mistral3Model(LlamaModel):
         self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"])

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        # TODO: probably not worth supporting quantized weight, as official BF16 is also available
+        if name.endswith("weight_scale_inv"):
+            raise ValueError("This is a quantized weight, please use BF16 weight instead")
+
         name = name.replace("language_model.", "")
         if "multi_modal_projector" in name or "vision_tower" in name:
             return []

View File

@@ -8,6 +8,10 @@
 #include <sys/sysctl.h>
 #endif

+#if !defined(HWCAP2_SVE2)
+#define HWCAP2_SVE2 (1 << 1)
+#endif
+
 #if !defined(HWCAP2_I8MM)
 #define HWCAP2_I8MM (1 << 13)
 #endif
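
Context for this hunk: on Linux/AArch64 the HWCAP2_* bits are reported by the kernel through the auxiliary vector, and the fallback #define above only matters when the libc headers are too old to provide it. A minimal, Linux-only sketch of how such a bit is typically queried (not part of this diff; assumes glibc/bionic getauxval):

    // hwcap2_probe.cpp - illustrative only, mirrors the HWCAP2_SVE2 fallback above
    #include <sys/auxv.h>
    #include <cstdio>

    #if !defined(HWCAP2_SVE2)
    #define HWCAP2_SVE2 (1 << 1)   // same fallback value as in the hunk above
    #endif

    int main() {
        unsigned long hwcap2 = getauxval(AT_HWCAP2); // CPU feature bits from the kernel
        std::printf("SVE2 supported: %s\n", (hwcap2 & HWCAP2_SVE2) ? "yes" : "no");
        return 0;
    }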

Binary file not shown.

View File

@@ -1263,7 +1263,11 @@ json convert_anthropic_to_oai(const json & body) {
     return oai_body;
 }

-json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64) {
+json format_embeddings_response_oaicompat(
+        const json & request,
+        const std::string & model_name,
+        const json & embeddings,
+        bool use_base64) {
     json data = json::array();
     int32_t n_tokens = 0;
     int i = 0;
@@ -1293,7 +1297,7 @@ json format_embeddings_response_oaicompat(const json & request, const json & emb
     }

     json res = json {
-        {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+        {"model", json_value(request, "model", model_name)},
         {"object", "list"},
         {"usage", json {
             {"prompt_tokens", n_tokens},
@@ -1307,6 +1311,7 @@ json format_embeddings_response_oaicompat(const json & request, const json & emb

 json format_response_rerank(
         const json & request,
+        const std::string & model_name,
         const json & ranks,
         bool is_tei_format,
         std::vector<std::string> & texts,
@@ -1338,7 +1343,7 @@ json format_response_rerank(
     if (is_tei_format) return results;

     json res = json{
-        {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+        {"model", json_value(request, "model", model_name)},
         {"object", "list"},
         {"usage", json{
             {"prompt_tokens", n_tokens},

View File

@@ -13,8 +13,6 @@
 #include <vector>
 #include <cinttypes>

-#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo"
-
 const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);

 using json = nlohmann::ordered_json;
@@ -298,11 +296,16 @@ json oaicompat_chat_params_parse(
 json convert_anthropic_to_oai(const json & body);

 // TODO: move it to server-task.cpp
-json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false);
+json format_embeddings_response_oaicompat(
+        const json & request,
+        const std::string & model_name,
+        const json & embeddings,
+        bool use_base64 = false);

 // TODO: move it to server-task.cpp
 json format_response_rerank(
         const json & request,
+        const std::string & model_name,
         const json & ranks,
         bool is_tei_format,
         std::vector<std::string> & texts,

View File

@@ -17,6 +17,7 @@
 #include <cinttypes>
 #include <memory>
 #include <unordered_set>
+#include <filesystem>

 // fix problem with std::min and std::max
 #if defined(_WIN32)
@@ -518,6 +519,8 @@ struct server_context_impl {
     // Necessary similarity of prompt for slot selection
     float slot_prompt_similarity = 0.0f;

+    std::string model_name; // name of the loaded model, to be used by API
+
     common_chat_templates_ptr chat_templates;

     oaicompat_parser_options oai_parser_opt;
@@ -755,6 +758,18 @@ struct server_context_impl {
         }
         SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n");

+        if (!params_base.model_alias.empty()) {
+            // user explicitly specified model name
+            model_name = params_base.model_alias;
+        } else if (!params_base.model.name.empty()) {
+            // use model name in registry format (for models in cache)
+            model_name = params_base.model.name;
+        } else {
+            // fallback: derive model name from file name
+            auto model_path = std::filesystem::path(params_base.model.path);
+            model_name = model_path.filename().string();
+        }
+
         // thinking is enabled if:
         // 1. It's not explicitly disabled (reasoning_budget == 0)
         // 2. The chat template supports it
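
The hunk above establishes a single source of truth for the name reported by the API: an explicit --alias wins, then the registry-style name of a cached model, then the bare file name of the loaded GGUF. A minimal, self-contained sketch of that precedence (hypothetical standalone helper for illustration; the real logic lives inline in server_context_impl):

    // resolve_model_name.cpp - illustration of the precedence used in the hunk above
    #include <filesystem>
    #include <string>

    static std::string resolve_model_name(
            const std::string & alias,      // --alias, if the user passed one
            const std::string & registry,   // registry-format name for cached models, if any
            const std::string & path) {     // path of the loaded model file
        if (!alias.empty())    return alias;     // user explicitly specified model name
        if (!registry.empty()) return registry;  // cached model keeps its registry name
        return std::filesystem::path(path).filename().string(); // fallback: file name
    }

    // e.g. resolve_model_name("", "", "/models/example-7b-q4_k_m.gguf") -> "example-7b-q4_k_m.gguf"
    // (example path is made up for illustration)
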
@@ -2608,7 +2623,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
         // OAI-compat
         task.params.res_type = res_type;
         task.params.oaicompat_cmpl_id = completion_id;
-        // oaicompat_model is already populated by params_from_json_cmpl
+        task.params.oaicompat_model = ctx_server.model_name;

         tasks.push_back(std::move(task));
     }
@@ -2936,7 +2951,7 @@ void server_routes::init_routes() {
         json data = {
             { "default_generation_settings", default_generation_settings_for_props },
             { "total_slots", ctx_server.params_base.n_parallel },
-            { "model_alias", ctx_server.params_base.model_alias },
+            { "model_alias", ctx_server.model_name },
             { "model_path", ctx_server.params_base.model.path },
             { "modalities", json {
                 {"vision", ctx_server.oai_parser_opt.allow_image},
@@ -3178,8 +3193,8 @@ void server_routes::init_routes() {
         json models = {
             {"models", {
                 {
-                    {"name", params.model_alias.empty() ? params.model.path : params.model_alias},
-                    {"model", params.model_alias.empty() ? params.model.path : params.model_alias},
+                    {"name", ctx_server.model_name},
+                    {"model", ctx_server.model_name},
                     {"modified_at", ""},
                     {"size", ""},
                     {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash
@@ -3201,7 +3216,7 @@ void server_routes::init_routes() {
             {"object", "list"},
             {"data", {
                 {
-                    {"id", params.model_alias.empty() ? params.model.path : params.model_alias},
+                    {"id", ctx_server.model_name},
                     {"object", "model"},
                     {"created", std::time(0)},
                     {"owned_by", "llamacpp"},
@@ -3348,6 +3363,7 @@ void server_routes::init_routes() {
         // write JSON response
         json root = format_response_rerank(
             body,
+            ctx_server.model_name,
             responses,
             is_tei_format,
             documents,
@@ -3610,7 +3626,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_embeddings_impl(cons
         // write JSON response
         json root = res_type == TASK_RESPONSE_TYPE_OAI_EMBD
-            ? format_embeddings_response_oaicompat(body, responses, use_base64)
+            ? format_embeddings_response_oaicompat(body, ctx_server.model_name, responses, use_base64)
             : json(responses);

         res->ok(root);
         return res;

View File

@@ -24,8 +24,55 @@
 #include <unistd.h>
 #endif

+#if defined(__APPLE__) && defined(__MACH__)
+// macOS: use _NSGetExecutablePath to get the executable path
+#include <mach-o/dyld.h>
+#include <limits.h>
+#endif
+
 #define CMD_EXIT "exit"

+static std::filesystem::path get_server_exec_path() {
+#if defined(_WIN32)
+    wchar_t buf[32768] = { 0 }; // Large buffer to handle long paths
+    DWORD len = GetModuleFileNameW(nullptr, buf, _countof(buf));
+    if (len == 0 || len >= _countof(buf)) {
+        throw std::runtime_error("GetModuleFileNameW failed or path too long");
+    }
+    return std::filesystem::path(buf);
+#elif defined(__APPLE__) && defined(__MACH__)
+    char small_path[PATH_MAX];
+    uint32_t size = sizeof(small_path);
+    if (_NSGetExecutablePath(small_path, &size) == 0) {
+        // resolve any symlinks to get absolute path
+        try {
+            return std::filesystem::canonical(std::filesystem::path(small_path));
+        } catch (...) {
+            return std::filesystem::path(small_path);
+        }
+    } else {
+        // buffer was too small, allocate required size and call again
+        std::vector<char> buf(size);
+        if (_NSGetExecutablePath(buf.data(), &size) == 0) {
+            try {
+                return std::filesystem::canonical(std::filesystem::path(buf.data()));
+            } catch (...) {
+                return std::filesystem::path(buf.data());
+            }
+        }
+        throw std::runtime_error("_NSGetExecutablePath failed after buffer resize");
+    }
+#else
+    char path[FILENAME_MAX];
+    ssize_t count = readlink("/proc/self/exe", path, FILENAME_MAX);
+    if (count <= 0) {
+        throw std::runtime_error("failed to resolve /proc/self/exe");
+    }
+    return std::filesystem::path(std::string(path, count));
+#endif
+}
+
 struct local_model {
     std::string name;
     std::string path;
@@ -99,6 +146,14 @@ server_models::server_models(
     for (char ** env = envp; *env != nullptr; env++) {
         base_env.push_back(std::string(*env));
     }
+    GGML_ASSERT(!base_args.empty());
+    // set binary path
+    try {
+        base_args[0] = get_server_exec_path().string();
+    } catch (const std::exception & e) {
+        LOG_WRN("failed to get server executable path: %s\n", e.what());
+        LOG_WRN("using original argv[0] as fallback: %s\n", base_args[0].c_str());
+    }

     // TODO: allow refreshing cached model list
     // add cached models
     auto cached_models = common_list_cached_models();
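
Why rewrite base_args[0]? argv[0] is whatever string was used to invoke the router (a bare command name resolved via PATH, or a path relative to some other working directory), so it is not a reliable way to re-launch the same binary, which the surrounding code suggests is what base_args is later used for. The helper above asks the OS for the real executable path instead. A minimal usage sketch, assuming only the get_server_exec_path() helper defined in the hunk above (compile the two together):

    // exec_path_demo.cpp - illustrative only; not part of this commit
    #include <cstdio>
    #include <exception>
    #include <string>
    #include <vector>

    int main(int argc, char ** argv) {
        std::vector<std::string> args(argv, argv + argc);
        try {
            // prefer the absolute path reported by the OS over raw argv[0]
            args[0] = get_server_exec_path().string();
        } catch (const std::exception & e) {
            std::fprintf(stderr, "keeping argv[0] as fallback: %s\n", e.what());
        }
        std::printf("re-launch command would start with: %s\n", args[0].c_str());
        return 0;
    }
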
@@ -587,26 +642,26 @@ static void res_ok(std::unique_ptr<server_http_res> & res, const json & response
     res->data = safe_json_to_str(response_data);
 }

-static void res_error(std::unique_ptr<server_http_res> & res, const json & error_data) {
+static void res_err(std::unique_ptr<server_http_res> & res, const json & error_data) {
     res->status = json_value(error_data, "code", 500);
     res->data = safe_json_to_str({{ "error", error_data }});
 }

 static bool router_validate_model(const std::string & name, server_models & models, bool models_autoload, std::unique_ptr<server_http_res> & res) {
     if (name.empty()) {
-        res_error(res, format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST));
+        res_err(res, format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST));
         return false;
     }
     auto meta = models.get_meta(name);
     if (!meta.has_value()) {
-        res_error(res, format_error_response("model not found", ERROR_TYPE_INVALID_REQUEST));
+        res_err(res, format_error_response("model not found", ERROR_TYPE_INVALID_REQUEST));
         return false;
     }
     if (models_autoload) {
         models.ensure_model_loaded(name);
     } else {
         if (meta->status != SERVER_MODEL_STATUS_LOADED) {
-            res_error(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
+            res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
             return false;
         }
     }
@@ -706,11 +761,11 @@ void server_models_routes::init_routes() {
         std::string name = json_value(body, "model", std::string());
         auto model = models.get_meta(name);
         if (!model.has_value()) {
-            res_error(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND));
+            res_err(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND));
             return res;
         }
         if (model->status == SERVER_MODEL_STATUS_LOADED) {
-            res_error(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST));
+            res_err(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST));
             return res;
         }
         models.load(name, false);
@@ -768,11 +823,11 @@ void server_models_routes::init_routes() {
         std::string name = json_value(body, "model", std::string());
         auto model = models.get_meta(name);
         if (!model.has_value()) {
-            res_error(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
+            res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
             return res;
         }
         if (model->status != SERVER_MODEL_STATUS_LOADED) {
-            res_error(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
+            res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
             return res;
         }
         models.unload(name);

View File

@@ -455,9 +455,6 @@ task_params server_task::params_from_json_cmpl(
         }
     }

-    std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
-    params.oaicompat_model = json_value(data, "model", model_name);
-
     return params;
 }

View File

@@ -41,7 +41,8 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
     assert res.status_code == 200
     assert "cmpl" in res.body["id"]  # make sure the completion id has the expected format
     assert res.body["system_fingerprint"].startswith("b")
-    assert res.body["model"] == model if model is not None else server.model_alias
+    # we no longer reflect back the model name, see https://github.com/ggml-org/llama.cpp/pull/17668
+    # assert res.body["model"] == model if model is not None else server.model_alias
     assert res.body["usage"]["prompt_tokens"] == n_prompt
     assert res.body["usage"]["completion_tokens"] == n_predicted
     choice = res.body["choices"][0]
@@ -59,7 +60,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
 )
 def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
     global server
-    server.model_alias = None  # try using DEFAULT_OAICOMPAT_MODEL
+    server.model_alias = "llama-test-model"
     server.start()
     res = server.make_stream_request("POST", "/chat/completions", data={
         "max_tokens": max_tokens,
@@ -81,7 +82,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
     else:
         assert "role" not in choice["delta"]
     assert data["system_fingerprint"].startswith("b")
-    assert "gpt-3.5" in data["model"]  # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
+    assert data["model"] == "llama-test-model"
     if last_cmpl_id is None:
         last_cmpl_id = data["id"]
     assert last_cmpl_id == data["id"]  # make sure the completion id is the same for all events in the stream

View File

@@ -575,6 +575,7 @@
 <DialogChatError
     message={activeErrorDialog?.message ?? ''}
+    contextInfo={activeErrorDialog?.contextInfo}
     onOpenChange={handleErrorDialogOpenChange}
     open={Boolean(activeErrorDialog)}
     type={activeErrorDialog?.type ?? 'server'}

View File

@@ -6,10 +6,11 @@
     open: boolean;
     type: 'timeout' | 'server';
     message: string;
+    contextInfo?: { n_prompt_tokens: number; n_ctx: number };
     onOpenChange?: (open: boolean) => void;
 }

-let { open = $bindable(), type, message, onOpenChange }: Props = $props();
+let { open = $bindable(), type, message, contextInfo, onOpenChange }: Props = $props();

 const isTimeout = $derived(type === 'timeout');
 const title = $derived(isTimeout ? 'TCP Timeout' : 'Server Error');
@@ -51,6 +52,15 @@
 <div class={`rounded-lg border px-4 py-3 text-sm ${badgeClass}`}>
     <p class="font-medium">{message}</p>
+    {#if contextInfo}
+        <div class="mt-2 space-y-1 text-xs opacity-80">
+            <p>
+                <span class="font-medium">Prompt tokens:</span>
+                {contextInfo.n_prompt_tokens.toLocaleString()}
+            </p>
+            <p><span class="font-medium">Context size:</span> {contextInfo.n_ctx.toLocaleString()}</p>
+        </div>
+    {/if}
 </div>

 <AlertDialog.Footer>

View File

@@ -767,18 +767,33 @@ export class ChatService {
      * @param response - HTTP response object
      * @returns Promise<Error> - Parsed error with context info if available
      */
-    private static async parseErrorResponse(response: Response): Promise<Error> {
+    private static async parseErrorResponse(
+        response: Response
+    ): Promise<Error & { contextInfo?: { n_prompt_tokens: number; n_ctx: number } }> {
         try {
             const errorText = await response.text();
             const errorData: ApiErrorResponse = JSON.parse(errorText);
             const message = errorData.error?.message || 'Unknown server error';
-            const error = new Error(message);
+            const error = new Error(message) as Error & {
+                contextInfo?: { n_prompt_tokens: number; n_ctx: number };
+            };
             error.name = response.status === 400 ? 'ServerError' : 'HttpError';
+
+            if (errorData.error && 'n_prompt_tokens' in errorData.error && 'n_ctx' in errorData.error) {
+                error.contextInfo = {
+                    n_prompt_tokens: errorData.error.n_prompt_tokens,
+                    n_ctx: errorData.error.n_ctx
+                };
+            }
+
             return error;
         } catch {
-            const fallback = new Error(`Server error (${response.status}): ${response.statusText}`);
+            const fallback = new Error(
+                `Server error (${response.status}): ${response.statusText}`
+            ) as Error & {
+                contextInfo?: { n_prompt_tokens: number; n_ctx: number };
+            };
             fallback.name = 'HttpError';
             return fallback;
         }
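
For reference, the parser above only attaches contextInfo when both fields sit directly inside the error object of the response body. A hedged sketch of such a payload in C++ (the server side of this diff), matching the res_err() wrapper earlier in this commit ({"error", error_data} with a "code" field that becomes the HTTP status); the message text and numbers are made up for illustration:

    // error_payload_demo.cpp - shape only; field names follow the checks in parseErrorResponse()
    #include <nlohmann/json.hpp>
    #include <iostream>

    int main() {
        nlohmann::ordered_json payload = {
            {"error", {
                {"code", 400},                                              // used as the HTTP status
                {"message", "request exceeds the available context size"},  // hypothetical text
                {"n_prompt_tokens", 9182},                                  // hypothetical value
                {"n_ctx", 4096}                                             // hypothetical value
            }}
        };
        std::cout << payload.dump(2) << std::endl; // what the webui would receive and parse
        return 0;
    }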

View File

@@ -58,7 +58,11 @@ class ChatStore {
     activeProcessingState = $state<ApiProcessingState | null>(null);
     currentResponse = $state('');
-    errorDialogState = $state<{ type: 'timeout' | 'server'; message: string } | null>(null);
+    errorDialogState = $state<{
+        type: 'timeout' | 'server';
+        message: string;
+        contextInfo?: { n_prompt_tokens: number; n_ctx: number };
+    } | null>(null);
     isLoading = $state(false);
     chatLoadingStates = new SvelteMap<string, boolean>();
     chatStreamingStates = new SvelteMap<string, { response: string; messageId: string }>();
@@ -335,8 +339,12 @@ class ChatStore {
         return error instanceof Error && (error.name === 'AbortError' || error instanceof DOMException);
     }

-    private showErrorDialog(type: 'timeout' | 'server', message: string): void {
-        this.errorDialogState = { type, message };
+    private showErrorDialog(
+        type: 'timeout' | 'server',
+        message: string,
+        contextInfo?: { n_prompt_tokens: number; n_ctx: number }
+    ): void {
+        this.errorDialogState = { type, message, contextInfo };
     }

     dismissErrorDialog(): void {
@@ -347,6 +355,23 @@ class ChatStore {
     // Message Operations
     // ─────────────────────────────────────────────────────────────────────────────

+    /**
+     * Finds a message by ID and optionally validates its role.
+     * Returns message and index, or null if not found or role doesn't match.
+     */
+    private getMessageByIdWithRole(
+        messageId: string,
+        expectedRole?: ChatRole
+    ): { message: DatabaseMessage; index: number } | null {
+        const index = conversationsStore.findMessageIndex(messageId);
+        if (index === -1) return null;
+
+        const message = conversationsStore.activeMessages[index];
+        if (expectedRole && message.role !== expectedRole) return null;
+
+        return { message, index };
+    }
+
     async addMessage(
         role: ChatRole,
         content: string,
@@ -508,7 +533,6 @@ class ChatStore {
             ) => {
                 this.stopStreaming();

-                // Build update data - only include model if not already persisted
                 const updateData: Record<string, unknown> = {
                     content: finalContent || streamedContent,
                     thinking: reasoningContent || streamedReasoningContent,
@@ -520,7 +544,6 @@ class ChatStore {
                 }

                 await DatabaseService.updateMessage(assistantMessage.id, updateData);

-                // Update UI state - always include model and timings if available
                 const idx = conversationsStore.findMessageIndex(assistantMessage.id);
                 const uiUpdate: Partial<DatabaseMessage> = {
                     content: updateData.content as string,
@@ -543,22 +566,38 @@ class ChatStore {
             },
             onError: (error: Error) => {
                 this.stopStreaming();

                 if (this.isAbortError(error)) {
                     this.setChatLoading(assistantMessage.convId, false);
                     this.clearChatStreaming(assistantMessage.convId);
                     this.clearProcessingState(assistantMessage.convId);

                     return;
                 }

                 console.error('Streaming error:', error);

                 this.setChatLoading(assistantMessage.convId, false);
                 this.clearChatStreaming(assistantMessage.convId);
                 this.clearProcessingState(assistantMessage.convId);

                 const idx = conversationsStore.findMessageIndex(assistantMessage.id);

                 if (idx !== -1) {
                     const failedMessage = conversationsStore.removeMessageAtIndex(idx);
                     if (failedMessage) DatabaseService.deleteMessage(failedMessage.id).catch(console.error);
                 }
-                this.showErrorDialog(error.name === 'TimeoutError' ? 'timeout' : 'server', error.message);
+
+                const contextInfo = (
+                    error as Error & { contextInfo?: { n_prompt_tokens: number; n_ctx: number } }
+                ).contextInfo;
+
+                this.showErrorDialog(
+                    error.name === 'TimeoutError' ? 'timeout' : 'server',
+                    error.message,
+                    contextInfo
+                );
+
                 if (onError) onError(error);
             }
         },
@@ -591,7 +630,9 @@ class ChatStore {
             await conversationsStore.updateConversationName(currentConv.id, content.trim());

         const assistantMessage = await this.createAssistantMessage(userMessage.id);
+
         if (!assistantMessage) throw new Error('Failed to create assistant message');
+
         conversationsStore.addMessageToActive(assistantMessage);

         await this.streamChatCompletion(
             conversationsStore.activeMessages.slice(0, -1),
@@ -607,15 +648,26 @@ class ChatStore {
             if (!this.errorDialogState) {
                 const dialogType =
                     error instanceof Error && error.name === 'TimeoutError' ? 'timeout' : 'server';
-                this.showErrorDialog(dialogType, error instanceof Error ? error.message : 'Unknown error');
+
+                const contextInfo = (
+                    error as Error & { contextInfo?: { n_prompt_tokens: number; n_ctx: number } }
+                ).contextInfo;
+
+                this.showErrorDialog(
+                    dialogType,
+                    error instanceof Error ? error.message : 'Unknown error',
+                    contextInfo
+                );
             }
         }
     }

     async stopGeneration(): Promise<void> {
         const activeConv = conversationsStore.activeConversation;
         if (!activeConv) return;
+
         await this.savePartialResponseIfNeeded(activeConv.id);
+
         this.stopStreaming();
         this.abortRequest(activeConv.id);
         this.setChatLoading(activeConv.id, false);
@@ -655,17 +707,22 @@ class ChatStore {
     private async savePartialResponseIfNeeded(convId?: string): Promise<void> {
         const conversationId = convId || conversationsStore.activeConversation?.id;
         if (!conversationId) return;
+
         const streamingState = this.chatStreamingStates.get(conversationId);
         if (!streamingState || !streamingState.response.trim()) return;
+
         const messages =
             conversationId === conversationsStore.activeConversation?.id
                 ? conversationsStore.activeMessages
                 : await conversationsStore.getConversationMessages(conversationId);
+
         if (!messages.length) return;
+
         const lastMessage = messages[messages.length - 1];
+
         if (lastMessage?.role === 'assistant') {
             try {
                 const updateData: { content: string; thinking?: string; timings?: ChatMessageTimings } = {
@@ -684,9 +741,13 @@ class ChatStore {
                         : undefined
                 };
             }

             await DatabaseService.updateMessage(lastMessage.id, updateData);
+
             lastMessage.content = this.currentResponse;
+
             if (updateData.thinking) lastMessage.thinking = updateData.thinking;
             if (updateData.timings) lastMessage.timings = updateData.timings;
         } catch (error) {
             lastMessage.content = this.currentResponse;
@@ -700,14 +761,12 @@ class ChatStore {
         if (!activeConv) return;
         if (this.isLoading) this.stopGeneration();

+        const result = this.getMessageByIdWithRole(messageId, 'user');
+        if (!result) return;
+
+        const { message: messageToUpdate, index: messageIndex } = result;
+        const originalContent = messageToUpdate.content;
+
         try {
-            const messageIndex = conversationsStore.findMessageIndex(messageId);
-            if (messageIndex === -1) return;
-            const messageToUpdate = conversationsStore.activeMessages[messageIndex];
-            const originalContent = messageToUpdate.content;
-            if (messageToUpdate.role !== 'user') return;
-
             const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
             const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
             const isFirstUserMessage = rootMessage && messageToUpdate.parent === rootMessage.id;
@@ -724,7 +783,9 @@ class ChatStore {
             }

             const messagesToRemove = conversationsStore.activeMessages.slice(messageIndex + 1);
+
             for (const message of messagesToRemove) await DatabaseService.deleteMessage(message.id);
+
             conversationsStore.sliceActiveMessages(messageIndex + 1);
             conversationsStore.updateConversationTimestamp();
@@ -732,8 +793,11 @@ class ChatStore {
             this.clearChatStreaming(activeConv.id);

             const assistantMessage = await this.createAssistantMessage();
+
             if (!assistantMessage) throw new Error('Failed to create assistant message');
+
             conversationsStore.addMessageToActive(assistantMessage);
+
             await conversationsStore.updateCurrentNode(assistantMessage.id);

             await this.streamChatCompletion(
                 conversationsStore.activeMessages.slice(0, -1),
@@ -758,12 +822,11 @@ class ChatStore {
         const activeConv = conversationsStore.activeConversation;
         if (!activeConv || this.isLoading) return;

-        try {
-            const messageIndex = conversationsStore.findMessageIndex(messageId);
-            if (messageIndex === -1) return;
-            const messageToRegenerate = conversationsStore.activeMessages[messageIndex];
-            if (messageToRegenerate.role !== 'assistant') return;
+        const result = this.getMessageByIdWithRole(messageId, 'assistant');
+        if (!result) return;
+        const { index: messageIndex } = result;

+        try {
             const messagesToRemove = conversationsStore.activeMessages.slice(messageIndex);
             for (const message of messagesToRemove) await DatabaseService.deleteMessage(message.id);
             conversationsStore.sliceActiveMessages(messageIndex);
@@ -832,6 +895,7 @@ class ChatStore {
             const siblings = allMessages.filter(
                 (m) => m.parent === messageToDelete.parent && m.id !== messageId
             );
+
             if (siblings.length > 0) {
                 const latestSibling = siblings.reduce((latest, sibling) =>
                     sibling.timestamp > latest.timestamp ? sibling : latest
@@ -845,6 +909,7 @@ class ChatStore {
             }

             await DatabaseService.deleteMessageCascading(activeConv.id, messageId);
+
             await conversationsStore.refreshActiveMessages();
             conversationsStore.updateConversationTimestamp();
         } catch (error) {
             console.error('Failed to delete message:', error);
@@ -862,12 +927,12 @@ class ChatStore {
     ): Promise<void> {
         const activeConv = conversationsStore.activeConversation;
         if (!activeConv || this.isLoading) return;

-        try {
-            const idx = conversationsStore.findMessageIndex(messageId);
-            if (idx === -1) return;
-            const msg = conversationsStore.activeMessages[idx];
-            if (msg.role !== 'assistant') return;
+        const result = this.getMessageByIdWithRole(messageId, 'assistant');
+        if (!result) return;
+
+        const { message: msg, index: idx } = result;

+        try {
             if (shouldBranch) {
                 const newMessage = await DatabaseService.createMessageBranch(
                     {
@@ -902,12 +967,12 @@ class ChatStore {
     async editUserMessagePreserveResponses(messageId: string, newContent: string): Promise<void> {
         const activeConv = conversationsStore.activeConversation;
         if (!activeConv) return;

-        try {
-            const idx = conversationsStore.findMessageIndex(messageId);
-            if (idx === -1) return;
-            const msg = conversationsStore.activeMessages[idx];
-            if (msg.role !== 'user') return;
+        const result = this.getMessageByIdWithRole(messageId, 'user');
+        if (!result) return;
+
+        const { message: msg, index: idx } = result;

+        try {
             await DatabaseService.updateMessage(messageId, {
                 content: newContent,
                 timestamp: Date.now()
@@ -916,6 +981,7 @@ class ChatStore {
             const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
             const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
+
             if (rootMessage && msg.parent === rootMessage.id && newContent.trim()) {
                 await conversationsStore.updateConversationTitleWithConfirmation(
                     activeConv.id,
@@ -932,15 +998,16 @@ class ChatStore {
     async editMessageWithBranching(messageId: string, newContent: string): Promise<void> {
         const activeConv = conversationsStore.activeConversation;
         if (!activeConv || this.isLoading) return;

-        try {
-            const idx = conversationsStore.findMessageIndex(messageId);
-            if (idx === -1) return;
-            const msg = conversationsStore.activeMessages[idx];
-            if (msg.role !== 'user') return;
+        const result = this.getMessageByIdWithRole(messageId, 'user');
+        if (!result) return;
+
+        const { message: msg } = result;

+        try {
             const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
             const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
             const isFirstUserMessage = rootMessage && msg.parent === rootMessage.id;
+
             const parentId = msg.parent || rootMessage?.id;
             if (!parentId) return;
@@ -1034,7 +1101,9 @@ class ChatStore {
     private async generateResponseForMessage(userMessageId: string): Promise<void> {
         const activeConv = conversationsStore.activeConversation;
         if (!activeConv) return;
+
         this.errorDialogState = null;
+
         this.setChatLoading(activeConv.id, true);
         this.clearChatStreaming(activeConv.id);
@@ -1071,26 +1140,30 @@ class ChatStore {
     async continueAssistantMessage(messageId: string): Promise<void> {
         const activeConv = conversationsStore.activeConversation;
         if (!activeConv || this.isLoading) return;

-        try {
-            const idx = conversationsStore.findMessageIndex(messageId);
-            if (idx === -1) return;
-            const msg = conversationsStore.activeMessages[idx];
-            if (msg.role !== 'assistant') return;
-            if (this.isChatLoading(activeConv.id)) return;
+        const result = this.getMessageByIdWithRole(messageId, 'assistant');
+        if (!result) return;
+
+        const { message: msg, index: idx } = result;
+
+        if (this.isChatLoading(activeConv.id)) return;

+        try {
             this.errorDialogState = null;
+
             this.setChatLoading(activeConv.id, true);
             this.clearChatStreaming(activeConv.id);

             const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
             const dbMessage = allMessages.find((m) => m.id === messageId);
+
             if (!dbMessage) {
                 this.setChatLoading(activeConv.id, false);
                 return;
             }
+
             const originalContent = dbMessage.content;
             const originalThinking = dbMessage.thinking || '';
+
             const conversationContext = conversationsStore.activeMessages.slice(0, idx);
             const contextWithContinue = [
                 ...conversationContext,
@@ -1107,6 +1180,7 @@ class ChatStore {
                 contextWithContinue,
                 {
                     ...this.getApiOptions(),
+
                     onChunk: (chunk: string) => {
                         hasReceivedContent = true;
                         appendedContent += chunk;
@@ -1114,6 +1188,7 @@ class ChatStore {
                         this.setChatStreaming(msg.convId, fullContent, msg.id);
                         conversationsStore.updateMessageAtIndex(idx, { content: fullContent });
                     },
+
                     onReasoningChunk: (reasoningChunk: string) => {
                         hasReceivedContent = true;
                         appendedThinking += reasoningChunk;
@@ -1121,6 +1196,7 @@ class ChatStore {
                             thinking: originalThinking + appendedThinking
                         });
                     },
+
                     onTimings: (timings: ChatMessageTimings, promptProgress?: ChatMessagePromptProgress) => {
                         const tokensPerSecond =
                             timings?.predicted_ms && timings?.predicted_n
@@ -1137,6 +1213,7 @@ class ChatStore {
                             msg.convId
                         );
                     },
+
                     onComplete: async (
                         finalContent?: string,
                         reasoningContent?: string,
@@ -1161,6 +1238,7 @@ class ChatStore {
                         this.clearChatStreaming(msg.convId);
                         this.clearProcessingState(msg.convId);
                     },
+
                     onError: async (error: Error) => {
                         if (this.isAbortError(error)) {
                             if (hasReceivedContent && appendedContent) {