Merge remote-tracking branch 'upstream/master' into backend-sampling
commit 2595818a68
@@ -9,6 +9,7 @@ jobs:
   update:
     name: Update Winget Package
     runs-on: ubuntu-latest
+    if: ${{ github.repository.owner.login == 'ggml-org' }}

     steps:
       - name: Install cargo binstall
@@ -2842,6 +2842,10 @@ class Mistral3Model(LlamaModel):
         self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"])

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        # TODO: probably not worth supporting quantized weight, as official BF16 is also available
+        if name.endswith("weight_scale_inv"):
+            raise ValueError("This is a quantized weight, please use BF16 weight instead")
+
         name = name.replace("language_model.", "")
         if "multi_modal_projector" in name or "vision_tower" in name:
             return []
@@ -8,6 +8,10 @@
 #include <sys/sysctl.h>
 #endif

+#if !defined(HWCAP2_SVE2)
+#define HWCAP2_SVE2 (1 << 1)
+#endif
+
 #if !defined(HWCAP2_I8MM)
 #define HWCAP2_I8MM (1 << 13)
 #endif
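
The fallback defines above only matter at compile time; at run time the corresponding bits are still queried from the kernel. A minimal, self-contained sketch of how such HWCAP2 bits are typically tested on Linux/aarch64 (illustrative only, this is not the surrounding ggml detection code):

#include <cstdio>

#if defined(__linux__) && defined(__aarch64__)
#include <sys/auxv.h>   // getauxval, AT_HWCAP2

#if !defined(HWCAP2_SVE2)
#define HWCAP2_SVE2 (1 << 1)    // same fallback as in the hunk above
#endif
#if !defined(HWCAP2_I8MM)
#define HWCAP2_I8MM (1 << 13)
#endif

int main() {
    const unsigned long hwcap2 = getauxval(AT_HWCAP2);   // feature bits reported by the kernel
    std::printf("SVE2: %s\n", (hwcap2 & HWCAP2_SVE2) ? "yes" : "no");
    std::printf("I8MM: %s\n", (hwcap2 & HWCAP2_I8MM) ? "yes" : "no");
    return 0;
}
#else
int main() { std::printf("not a linux/aarch64 build\n"); return 0; }
#endif
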
Binary file not shown.
@@ -1263,7 +1263,11 @@ json convert_anthropic_to_oai(const json & body) {
     return oai_body;
 }

-json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64) {
+json format_embeddings_response_oaicompat(
+        const json & request,
+        const std::string & model_name,
+        const json & embeddings,
+        bool use_base64) {
     json data = json::array();
     int32_t n_tokens = 0;
     int i = 0;
@@ -1293,7 +1297,7 @@ json format_embeddings_response_oaicompat(const json & request, const json & emb
     }

     json res = json {
-        {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+        {"model", json_value(request, "model", model_name)},
         {"object", "list"},
         {"usage", json {
             {"prompt_tokens", n_tokens},
@@ -1307,6 +1311,7 @@ json format_embeddings_response_oaicompat(const json & request, const json & emb

 json format_response_rerank(
         const json & request,
+        const std::string & model_name,
         const json & ranks,
         bool is_tei_format,
         std::vector<std::string> & texts,
@@ -1338,7 +1343,7 @@ json format_response_rerank(
     if (is_tei_format) return results;

     json res = json{
-        {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+        {"model", json_value(request, "model", model_name)},
         {"object", "list"},
         {"usage", json{
             {"prompt_tokens", n_tokens},
@@ -13,8 +13,6 @@
 #include <vector>
 #include <cinttypes>

-#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo"
-
 const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);

 using json = nlohmann::ordered_json;
@@ -298,11 +296,16 @@ json oaicompat_chat_params_parse(
 json convert_anthropic_to_oai(const json & body);

 // TODO: move it to server-task.cpp
-json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false);
+json format_embeddings_response_oaicompat(
+        const json & request,
+        const std::string & model_name,
+        const json & embeddings,
+        bool use_base64 = false);

 // TODO: move it to server-task.cpp
 json format_response_rerank(
         const json & request,
+        const std::string & model_name,
         const json & ranks,
         bool is_tei_format,
         std::vector<std::string> & texts,
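
With DEFAULT_OAICOMPAT_MODEL gone from this header, the "model" field reported by these helpers falls back to the name resolved by the server rather than a hard-coded "gpt-3.5-turbo". A small, self-contained sketch of that fallback using plain nlohmann::json (model_or_default is a hypothetical stand-in for the project's json_value helper, and the model name is made up):

#include <nlohmann/json.hpp>
#include <iostream>
#include <string>

using json = nlohmann::ordered_json;

// stand-in for json_value(request, "model", model_name)
static std::string model_or_default(const json & request, const std::string & model_name) {
    return request.contains("model") && request["model"].is_string()
        ? request["model"].get<std::string>()
        : model_name;
}

int main() {
    const std::string model_name = "qwen2.5-7b-instruct-q4_k_m.gguf";  // hypothetical server-resolved name

    json with_model    = { {"model", "my-alias"}, {"input", "hi"} };
    json without_model = { {"input", "hi"} };

    std::cout << model_or_default(with_model,    model_name) << "\n";  // my-alias
    std::cout << model_or_default(without_model, model_name) << "\n";  // qwen2.5-7b-instruct-q4_k_m.gguf
    return 0;
}
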
@@ -17,6 +17,7 @@
 #include <cinttypes>
 #include <memory>
 #include <unordered_set>
+#include <filesystem>

 // fix problem with std::min and std::max
 #if defined(_WIN32)
@@ -518,6 +519,8 @@ struct server_context_impl {
     // Necessary similarity of prompt for slot selection
     float slot_prompt_similarity = 0.0f;

+    std::string model_name; // name of the loaded model, to be used by API
+
     common_chat_templates_ptr chat_templates;
     oaicompat_parser_options oai_parser_opt;

@@ -755,6 +758,18 @@ struct server_context_impl {
         }
         SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n");

+        if (!params_base.model_alias.empty()) {
+            // user explicitly specified model name
+            model_name = params_base.model_alias;
+        } else if (!params_base.model.name.empty()) {
+            // use model name in registry format (for models in cache)
+            model_name = params_base.model.name;
+        } else {
+            // fallback: derive model name from file name
+            auto model_path = std::filesystem::path(params_base.model.path);
+            model_name = model_path.filename().string();
+        }
+
         // thinking is enabled if:
         // 1. It's not explicitly disabled (reasoning_budget == 0)
         // 2. The chat template supports it
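
The precedence above is: explicit alias, then the registry-style name of a cached model, then the basename of the model file. A self-contained sketch of the same chain (the struct and sample values are hypothetical, not the actual params_base layout):

#include <filesystem>
#include <iostream>
#include <string>

// hypothetical stand-in for the relevant fields of params_base
struct model_params {
    std::string model_alias;  // user-provided alias, may be empty
    std::string model_name;   // registry-style name, e.g. "ggml-org/gemma-3-4b-it-GGUF", may be empty
    std::string model_path;   // local file path
};

static std::string resolve_model_name(const model_params & p) {
    if (!p.model_alias.empty()) return p.model_alias;               // 1. explicit alias wins
    if (!p.model_name.empty())  return p.model_name;                // 2. registry name for cached models
    return std::filesystem::path(p.model_path).filename().string(); // 3. fallback: file name only
}

int main() {
    std::cout << resolve_model_name({"", "", "/models/llama-3.2-1b-q8_0.gguf"}) << "\n";          // llama-3.2-1b-q8_0.gguf
    std::cout << resolve_model_name({"my-alias", "", "/models/llama-3.2-1b-q8_0.gguf"}) << "\n";  // my-alias
    return 0;
}
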
@@ -2608,7 +2623,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
         // OAI-compat
         task.params.res_type = res_type;
         task.params.oaicompat_cmpl_id = completion_id;
-        // oaicompat_model is already populated by params_from_json_cmpl
+        task.params.oaicompat_model = ctx_server.model_name;

         tasks.push_back(std::move(task));
     }
@@ -2936,7 +2951,7 @@ void server_routes::init_routes() {
         json data = {
             { "default_generation_settings", default_generation_settings_for_props },
             { "total_slots", ctx_server.params_base.n_parallel },
-            { "model_alias", ctx_server.params_base.model_alias },
+            { "model_alias", ctx_server.model_name },
             { "model_path", ctx_server.params_base.model.path },
             { "modalities", json {
                 {"vision", ctx_server.oai_parser_opt.allow_image},
@@ -3178,8 +3193,8 @@ void server_routes::init_routes() {
         json models = {
             {"models", {
                 {
-                    {"name", params.model_alias.empty() ? params.model.path : params.model_alias},
-                    {"model", params.model_alias.empty() ? params.model.path : params.model_alias},
+                    {"name", ctx_server.model_name},
+                    {"model", ctx_server.model_name},
                     {"modified_at", ""},
                     {"size", ""},
                     {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash
@@ -3201,7 +3216,7 @@ void server_routes::init_routes() {
             {"object", "list"},
             {"data", {
                 {
-                    {"id", params.model_alias.empty() ? params.model.path : params.model_alias},
+                    {"id", ctx_server.model_name},
                     {"object", "model"},
                     {"created", std::time(0)},
                     {"owned_by", "llamacpp"},
@@ -3348,6 +3363,7 @@ void server_routes::init_routes() {
         // write JSON response
         json root = format_response_rerank(
             body,
+            ctx_server.model_name,
             responses,
             is_tei_format,
             documents,
@@ -3610,7 +3626,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_embeddings_impl(cons

     // write JSON response
     json root = res_type == TASK_RESPONSE_TYPE_OAI_EMBD
-        ? format_embeddings_response_oaicompat(body, responses, use_base64)
+        ? format_embeddings_response_oaicompat(body, ctx_server.model_name, responses, use_base64)
         : json(responses);
     res->ok(root);
     return res;
@@ -24,8 +24,55 @@
 #include <unistd.h>
 #endif

+#if defined(__APPLE__) && defined(__MACH__)
+// macOS: use _NSGetExecutablePath to get the executable path
+#include <mach-o/dyld.h>
+#include <limits.h>
+#endif
+
 #define CMD_EXIT "exit"

+static std::filesystem::path get_server_exec_path() {
+#if defined(_WIN32)
+    wchar_t buf[32768] = { 0 }; // Large buffer to handle long paths
+    DWORD len = GetModuleFileNameW(nullptr, buf, _countof(buf));
+    if (len == 0 || len >= _countof(buf)) {
+        throw std::runtime_error("GetModuleFileNameW failed or path too long");
+    }
+    return std::filesystem::path(buf);
+#elif defined(__APPLE__) && defined(__MACH__)
+    char small_path[PATH_MAX];
+    uint32_t size = sizeof(small_path);
+
+    if (_NSGetExecutablePath(small_path, &size) == 0) {
+        // resolve any symlinks to get absolute path
+        try {
+            return std::filesystem::canonical(std::filesystem::path(small_path));
+        } catch (...) {
+            return std::filesystem::path(small_path);
+        }
+    } else {
+        // buffer was too small, allocate required size and call again
+        std::vector<char> buf(size);
+        if (_NSGetExecutablePath(buf.data(), &size) == 0) {
+            try {
+                return std::filesystem::canonical(std::filesystem::path(buf.data()));
+            } catch (...) {
+                return std::filesystem::path(buf.data());
+            }
+        }
+        throw std::runtime_error("_NSGetExecutablePath failed after buffer resize");
+    }
+#else
+    char path[FILENAME_MAX];
+    ssize_t count = readlink("/proc/self/exe", path, FILENAME_MAX);
+    if (count <= 0) {
+        throw std::runtime_error("failed to resolve /proc/self/exe");
+    }
+    return std::filesystem::path(std::string(path, count));
+#endif
+}
+
 struct local_model {
     std::string name;
     std::string path;
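
get_server_exec_path() lets the router respawn the server binary by its real on-disk path instead of trusting argv[0]. A Linux-only, self-contained sketch of the /proc/self/exe branch (the committed function additionally covers Windows via GetModuleFileNameW and macOS via _NSGetExecutablePath):

#include <cstdio>      // FILENAME_MAX
#include <filesystem>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unistd.h>    // readlink

static std::filesystem::path self_exe_path() {
    char path[FILENAME_MAX];
    const ssize_t count = readlink("/proc/self/exe", path, sizeof(path));
    if (count <= 0) {
        throw std::runtime_error("failed to resolve /proc/self/exe");
    }
    // readlink does not null-terminate, so build the string from the byte count
    return std::filesystem::path(std::string(path, static_cast<size_t>(count)));
}

int main() {
    // e.g. usable as a trustworthy argv[0] when spawning child server processes
    std::cout << self_exe_path() << "\n";
    return 0;
}
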
@@ -99,6 +146,14 @@ server_models::server_models(
     for (char ** env = envp; *env != nullptr; env++) {
         base_env.push_back(std::string(*env));
     }
+    GGML_ASSERT(!base_args.empty());
+    // set binary path
+    try {
+        base_args[0] = get_server_exec_path().string();
+    } catch (const std::exception & e) {
+        LOG_WRN("failed to get server executable path: %s\n", e.what());
+        LOG_WRN("using original argv[0] as fallback: %s\n", base_args[0].c_str());
+    }
     // TODO: allow refreshing cached model list
     // add cached models
     auto cached_models = common_list_cached_models();
@@ -587,26 +642,26 @@ static void res_ok(std::unique_ptr<server_http_res> & res, const json & response
     res->data = safe_json_to_str(response_data);
 }

-static void res_error(std::unique_ptr<server_http_res> & res, const json & error_data) {
+static void res_err(std::unique_ptr<server_http_res> & res, const json & error_data) {
     res->status = json_value(error_data, "code", 500);
     res->data = safe_json_to_str({{ "error", error_data }});
 }

 static bool router_validate_model(const std::string & name, server_models & models, bool models_autoload, std::unique_ptr<server_http_res> & res) {
     if (name.empty()) {
-        res_error(res, format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST));
+        res_err(res, format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST));
         return false;
     }
     auto meta = models.get_meta(name);
     if (!meta.has_value()) {
-        res_error(res, format_error_response("model not found", ERROR_TYPE_INVALID_REQUEST));
+        res_err(res, format_error_response("model not found", ERROR_TYPE_INVALID_REQUEST));
         return false;
     }
     if (models_autoload) {
         models.ensure_model_loaded(name);
     } else {
         if (meta->status != SERVER_MODEL_STATUS_LOADED) {
-            res_error(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
+            res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
             return false;
         }
     }
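
For reference, res_err() keeps the existing response shape: the HTTP status is read from the error object's "code" field (500 if absent) and the body wraps that object under an "error" key. A self-contained approximation with plain nlohmann::json (the exact fields produced by format_error_response are not shown in this hunk, so the error object below is a made-up example):

#include <nlohmann/json.hpp>
#include <iostream>
#include <string>

using json = nlohmann::ordered_json;

int main() {
    // hypothetical output of format_error_response(...)
    json error_data = { {"code", 400}, {"message", "model name is missing from the request"} };

    int status = error_data.value("code", 500);               // mirrors res->status = json_value(error_data, "code", 500)
    std::string body = json{ {"error", error_data} }.dump();  // mirrors res->data = safe_json_to_str({{ "error", error_data }})

    std::cout << status << "\n" << body << "\n";
    // 400
    // {"error":{"code":400,"message":"model name is missing from the request"}}
    return 0;
}
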
@@ -706,11 +761,11 @@ void server_models_routes::init_routes() {
         std::string name = json_value(body, "model", std::string());
         auto model = models.get_meta(name);
         if (!model.has_value()) {
-            res_error(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND));
+            res_err(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND));
             return res;
         }
         if (model->status == SERVER_MODEL_STATUS_LOADED) {
-            res_error(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST));
+            res_err(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST));
             return res;
         }
         models.load(name, false);
@@ -768,11 +823,11 @@ void server_models_routes::init_routes() {
         std::string name = json_value(body, "model", std::string());
         auto model = models.get_meta(name);
         if (!model.has_value()) {
-            res_error(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
+            res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
             return res;
         }
         if (model->status != SERVER_MODEL_STATUS_LOADED) {
-            res_error(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
+            res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
             return res;
         }
         models.unload(name);
@@ -455,9 +455,6 @@ task_params server_task::params_from_json_cmpl(
         }
     }

-    std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
-    params.oaicompat_model = json_value(data, "model", model_name);
-
     return params;
 }

@@ -41,7 +41,8 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
     assert res.status_code == 200
     assert "cmpl" in res.body["id"] # make sure the completion id has the expected format
     assert res.body["system_fingerprint"].startswith("b")
-    assert res.body["model"] == model if model is not None else server.model_alias
+    # we no longer reflect back the model name, see https://github.com/ggml-org/llama.cpp/pull/17668
+    # assert res.body["model"] == model if model is not None else server.model_alias
     assert res.body["usage"]["prompt_tokens"] == n_prompt
     assert res.body["usage"]["completion_tokens"] == n_predicted
     choice = res.body["choices"][0]
@@ -59,7 +60,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
 )
 def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
     global server
-    server.model_alias = None # try using DEFAULT_OAICOMPAT_MODEL
+    server.model_alias = "llama-test-model"
     server.start()
     res = server.make_stream_request("POST", "/chat/completions", data={
         "max_tokens": max_tokens,
@@ -81,7 +82,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
         else:
             assert "role" not in choice["delta"]
         assert data["system_fingerprint"].startswith("b")
-        assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
+        assert data["model"] == "llama-test-model"
         if last_cmpl_id is None:
             last_cmpl_id = data["id"]
         assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream
@@ -575,6 +575,7 @@

 <DialogChatError
   message={activeErrorDialog?.message ?? ''}
+  contextInfo={activeErrorDialog?.contextInfo}
   onOpenChange={handleErrorDialogOpenChange}
   open={Boolean(activeErrorDialog)}
   type={activeErrorDialog?.type ?? 'server'}
@@ -6,10 +6,11 @@
     open: boolean;
     type: 'timeout' | 'server';
     message: string;
+    contextInfo?: { n_prompt_tokens: number; n_ctx: number };
     onOpenChange?: (open: boolean) => void;
   }

-  let { open = $bindable(), type, message, onOpenChange }: Props = $props();
+  let { open = $bindable(), type, message, contextInfo, onOpenChange }: Props = $props();

   const isTimeout = $derived(type === 'timeout');
   const title = $derived(isTimeout ? 'TCP Timeout' : 'Server Error');
@@ -51,6 +52,15 @@

   <div class={`rounded-lg border px-4 py-3 text-sm ${badgeClass}`}>
     <p class="font-medium">{message}</p>
+    {#if contextInfo}
+      <div class="mt-2 space-y-1 text-xs opacity-80">
+        <p>
+          <span class="font-medium">Prompt tokens:</span>
+          {contextInfo.n_prompt_tokens.toLocaleString()}
+        </p>
+        <p><span class="font-medium">Context size:</span> {contextInfo.n_ctx.toLocaleString()}</p>
+      </div>
+    {/if}
   </div>

   <AlertDialog.Footer>
@@ -767,18 +767,33 @@ export class ChatService {
   * @param response - HTTP response object
   * @returns Promise<Error> - Parsed error with context info if available
   */
-  private static async parseErrorResponse(response: Response): Promise<Error> {
+  private static async parseErrorResponse(
+    response: Response
+  ): Promise<Error & { contextInfo?: { n_prompt_tokens: number; n_ctx: number } }> {
     try {
       const errorText = await response.text();
       const errorData: ApiErrorResponse = JSON.parse(errorText);

       const message = errorData.error?.message || 'Unknown server error';
-      const error = new Error(message);
+      const error = new Error(message) as Error & {
+        contextInfo?: { n_prompt_tokens: number; n_ctx: number };
+      };
       error.name = response.status === 400 ? 'ServerError' : 'HttpError';

+      if (errorData.error && 'n_prompt_tokens' in errorData.error && 'n_ctx' in errorData.error) {
+        error.contextInfo = {
+          n_prompt_tokens: errorData.error.n_prompt_tokens,
+          n_ctx: errorData.error.n_ctx
+        };
+      }
+
       return error;
     } catch {
-      const fallback = new Error(`Server error (${response.status}): ${response.statusText}`);
+      const fallback = new Error(
+        `Server error (${response.status}): ${response.statusText}`
+      ) as Error & {
+        contextInfo?: { n_prompt_tokens: number; n_ctx: number };
+      };
       fallback.name = 'HttpError';
       return fallback;
     }
@@ -58,7 +58,11 @@ class ChatStore {

   activeProcessingState = $state<ApiProcessingState | null>(null);
   currentResponse = $state('');
-  errorDialogState = $state<{ type: 'timeout' | 'server'; message: string } | null>(null);
+  errorDialogState = $state<{
+    type: 'timeout' | 'server';
+    message: string;
+    contextInfo?: { n_prompt_tokens: number; n_ctx: number };
+  } | null>(null);
   isLoading = $state(false);
   chatLoadingStates = new SvelteMap<string, boolean>();
   chatStreamingStates = new SvelteMap<string, { response: string; messageId: string }>();
@@ -335,8 +339,12 @@ class ChatStore {
     return error instanceof Error && (error.name === 'AbortError' || error instanceof DOMException);
   }

-  private showErrorDialog(type: 'timeout' | 'server', message: string): void {
-    this.errorDialogState = { type, message };
+  private showErrorDialog(
+    type: 'timeout' | 'server',
+    message: string,
+    contextInfo?: { n_prompt_tokens: number; n_ctx: number }
+  ): void {
+    this.errorDialogState = { type, message, contextInfo };
   }

   dismissErrorDialog(): void {
@@ -347,6 +355,23 @@ class ChatStore {
   // Message Operations
   // ─────────────────────────────────────────────────────────────────────────────

+  /**
+   * Finds a message by ID and optionally validates its role.
+   * Returns message and index, or null if not found or role doesn't match.
+   */
+  private getMessageByIdWithRole(
+    messageId: string,
+    expectedRole?: ChatRole
+  ): { message: DatabaseMessage; index: number } | null {
+    const index = conversationsStore.findMessageIndex(messageId);
+    if (index === -1) return null;
+
+    const message = conversationsStore.activeMessages[index];
+    if (expectedRole && message.role !== expectedRole) return null;
+
+    return { message, index };
+  }
+
   async addMessage(
     role: ChatRole,
     content: string,
@@ -508,7 +533,6 @@ class ChatStore {
       ) => {
         this.stopStreaming();

-        // Build update data - only include model if not already persisted
         const updateData: Record<string, unknown> = {
           content: finalContent || streamedContent,
           thinking: reasoningContent || streamedReasoningContent,
@@ -520,7 +544,6 @@ class ChatStore {
         }
         await DatabaseService.updateMessage(assistantMessage.id, updateData);

-        // Update UI state - always include model and timings if available
         const idx = conversationsStore.findMessageIndex(assistantMessage.id);
         const uiUpdate: Partial<DatabaseMessage> = {
           content: updateData.content as string,
@@ -543,22 +566,38 @@ class ChatStore {
       },
       onError: (error: Error) => {
         this.stopStreaming();
+
         if (this.isAbortError(error)) {
           this.setChatLoading(assistantMessage.convId, false);
           this.clearChatStreaming(assistantMessage.convId);
           this.clearProcessingState(assistantMessage.convId);
+
           return;
         }
+
         console.error('Streaming error:', error);
+
         this.setChatLoading(assistantMessage.convId, false);
         this.clearChatStreaming(assistantMessage.convId);
         this.clearProcessingState(assistantMessage.convId);
+
         const idx = conversationsStore.findMessageIndex(assistantMessage.id);
+
         if (idx !== -1) {
           const failedMessage = conversationsStore.removeMessageAtIndex(idx);
           if (failedMessage) DatabaseService.deleteMessage(failedMessage.id).catch(console.error);
         }
-        this.showErrorDialog(error.name === 'TimeoutError' ? 'timeout' : 'server', error.message);
+
+        const contextInfo = (
+          error as Error & { contextInfo?: { n_prompt_tokens: number; n_ctx: number } }
+        ).contextInfo;
+
+        this.showErrorDialog(
+          error.name === 'TimeoutError' ? 'timeout' : 'server',
+          error.message,
+          contextInfo
+        );
+
         if (onError) onError(error);
       }
     },
@@ -591,7 +630,9 @@ class ChatStore {
       await conversationsStore.updateConversationName(currentConv.id, content.trim());

       const assistantMessage = await this.createAssistantMessage(userMessage.id);

+      if (!assistantMessage) throw new Error('Failed to create assistant message');
+
       conversationsStore.addMessageToActive(assistantMessage);
       await this.streamChatCompletion(
         conversationsStore.activeMessages.slice(0, -1),
@@ -607,15 +648,26 @@ class ChatStore {
       if (!this.errorDialogState) {
         const dialogType =
           error instanceof Error && error.name === 'TimeoutError' ? 'timeout' : 'server';
-        this.showErrorDialog(dialogType, error instanceof Error ? error.message : 'Unknown error');
+        const contextInfo = (
+          error as Error & { contextInfo?: { n_prompt_tokens: number; n_ctx: number } }
+        ).contextInfo;
+
+        this.showErrorDialog(
+          dialogType,
+          error instanceof Error ? error.message : 'Unknown error',
+          contextInfo
+        );
       }
     }
   }

   async stopGeneration(): Promise<void> {
     const activeConv = conversationsStore.activeConversation;
+
     if (!activeConv) return;
+
     await this.savePartialResponseIfNeeded(activeConv.id);
+
     this.stopStreaming();
     this.abortRequest(activeConv.id);
     this.setChatLoading(activeConv.id, false);
@@ -655,17 +707,22 @@ class ChatStore {

   private async savePartialResponseIfNeeded(convId?: string): Promise<void> {
     const conversationId = convId || conversationsStore.activeConversation?.id;

     if (!conversationId) return;

     const streamingState = this.chatStreamingStates.get(conversationId);

     if (!streamingState || !streamingState.response.trim()) return;

     const messages =
       conversationId === conversationsStore.activeConversation?.id
         ? conversationsStore.activeMessages
         : await conversationsStore.getConversationMessages(conversationId);

     if (!messages.length) return;

     const lastMessage = messages[messages.length - 1];

     if (lastMessage?.role === 'assistant') {
       try {
         const updateData: { content: string; thinking?: string; timings?: ChatMessageTimings } = {
@@ -684,9 +741,13 @@ class ChatStore {
             : undefined
         };
       }
+
       await DatabaseService.updateMessage(lastMessage.id, updateData);
+
       lastMessage.content = this.currentResponse;
+
       if (updateData.thinking) lastMessage.thinking = updateData.thinking;
+
       if (updateData.timings) lastMessage.timings = updateData.timings;
     } catch (error) {
       lastMessage.content = this.currentResponse;
@@ -700,14 +761,12 @@ class ChatStore {
     if (!activeConv) return;
     if (this.isLoading) this.stopGeneration();

-    try {
-      const messageIndex = conversationsStore.findMessageIndex(messageId);
-      if (messageIndex === -1) return;
-
-      const messageToUpdate = conversationsStore.activeMessages[messageIndex];
+    const result = this.getMessageByIdWithRole(messageId, 'user');
+    if (!result) return;
+    const { message: messageToUpdate, index: messageIndex } = result;
     const originalContent = messageToUpdate.content;
-    if (messageToUpdate.role !== 'user') return;

+    try {
       const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
       const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
       const isFirstUserMessage = rootMessage && messageToUpdate.parent === rootMessage.id;
@@ -724,7 +783,9 @@ class ChatStore {
       }

       const messagesToRemove = conversationsStore.activeMessages.slice(messageIndex + 1);
+
       for (const message of messagesToRemove) await DatabaseService.deleteMessage(message.id);
+
       conversationsStore.sliceActiveMessages(messageIndex + 1);
       conversationsStore.updateConversationTimestamp();

@@ -732,8 +793,11 @@ class ChatStore {
       this.clearChatStreaming(activeConv.id);

       const assistantMessage = await this.createAssistantMessage();
+
+      if (!assistantMessage) throw new Error('Failed to create assistant message');
+
       conversationsStore.addMessageToActive(assistantMessage);

       await conversationsStore.updateCurrentNode(assistantMessage.id);
       await this.streamChatCompletion(
         conversationsStore.activeMessages.slice(0, -1),
@@ -758,12 +822,11 @@ class ChatStore {
     const activeConv = conversationsStore.activeConversation;
     if (!activeConv || this.isLoading) return;

-    try {
-      const messageIndex = conversationsStore.findMessageIndex(messageId);
-      if (messageIndex === -1) return;
-      const messageToRegenerate = conversationsStore.activeMessages[messageIndex];
-      if (messageToRegenerate.role !== 'assistant') return;
+    const result = this.getMessageByIdWithRole(messageId, 'assistant');
+    if (!result) return;
+    const { index: messageIndex } = result;

+    try {
       const messagesToRemove = conversationsStore.activeMessages.slice(messageIndex);
       for (const message of messagesToRemove) await DatabaseService.deleteMessage(message.id);
       conversationsStore.sliceActiveMessages(messageIndex);
@@ -832,6 +895,7 @@ class ChatStore {
       const siblings = allMessages.filter(
         (m) => m.parent === messageToDelete.parent && m.id !== messageId
       );
+
       if (siblings.length > 0) {
         const latestSibling = siblings.reduce((latest, sibling) =>
           sibling.timestamp > latest.timestamp ? sibling : latest
@@ -845,6 +909,7 @@ class ChatStore {
       }
       await DatabaseService.deleteMessageCascading(activeConv.id, messageId);
       await conversationsStore.refreshActiveMessages();
+
       conversationsStore.updateConversationTimestamp();
     } catch (error) {
       console.error('Failed to delete message:', error);
@@ -862,12 +927,12 @@ class ChatStore {
   ): Promise<void> {
     const activeConv = conversationsStore.activeConversation;
     if (!activeConv || this.isLoading) return;
-    try {
-      const idx = conversationsStore.findMessageIndex(messageId);
-      if (idx === -1) return;
-      const msg = conversationsStore.activeMessages[idx];
-      if (msg.role !== 'assistant') return;

+    const result = this.getMessageByIdWithRole(messageId, 'assistant');
+    if (!result) return;
+    const { message: msg, index: idx } = result;
+
+    try {
       if (shouldBranch) {
         const newMessage = await DatabaseService.createMessageBranch(
           {
@@ -902,12 +967,12 @@ class ChatStore {
   async editUserMessagePreserveResponses(messageId: string, newContent: string): Promise<void> {
     const activeConv = conversationsStore.activeConversation;
     if (!activeConv) return;
-    try {
-      const idx = conversationsStore.findMessageIndex(messageId);
-      if (idx === -1) return;
-      const msg = conversationsStore.activeMessages[idx];
-      if (msg.role !== 'user') return;

+    const result = this.getMessageByIdWithRole(messageId, 'user');
+    if (!result) return;
+    const { message: msg, index: idx } = result;
+
+    try {
       await DatabaseService.updateMessage(messageId, {
         content: newContent,
         timestamp: Date.now()
@@ -916,6 +981,7 @@ class ChatStore {

       const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
       const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
+
       if (rootMessage && msg.parent === rootMessage.id && newContent.trim()) {
         await conversationsStore.updateConversationTitleWithConfirmation(
           activeConv.id,
|
|||
async editMessageWithBranching(messageId: string, newContent: string): Promise<void> {
|
||||
const activeConv = conversationsStore.activeConversation;
|
||||
if (!activeConv || this.isLoading) return;
|
||||
try {
|
||||
const idx = conversationsStore.findMessageIndex(messageId);
|
||||
if (idx === -1) return;
|
||||
const msg = conversationsStore.activeMessages[idx];
|
||||
if (msg.role !== 'user') return;
|
||||
|
||||
const result = this.getMessageByIdWithRole(messageId, 'user');
|
||||
if (!result) return;
|
||||
const { message: msg } = result;
|
||||
|
||||
try {
|
||||
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
|
||||
const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
|
||||
const isFirstUserMessage = rootMessage && msg.parent === rootMessage.id;
|
||||
|
||||
const parentId = msg.parent || rootMessage?.id;
|
||||
if (!parentId) return;
|
||||
|
||||
|
|
@@ -1034,7 +1101,9 @@ class ChatStore {

   private async generateResponseForMessage(userMessageId: string): Promise<void> {
     const activeConv = conversationsStore.activeConversation;
+
     if (!activeConv) return;
+
     this.errorDialogState = null;
     this.setChatLoading(activeConv.id, true);
     this.clearChatStreaming(activeConv.id);
@@ -1071,26 +1140,30 @@ class ChatStore {
   async continueAssistantMessage(messageId: string): Promise<void> {
     const activeConv = conversationsStore.activeConversation;
     if (!activeConv || this.isLoading) return;
-    try {
-      const idx = conversationsStore.findMessageIndex(messageId);
-      if (idx === -1) return;
-      const msg = conversationsStore.activeMessages[idx];
-      if (msg.role !== 'assistant') return;
+
+    const result = this.getMessageByIdWithRole(messageId, 'assistant');
+    if (!result) return;
+    const { message: msg, index: idx } = result;
+
+    if (this.isChatLoading(activeConv.id)) return;
+
+    try {
       this.errorDialogState = null;
       this.setChatLoading(activeConv.id, true);
       this.clearChatStreaming(activeConv.id);

       const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
       const dbMessage = allMessages.find((m) => m.id === messageId);

       if (!dbMessage) {
         this.setChatLoading(activeConv.id, false);
+
         return;
       }

       const originalContent = dbMessage.content;
       const originalThinking = dbMessage.thinking || '';

       const conversationContext = conversationsStore.activeMessages.slice(0, idx);
       const contextWithContinue = [
         ...conversationContext,
@@ -1107,6 +1180,7 @@ class ChatStore {
         contextWithContinue,
         {
           ...this.getApiOptions(),
+
           onChunk: (chunk: string) => {
             hasReceivedContent = true;
             appendedContent += chunk;
@@ -1114,6 +1188,7 @@ class ChatStore {
             this.setChatStreaming(msg.convId, fullContent, msg.id);
             conversationsStore.updateMessageAtIndex(idx, { content: fullContent });
           },
+
           onReasoningChunk: (reasoningChunk: string) => {
             hasReceivedContent = true;
             appendedThinking += reasoningChunk;
@@ -1121,6 +1196,7 @@ class ChatStore {
               thinking: originalThinking + appendedThinking
             });
           },
+
           onTimings: (timings: ChatMessageTimings, promptProgress?: ChatMessagePromptProgress) => {
             const tokensPerSecond =
               timings?.predicted_ms && timings?.predicted_n
@@ -1137,6 +1213,7 @@ class ChatStore {
               msg.convId
             );
           },
+
           onComplete: async (
             finalContent?: string,
             reasoningContent?: string,
@@ -1161,6 +1238,7 @@ class ChatStore {
             this.clearChatStreaming(msg.convId);
             this.clearProcessingState(msg.convId);
           },
+
           onError: async (error: Error) => {
             if (this.isAbortError(error)) {
               if (hasReceivedContent && appendedContent) {