Merge branch 'master' into xsn/server_model_management_v1_2

This commit is contained in:
Xuan Son Nguyen 2025-11-20 14:19:16 +01:00
commit 919d3f8cbf
79 changed files with 2598 additions and 1204 deletions

View File

@ -26,7 +26,6 @@
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <thread> #include <thread>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
@ -60,6 +59,14 @@
#pragma warning(disable: 4244 4267) // possible loss of data #pragma warning(disable: 4244 4267) // possible loss of data
#endif #endif
common_time_meas::common_time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
common_time_meas::~common_time_meas() {
if (t_start_us >= 0) {
t_acc += ggml_time_us() - t_start_us;
}
}
// //
// CPU utils // CPU utils
// //

View File

@ -2,17 +2,15 @@
#pragma once #pragma once
#include "ggml-opt.h"
#include "llama-cpp.h"
#include <set> #include <set>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <string_view> #include <string_view>
#include <vector> #include <vector>
#include <map> #include <map>
#include <sstream>
#include <cmath>
#include "ggml-opt.h"
#include "llama-cpp.h"
#ifdef _WIN32 #ifdef _WIN32
#define DIRECTORY_SEPARATOR '\\' #define DIRECTORY_SEPARATOR '\\'
@ -30,6 +28,15 @@
#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf" #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
struct common_time_meas {
common_time_meas(int64_t & t_acc, bool disable = false);
~common_time_meas();
const int64_t t_start_us;
int64_t & t_acc;
};
struct common_adapter_lora_info { struct common_adapter_lora_info {
std::string path; std::string path;
float scale; float scale;

View File

@ -3,9 +3,10 @@
#include "common.h" #include "common.h"
#include "log.h" #include "log.h"
#include <cmath>
#include <unordered_map>
#include <algorithm> #include <algorithm>
#include <cmath>
#include <cstring>
#include <unordered_map>
// the ring buffer works similarly to std::deque, but with a fixed capacity // the ring buffer works similarly to std::deque, but with a fixed capacity
// TODO: deduplicate with llama-impl.h // TODO: deduplicate with llama-impl.h
@ -112,6 +113,13 @@ struct common_sampler {
llama_token_data_array cur_p; llama_token_data_array cur_p;
void reset() {
prev.clear();
llama_sampler_reset(grmr);
llama_sampler_reset(chain);
}
void set_logits(struct llama_context * ctx, int idx) { void set_logits(struct llama_context * ctx, int idx) {
const auto * logits = llama_get_logits_ith(ctx, idx); const auto * logits = llama_get_logits_ith(ctx, idx);
@ -128,6 +136,12 @@ struct common_sampler {
cur_p = { cur.data(), cur.size(), -1, false }; cur_p = { cur.data(), cur.size(), -1, false };
} }
common_time_meas tm() {
return common_time_meas(t_total_us, params.no_perf);
}
mutable int64_t t_total_us = 0;
}; };
std::string common_params_sampling::print() const { std::string common_params_sampling::print() const {
@ -298,6 +312,8 @@ void common_sampler_free(struct common_sampler * gsmpl) {
} }
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) { void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
const auto tm = gsmpl->tm();
if (accept_grammar) { if (accept_grammar) {
llama_sampler_accept(gsmpl->grmr, token); llama_sampler_accept(gsmpl->grmr, token);
} }
@ -308,9 +324,7 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo
} }
void common_sampler_reset(struct common_sampler * gsmpl) { void common_sampler_reset(struct common_sampler * gsmpl) {
llama_sampler_reset(gsmpl->grmr); gsmpl->reset();
llama_sampler_reset(gsmpl->chain);
} }
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) { struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
@ -327,16 +341,54 @@ struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) { void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
// TODO: measure grammar performance // TODO: measure grammar performance
const double t_sampling_ms = gsmpl ? 1e-3*gsmpl->t_total_us : 0;
llama_perf_sampler_data data_smpl;
llama_perf_context_data data_ctx;
memset(&data_smpl, 0, sizeof(data_smpl));
memset(&data_ctx, 0, sizeof(data_ctx));
if (gsmpl) { if (gsmpl) {
llama_perf_sampler_print(gsmpl->chain); auto & data = data_smpl;
data = llama_perf_sampler(gsmpl->chain);
// note: the sampling time includes the samplers time + extra time spent in common/sampling
LOG_INF("%s: sampling time = %10.2f ms\n", __func__, t_sampling_ms);
LOG_INF("%s: samplers time = %10.2f ms / %5d tokens\n", __func__, data.t_sample_ms, data.n_sample);
} }
if (ctx) { if (ctx) {
llama_perf_context_print(ctx); auto & data = data_ctx;
data = llama_perf_context(ctx);
const double t_end_ms = 1e-3 * ggml_time_us();
const double t_total_ms = t_end_ms - data.t_start_ms;
const double t_unacc_ms = t_total_ms - (t_sampling_ms + data.t_p_eval_ms + data.t_eval_ms);
const double t_unacc_pc = 100.0 * t_unacc_ms / t_total_ms;
LOG_INF("%s: load time = %10.2f ms\n", __func__, data.t_load_ms);
LOG_INF("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
LOG_INF("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
LOG_INF("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
LOG_INF("%s: unaccounted time = %10.2f ms / %5.1f %% (total - sampling - prompt eval - eval) / (total)\n", __func__, t_unacc_ms, t_unacc_pc);
LOG_INF("%s: graphs reused = %10d\n", __func__, data.n_reused);
llama_memory_breakdown_print(ctx); llama_memory_breakdown_print(ctx);
} }
} }
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) { llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
llama_synchronize(ctx);
// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
const auto tm = gsmpl->tm();
gsmpl->set_logits(ctx, idx); gsmpl->set_logits(ctx, idx);
auto & grmr = gsmpl->grmr; auto & grmr = gsmpl->grmr;
@ -428,6 +480,8 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
// helpers // helpers
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) { llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
const auto tm = gsmpl->tm();
auto * res = &gsmpl->cur_p; auto * res = &gsmpl->cur_p;
if (do_sort && !res->sorted) { if (do_sort && !res->sorted) {

View File

@ -1673,11 +1673,9 @@ class GPTNeoXModel(TextModel):
model_arch = gguf.MODEL_ARCH.GPTNEOX model_arch = gguf.MODEL_ARCH.GPTNEOX
def set_gguf_parameters(self): def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_dimension_count( self.gguf_writer.add_rope_dimension_count(
int(self.hparams["rotary_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])), int(self.hparams["rotary_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])),
@ -1735,7 +1733,7 @@ class BloomModel(TextModel):
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed)) self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
self.gguf_writer.add_embedding_length(n_embed) self.gguf_writer.add_embedding_length(n_embed)
self.gguf_writer.add_feed_forward_length(4 * n_embed) self.gguf_writer.add_feed_forward_length(4 * n_embed)
self.gguf_writer.add_block_count(self.hparams["n_layer"]) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(n_head) self.gguf_writer.add_head_count_kv(n_head)
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
@ -1798,10 +1796,9 @@ class MPTModel(TextModel):
self.gguf_writer.add_unk_token_id(0) self.gguf_writer.add_unk_token_id(0)
def set_gguf_parameters(self): def set_gguf_parameters(self):
block_count = self.hparams["n_layers"]
self.gguf_writer.add_context_length(self.hparams["max_seq_len"]) self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
self.gguf_writer.add_embedding_length(self.hparams["d_model"]) self.gguf_writer.add_embedding_length(self.hparams["d_model"])
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_feed_forward_length(4 * self.hparams["d_model"]) self.gguf_writer.add_feed_forward_length(4 * self.hparams["d_model"])
self.gguf_writer.add_head_count(self.hparams["n_heads"]) self.gguf_writer.add_head_count(self.hparams["n_heads"])
if kv_n_heads := self.hparams["attn_config"].get("kv_n_heads"): if kv_n_heads := self.hparams["attn_config"].get("kv_n_heads"):
@ -1834,7 +1831,6 @@ class OrionModel(TextModel):
self._set_vocab_sentencepiece() self._set_vocab_sentencepiece()
def set_gguf_parameters(self): def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
head_count = self.hparams["num_attention_heads"] head_count = self.hparams["num_attention_heads"]
head_count_kv = self.hparams.get("num_key_value_heads", head_count) head_count_kv = self.hparams.get("num_key_value_heads", head_count)
@ -1852,7 +1848,7 @@ class OrionModel(TextModel):
self.gguf_writer.add_tensor_data_layout("Meta AI original pth") self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
self.gguf_writer.add_context_length(ctx_length) self.gguf_writer.add_context_length(ctx_length)
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_head_count(head_count) self.gguf_writer.add_head_count(head_count)
self.gguf_writer.add_head_count_kv(head_count_kv) self.gguf_writer.add_head_count_kv(head_count_kv)
@ -1869,7 +1865,6 @@ class BaichuanModel(TextModel):
self._set_vocab_sentencepiece() self._set_vocab_sentencepiece()
def set_gguf_parameters(self): def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
head_count = self.hparams["num_attention_heads"] head_count = self.hparams["num_attention_heads"]
head_count_kv = self.hparams.get("num_key_value_heads", head_count) head_count_kv = self.hparams.get("num_key_value_heads", head_count)
@ -1886,7 +1881,7 @@ class BaichuanModel(TextModel):
self.gguf_writer.add_tensor_data_layout("Meta AI original pth") self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
self.gguf_writer.add_context_length(ctx_length) self.gguf_writer.add_context_length(ctx_length)
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count(head_count) self.gguf_writer.add_head_count(head_count)
@ -1993,7 +1988,6 @@ class XverseModel(TextModel):
special_vocab.add_to_gguf(self.gguf_writer) special_vocab.add_to_gguf(self.gguf_writer)
def set_gguf_parameters(self): def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
head_count = self.hparams["num_attention_heads"] head_count = self.hparams["num_attention_heads"]
head_count_kv = self.hparams.get("num_key_value_heads", head_count) head_count_kv = self.hparams.get("num_key_value_heads", head_count)
@ -2010,7 +2004,7 @@ class XverseModel(TextModel):
self.gguf_writer.add_tensor_data_layout("Meta AI original pth") self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
self.gguf_writer.add_context_length(ctx_length) self.gguf_writer.add_context_length(ctx_length)
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count(head_count) self.gguf_writer.add_head_count(head_count)
@ -2053,10 +2047,6 @@ class FalconModel(TextModel):
model_arch = gguf.MODEL_ARCH.FALCON model_arch = gguf.MODEL_ARCH.FALCON
def set_gguf_parameters(self): def set_gguf_parameters(self):
block_count = self.hparams.get("num_hidden_layers")
if block_count is None:
block_count = self.hparams["n_layer"] # old name
n_head = self.hparams.get("num_attention_heads") n_head = self.hparams.get("num_attention_heads")
if n_head is None: if n_head is None:
n_head = self.hparams["n_head"] # old name n_head = self.hparams["n_head"] # old name
@ -2069,7 +2059,7 @@ class FalconModel(TextModel):
self.gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform self.gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_feed_forward_length(4 * self.hparams["hidden_size"]) self.gguf_writer.add_feed_forward_length(4 * self.hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(n_head_kv) self.gguf_writer.add_head_count_kv(n_head_kv)
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
@ -2107,12 +2097,10 @@ class StarCoderModel(TextModel):
model_arch = gguf.MODEL_ARCH.STARCODER model_arch = gguf.MODEL_ARCH.STARCODER
def set_gguf_parameters(self): def set_gguf_parameters(self):
block_count = self.hparams["n_layer"]
self.gguf_writer.add_context_length(self.hparams["n_positions"]) self.gguf_writer.add_context_length(self.hparams["n_positions"])
self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"]) self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(self.hparams["n_head"]) self.gguf_writer.add_head_count(self.hparams["n_head"])
self.gguf_writer.add_head_count_kv(1) self.gguf_writer.add_head_count_kv(1)
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
@ -2142,14 +2130,12 @@ class RefactModel(TextModel):
multiple_of = 256 multiple_of = 256
ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
block_count = self.hparams["n_layer"]
# refact uses Alibi. So this is from config.json which might be used by training. # refact uses Alibi. So this is from config.json which might be used by training.
self.gguf_writer.add_context_length(self.hparams["n_positions"]) self.gguf_writer.add_context_length(self.hparams["n_positions"])
self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
self.gguf_writer.add_feed_forward_length(ff_dim) self.gguf_writer.add_feed_forward_length(ff_dim)
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(self.hparams["n_head"]) self.gguf_writer.add_head_count(self.hparams["n_head"])
self.gguf_writer.add_head_count_kv(1) self.gguf_writer.add_head_count_kv(1)
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"]) self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
@ -2196,11 +2182,10 @@ class StableLMModel(TextModel):
def set_gguf_parameters(self): def set_gguf_parameters(self):
hparams = self.hparams hparams = self.hparams
block_count = hparams["num_hidden_layers"]
self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(hparams["hidden_size"]) self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"]) rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"])
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"]))) self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
@ -3151,7 +3136,7 @@ class DbrxModel(TextModel):
def set_gguf_parameters(self): def set_gguf_parameters(self):
ffn_config = self.hparams["ffn_config"] ffn_config = self.hparams["ffn_config"]
attn_config = self.hparams["attn_config"] attn_config = self.hparams["attn_config"]
self.gguf_writer.add_block_count(self.hparams["n_layers"]) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_context_length(self.hparams["max_seq_len"]) self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
self.gguf_writer.add_embedding_length(self.hparams["d_model"]) self.gguf_writer.add_embedding_length(self.hparams["d_model"])
@ -3353,7 +3338,7 @@ class QwenModel(TextModel):
def set_gguf_parameters(self): def set_gguf_parameters(self):
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"]) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"]) self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
@ -4384,7 +4369,7 @@ class GPT2Model(TextModel):
model_arch = gguf.MODEL_ARCH.GPT2 model_arch = gguf.MODEL_ARCH.GPT2
def set_gguf_parameters(self): def set_gguf_parameters(self):
self.gguf_writer.add_block_count(self.hparams["n_layer"]) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_context_length(self.hparams["n_ctx"]) self.gguf_writer.add_context_length(self.hparams["n_ctx"])
self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"]) self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
@ -4416,8 +4401,6 @@ class Phi2Model(TextModel):
model_arch = gguf.MODEL_ARCH.PHI2 model_arch = gguf.MODEL_ARCH.PHI2
def set_gguf_parameters(self): def set_gguf_parameters(self):
block_count = self.find_hparam(["num_hidden_layers", "n_layer"])
rot_pct = self.find_hparam(["partial_rotary_factor"]) rot_pct = self.find_hparam(["partial_rotary_factor"])
n_embd = self.find_hparam(["hidden_size", "n_embd"]) n_embd = self.find_hparam(["hidden_size", "n_embd"])
n_head = self.find_hparam(["num_attention_heads", "n_head"]) n_head = self.find_hparam(["num_attention_heads", "n_head"])
@ -4426,7 +4409,7 @@ class Phi2Model(TextModel):
self.gguf_writer.add_embedding_length(n_embd) self.gguf_writer.add_embedding_length(n_embd)
self.gguf_writer.add_feed_forward_length(4 * n_embd) self.gguf_writer.add_feed_forward_length(4 * n_embd)
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(n_head) self.gguf_writer.add_head_count_kv(n_head)
self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_epsilon", "layer_norm_eps"])) self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_epsilon", "layer_norm_eps"]))
@ -4544,8 +4527,6 @@ class Phi3MiniModel(TextModel):
special_vocab.add_to_gguf(self.gguf_writer) special_vocab.add_to_gguf(self.gguf_writer)
def set_gguf_parameters(self): def set_gguf_parameters(self):
block_count = self.find_hparam(["num_hidden_layers", "n_layer"])
n_embd = self.find_hparam(["hidden_size", "n_embd"]) n_embd = self.find_hparam(["hidden_size", "n_embd"])
n_head = self.find_hparam(["num_attention_heads", "n_head"]) n_head = self.find_hparam(["num_attention_heads", "n_head"])
n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"]) n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
@ -4559,7 +4540,7 @@ class Phi3MiniModel(TextModel):
self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds) self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds)
self.gguf_writer.add_embedding_length(n_embd) self.gguf_writer.add_embedding_length(n_embd)
self.gguf_writer.add_feed_forward_length(self.find_hparam(["intermediate_size"])) self.gguf_writer.add_feed_forward_length(self.find_hparam(["intermediate_size"]))
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(n_head_kv) self.gguf_writer.add_head_count_kv(n_head_kv)
self.gguf_writer.add_layer_norm_rms_eps(rms_eps) self.gguf_writer.add_layer_norm_rms_eps(rms_eps)
@ -4679,12 +4660,11 @@ class PlamoModel(TextModel):
def set_gguf_parameters(self): def set_gguf_parameters(self):
hparams = self.hparams hparams = self.hparams
block_count = hparams["num_hidden_layers"]
self.gguf_writer.add_context_length(4096) # not in config.json self.gguf_writer.add_context_length(4096) # not in config.json
self.gguf_writer.add_embedding_length(hparams["hidden_size"]) self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(hparams["num_attention_heads"]) self.gguf_writer.add_head_count(hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"]) self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
@ -4807,7 +4787,6 @@ class Plamo2Model(TextModel):
def set_gguf_parameters(self): def set_gguf_parameters(self):
hparams = self.hparams hparams = self.hparams
block_count = hparams["num_hidden_layers"]
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
# Which layers are Mamba layers # Which layers are Mamba layers
@ -4819,10 +4798,10 @@ class Plamo2Model(TextModel):
num_attention_heads = [] num_attention_heads = []
if mamba_enabled: if mamba_enabled:
for i in range(block_count): for i in range(self.block_count):
if block_count <= (mamba_step // 2): if self.block_count <= (mamba_step // 2):
# use attention in last layer # use attention in last layer
is_mamba = (i != block_count - 1) is_mamba = (i != self.block_count - 1)
else: else:
is_mamba = (i % mamba_step) != (mamba_step // 2) is_mamba = (i % mamba_step) != (mamba_step // 2)
if is_mamba: if is_mamba:
@ -4840,7 +4819,7 @@ class Plamo2Model(TextModel):
self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096)) self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096))
self.gguf_writer.add_key_length(hparams.get("hidden_size_per_head", 128)) self.gguf_writer.add_key_length(hparams.get("hidden_size_per_head", 128))
self.gguf_writer.add_value_length(hparams.get("hidden_size_per_head", 128)) self.gguf_writer.add_value_length(hparams.get("hidden_size_per_head", 128))
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06)) self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000)) self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000))
@ -4897,12 +4876,10 @@ class CodeShellModel(TextModel):
model_arch = gguf.MODEL_ARCH.CODESHELL model_arch = gguf.MODEL_ARCH.CODESHELL
def set_gguf_parameters(self): def set_gguf_parameters(self):
block_count = self.hparams["n_layer"]
self.gguf_writer.add_context_length(self.hparams["n_positions"]) self.gguf_writer.add_context_length(self.hparams["n_positions"])
self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"]) self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(self.hparams["n_head"]) self.gguf_writer.add_head_count(self.hparams["n_head"])
self.gguf_writer.add_head_count_kv(self.hparams["num_query_groups"]) self.gguf_writer.add_head_count_kv(self.hparams["num_query_groups"])
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
@ -5044,7 +5021,7 @@ class InternLM2Model(TextModel):
def set_gguf_parameters(self): def set_gguf_parameters(self):
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"]) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"]) self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
@ -5665,11 +5642,10 @@ class GemmaModel(TextModel):
def set_gguf_parameters(self): def set_gguf_parameters(self):
hparams = self.hparams hparams = self.hparams
block_count = hparams["num_hidden_layers"]
self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(hparams["hidden_size"]) self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
self.gguf_writer.add_head_count(hparams["num_attention_heads"]) self.gguf_writer.add_head_count(hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"]) self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
@ -5705,11 +5681,10 @@ class Gemma2Model(TextModel):
def set_gguf_parameters(self): def set_gguf_parameters(self):
hparams = self.hparams hparams = self.hparams
block_count = hparams["num_hidden_layers"]
self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(hparams["hidden_size"]) self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
self.gguf_writer.add_head_count(hparams["num_attention_heads"]) self.gguf_writer.add_head_count(hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"]) self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
@ -5753,12 +5728,11 @@ class Gemma3Model(TextModel):
def set_gguf_parameters(self): def set_gguf_parameters(self):
hparams = self.hparams hparams = self.hparams
block_count = hparams["num_hidden_layers"]
# some default values are not specified in the hparams # some default values are not specified in the hparams
self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 131072)) self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 131072))
self.gguf_writer.add_embedding_length(hparams["hidden_size"]) self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 8)) self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 8))
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6)) self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6))
@ -6034,7 +6008,6 @@ class Rwkv6Model(TextModel):
self._set_vocab_rwkv_world() self._set_vocab_rwkv_world()
def set_gguf_parameters(self): def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
head_size = self.hparams["head_size"] head_size = self.hparams["head_size"]
hidden_size = self.hparams["hidden_size"] hidden_size = self.hparams["hidden_size"]
layer_norm_eps = self.hparams["layer_norm_epsilon"] layer_norm_eps = self.hparams["layer_norm_epsilon"]
@ -6046,7 +6019,7 @@ class Rwkv6Model(TextModel):
# RWKV isn't context limited # RWKV isn't context limited
self.gguf_writer.add_context_length(1048576) self.gguf_writer.add_context_length(1048576)
self.gguf_writer.add_embedding_length(hidden_size) self.gguf_writer.add_embedding_length(hidden_size)
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_layer_norm_eps(layer_norm_eps) self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers) self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers)
self.gguf_writer.add_wkv_head_size(head_size) self.gguf_writer.add_wkv_head_size(head_size)
@ -6110,7 +6083,6 @@ class RWKV6Qwen2Model(Rwkv6Model):
self._set_vocab_gpt2() self._set_vocab_gpt2()
def set_gguf_parameters(self): def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
num_attention_heads = self.hparams["num_attention_heads"] num_attention_heads = self.hparams["num_attention_heads"]
num_key_value_heads = self.hparams["num_key_value_heads"] num_key_value_heads = self.hparams["num_key_value_heads"]
hidden_size = self.hparams["hidden_size"] hidden_size = self.hparams["hidden_size"]
@ -6123,7 +6095,7 @@ class RWKV6Qwen2Model(Rwkv6Model):
# RWKV isn't context limited # RWKV isn't context limited
self.gguf_writer.add_context_length(1048576) self.gguf_writer.add_context_length(1048576)
self.gguf_writer.add_embedding_length(hidden_size) self.gguf_writer.add_embedding_length(hidden_size)
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_wkv_head_size(head_size) self.gguf_writer.add_wkv_head_size(head_size)
self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim) self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim)
self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim) self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim)
@ -6164,7 +6136,6 @@ class Rwkv7Model(TextModel):
return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32 return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
def set_gguf_parameters(self): def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
try: try:
head_size = self.hparams["head_size"] head_size = self.hparams["head_size"]
layer_norm_eps = self.hparams["layer_norm_epsilon"] layer_norm_eps = self.hparams["layer_norm_epsilon"]
@ -6189,7 +6160,7 @@ class Rwkv7Model(TextModel):
# RWKV isn't context limited # RWKV isn't context limited
self.gguf_writer.add_context_length(1048576) self.gguf_writer.add_context_length(1048576)
self.gguf_writer.add_embedding_length(hidden_size) self.gguf_writer.add_embedding_length(hidden_size)
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_layer_norm_eps(layer_norm_eps) self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
self.gguf_writer.add_wkv_head_size(head_size) self.gguf_writer.add_wkv_head_size(head_size)
self.gguf_writer.add_decay_lora_rank(lora_rank_decay) self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
@ -6283,7 +6254,6 @@ class ARwkv7Model(Rwkv7Model):
self._set_vocab_gpt2() self._set_vocab_gpt2()
def set_gguf_parameters(self): def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
hidden_size = self.hparams["hidden_size"] hidden_size = self.hparams["hidden_size"]
head_size = self.hparams["head_size"] head_size = self.hparams["head_size"]
rms_norm_eps = self.hparams["rms_norm_eps"] rms_norm_eps = self.hparams["rms_norm_eps"]
@ -6300,7 +6270,7 @@ class ARwkv7Model(Rwkv7Model):
# RWKV isn't context limited # RWKV isn't context limited
self.gguf_writer.add_context_length(1048576) self.gguf_writer.add_context_length(1048576)
self.gguf_writer.add_embedding_length(hidden_size) self.gguf_writer.add_embedding_length(hidden_size)
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_wkv_head_size(head_size) self.gguf_writer.add_wkv_head_size(head_size)
self.gguf_writer.add_decay_lora_rank(lora_rank_decay) self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
@ -7524,7 +7494,7 @@ class T5Model(TextModel):
self.gguf_writer.add_context_length(n_ctx) self.gguf_writer.add_context_length(n_ctx)
self.gguf_writer.add_embedding_length(self.hparams["d_model"]) self.gguf_writer.add_embedding_length(self.hparams["d_model"])
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"]) self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
self.gguf_writer.add_block_count(self.hparams["num_layers"]) self.gguf_writer.add_block_count(self.block_count)
if (dec_n_layer := self.hparams.get("num_decoder_layers")) is not None: if (dec_n_layer := self.hparams.get("num_decoder_layers")) is not None:
self.gguf_writer.add_decoder_block_count(dec_n_layer) self.gguf_writer.add_decoder_block_count(dec_n_layer)
self.gguf_writer.add_head_count(self.hparams["num_heads"]) self.gguf_writer.add_head_count(self.hparams["num_heads"])
@ -7663,7 +7633,7 @@ class T5EncoderModel(TextModel):
self.gguf_writer.add_context_length(n_ctx) self.gguf_writer.add_context_length(n_ctx)
self.gguf_writer.add_embedding_length(self.hparams["d_model"]) self.gguf_writer.add_embedding_length(self.hparams["d_model"])
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"]) self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
self.gguf_writer.add_block_count(self.hparams["num_layers"]) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(self.hparams["num_heads"]) self.gguf_writer.add_head_count(self.hparams["num_heads"])
self.gguf_writer.add_key_length(self.hparams["d_kv"]) self.gguf_writer.add_key_length(self.hparams["d_kv"])
self.gguf_writer.add_value_length(self.hparams["d_kv"]) self.gguf_writer.add_value_length(self.hparams["d_kv"])
@ -7726,7 +7696,7 @@ class JaisModel(TextModel):
self._set_vocab_gpt2() self._set_vocab_gpt2()
def set_gguf_parameters(self): def set_gguf_parameters(self):
self.gguf_writer.add_block_count(self.hparams["n_layer"]) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_context_length(self.hparams["n_positions"]) self.gguf_writer.add_context_length(self.hparams["n_positions"])
self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
self.gguf_writer.add_feed_forward_length(self.hparams["n_inner"]) self.gguf_writer.add_feed_forward_length(self.hparams["n_inner"])
@ -8068,7 +8038,7 @@ class ChatGLMModel(TextModel):
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed)) self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
self.gguf_writer.add_embedding_length(n_embed) self.gguf_writer.add_embedding_length(n_embed)
self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", self.hparams.get("intermediate_size", 4 * n_embed))) self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", self.hparams.get("intermediate_size", 4 * n_embed)))
self.gguf_writer.add_block_count(self.hparams.get("num_layers", self.hparams["num_hidden_layers"])) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(n_head_kv) self.gguf_writer.add_head_count_kv(n_head_kv)
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon",1e-5)) self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon",1e-5))
@ -8150,7 +8120,6 @@ class ExaoneModel(TextModel):
num_kv_heads = hparams.get("num_key_value_heads", num_heads) num_kv_heads = hparams.get("num_key_value_heads", num_heads)
layer_norm_eps = hparams["layer_norm_epsilon"] layer_norm_eps = hparams["layer_norm_epsilon"]
intermediate_size = hparams["intermediate_size"] if "intermediate_size" in hparams else 4 * embed_dim intermediate_size = hparams["intermediate_size"] if "intermediate_size" in hparams else 4 * embed_dim
num_layers = hparams["num_layers"]
# ignore for now as EXAONE-3.0-7.8B-Instruct attentino_dropout is 0.0 # ignore for now as EXAONE-3.0-7.8B-Instruct attentino_dropout is 0.0
# attention_dropout_rate = hparams["attention_dropout"] # attention_dropout_rate = hparams["attention_dropout"]
# ignore for now as EXAONE-3.0-7.8B-Instruct embed_dropout is 0.0 # ignore for now as EXAONE-3.0-7.8B-Instruct embed_dropout is 0.0
@ -8161,7 +8130,7 @@ class ExaoneModel(TextModel):
self.gguf_writer.add_context_length(max_position_embeddings) self.gguf_writer.add_context_length(max_position_embeddings)
self.gguf_writer.add_layer_norm_rms_eps(layer_norm_eps) self.gguf_writer.add_layer_norm_rms_eps(layer_norm_eps)
self.gguf_writer.add_feed_forward_length(intermediate_size) self.gguf_writer.add_feed_forward_length(intermediate_size)
self.gguf_writer.add_block_count(num_layers) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_file_type(self.ftype) self.gguf_writer.add_file_type(self.ftype)
if (rope_theta := self.hparams.get("rope_theta")) is not None: if (rope_theta := self.hparams.get("rope_theta")) is not None:

View File

@ -277,10 +277,15 @@ def parse_args() -> argparse.Namespace:
return parser.parse_args() return parser.parse_args()
def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]: def load_hparams_from_hf(hf_model_id: str) -> tuple[dict[str, Any], Path | None]:
from huggingface_hub import try_to_load_from_cache
# normally, adapter does not come with base model config, we need to load it from AutoConfig # normally, adapter does not come with base model config, we need to load it from AutoConfig
config = AutoConfig.from_pretrained(hf_model_id) config = AutoConfig.from_pretrained(hf_model_id)
return config.to_dict() cache_dir = try_to_load_from_cache(hf_model_id, "config.json")
cache_dir = Path(cache_dir).parent if isinstance(cache_dir, str) else None
return config.to_dict(), cache_dir
if __name__ == '__main__': if __name__ == '__main__':
@ -325,13 +330,13 @@ if __name__ == '__main__':
# load base model # load base model
if base_model_id is not None: if base_model_id is not None:
logger.info(f"Loading base model from Hugging Face: {base_model_id}") logger.info(f"Loading base model from Hugging Face: {base_model_id}")
hparams = load_hparams_from_hf(base_model_id) hparams, dir_base_model = load_hparams_from_hf(base_model_id)
elif dir_base_model is None: elif dir_base_model is None:
if "base_model_name_or_path" in lparams: if "base_model_name_or_path" in lparams:
model_id = lparams["base_model_name_or_path"] model_id = lparams["base_model_name_or_path"]
logger.info(f"Loading base model from Hugging Face: {model_id}") logger.info(f"Loading base model from Hugging Face: {model_id}")
try: try:
hparams = load_hparams_from_hf(model_id) hparams, dir_base_model = load_hparams_from_hf(model_id)
except OSError as e: except OSError as e:
logger.error(f"Failed to load base model config: {e}") logger.error(f"Failed to load base model config: {e}")
logger.error("Please try downloading the base model and add its path to --base") logger.error("Please try downloading the base model and add its path to --base")
@ -480,6 +485,7 @@ if __name__ == '__main__':
dir_lora_model=dir_lora, dir_lora_model=dir_lora,
lora_alpha=alpha, lora_alpha=alpha,
hparams=hparams, hparams=hparams,
remote_hf_model_id=base_model_id,
) )
logger.info("Exporting model...") logger.info("Exporting model...")

View File

@ -17,12 +17,12 @@ Legend:
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ | | ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | | ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | | ❌ | | ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | | ❌ |
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | | ADD_ID | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | | ❌ | | ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | | ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | | ❌ | | CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | | CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | | CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ | | CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
@ -43,9 +43,9 @@ Legend:
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ | | ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ | | EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | | EXPM1 | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
| FILL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | | ❌ | | FILL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | | FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | | ❌ | | FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | | GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | | GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | | GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
@ -87,7 +87,7 @@ Legend:
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | | ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | | ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | | ❌ | | ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
@ -99,7 +99,7 @@ Legend:
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | | SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | | SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | | ❌ | | SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | | SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
@ -107,7 +107,7 @@ Legend:
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | | SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | | SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | | ❌ | | STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | | SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| SUM | ❌ | ✅ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | | SUM | ❌ | ✅ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | | SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
@ -116,6 +116,6 @@ Legend:
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ❌ | | TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | | ❌ | | TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | | UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |

View File

@ -5,8 +5,8 @@
"Vulkan0","SGN","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" "Vulkan0","SGN","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
"Vulkan0","NEG","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","NEG","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","NEG","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","NEG","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","STEP","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" "Vulkan0","STEP","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","STEP","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" "Vulkan0","STEP","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","TANH","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","TANH","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","TANH","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","TANH","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","ELU","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" "Vulkan0","ELU","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
@ -29,18 +29,18 @@
"Vulkan0","EXP","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","EXP","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","EXPM1","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" "Vulkan0","EXPM1","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
"Vulkan0","EXPM1","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" "Vulkan0","EXPM1","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
"Vulkan0","SOFTPLUS","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" "Vulkan0","SOFTPLUS","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","SOFTPLUS","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" "Vulkan0","SOFTPLUS","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","GELU_ERF","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","GELU_ERF","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" "Vulkan0","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" "Vulkan0","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" "Vulkan0","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" "Vulkan0","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" "Vulkan0","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" "Vulkan0","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" "Vulkan0","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" "Vulkan0","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","ABS","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan" "Vulkan0","ABS","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan"
"Vulkan0","ABS","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","Vulkan" "Vulkan0","ABS","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","Vulkan"
"Vulkan0","SGN","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan" "Vulkan0","SGN","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan"
@ -89,8 +89,8 @@
"Vulkan0","SGN","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" "Vulkan0","SGN","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
"Vulkan0","NEG","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","NEG","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","NEG","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","NEG","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","STEP","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" "Vulkan0","STEP","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","STEP","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" "Vulkan0","STEP","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","TANH","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","TANH","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","TANH","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","TANH","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","ELU","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" "Vulkan0","ELU","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
@ -113,18 +113,18 @@
"Vulkan0","EXP","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","EXP","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","EXPM1","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" "Vulkan0","EXPM1","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
"Vulkan0","EXPM1","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" "Vulkan0","EXPM1","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
"Vulkan0","SOFTPLUS","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" "Vulkan0","SOFTPLUS","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","SOFTPLUS","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" "Vulkan0","SOFTPLUS","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","GELU_ERF","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","GELU_ERF","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" "Vulkan0","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" "Vulkan0","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" "Vulkan0","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" "Vulkan0","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" "Vulkan0","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" "Vulkan0","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" "Vulkan0","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" "Vulkan0","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","ABS","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan" "Vulkan0","ABS","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan"
"Vulkan0","ABS","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","Vulkan" "Vulkan0","ABS","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","Vulkan"
"Vulkan0","SGN","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan" "Vulkan0","SGN","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan"
@ -5654,7 +5654,7 @@
"Vulkan0","SUB","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan" "Vulkan0","SUB","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan"
"Vulkan0","MUL","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan" "Vulkan0","MUL","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan"
"Vulkan0","DIV","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan" "Vulkan0","DIV","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan"
"Vulkan0","ADD1","type=f32,ne=[10,5,4,3]","support","0","no","Vulkan" "Vulkan0","ADD1","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
"Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=0.000000,inplace=0","support","1","yes","Vulkan" "Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=0.000000,inplace=0","support","1","yes","Vulkan"
"Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=1.000000,inplace=0","support","1","yes","Vulkan" "Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=1.000000,inplace=0","support","1","yes","Vulkan"
"Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=1.000000,inplace=1","support","1","yes","Vulkan" "Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=1.000000,inplace=1","support","1","yes","Vulkan"
@ -8632,10 +8632,10 @@
"Vulkan0","COS","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan" "Vulkan0","COS","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
"Vulkan0","CLAMP","type=f16,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","0","no","Vulkan" "Vulkan0","CLAMP","type=f16,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","0","no","Vulkan"
"Vulkan0","LEAKY_RELU","type=f16,ne_a=[10,5,4,3],negative_slope=0.100000","support","0","no","Vulkan" "Vulkan0","LEAKY_RELU","type=f16,ne_a=[10,5,4,3],negative_slope=0.100000","support","0","no","Vulkan"
"Vulkan0","FLOOR","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan" "Vulkan0","FLOOR","type=f16,ne=[10,2,2,2]","support","1","yes","Vulkan"
"Vulkan0","CEIL","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan" "Vulkan0","CEIL","type=f16,ne=[10,2,2,2]","support","1","yes","Vulkan"
"Vulkan0","ROUND","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan" "Vulkan0","ROUND","type=f16,ne=[10,2,2,2]","support","1","yes","Vulkan"
"Vulkan0","TRUNC","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan" "Vulkan0","TRUNC","type=f16,ne=[10,2,2,2]","support","1","yes","Vulkan"
"Vulkan0","SQR","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" "Vulkan0","SQR","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
"Vulkan0","SQRT","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" "Vulkan0","SQRT","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
"Vulkan0","LOG","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","LOG","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan"
@ -8643,10 +8643,10 @@
"Vulkan0","COS","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" "Vulkan0","COS","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
"Vulkan0","CLAMP","type=f16,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","0","no","Vulkan" "Vulkan0","CLAMP","type=f16,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","0","no","Vulkan"
"Vulkan0","LEAKY_RELU","type=f16,ne_a=[7,1,5,3],negative_slope=0.100000","support","0","no","Vulkan" "Vulkan0","LEAKY_RELU","type=f16,ne_a=[7,1,5,3],negative_slope=0.100000","support","0","no","Vulkan"
"Vulkan0","FLOOR","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" "Vulkan0","FLOOR","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan"
"Vulkan0","CEIL","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" "Vulkan0","CEIL","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan"
"Vulkan0","ROUND","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" "Vulkan0","ROUND","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan"
"Vulkan0","TRUNC","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" "Vulkan0","TRUNC","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan"
"Vulkan0","SQR","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan" "Vulkan0","SQR","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
"Vulkan0","SQRT","type=f32,ne=[10,3,3,2]","support","1","yes","Vulkan" "Vulkan0","SQRT","type=f32,ne=[10,3,3,2]","support","1","yes","Vulkan"
"Vulkan0","LOG","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan" "Vulkan0","LOG","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
@ -8654,10 +8654,10 @@
"Vulkan0","COS","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan" "Vulkan0","COS","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
"Vulkan0","CLAMP","type=f32,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan" "Vulkan0","CLAMP","type=f32,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan"
"Vulkan0","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","1","yes","Vulkan" "Vulkan0","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","1","yes","Vulkan"
"Vulkan0","FLOOR","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan" "Vulkan0","FLOOR","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
"Vulkan0","CEIL","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan" "Vulkan0","CEIL","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
"Vulkan0","ROUND","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan" "Vulkan0","ROUND","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
"Vulkan0","TRUNC","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan" "Vulkan0","TRUNC","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
"Vulkan0","SQR","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","SQR","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
"Vulkan0","SQRT","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","SQRT","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
"Vulkan0","LOG","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","LOG","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
@ -8665,10 +8665,10 @@
"Vulkan0","COS","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","COS","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
"Vulkan0","CLAMP","type=f32,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan" "Vulkan0","CLAMP","type=f32,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan"
"Vulkan0","LEAKY_RELU","type=f32,ne_a=[7,1,5,3],negative_slope=0.100000","support","1","yes","Vulkan" "Vulkan0","LEAKY_RELU","type=f32,ne_a=[7,1,5,3],negative_slope=0.100000","support","1","yes","Vulkan"
"Vulkan0","FLOOR","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan" "Vulkan0","FLOOR","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
"Vulkan0","CEIL","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan" "Vulkan0","CEIL","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
"Vulkan0","ROUND","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan" "Vulkan0","ROUND","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
"Vulkan0","TRUNC","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan" "Vulkan0","TRUNC","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
"Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,1,1],n_past=5","support","1","yes","Vulkan" "Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,1,1],n_past=5","support","1","yes","Vulkan"
"Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,3,1],n_past=5","support","1","yes","Vulkan" "Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,3,1],n_past=5","support","1","yes","Vulkan"
"Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,3,2],n_past=5","support","1","yes","Vulkan" "Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,3,2],n_past=5","support","1","yes","Vulkan"
@ -9478,7 +9478,7 @@
"Vulkan0","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","0","no","Vulkan" "Vulkan0","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","0","no","Vulkan"
"Vulkan0","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","0","no","Vulkan" "Vulkan0","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","0","no","Vulkan"
"Vulkan0","ROLL","shift0=3,shift1=-2,shift3=1,shift4=-1","support","1","yes","Vulkan" "Vulkan0","ROLL","shift0=3,shift1=-2,shift3=1,shift4=-1","support","1","yes","Vulkan"
"Vulkan0","ARANGE","type=f32,start=0.000000,stop=10.000000,step=1.000000","support","0","no","Vulkan" "Vulkan0","ARANGE","type=f32,start=0.000000,stop=10.000000,step=1.000000","support","1","yes","Vulkan"
"Vulkan0","TIMESTEP_EMBEDDING","type=f32,ne_a=[2,1,1,1],dim=320,max_period=10000","support","1","yes","Vulkan" "Vulkan0","TIMESTEP_EMBEDDING","type=f32,ne_a=[2,1,1,1],dim=320,max_period=10000","support","1","yes","Vulkan"
"Vulkan0","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","1","yes","Vulkan" "Vulkan0","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","1","yes","Vulkan"
"Vulkan0","CUMSUM","type=f32,ne=[10,5,4,3]","support","0","no","Vulkan" "Vulkan0","CUMSUM","type=f32,ne=[10,5,4,3]","support","0","no","Vulkan"
@ -9487,9 +9487,9 @@
"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","0","no","Vulkan" "Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","0","no","Vulkan"
"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","0","no","Vulkan" "Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","0","no","Vulkan"
"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","0","no","Vulkan" "Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","0","no","Vulkan"
"Vulkan0","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","0","no","Vulkan" "Vulkan0","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","1","yes","Vulkan"
"Vulkan0","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","0","no","Vulkan" "Vulkan0","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","1","yes","Vulkan"
"Vulkan0","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","0","no","Vulkan" "Vulkan0","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","1","yes","Vulkan"
"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","0","no","Vulkan" "Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","0","no","Vulkan"
"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","0","no","Vulkan" "Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","0","no","Vulkan"
"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","0","no","Vulkan" "Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","0","no","Vulkan"

Can't render this file because it is too large.

View File

@ -4,10 +4,10 @@
#include "llama.h" #include "llama.h"
#include "ggml.h" #include "ggml.h"
#include <cmath>
#include <cstdio> #include <cstdio>
#include <string> #include <string>
#include <vector> #include <vector>
#include <numeric>
/** /**
* This the arbitrary data which will be passed to each callback. * This the arbitrary data which will be passed to each callback.
@ -37,23 +37,23 @@ static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
return u.f; return u.f;
} }
static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) { static float ggml_get_float_value(const uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0]; size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
float v; float v;
if (type == GGML_TYPE_F16) { if (type == GGML_TYPE_F16) {
v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]); v = ggml_fp16_to_fp32(*(const ggml_fp16_t *) &data[i]);
} else if (type == GGML_TYPE_F32) { } else if (type == GGML_TYPE_F32) {
v = *(float *) &data[i]; v = *(const float *) &data[i];
} else if (type == GGML_TYPE_I64) { } else if (type == GGML_TYPE_I64) {
v = (float) *(int64_t *) &data[i]; v = (float) *(const int64_t *) &data[i];
} else if (type == GGML_TYPE_I32) { } else if (type == GGML_TYPE_I32) {
v = (float) *(int32_t *) &data[i]; v = (float) *(const int32_t *) &data[i];
} else if (type == GGML_TYPE_I16) { } else if (type == GGML_TYPE_I16) {
v = (float) *(int16_t *) &data[i]; v = (float) *(const int16_t *) &data[i];
} else if (type == GGML_TYPE_I8) { } else if (type == GGML_TYPE_I8) {
v = (float) *(int8_t *) &data[i]; v = (float) *(const int8_t *) &data[i];
} else if (type == GGML_TYPE_BF16) { } else if (type == GGML_TYPE_BF16) {
v = ggml_compute_bf16_to_fp32(*(ggml_bf16_t *) &data[i]); v = ggml_compute_bf16_to_fp32(*(const ggml_bf16_t *) &data[i]);
} else { } else {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }

View File

@ -392,9 +392,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}") string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")
if (EXTRACTED_NUMBER GREATER_EQUAL 10) if (EXTRACTED_NUMBER GREATER_EQUAL 10)
list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64) list(APPEND ARCH_FLAGS -mcpu=power10)
elseif (EXTRACTED_NUMBER EQUAL 9) elseif (EXTRACTED_NUMBER EQUAL 9)
list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64) list(APPEND ARCH_FLAGS -mcpu=power9)
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le") elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native) list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native)
else() else()

View File

@ -39,7 +39,7 @@
#include "kernels.h" #include "kernels.h"
#define NELEMS(x) sizeof(x) / sizeof(*x) #define NELEMS(x) (sizeof(x) / sizeof(*x))
template<size_t(*Fn)(size_t,size_t,size_t)> template<size_t(*Fn)(size_t,size_t,size_t)>
static inline size_t kernel_offs_fn3(size_t a, size_t b, size_t c) { static inline size_t kernel_offs_fn3(size_t a, size_t b, size_t c) {
@ -635,6 +635,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
}, },
#endif #endif
#endif #endif
{ /* Sentinel */ }
}; };
static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = { static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
@ -803,6 +804,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
/* .op_type = */ GGML_TYPE_F32, /* .op_type = */ GGML_TYPE_F32,
}, },
#endif #endif
{ /* Sentinel */ }
}; };
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) { ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
@ -810,7 +812,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) { if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8) #if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu && if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type && gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type && gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&
@ -820,7 +822,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
} }
} }
if (!kernel) { if (!kernel) {
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) { for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) {
if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu && if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu &&
gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type && gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type &&
gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type && gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type &&
@ -830,6 +832,10 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
} }
} }
} }
#else
GGML_UNUSED(gemm_gemv_kernels);
GGML_UNUSED(gemm_gemv_kernels_q8);
GGML_UNUSED(cpu_features);
#endif #endif
} }
@ -840,12 +846,14 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features)
ggml_kleidiai_kernels * kernels = nullptr; ggml_kleidiai_kernels * kernels = nullptr;
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8) #if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) { if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
kernels = &gemm_gemv_kernels[i]; kernels = &gemm_gemv_kernels[i];
break; break;
} }
} }
#else
GGML_UNUSED(features);
#endif #endif
return kernels; return kernels;
@ -855,12 +863,14 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features)
ggml_kleidiai_kernels * kernels = nullptr; ggml_kleidiai_kernels * kernels = nullptr;
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8) #if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) { for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) {
if ((features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu) { if ((features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu) {
kernels = &gemm_gemv_kernels_q8[i]; kernels = &gemm_gemv_kernels_q8[i];
break; break;
} }
} }
#else
GGML_UNUSED(features);
#endif #endif
return kernels; return kernels;

View File

@ -9696,13 +9696,12 @@ static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params
for (int64_t i00 = 0; i00 < n; ++i00) { for (int64_t i00 = 0; i00 < n; ++i00) {
float sum = 0.0f; float sum = 0.0f;
for (int64_t t = 0; t < i00; ++t) { for (int64_t t = 0; t < i00; ++t) {
sum += A_batch[i00 * n + t] * X_batch[i01 * n + t]; sum += A_batch[i00 * n + t] * X_batch[t * k + i01];
} }
const float diag = A_batch[i00 * n + i00]; const float diag = A_batch[i00 * n + i00];
GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix"); GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix");
X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
X_batch[i01 * n + i00] = (B_batch[i00 * k + i01] - sum) / diag;
} }
} }
} }

View File

@ -160,18 +160,18 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
#define GGML_F32xt svfloat32_t #define GGML_F32xt svfloat32_t
#define GGML_F32xt_ZERO svdup_n_f32(0.0f) #define GGML_F32xt_ZERO svdup_n_f32(0.0f)
#define GGML_F32xt_SET1(x) svdup_n_f32(x) #define GGML_F32xt_SET1(x) svdup_n_f32(x)
#define GGML_F32xt_LOAD_IMPL(pg, a, ...) svld1_f32(pg, a) #define GGML_F32xt_LOAD_IMPL(pg, a) svld1_f32(pg, a)
#define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__) #define GGML_F32xt_LOAD(a) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, a)
#define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b) #define GGML_F32xt_STORE_IMPL(pg, a, b) svst1_f32(pg, a, b)
#define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__) #define GGML_F32xt_STORE(a, b) GGML_F32xt_STORE_IMPL(DEFAULT_PG, a, b)
#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a) #define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a)
#define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__) #define GGML_F32xt_FMA(a, b, c) GGML_F32xt_FMA_IMPL(DEFAULT_PG, a, b, c)
#define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b) #define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b)
#define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__) #define GGML_F32xt_ADD(a, b) GGML_F32xt_ADD_IMPL(DEFAULT_PG, a, b)
#define GGML_F32xt_MUL_IMPL(pg, a, b) svmul_f32_m(pg, a, b) #define GGML_F32xt_MUL_IMPL(pg, a, b) svmul_f32_m(pg, a, b)
#define GGML_F32xt_MUL(...) GGML_F32xt_MUL_IMPL(DEFAULT_PG, __VA_ARGS__) #define GGML_F32xt_MUL(a, b) GGML_F32xt_MUL_IMPL(DEFAULT_PG, a, b)
#define GGML_F32xt_REDUCE_ONE_IMPL(pg, a) svaddv(pg, a) #define GGML_F32xt_REDUCE_ONE_IMPL(pg, a) svaddv(pg, a)
#define GGML_F32xt_REDUCE_ONE(...) GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, __VA_ARGS__) #define GGML_F32xt_REDUCE_ONE(a) GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, a)
#define GGML_F32xt_REDUCE_IMPL(pg, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \ #define GGML_F32xt_REDUCE_IMPL(pg, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \
{ \ { \
sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum2); \ sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum2); \
@ -183,7 +183,8 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum5); \ sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum5); \
(res) = (ggml_float) GGML_F32xt_REDUCE_ONE(sum1); \ (res) = (ggml_float) GGML_F32xt_REDUCE_ONE(sum1); \
} }
#define GGML_F32xt_REDUCE(...) GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, __VA_ARGS__) #define GGML_F32xt_REDUCE(res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \
GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8)
#define GGML_F32_VEC GGML_F32xt #define GGML_F32_VEC GGML_F32xt
#define GGML_F32_VEC_ZERO GGML_F32xt_ZERO #define GGML_F32_VEC_ZERO GGML_F32xt_ZERO
@ -206,11 +207,11 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
#define GGML_F32Cxt_STORE(dst_ptr, src_vec) svst1_f16(DEFAULT_PG16, (__fp16 *)(dst_ptr), (src_vec)) #define GGML_F32Cxt_STORE(dst_ptr, src_vec) svst1_f16(DEFAULT_PG16, (__fp16 *)(dst_ptr), (src_vec))
#define GGML_F32Cxt_FMA_IMPL(pg, a, b, c) svmad_f16_x(pg, b, c, a) #define GGML_F32Cxt_FMA_IMPL(pg, a, b, c) svmad_f16_x(pg, b, c, a)
#define GGML_F32Cxt_FMA(...) GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, __VA_ARGS__) #define GGML_F32Cxt_FMA(a, b, c) GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, a, b, c)
#define GGML_F32Cxt_ADD_IMPL(pg, a, b) svadd_f16_x(pg, a, b) #define GGML_F32Cxt_ADD_IMPL(pg, a, b) svadd_f16_x(pg, a, b)
#define GGML_F32Cxt_ADD(...) GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, __VA_ARGS__) #define GGML_F32Cxt_ADD(a, b) GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, a, b)
#define GGML_F32Cxt_MUL_IMPL(pg, a, b) svmul_f16_x(pg, a, b) #define GGML_F32Cxt_MUL_IMPL(pg, a, b) svmul_f16_x(pg, a, b)
#define GGML_F32Cxt_MUL(...) GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, __VA_ARGS__) #define GGML_F32Cxt_MUL(a, b) GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, a, b)
#define GGML_F32Cxt_REDUCE GGML_F16xt_REDUCE_MIXED #define GGML_F32Cxt_REDUCE GGML_F16xt_REDUCE_MIXED
#define GGML_F16x_VEC GGML_F32Cxt #define GGML_F16x_VEC GGML_F32Cxt
@ -224,7 +225,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
#define GGML_F16x_VEC_REDUCE GGML_F32Cxt_REDUCE #define GGML_F16x_VEC_REDUCE GGML_F32Cxt_REDUCE
#define GGML_F16xt_REDUCE_ONE_IMPL(pg, a) svaddv_f16(pg, a) #define GGML_F16xt_REDUCE_ONE_IMPL(pg, a) svaddv_f16(pg, a)
#define GGML_F16xt_REDUCE_ONE(...) GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, __VA_ARGS__) #define GGML_F16xt_REDUCE_ONE(a) GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, a)
#define GGML_F16xt_REDUCE_MIXED_IMPL(pg16, res, sum1, sum2, sum3, sum4) \ #define GGML_F16xt_REDUCE_MIXED_IMPL(pg16, res, sum1, sum2, sum3, sum4) \
{ \ { \
@ -234,7 +235,8 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
__fp16 sum_f16 = svaddv_f16(pg16, sum1); \ __fp16 sum_f16 = svaddv_f16(pg16, sum1); \
(res) = (ggml_float) sum_f16; \ (res) = (ggml_float) sum_f16; \
} }
#define GGML_F16xt_REDUCE_MIXED(...) GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, __VA_ARGS__) #define GGML_F16xt_REDUCE_MIXED(res, sum1, sum2, sum3, sum4) \
GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, res, sum1, sum2, sum3, sum4)
// F16 NEON // F16 NEON

View File

@ -698,8 +698,7 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
} }
inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) { inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
#if defined(GGML_SIMD) #if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
#if defined(__ARM_FEATURE_SVE)
const int sve_register_length = svcntb() * 8; const int sve_register_length = svcntb() * 8;
const int ggml_f16_epr = sve_register_length / 16; const int ggml_f16_epr = sve_register_length / 16;
const int ggml_f16_step = 2 * ggml_f16_epr; const int ggml_f16_step = 2 * ggml_f16_epr;
@ -725,13 +724,16 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
svfloat16_t out = svmul_f16_m(pg, hy, vx); svfloat16_t out = svmul_f16_m(pg, hy, vx);
svst1_f16(pg, (__fp16 *)(y + np), out); svst1_f16(pg, (__fp16 *)(y + np), out);
} }
#elif defined(__riscv_v_intrinsic) #elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
// todo: RVV impl for (int i = 0, vl; i < n; i += vl) {
// scalar vl = __riscv_vsetvl_e16m2(n - i);
for (int i = 0; i < n; ++i) { vfloat16m2_t vy = __riscv_vle16_v_f16m2((_Float16 *)&y[i], vl);
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); vfloat32m4_t vy32 = __riscv_vfwcvt_f_f_v_f32m4(vy, vl);
vy32 = __riscv_vfmul_vf_f32m4(vy32, v, vl);
vy = __riscv_vfncvt_f_f_w_f16m2(vy32, vl);
__riscv_vse16_v_f16m2((_Float16 *)&y[i], vy, vl);
} }
#else #elif defined(GGML_SIMD)
const int np = (n & ~(GGML_F16_STEP - 1)); const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
@ -751,7 +753,6 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
for (int i = np; i < n; ++i) { for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
} }
#endif
#else #else
// scalar // scalar
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {

View File

@ -384,7 +384,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
char * src1_ddc = (char *) src1->data; char * src1_ddc = (char *) src1->data;
const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1); const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) && src0->ne[3] == 1; const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) &&
src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0);
if (src0->type == src1->type && contiguous_srcs) { if (src0->type == src1->type && contiguous_srcs) {
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));

View File

@ -3001,6 +3001,10 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope, static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
const ggml_tensor * view, const ggml_tensor * view,
const ggml_tensor * set_rows) { const ggml_tensor * set_rows) {
if (rope->op != GGML_OP_ROPE || view->op != GGML_OP_VIEW || set_rows->op != GGML_OP_SET_ROWS) {
return false;
}
// ne3 not tested // ne3 not tested
if (rope->src[0]->ne[3] != 1) { if (rope->src[0]->ne[3] != 1) {
return false; return false;
@ -3744,10 +3748,110 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t
return ctx->description.c_str(); return ctx->description.c_str();
} }
#if defined(__linux__)
// Helper function to get available memory from /proc/meminfo for UMA systems
static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_kb, long * free_swap_kb) {
FILE * meminfo_file = nullptr;
// 2KB buffer for reading /proc/meminfo since it does not report size info, should be enough
const size_t BUFFER_SIZE = 2048;
auto file_buffer = std::make_unique<char[]>(BUFFER_SIZE);
size_t bytes_read = 0;
long huge_tlb_total_pages = -1;
long huge_tlb_free_pages = -1;
long huge_tlb_page_size = -1;
if (available_memory_kb == nullptr || free_swap_kb == nullptr) {
return false;
}
meminfo_file = fopen("/proc/meminfo", "r");
if (meminfo_file == nullptr) {
GGML_LOG_ERROR("%s: failed to open /proc/meminfo\n", __func__);
return false;
}
// Read file into buffer
bytes_read = fread(file_buffer.get(), 1, BUFFER_SIZE - 1, meminfo_file);
fclose(meminfo_file);
if (bytes_read == 0) {
GGML_LOG_ERROR("%s: failed to read from /proc/meminfo\n", __func__);
return false;
}
file_buffer[bytes_read] = '\0';
*available_memory_kb = -1;
*free_swap_kb = -1;
// Parse the file buffer line by line
char * line = file_buffer.get();
char * line_next;
while (line < file_buffer.get() + bytes_read) {
// Find the end of the current line
line_next = strchr(line, '\n');
if (line_next != nullptr) {
*line_next = '\0';
line_next++;
} else {
line_next = file_buffer.get() + bytes_read;
}
long value;
if (sscanf(line, "MemAvailable: %ld kB", &value) == 1) {
*available_memory_kb = value;
} else if (sscanf(line, "SwapFree: %ld kB", &value) == 1) {
*free_swap_kb = value;
} else if (sscanf(line, "HugePages_Total: %ld", &value) == 1) {
huge_tlb_total_pages = value;
} else if (sscanf(line, "HugePages_Free: %ld", &value) == 1) {
huge_tlb_free_pages = value;
} else if (sscanf(line, "Hugepagesize: %ld kB", &value) == 1) {
huge_tlb_page_size = value;
}
line = line_next;
}
if (huge_tlb_total_pages != 0 && huge_tlb_total_pages != -1) {
*available_memory_kb = huge_tlb_free_pages * huge_tlb_page_size;
// Hugetlbfs pages are not swappable.
*free_swap_kb = 0;
}
GGML_LOG_DEBUG("%s: final available_memory_kb: %ld\n", __func__, *available_memory_kb);
return true;
}
#endif // defined(__linux__)
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
ggml_cuda_set_device(ctx->device); ggml_cuda_set_device(ctx->device);
CUDA_CHECK(cudaMemGetInfo(free, total)); CUDA_CHECK(cudaMemGetInfo(free, total));
// ref: https://github.com/ggml-org/llama.cpp/pull/17368
#if defined(__linux__)
// Check if this is a UMA (Unified Memory Architecture) system
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, ctx->device));
// Check if UMA is explicitly enabled via environment variable
bool uma_env = getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr;
bool is_uma = prop.unifiedAddressing > 0 || uma_env;
if (is_uma) {
// For UMA systems (like DGX Spark), use system memory info
long available_memory_kb = 0;
long free_swap_kb = 0;
if (ggml_backend_cuda_get_available_uma_memory(&available_memory_kb, &free_swap_kb) && available_memory_kb > 0) {
*free = (size_t)available_memory_kb * 1024;
} else {
GGML_LOG_ERROR("%s: /proc/meminfo reading failed, using cudaMemGetInfo\n", __func__);
}
}
#endif // defined(__linux__)
} }
static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) { static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) {

View File

@ -11,6 +11,7 @@
#include <cassert> #include <cassert>
#include <algorithm> #include <algorithm>
#include <limits> #include <limits>
#include <cmath>
static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) { static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) {
if (!t) { if (!t) {

View File

@ -406,8 +406,8 @@ enum shader_reduction_mode {
SHADER_REDUCTION_MODE_COUNT, SHADER_REDUCTION_MODE_COUNT,
}; };
// argsort pipelines for up to 1<<10 invocations per workgroup
static constexpr uint32_t num_argsort_pipelines = 11; static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
static constexpr uint32_t num_topk_moe_pipelines = 10; static constexpr uint32_t num_topk_moe_pipelines = 10;
static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
@ -526,6 +526,7 @@ struct vk_device_struct {
bool multi_add; bool multi_add;
bool shader_int64; bool shader_int64;
bool buffer_device_address; bool buffer_device_address;
bool vulkan_memory_model;
bool add_rms_fusion; bool add_rms_fusion;
uint32_t partials_binding_alignment; uint32_t partials_binding_alignment;
@ -539,6 +540,9 @@ struct vk_device_struct {
uint32_t subgroup_max_size; uint32_t subgroup_max_size;
bool subgroup_require_full_support; bool subgroup_require_full_support;
// floor(log2(maxComputeWorkGroupInvocations))
uint32_t max_workgroup_size_log2 {};
bool coopmat_support; bool coopmat_support;
bool coopmat_acc_f32_support {}; bool coopmat_acc_f32_support {};
bool coopmat_acc_f16_support {}; bool coopmat_acc_f16_support {};
@ -638,6 +642,7 @@ struct vk_device_struct {
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32;
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_cpy_transpose_16, pipeline_cpy_transpose_32;
vk_pipeline pipeline_set_rows_i32[GGML_TYPE_COUNT]; vk_pipeline pipeline_set_rows_i32[GGML_TYPE_COUNT];
vk_pipeline pipeline_set_rows_i64[GGML_TYPE_COUNT]; vk_pipeline pipeline_set_rows_i64[GGML_TYPE_COUNT];
vk_pipeline pipeline_norm_f32; vk_pipeline pipeline_norm_f32;
@ -664,6 +669,20 @@ struct vk_device_struct {
vk_pipeline pipeline_hardsigmoid[2]; vk_pipeline pipeline_hardsigmoid[2];
vk_pipeline pipeline_hardswish[2]; vk_pipeline pipeline_hardswish[2];
vk_pipeline pipeline_abs[2]; vk_pipeline pipeline_abs[2];
vk_pipeline pipeline_softplus[2];
vk_pipeline pipeline_step[2];
vk_pipeline pipeline_round[2];
vk_pipeline pipeline_ceil[2];
vk_pipeline pipeline_floor[2];
vk_pipeline pipeline_trunc[2];
vk_pipeline pipeline_add1_f16_f16;
vk_pipeline pipeline_add1_f16_f32;
vk_pipeline pipeline_add1_f32_f32;
vk_pipeline pipeline_arange_f32;
vk_pipeline pipeline_fill_f32;
vk_pipeline pipeline_geglu[2]; vk_pipeline pipeline_geglu[2];
vk_pipeline pipeline_reglu[2]; vk_pipeline pipeline_reglu[2];
@ -683,6 +702,7 @@ struct vk_device_struct {
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16; vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines]; vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
vk_pipeline pipeline_sum_rows_f32; vk_pipeline pipeline_sum_rows_f32;
vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_argmax_f32;
vk_pipeline pipeline_count_equal_i32; vk_pipeline pipeline_count_equal_i32;
@ -1173,8 +1193,14 @@ struct vk_op_soft_max_push_constants {
struct vk_op_argsort_push_constants { struct vk_op_argsort_push_constants {
uint32_t ncols; uint32_t ncols;
uint32_t ncols_padded;
uint32_t ncols_padded_log2;
uint32_t nrows; uint32_t nrows;
int32_t order; uint32_t order;
uint32_t outer_start;
uint32_t outer_end;
uint32_t inner_start;
uint32_t inner_end;
}; };
struct vk_op_im2col_push_constants { struct vk_op_im2col_push_constants {
@ -2901,15 +2927,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (path == FAPATH) { \ if (path == FAPATH) { \
if (aligned) { \ if (aligned) { \
if (f32acc) { \ if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} else { \ } else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} \ } \
} else { \ } else { \
if (f32acc) { \ if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} else { \ } else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} \ } \
} \ } \
} \ } \
@ -3697,6 +3723,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_32, "cpy_transpose_32", cpy_transpose_32_len, cpy_transpose_32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_16, "cpy_transpose_16", cpy_transpose_16_len, cpy_transpose_16_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
if (device->float_controls_rte_fp16) { if (device->float_controls_rte_fp16) {
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
@ -3826,6 +3855,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_UNARY(hardsigmoid) CREATE_UNARY(hardsigmoid)
CREATE_UNARY(hardswish) CREATE_UNARY(hardswish)
CREATE_UNARY(abs) CREATE_UNARY(abs)
CREATE_UNARY(softplus)
CREATE_UNARY(step)
CREATE_UNARY(round)
CREATE_UNARY(ceil)
CREATE_UNARY(floor)
CREATE_UNARY(trunc)
#undef CREATE_UNARY #undef CREATE_UNARY
#define CREATE_UNARY_RTE(name) \ #define CREATE_UNARY_RTE(name) \
@ -3839,6 +3874,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_UNARY_RTE(exp) CREATE_UNARY_RTE(exp)
#undef CREATE_UNARY_RTE #undef CREATE_UNARY_RTE
ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f16, "add1_f16_f16", add1_f16_f16_len, add1_f16_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
#define CREATE_GLU(name) \ #define CREATE_GLU(name) \
if (device->float_controls_rte_fp16) { \ if (device->float_controls_rte_fp16) { \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
@ -3891,7 +3934,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
} }
for (uint32_t i = 0; i < num_argsort_pipelines; ++i) { for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<<i, 1, 1}, {1u<<i, i}, 1, true); uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2);
if (i <= device->max_workgroup_size_log2 &&
2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
const uint32_t NCOLS_PADDED_LOG2 = i;
ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
}
const uint32_t WG_UNROLL_FACTOR = BLOCK_SIZE > 1 ? 2 : 1;
BLOCK_SIZE /= WG_UNROLL_FACTOR;
ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true);
} }
ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
@ -4292,6 +4343,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated; device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations)));
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties(); std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
// Try to find a non-graphics compute queue and transfer-focused queues // Try to find a non-graphics compute queue and transfer-focused queues
@ -4431,6 +4484,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->shader_int64 = device_features2.features.shaderInt64; device->shader_int64 = device_features2.features.shaderInt64;
device->buffer_device_address = vk12_features.bufferDeviceAddress; device->buffer_device_address = vk12_features.bufferDeviceAddress;
device->vulkan_memory_model = vk12_features.vulkanMemoryModel;
if (device->subgroup_size_control) { if (device->subgroup_size_control) {
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize; device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
@ -6247,6 +6301,17 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
// Choose "contiguous copy" shader if src/dst are contiguous // Choose "contiguous copy" shader if src/dst are contiguous
bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst)); bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst));
// Use optimized "transpose" shader if src dim1 is the innermost dimension.
bool transpose = dst && src->nb[1] == ggml_type_size(to) && ggml_are_same_shape(dst, src);
if (transpose && src->type == to) {
if (ggml_type_size(to) == 4) {
return ctx->device->pipeline_cpy_transpose_32;
} else if (ggml_type_size(to) == 2) {
return ctx->device->pipeline_cpy_transpose_16;
}
}
if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) { if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) {
if (contig) { if (contig) {
return ctx->device->pipeline_contig_cpy_f32_f32; return ctx->device->pipeline_contig_cpy_f32_f32;
@ -8242,6 +8307,18 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16]; return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_ABS: case GGML_UNARY_OP_ABS:
return ctx->device->pipeline_abs[dst->type == GGML_TYPE_F16]; return ctx->device->pipeline_abs[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_SOFTPLUS:
return ctx->device->pipeline_softplus[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_STEP:
return ctx->device->pipeline_step[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_ROUND:
return ctx->device->pipeline_round[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_CEIL:
return ctx->device->pipeline_ceil[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_FLOOR:
return ctx->device->pipeline_floor[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_TRUNC:
return ctx->device->pipeline_trunc[dst->type == GGML_TYPE_F16];
default: default:
break; break;
} }
@ -8344,19 +8421,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
} }
return nullptr; return nullptr;
} }
case GGML_OP_ARGSORT:
if (ctx->num_additional_fused_ops) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
GGML_ASSERT(idx < num_topk_moe_pipelines);
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
return ctx->device->pipeline_topk_moe[idx][mode];
}
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
return ctx->device->pipeline_argsort_f32[idx];
}
return nullptr;
case GGML_OP_SUM: case GGML_OP_SUM:
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN: case GGML_OP_MEAN:
@ -8449,7 +8513,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_CONV_TRANSPOSE_2D:
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
std::array<uint32_t, 3> elements; std::array<uint32_t, 3> elements{};
if (op == GGML_OP_CONV_2D) elements = ggml_vk_get_conv_elements(dst); if (op == GGML_OP_CONV_2D) elements = ggml_vk_get_conv_elements(dst);
else if (op == GGML_OP_CONV_TRANSPOSE_2D) elements = ggml_vk_get_conv_transpose_2d_elements(dst); else if (op == GGML_OP_CONV_TRANSPOSE_2D) elements = ggml_vk_get_conv_transpose_2d_elements(dst);
vk_conv_shapes shape; vk_conv_shapes shape;
@ -8527,6 +8591,27 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
} }
} }
return nullptr; return nullptr;
case GGML_OP_ADD1:
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_add1_f16_f16;
}
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_add1_f16_f32;
}
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_add1_f32_f32;
}
return nullptr;
case GGML_OP_ARANGE:
if (dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_arange_f32;
}
return nullptr;
case GGML_OP_FILL:
if (dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_fill_f32;
}
return nullptr;
default: default:
return nullptr; return nullptr;
} }
@ -8748,8 +8833,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
break; break;
case GGML_OP_ARGSORT: case GGML_OP_ARGSORT:
elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 }; GGML_ASSERT(0);
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
break; break;
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
{ {
@ -8817,6 +8901,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_SUB: case GGML_OP_SUB:
case GGML_OP_DIV: case GGML_OP_DIV:
case GGML_OP_MUL: case GGML_OP_MUL:
case GGML_OP_ADD1:
case GGML_OP_ARANGE:
case GGML_OP_FILL:
case GGML_OP_SCALE: case GGML_OP_SCALE:
case GGML_OP_SQR: case GGML_OP_SQR:
case GGML_OP_SQRT: case GGML_OP_SQRT:
@ -8858,6 +8945,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
} else { } else {
elements = { ne, 1, 1 }; elements = { ne, 1, 1 };
} }
if (pipeline == ctx->device->pipeline_cpy_transpose_32 ||
pipeline == ctx->device->pipeline_cpy_transpose_16) {
// 32x32 tiles
elements[0] = (uint32_t)CEIL_DIV(dst->ne[0], 32);
elements[1] = (uint32_t)CEIL_DIV(dst->ne[1], 32);
elements[2] = (uint32_t)(dst->ne[2]*dst->ne[3]);
elements[0] = std::min(elements[0], ctx->device->properties.limits.maxComputeWorkGroupCount[0]);
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
}
} break; } break;
case GGML_OP_ADD_ID: case GGML_OP_ADD_ID:
{ {
@ -9423,6 +9521,63 @@ static void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, cons
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst)); ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst));
} }
static void ggml_vk_add1(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ADD1, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
0.0f, 0.0f, 0,
});
}
static void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
VK_LOG_DEBUG("ggml_vk_arange(dst=" << dst << ", ne=" << ggml_nelements(dst) << ")");
vk_op_push_constants pc = {
(uint32_t)ggml_nelements(dst),
1,
ggml_get_op_params_f32(dst, 0),
ggml_get_op_params_f32(dst, 2),
};
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_ARANGE);
GGML_ASSERT(pipeline != nullptr);
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false);
std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(dst), 1, 1 };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements);
}
static void ggml_vk_fill(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
VK_LOG_DEBUG("ggml_vk_fill(dst=" << dst << ", ne=" << ggml_nelements(dst) << ")");
vk_op_push_constants pc = {
(uint32_t)ggml_nelements(dst),
1,
ggml_get_op_params_f32(dst, 0),
0.0f,
};
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_FILL);
GGML_ASSERT(pipeline != nullptr);
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false);
std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(dst), 1, 1 };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements);
}
static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst)); ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst));
} }
@ -9865,16 +10020,89 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
} }
static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
int32_t * op_params = (int32_t *)dst->op_params; const uint32_t * op_params = (const uint32_t *)dst->op_params;
uint32_t ncols = src0->ne[0]; uint32_t ncols = src0->ne[0];
uint32_t nrows = ggml_nrows(src0); uint32_t nrows = ggml_nrows(src0);
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, { uint32_t ncols_pad_log2 = (uint32_t)ceilf(log2f(float(ncols)));
ncols, uint32_t ncolsp2 = 1 << ncols_pad_log2;
nrows,
op_params[0], vk_op_argsort_push_constants pc { ncols, ncolsp2, ncols_pad_log2, nrows, op_params[0], 0, 0, 0, 0, };
});
// Pick the largest workgroup size <= ncolsp2
uint32_t pipeline_idx = std::min(ncols_pad_log2, num_argsort_pipelines - 1);
// Use the "small" argsort shader if the whole sort can be done by a single workgroup.
bool use_small = ncols_pad_log2 <= ctx->device->max_workgroup_size_log2 &&
ctx->device->pipeline_argsort_f32[pipeline_idx] != nullptr;
vk_pipeline pipeline = use_small ? ctx->device->pipeline_argsort_f32[pipeline_idx]
: ctx->device->pipeline_argsort_large_f32[pipeline_idx];
vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0);
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
vk_subbuffer subbuf1 = dst_buf;
// Reserve space for ivec2 per element, with rows padded to a power of two
if (!use_small) {
const size_t x_sz = size_t{ncolsp2} * nrows * 2 * sizeof(int);
if (ctx->prealloc_size_x < x_sz) {
ctx->prealloc_size_x = x_sz;
ggml_vk_preallocate_buffers(ctx, subctx);
}
if (ctx->prealloc_x_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
subbuf1 = { ctx->prealloc_x, 0, ctx->prealloc_x->size };
}
std::array<uint32_t, 3> elements;
elements[0] = ncolsp2;
elements[1] = std::min((uint32_t)ggml_nrows(src0), ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
elements[2] = 1;
// First dispatch initializes tmp_idx and does the first N passes where
// there is only communication between threads in the same workgroup.
{
vk_op_argsort_push_constants pc2 = pc;
pc2.outer_start = 0;
pc2.outer_end = std::min(ncols_pad_log2, ctx->device->max_workgroup_size_log2);
pc2.inner_start = 0;
pc2.inner_end = 100;
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
}
if (!use_small) {
ggml_vk_sync_buffers(ctx, subctx);
// Loop over outer/inner passes, synchronizing between each pass.
for (uint32_t outer = ctx->device->max_workgroup_size_log2; outer < ncols_pad_log2; ++outer) {
for (uint32_t inner = 0; inner < outer + 1; ++inner) {
vk_op_argsort_push_constants pc2 = pc;
pc2.outer_start = outer;
pc2.outer_end = outer + 1;
pc2.inner_start = inner;
pc2.inner_end = inner + 1;
// When the inner idx is large enough, there's only communication
// within a workgroup. So the remaining inner iterations can all
// run in the same dispatch.
if (outer - inner < pipeline_idx) {
pc2.inner_end = 100;
inner = outer;
pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx];
} else {
// Smaller workgroup empirically seems to perform better
pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx - 2];
}
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
ggml_vk_sync_buffers(ctx, subctx);
}
}
ctx->prealloc_x_need_sync = true;
}
} }
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@ -11182,6 +11410,12 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_ABS: case GGML_UNARY_OP_ABS:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_STEP:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_TRUNC:
break; break;
default: default:
return false; return false;
@ -11223,6 +11457,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_SUB: case GGML_OP_SUB:
case GGML_OP_MUL: case GGML_OP_MUL:
case GGML_OP_DIV: case GGML_OP_DIV:
case GGML_OP_ADD1:
case GGML_OP_ARANGE:
case GGML_OP_FILL:
case GGML_OP_CONCAT: case GGML_OP_CONCAT:
case GGML_OP_UPSCALE: case GGML_OP_UPSCALE:
case GGML_OP_SCALE: case GGML_OP_SCALE:
@ -11435,6 +11672,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_UPSCALE: case GGML_OP_UPSCALE:
ggml_vk_upscale(ctx, compute_ctx, src0, node); ggml_vk_upscale(ctx, compute_ctx, src0, node);
break;
case GGML_OP_ADD1:
ggml_vk_add1(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_ARANGE:
ggml_vk_arange(ctx, compute_ctx, node);
break;
case GGML_OP_FILL:
ggml_vk_fill(ctx, compute_ctx, node);
break; break;
case GGML_OP_SCALE: case GGML_OP_SCALE:
ggml_vk_scale(ctx, compute_ctx, src0, node); ggml_vk_scale(ctx, compute_ctx, src0, node);
@ -11519,6 +11768,12 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_ABS: case GGML_UNARY_OP_ABS:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_STEP:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_TRUNC:
ggml_vk_unary(ctx, compute_ctx, src0, node); ggml_vk_unary(ctx, compute_ctx, src0, node);
break; break;
default: default:
@ -11721,6 +11976,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_OP_SUB: case GGML_OP_SUB:
case GGML_OP_MUL: case GGML_OP_MUL:
case GGML_OP_DIV: case GGML_OP_DIV:
case GGML_OP_ADD1:
case GGML_OP_ARANGE:
case GGML_OP_FILL:
case GGML_OP_ADD_ID: case GGML_OP_ADD_ID:
case GGML_OP_CONCAT: case GGML_OP_CONCAT:
case GGML_OP_UPSCALE: case GGML_OP_UPSCALE:
@ -11792,6 +12050,12 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_ABS: case GGML_UNARY_OP_ABS:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_STEP:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_TRUNC:
buf = tensor->buffer; buf = tensor->buffer;
break; break;
default: default:
@ -13394,6 +13658,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_ABS: case GGML_UNARY_OP_ABS:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_STEP:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_TRUNC:
return ggml_is_contiguous(op->src[0]) && return ggml_is_contiguous(op->src[0]) &&
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@ -13695,10 +13965,25 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_LOG: case GGML_OP_LOG:
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
case GGML_OP_ARGSORT: case GGML_OP_ARGSORT:
return op->ne[0] <= max_argsort_cols; {
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
return false;
}
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
// pipeline_argsort_large_f32 requires vulkan memory model.
if (device->vulkan_memory_model) {
return true;
} else {
return op->ne[0] <= (1 << device->max_workgroup_size_log2);
}
}
case GGML_OP_UPSCALE: case GGML_OP_UPSCALE:
case GGML_OP_ACC: case GGML_OP_ACC:
case GGML_OP_CONCAT: case GGML_OP_CONCAT:
case GGML_OP_ADD1:
case GGML_OP_ARANGE:
case GGML_OP_FILL:
case GGML_OP_SCALE: case GGML_OP_SCALE:
case GGML_OP_PAD: case GGML_OP_PAD:
case GGML_OP_ROLL: case GGML_OP_ROLL:
@ -14181,6 +14466,16 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
} else if (tensor->op == GGML_OP_SCALE) { } else if (tensor->op == GGML_OP_SCALE) {
const float * params = (const float *)tensor->op_params; const float * params = (const float *)tensor->op_params;
tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]); tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]);
} else if (tensor->op == GGML_OP_ADD1) {
tensor_clone = ggml_add1(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_ARANGE) {
const float start = ggml_get_op_params_f32(tensor, 0);
const float stop = ggml_get_op_params_f32(tensor, 1);
const float step = ggml_get_op_params_f32(tensor, 2);
tensor_clone = ggml_arange(ggml_ctx, start, stop, step);
} else if (tensor->op == GGML_OP_FILL) {
const float value = ggml_get_op_params_f32(tensor, 0);
tensor_clone = ggml_fill(ggml_ctx, tensor_clone, value);
} else if (tensor->op == GGML_OP_SQR) { } else if (tensor->op == GGML_OP_SQR) {
tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]); tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_SQRT) { } else if (tensor->op == GGML_OP_SQRT) {
@ -14294,6 +14589,24 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_UNARY_OP_ABS: case GGML_UNARY_OP_ABS:
tensor_clone = ggml_abs(ggml_ctx, src_clone[0]); tensor_clone = ggml_abs(ggml_ctx, src_clone[0]);
break; break;
case GGML_UNARY_OP_SOFTPLUS:
tensor_clone = ggml_softplus(ggml_ctx, src_clone[0]);
break;
case GGML_UNARY_OP_STEP:
tensor_clone = ggml_step(ggml_ctx, src_clone[0]);
break;
case GGML_UNARY_OP_ROUND:
tensor_clone = ggml_round(ggml_ctx, src_clone[0]);
break;
case GGML_UNARY_OP_CEIL:
tensor_clone = ggml_ceil(ggml_ctx, src_clone[0]);
break;
case GGML_UNARY_OP_FLOOR:
tensor_clone = ggml_floor(ggml_ctx, src_clone[0]);
break;
case GGML_UNARY_OP_TRUNC:
tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]);
break;
default: default:
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");

View File

@ -0,0 +1,28 @@
#version 450
#extension GL_EXT_shader_16bit_storage : require
#include "types.glsl"
#include "generic_binary_head.glsl"
const uint num_threads = 256;
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
void main() {
uint idx = get_idx();
const uint num_iter = 2;
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= p.ne) {
continue;
}
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset()]));
idx += num_threads;
}
}

View File

@ -0,0 +1,20 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
// p.param1 = start, p.param2 = step
float value = p.param1 + p.param2 * float(i);
data_d[i] = D_TYPE(value);
}

View File

@ -4,28 +4,27 @@
#include "types.glsl" #include "types.glsl"
layout(constant_id = 0) const int BLOCK_SIZE = 1024; layout(constant_id = 0) const int BLOCK_SIZE = 1024;
layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10; layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10;
#define ASC 0 #define ASC 0
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) buffer D {int data_d[];}; layout (binding = 2) writeonly buffer D {int data_d[];};
layout (push_constant) uniform parameter { layout (push_constant) uniform parameter {
uint ncols; uint ncols;
uint ncols_padded;
uint ncols_padded_log2;
uint nrows; uint nrows;
uint order; uint order;
uint outer_start;
uint outer_end;
uint inner_start;
uint inner_end;
} p; } p;
shared int dst_row[BLOCK_SIZE]; shared ivec2 dst_row[BLOCK_SIZE];
shared A_TYPE a_sh[BLOCK_SIZE];
void swap(uint idx0, uint idx1) {
int tmp = dst_row[idx0];
dst_row[idx0] = dst_row[idx1];
dst_row[idx1] = tmp;
}
void argsort(bool needs_bounds_check, const uint row) { void argsort(bool needs_bounds_check, const uint row) {
// bitonic sort // bitonic sort
@ -34,11 +33,10 @@ void argsort(bool needs_bounds_check, const uint row) {
const uint row_offset = row * p.ncols; const uint row_offset = row * p.ncols;
// initialize indices // initialize indices
dst_row[col] = col; dst_row[col] = ivec2(col, floatBitsToInt(data_a[row_offset + col]));
a_sh[col] = data_a[row_offset + col];
barrier(); barrier();
uint num_outer_loop_iters = BLOCK_SIZE_LOG2; uint num_outer_loop_iters = NCOLS_PADDED_LOG2;
[[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) { [[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
uint num_inner_loop_iters = outer_idx + 1; uint num_inner_loop_iters = outer_idx + 1;
[[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) { [[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
@ -47,14 +45,15 @@ void argsort(bool needs_bounds_check, const uint row) {
int idx_0 = (col & k) == 0 ? col : ixj; int idx_0 = (col & k) == 0 ? col : ixj;
int idx_1 = (col & k) == 0 ? ixj : col; int idx_1 = (col & k) == 0 ? ixj : col;
int sh_idx_0 = dst_row[idx_0]; ivec2 sh_idx_0 = dst_row[idx_0];
int sh_idx_1 = dst_row[idx_1]; ivec2 sh_idx_1 = dst_row[idx_1];
bool idx_0_oob = needs_bounds_check ? sh_idx_0 >= p.ncols : false; bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false;
bool idx_1_oob = needs_bounds_check ? sh_idx_1 >= p.ncols : false; bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false;
if ((idx_0_oob || if ((idx_0_oob ||
(!idx_1_oob && a_sh[sh_idx_0] > a_sh[sh_idx_1])) && (ixj > col)) { (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y))) && (ixj > col)) {
swap(idx_0, idx_1); dst_row[idx_0] = sh_idx_1;
dst_row[idx_1] = sh_idx_0;
} }
barrier(); barrier();
@ -63,9 +62,9 @@ void argsort(bool needs_bounds_check, const uint row) {
if (col < p.ncols) { if (col < p.ncols) {
if (p.order == ASC) { if (p.order == ASC) {
data_d[row_offset + col] = dst_row[col]; data_d[row_offset + col] = dst_row[col].x;
} else { } else {
data_d[row_offset + p.ncols - col - 1] = dst_row[col]; data_d[row_offset + p.ncols - col - 1] = dst_row[col].x;
} }
} }
} }

View File

@ -0,0 +1,114 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_KHR_memory_scope_semantics : enable
#pragma use_vulkan_memory_model
#include "types.glsl"
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
layout(constant_id = 1) const int WG_UNROLL_FACTOR = 2;
#define ASC 0
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) workgroupcoherent buffer B {ivec2 tmp_idx[];};
layout (binding = 2) workgroupcoherent buffer D {int data_d[];};
layout (push_constant) uniform parameter {
uint ncols;
uint ncols_padded;
uint ncols_padded_log2;
uint nrows;
uint order;
uint outer_start;
uint outer_end;
uint inner_start;
uint inner_end;
} p;
void argsort(bool needs_bounds_check, const uint row) {
// bitonic sort
int col = int(gl_GlobalInvocationID.x);
col = (col % BLOCK_SIZE) + (col / BLOCK_SIZE) * BLOCK_SIZE * WG_UNROLL_FACTOR;
const uint row_offset = row * p.ncols;
uint idx_offset = row * p.ncols_padded;
bool need_barrier = false;
// initialize indices
if (p.outer_start == 0 && p.inner_start == 0) {
[[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) {
uint c = u*BLOCK_SIZE + col;
if (c < p.ncols_padded) {
ivec2 v = ivec2(c, floatBitsToInt(data_a[row_offset + c]));
tmp_idx[idx_offset + c] = v;
}
}
need_barrier = true;
}
[[unroll]] for (uint outer_idx = p.outer_start, k = (2 << outer_idx); outer_idx < p.outer_end; k *= 2, outer_idx++) {
uint inner_end = min(p.inner_end, outer_idx + 1);
for (uint j = k >> (p.inner_start + 1), inner_idx = p.inner_start; inner_idx < inner_end; j /= 2, inner_idx++) {
if (need_barrier) {
controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease);
}
need_barrier = true;
[[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) {
int c = u*BLOCK_SIZE + col;
const int ixj = int(c ^ j);
if (ixj < c) {
continue;
}
int idx_0 = (c & k) == 0 ? c : ixj;
int idx_1 = (c & k) == 0 ? ixj : c;
ivec2 sh_idx_0 = tmp_idx[idx_offset + idx_0];
ivec2 sh_idx_1 = tmp_idx[idx_offset + idx_1];
bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false;
bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false;
if ((idx_0_oob ||
(!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y)))) {
tmp_idx[idx_offset + idx_0] = sh_idx_1;
tmp_idx[idx_offset + idx_1] = sh_idx_0;
}
}
}
}
if (p.outer_end == p.ncols_padded_log2 &&
p.inner_end >= p.ncols_padded_log2 + 1) {
controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease);
[[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) {
uint c = u*BLOCK_SIZE + col;
if (c < p.ncols) {
if (p.order == ASC) {
data_d[row_offset + c] = tmp_idx[idx_offset + c].x;
} else {
data_d[row_offset + p.ncols - c - 1] = tmp_idx[idx_offset + c].x;
}
}
}
}
}
void main() {
if (p.ncols == p.ncols_padded) {
uint row = gl_WorkGroupID.y;
while (row < p.nrows) {
argsort(false, row);
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
} else {
uint row = gl_WorkGroupID.y;
while (row < p.nrows) {
argsort(true, row);
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
}
}

View File

@ -0,0 +1,22 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float x = float(data_a[i]);
data_d[i] = D_TYPE(ceil(x));
}

View File

@ -0,0 +1,67 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
// workgroup does 32x32 tile, but uses 32x8 threads
#define TILE_DIM 32
layout(local_size_x = 32, local_size_y = 8, local_size_z = 1) in;
shared uint sh[TILE_DIM][TILE_DIM + 1];
void iter(uvec3 wg_id) {
const uint tile_col = wg_id.x;
const uint tile_row = wg_id.y;
const uint tid_col = gl_LocalInvocationID.x;
const uint tid_row = gl_LocalInvocationID.y;
const uint i2 = wg_id.z % p.ne12;
const uint i3 = wg_id.z / p.ne12;
const uint i02 = i2;
const uint i03 = i3;
// The workgroup does TILE_DIM x TILE_DIM, but swaps the LSBs of the
// src coords to make memory accesses contiguous, dst has tid.x in i0,
// src has tid.x in i01
[[unroll]] for (uint y = 0; y < 4; ++y) {
const uint i00 = tile_col * TILE_DIM + tid_row + 8 * y;
const uint i01 = tile_row * TILE_DIM + tid_col;
if (i00 < p.ne00 && i01 < p.ne01 && i02 < p.ne02 && i03 < p.ne03) {
const uint src_idx = i00 * p.nb00 + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
sh[tid_row + 8 * y][tid_col] = uint(data_a[get_aoffset() + src_idx]);
}
}
barrier();
[[unroll]] for (uint y = 0; y < 4; ++y) {
const uint i0 = tile_col * TILE_DIM + tid_col;
const uint i1 = tile_row * TILE_DIM + tid_row + 8 * y;
if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) {
const uint dst_idx = i0 * p.nb10 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
// load transposed
data_d[get_doffset() + dst_idx] = D_TYPE(sh[tid_col][tid_row + 8 * y]);
}
}
}
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
void main() {
uint z = gl_WorkGroupID.z;
uint y = gl_WorkGroupID.y;
bool need_barrier = false;
for (uint z = gl_WorkGroupID.z; z < p.ne12 * p.ne13; z += gl_NumWorkGroups.z) {
for (uint y = gl_WorkGroupID.y; y < CEIL_DIV(p.ne11, TILE_DIM); y += gl_NumWorkGroups.y) {
for (uint x = gl_WorkGroupID.x; x < CEIL_DIV(p.ne10, TILE_DIM); x += gl_NumWorkGroups.x) {
if (need_barrier) {
barrier();
}
need_barrier = true;
iter(uvec3(x, y, z));
}
}
}
}

View File

@ -0,0 +1,19 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
// p.param1 = fill value
data_d[i] = D_TYPE(p.param1);
}

View File

@ -0,0 +1,22 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float x = float(data_a[i]);
data_d[i] = D_TYPE(floor(x));
}

View File

@ -0,0 +1,29 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float x = float(data_a[i]);
float result;
// Round halfway cases away from zero as roundf does.
if (x >= 0.0) {
result = floor(x + 0.5);
} else {
result = ceil(x - 0.5);
}
data_d[i] = D_TYPE(result);
}

View File

@ -0,0 +1,23 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float x = float(data_a[i]);
const float result = (x > 20.0f) ? x : log(1.0f + exp(x));
data_d[i] = D_TYPE(result);
}

View File

@ -0,0 +1,22 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float x = float(data_a[i]);
data_d[i] = D_TYPE(x >= 0.0f ? 1.0f : 0.0f);
}

View File

@ -0,0 +1,22 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float x = float(data_a[i]);
data_d[i] = D_TYPE(trunc(x));
}

View File

@ -734,6 +734,9 @@ void process_shaders() {
string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}});
string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}});
string_to_spv("cpy_transpose_16", "copy_transpose.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}});
string_to_spv("cpy_transpose_32", "copy_transpose.comp", {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}});
for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
@ -843,6 +846,25 @@ void process_shaders() {
string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("add1_f16_f16", "add1.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
string_to_spv("add1_f16_f32", "add1.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
string_to_spv("add1_f32_f32", "add1.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("arange_f32", "arange.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("fill_f32", "fill.comp", {{"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("step_f16", "step.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("step_f32", "step.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("round_f16", "round.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("round_f32", "round.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("ceil_f16", "ceil.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("ceil_f32", "ceil.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("floor_f16", "floor.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("floor_f32", "floor.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("trunc_f16", "trunc.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("trunc_f32", "trunc.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
for (auto rte : {false, true}) { for (auto rte : {false, true}) {
std::string suffix = rte ? "_rte" : ""; std::string suffix = rte ? "_rte" : "";
string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
@ -889,6 +911,7 @@ void process_shaders() {
string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});
string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

View File

@ -1 +1 @@
7b6abb2b92fcef35cb01c6ce6ada9bd85306522d 781baf2a14d9e0aaee542b2e1bb918bfc4132199

View File

@ -6,8 +6,10 @@
#include <cmath> #include <cmath>
#include <algorithm> #include <algorithm>
#include <cstdint>
#include <stdexcept> #include <stdexcept>
#define MAX_REPETITION_THRESHOLD 2000
// //
// helpers // helpers
// //
@ -345,7 +347,9 @@ const char * llama_grammar_parser::parse_sequence(
size_t last_sym_start = rule.size(); size_t last_sym_start = rule.size();
const char * pos = src; const char * pos = src;
auto handle_repetitions = [&](int min_times, int max_times) { // use UINT64_MAX as the empty value because we aligned to the proper unsigned long type so -1 can't be used
// (though it's technically the same as -1 now)
auto handle_repetitions = [&](unsigned long min_times, unsigned long max_times) {
if (last_sym_start == rule.size()) { if (last_sym_start == rule.size()) {
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
@ -373,20 +377,20 @@ const char * llama_grammar_parser::parse_sequence(
rule.resize(last_sym_start); rule.resize(last_sym_start);
} else { } else {
// Repeat the previous elements (min_times - 1) times // Repeat the previous elements (min_times - 1) times
for (int i = 1; i < min_times; i++) { for (unsigned long i = 1; i < min_times; i++) {
rule.insert(rule.end(), prev_rule.begin(), prev_rule.end()); rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
} }
} }
uint32_t last_rec_rule_id = 0; uint32_t last_rec_rule_id = 0;
auto n_opt = max_times < 0 ? 1 : max_times - min_times; auto n_opt = max_times == UINT64_MAX ? 1 : max_times - min_times;
llama_grammar_rule rec_rule(prev_rule); llama_grammar_rule rec_rule(prev_rule);
for (int i = 0; i < n_opt; i++) { for (unsigned long i = 0; i < n_opt; i++) {
rec_rule.resize(prev_rule.size()); rec_rule.resize(prev_rule.size());
uint32_t rec_rule_id = generate_symbol_id( rule_name); uint32_t rec_rule_id = generate_symbol_id( rule_name);
if (i > 0 || max_times < 0) { if (i > 0 || max_times == UINT64_MAX) {
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times == UINT64_MAX ? rec_rule_id : last_rec_rule_id});
} }
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
rec_rule.push_back({LLAMA_GRETYPE_END, 0}); rec_rule.push_back({LLAMA_GRETYPE_END, 0});
@ -478,10 +482,10 @@ const char * llama_grammar_parser::parse_sequence(
throw std::runtime_error(std::string("expecting an int at ") + pos); throw std::runtime_error(std::string("expecting an int at ") + pos);
} }
const char * int_end = parse_int(pos); const char * int_end = parse_int(pos);
int min_times = std::stoul(std::string(pos, int_end - pos)); unsigned long min_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested); pos = parse_space(int_end, is_nested);
int max_times = -1; unsigned long max_times = UINT64_MAX;
if (*pos == '}') { if (*pos == '}') {
max_times = min_times; max_times = min_times;
@ -502,6 +506,9 @@ const char * llama_grammar_parser::parse_sequence(
} else { } else {
throw std::runtime_error(std::string("expecting ',' at ") + pos); throw std::runtime_error(std::string("expecting ',' at ") + pos);
} }
if (min_times > MAX_REPETITION_THRESHOLD || (max_times != UINT64_MAX && max_times > MAX_REPETITION_THRESHOLD)) {
throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions"));
}
handle_repetitions(min_times, max_times); handle_repetitions(min_times, max_times);
} else { } else {
break; break;

View File

@ -23,7 +23,7 @@ time_meas::~time_meas() {
if (t_start_us >= 0) { if (t_start_us >= 0) {
t_acc += ggml_time_us() - t_start_us; t_acc += ggml_time_us() - t_start_us;
} }
} }
void llama_log_set(ggml_log_callback log_callback, void * user_data) { void llama_log_set(ggml_log_callback log_callback, void * user_data) {
ggml_log_set(log_callback, user_data); ggml_log_set(log_callback, user_data);

View File

@ -472,9 +472,6 @@ static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
for (auto * smpl : chain->samplers) { for (auto * smpl : chain->samplers) {
llama_sampler_reset(smpl); llama_sampler_reset(smpl);
} }
chain->t_sample_us = 0;
chain->n_sample = 0;
} }
static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) { static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
@ -2670,8 +2667,7 @@ struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * c
void llama_perf_sampler_print(const struct llama_sampler * chain) { void llama_perf_sampler_print(const struct llama_sampler * chain) {
const auto data = llama_perf_sampler(chain); const auto data = llama_perf_sampler(chain);
LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", LLAMA_LOG_INFO("%s: samplers time = %10.2f ms / %5d runs\n", __func__, data.t_sample_ms, data.n_sample);
__func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample);
} }
void llama_perf_sampler_reset(struct llama_sampler * chain) { void llama_perf_sampler_reset(struct llama_sampler * chain) {
@ -2681,5 +2677,6 @@ void llama_perf_sampler_reset(struct llama_sampler * chain) {
auto * ctx = (struct llama_sampler_chain *) chain->ctx; auto * ctx = (struct llama_sampler_chain *) chain->ctx;
ctx->t_sample_us = ctx->n_sample = 0; ctx->t_sample_us = 0;
ctx->n_sample = 0;
} }

View File

@ -2776,24 +2776,34 @@ struct test_cpy : public test_case {
struct test_cont : public test_case { struct test_cont : public test_case {
const ggml_type type; const ggml_type type;
const std::array<int64_t, 4> ne; const std::array<int64_t, 4> ne;
bool use_view_slice;
std::string vars() override { std::string vars() override {
return VARS_TO_STR2(type, ne); return VARS_TO_STR3(type, ne, use_view_slice);
} }
test_cont(ggml_type type = GGML_TYPE_F32, test_cont(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 10, 10, 1}) std::array<int64_t, 4> ne = {10, 10, 10, 1},
: type(type), ne(ne) {} bool use_view_slice = false)
: type(type), ne(ne), use_view_slice(use_view_slice) {}
ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(src); ggml_set_param(src);
ggml_set_name(src, "src"); ggml_set_name(src, "src");
src = ggml_transpose(ctx, src);
ggml_set_name(src, "src_transposed");
ggml_tensor * out = ggml_cont(ctx, src); ggml_tensor * dst;
if (use_view_slice) {
dst = ggml_view_4d(ctx, src, src->ne[0], 1, src->ne[2], src->ne[3],
src->nb[1], src->nb[2], src->nb[3], src->nb[0] * (src->ne[1] - 1));
ggml_set_name(dst, "src_view_slice");
} else {
dst = ggml_transpose(ctx, src);
ggml_set_name(dst, "src_transposed");
}
ggml_tensor * out = ggml_cont(ctx, dst);
ggml_set_name(out, "out"); ggml_set_name(out, "out");
return out; return out;
@ -6945,16 +6955,17 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0})); test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
test_cases.emplace_back(new test_cont()); for (ggml_type type_dst : { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16 }) {
test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1})); for (bool use_view_slice : { true, false }) {
test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 3 ,5})); for (std::array<int64_t, 4> ne : std::initializer_list<std::array<int64_t, 4>>{ {2, 1, 1, 1}, {2, 1, 3, 5},
test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 3, 5 ,7})); {2, 3, 5, 7}, {1, 4, 4, 1}, {1, 8, 17, 1}, {10, 10, 10, 1} }) {
test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 1, 1 ,1})); if (use_view_slice && (type_dst == GGML_TYPE_F16 || type_dst == GGML_TYPE_BF16)) {
test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 1, 3 ,5})); continue; // TODO: add after WebGPU is fixed
test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 3, 5 ,7})); }
test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 1, 1 ,1})); test_cases.emplace_back(new test_cont(type_dst, ne, use_view_slice));
test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 1, 3 ,5})); }
test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 3, 5 ,7})); }
}
auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) { auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) { for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) {
@ -7015,6 +7026,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16)); test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
test_cases.emplace_back(new test_add1()); test_cases.emplace_back(new test_add1());
test_cases.emplace_back(new test_add1(GGML_TYPE_F32, {1024, 1024, 1, 1}));
test_cases.emplace_back(new test_scale()); test_cases.emplace_back(new test_scale());
test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f)); test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f, true)); // inplace test test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f, true)); // inplace test
@ -7354,9 +7366,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_clamp (type, {7, 1, 5, 3})); test_cases.emplace_back(new test_clamp (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_leaky_relu(type, {7, 1, 5, 3})); test_cases.emplace_back(new test_leaky_relu(type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_floor (type, {7, 1, 5, 3})); test_cases.emplace_back(new test_floor (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_floor (type, { 1024, 1024, 1, 1 }));
test_cases.emplace_back(new test_ceil (type, {7, 1, 5, 3})); test_cases.emplace_back(new test_ceil (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_ceil (type, { 1024, 1024, 1, 1 }));
test_cases.emplace_back(new test_round (type, {7, 1, 5, 3})); test_cases.emplace_back(new test_round (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_round (type, { 1024, 1024, 1, 1 }));
test_cases.emplace_back(new test_trunc (type, {7, 1, 5, 3})); test_cases.emplace_back(new test_trunc (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_trunc (type, { 1024, 1024, 1, 1 }));
} }
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5)); test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));
@ -7501,13 +7517,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
} }
for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) { for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order)); for (uint32_t i = 4; i <= 1024*1024; i *= 2) {
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {i-1, 1, 1, 1}));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {i, 1, 1, 1}));
}
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1023, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1023, 2, 1, 3}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 2, 1, 3}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1025, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1025, 2, 1, 3}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // many backends only handle up to 1024
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2047, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2047, 2, 1, 3}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order));
@ -7556,6 +7574,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1})); test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
test_cases.emplace_back(new test_roll()); test_cases.emplace_back(new test_roll());
test_cases.emplace_back(new test_arange()); test_cases.emplace_back(new test_arange());
test_cases.emplace_back(new test_arange(GGML_TYPE_F32, 0.0f, 1048576.0f, 1.0f));
test_cases.emplace_back(new test_timestep_embedding()); test_cases.emplace_back(new test_timestep_embedding());
test_cases.emplace_back(new test_leaky_relu()); test_cases.emplace_back(new test_leaky_relu());
@ -7583,6 +7602,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_fill(0.0f)); test_cases.emplace_back(new test_fill(0.0f));
test_cases.emplace_back(new test_fill(2.0f, GGML_TYPE_F32, { 303, 207, 11, 3 })); test_cases.emplace_back(new test_fill(2.0f, GGML_TYPE_F32, { 303, 207, 11, 3 }));
test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F32, { 800, 600, 4, 4 })); test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F32, { 800, 600, 4, 4 }));
test_cases.emplace_back(new test_fill(3.5f, GGML_TYPE_F32, { 2048, 512, 2, 2 }));
test_cases.emplace_back(new test_solve_tri()); test_cases.emplace_back(new test_solve_tri());
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 11, 11, 1, 1 }, { 5, 11, 1, 1 })); test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 11, 11, 1, 1 }, { 5, 11, 1, 1 }));

View File

@ -147,11 +147,15 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
auto * mem = llama_get_memory(ctx); llama_memory_t mem = llama_get_memory(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
// note: the time for chat template initialization is not negligible:
auto chat_templates = common_chat_templates_init(model, params.chat_template); auto chat_templates = common_chat_templates_init(model, params.chat_template);
// start measuring performance timings from here
llama_perf_context_reset(ctx);
LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads); LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);

Binary file not shown.

View File

@ -25,3 +25,4 @@ vite.config.ts.timestamp-*
*storybook.log *storybook.log
storybook-static storybook-static
*.code-workspace

View File

@ -2109,9 +2109,9 @@
} }
}, },
"node_modules/@sveltejs/kit": { "node_modules/@sveltejs/kit": {
"version": "2.48.4", "version": "2.48.5",
"resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.48.4.tgz", "resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.48.5.tgz",
"integrity": "sha512-TGFX1pZUt9qqY20Cv5NyYvy0iLWHf2jXi8s+eCGsig7jQMdwZWKUFMR6TbvFNhfDSUpc1sH/Y5EHv20g3HHA3g==", "integrity": "sha512-/rnwfSWS3qwUSzvHynUTORF9xSJi7PCR9yXkxUOnRrNqyKmCmh3FPHH+E9BbgqxXfTevGXBqgnlh9kMb+9T5XA==",
"dev": true, "dev": true,
"license": "MIT", "license": "MIT",
"dependencies": { "dependencies": {
@ -5087,9 +5087,9 @@
"license": "MIT" "license": "MIT"
}, },
"node_modules/js-yaml": { "node_modules/js-yaml": {
"version": "4.1.0", "version": "4.1.1",
"resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz",
"integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", "integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==",
"dev": true, "dev": true,
"license": "MIT", "license": "MIT",
"dependencies": { "dependencies": {

View File

@ -0,0 +1,273 @@
<script lang="ts">
import { FileText, Image, Music, FileIcon, Eye } from '@lucide/svelte';
import { FileTypeCategory, MimeTypeApplication } from '$lib/enums/files';
import { convertPDFToImage } from '$lib/utils/pdf-processing';
import { Button } from '$lib/components/ui/button';
import { getFileTypeCategory } from '$lib/utils/file-type';
interface Props {
// Either an uploaded file or a stored attachment
uploadedFile?: ChatUploadedFile;
attachment?: DatabaseMessageExtra;
// For uploaded files
preview?: string;
name?: string;
type?: string;
textContent?: string;
}
let { uploadedFile, attachment, preview, name, type, textContent }: Props = $props();
let displayName = $derived(uploadedFile?.name || attachment?.name || name || 'Unknown File');
let displayPreview = $derived(
uploadedFile?.preview || (attachment?.type === 'imageFile' ? attachment.base64Url : preview)
);
let displayType = $derived(
uploadedFile?.type ||
(attachment?.type === 'imageFile'
? 'image'
: attachment?.type === 'textFile'
? 'text'
: attachment?.type === 'audioFile'
? attachment.mimeType || 'audio'
: attachment?.type === 'pdfFile'
? MimeTypeApplication.PDF
: type || 'unknown')
);
let displayTextContent = $derived(
uploadedFile?.textContent ||
(attachment?.type === 'textFile'
? attachment.content
: attachment?.type === 'pdfFile'
? attachment.content
: textContent)
);
let isAudio = $derived(
getFileTypeCategory(displayType) === FileTypeCategory.AUDIO || displayType === 'audio'
);
let isImage = $derived(
getFileTypeCategory(displayType) === FileTypeCategory.IMAGE || displayType === 'image'
);
let isPdf = $derived(displayType === MimeTypeApplication.PDF);
let isText = $derived(
getFileTypeCategory(displayType) === FileTypeCategory.TEXT || displayType === 'text'
);
let IconComponent = $derived(() => {
if (isImage) return Image;
if (isText || isPdf) return FileText;
if (isAudio) return Music;
return FileIcon;
});
let pdfViewMode = $state<'text' | 'pages'>('pages');
let pdfImages = $state<string[]>([]);
let pdfImagesLoading = $state(false);
let pdfImagesError = $state<string | null>(null);
async function loadPdfImages() {
if (!isPdf || pdfImages.length > 0 || pdfImagesLoading) return;
pdfImagesLoading = true;
pdfImagesError = null;
try {
let file: File | null = null;
if (uploadedFile?.file) {
file = uploadedFile.file;
} else if (attachment?.type === 'pdfFile') {
// Check if we have pre-processed images
if (attachment.images && Array.isArray(attachment.images)) {
pdfImages = attachment.images;
return;
}
// Convert base64 back to File for processing
if (attachment.base64Data) {
const base64Data = attachment.base64Data;
const byteCharacters = atob(base64Data);
const byteNumbers = new Array(byteCharacters.length);
for (let i = 0; i < byteCharacters.length; i++) {
byteNumbers[i] = byteCharacters.charCodeAt(i);
}
const byteArray = new Uint8Array(byteNumbers);
file = new File([byteArray], displayName, { type: MimeTypeApplication.PDF });
}
}
if (file) {
pdfImages = await convertPDFToImage(file);
} else {
throw new Error('No PDF file available for conversion');
}
} catch (error) {
pdfImagesError = error instanceof Error ? error.message : 'Failed to load PDF images';
} finally {
pdfImagesLoading = false;
}
}
export function reset() {
pdfImages = [];
pdfImagesLoading = false;
pdfImagesError = null;
pdfViewMode = 'pages';
}
$effect(() => {
if (isPdf && pdfViewMode === 'pages') {
loadPdfImages();
}
});
</script>
<div class="space-y-4">
<div class="flex items-center justify-end gap-6">
{#if isPdf}
<div class="flex items-center gap-2">
<Button
variant={pdfViewMode === 'text' ? 'default' : 'outline'}
size="sm"
onclick={() => (pdfViewMode = 'text')}
disabled={pdfImagesLoading}
>
<FileText class="mr-1 h-4 w-4" />
Text
</Button>
<Button
variant={pdfViewMode === 'pages' ? 'default' : 'outline'}
size="sm"
onclick={() => {
pdfViewMode = 'pages';
loadPdfImages();
}}
disabled={pdfImagesLoading}
>
{#if pdfImagesLoading}
<div
class="mr-1 h-4 w-4 animate-spin rounded-full border-2 border-current border-t-transparent"
></div>
{:else}
<Eye class="mr-1 h-4 w-4" />
{/if}
Pages
</Button>
</div>
{/if}
</div>
<div class="flex-1 overflow-auto">
{#if isImage && displayPreview}
<div class="flex items-center justify-center">
<img
src={displayPreview}
alt={displayName}
class="max-h-full rounded-lg object-contain shadow-lg"
/>
</div>
{:else if isPdf && pdfViewMode === 'pages'}
{#if pdfImagesLoading}
<div class="flex items-center justify-center p-8">
<div class="text-center">
<div
class="mx-auto mb-4 h-8 w-8 animate-spin rounded-full border-4 border-primary border-t-transparent"
></div>
<p class="text-muted-foreground">Converting PDF to images...</p>
</div>
</div>
{:else if pdfImagesError}
<div class="flex items-center justify-center p-8">
<div class="text-center">
<FileText class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
<p class="mb-4 text-muted-foreground">Failed to load PDF images</p>
<p class="text-sm text-muted-foreground">{pdfImagesError}</p>
<Button class="mt-4" onclick={() => (pdfViewMode = 'text')}>View as Text</Button>
</div>
</div>
{:else if pdfImages.length > 0}
<div class="max-h-[70vh] space-y-4 overflow-auto">
{#each pdfImages as image, index (image)}
<div class="text-center">
<p class="mb-2 text-sm text-muted-foreground">Page {index + 1}</p>
<img
src={image}
alt="PDF Page {index + 1}"
class="mx-auto max-w-full rounded-lg shadow-lg"
/>
</div>
{/each}
</div>
{:else}
<div class="flex items-center justify-center p-8">
<div class="text-center">
<FileText class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
<p class="mb-4 text-muted-foreground">No PDF pages available</p>
</div>
</div>
{/if}
{:else if (isText || (isPdf && pdfViewMode === 'text')) && displayTextContent}
<div
class="max-h-[60vh] overflow-auto rounded-lg bg-muted p-4 font-mono text-sm break-words whitespace-pre-wrap"
>
{displayTextContent}
</div>
{:else if isAudio}
<div class="flex items-center justify-center p-8">
<div class="w-full max-w-md text-center">
<Music class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
{#if attachment?.type === 'audioFile'}
<audio
controls
class="mb-4 w-full"
src="data:{attachment.mimeType};base64,{attachment.base64Data}"
>
Your browser does not support the audio element.
</audio>
{:else if uploadedFile?.preview}
<audio controls class="mb-4 w-full" src={uploadedFile.preview}>
Your browser does not support the audio element.
</audio>
{:else}
<p class="mb-4 text-muted-foreground">Audio preview not available</p>
{/if}
<p class="text-sm text-muted-foreground">
{displayName}
</p>
</div>
</div>
{:else}
<div class="flex items-center justify-center p-8">
<div class="text-center">
{#if IconComponent}
<IconComponent class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
{/if}
<p class="mb-4 text-muted-foreground">Preview not available for this file type</p>
</div>
</div>
{/if}
</div>
</div>

View File

@ -1,314 +0,0 @@
<script lang="ts">
import * as Dialog from '$lib/components/ui/dialog';
import { FileText, Image, Music, FileIcon, Eye } from '@lucide/svelte';
import { FileTypeCategory, MimeTypeApplication } from '$lib/enums/files';
import { convertPDFToImage } from '$lib/utils/pdf-processing';
import { Button } from '$lib/components/ui/button';
import { getFileTypeCategory } from '$lib/utils/file-type';
import { formatFileSize } from '$lib/utils/file-preview';
interface Props {
open: boolean;
// Either an uploaded file or a stored attachment
uploadedFile?: ChatUploadedFile;
attachment?: DatabaseMessageExtra;
// For uploaded files
preview?: string;
name?: string;
type?: string;
size?: number;
textContent?: string;
}
let {
open = $bindable(),
uploadedFile,
attachment,
preview,
name,
type,
size,
textContent
}: Props = $props();
let displayName = $derived(uploadedFile?.name || attachment?.name || name || 'Unknown File');
let displayPreview = $derived(
uploadedFile?.preview || (attachment?.type === 'imageFile' ? attachment.base64Url : preview)
);
let displayType = $derived(
uploadedFile?.type ||
(attachment?.type === 'imageFile'
? 'image'
: attachment?.type === 'textFile'
? 'text'
: attachment?.type === 'audioFile'
? attachment.mimeType || 'audio'
: attachment?.type === 'pdfFile'
? MimeTypeApplication.PDF
: type || 'unknown')
);
let displaySize = $derived(uploadedFile?.size || size);
let displayTextContent = $derived(
uploadedFile?.textContent ||
(attachment?.type === 'textFile'
? attachment.content
: attachment?.type === 'pdfFile'
? attachment.content
: textContent)
);
let isAudio = $derived(
getFileTypeCategory(displayType) === FileTypeCategory.AUDIO || displayType === 'audio'
);
let isImage = $derived(
getFileTypeCategory(displayType) === FileTypeCategory.IMAGE || displayType === 'image'
);
let isPdf = $derived(displayType === MimeTypeApplication.PDF);
let isText = $derived(
getFileTypeCategory(displayType) === FileTypeCategory.TEXT || displayType === 'text'
);
let IconComponent = $derived(() => {
if (isImage) return Image;
if (isText || isPdf) return FileText;
if (isAudio) return Music;
return FileIcon;
});
let pdfViewMode = $state<'text' | 'pages'>('pages');
let pdfImages = $state<string[]>([]);
let pdfImagesLoading = $state(false);
let pdfImagesError = $state<string | null>(null);
async function loadPdfImages() {
if (!isPdf || pdfImages.length > 0 || pdfImagesLoading) return;
pdfImagesLoading = true;
pdfImagesError = null;
try {
let file: File | null = null;
if (uploadedFile?.file) {
file = uploadedFile.file;
} else if (attachment?.type === 'pdfFile') {
// Check if we have pre-processed images
if (attachment.images && Array.isArray(attachment.images)) {
pdfImages = attachment.images;
return;
}
// Convert base64 back to File for processing
if (attachment.base64Data) {
const base64Data = attachment.base64Data;
const byteCharacters = atob(base64Data);
const byteNumbers = new Array(byteCharacters.length);
for (let i = 0; i < byteCharacters.length; i++) {
byteNumbers[i] = byteCharacters.charCodeAt(i);
}
const byteArray = new Uint8Array(byteNumbers);
file = new File([byteArray], displayName, { type: MimeTypeApplication.PDF });
}
}
if (file) {
pdfImages = await convertPDFToImage(file);
} else {
throw new Error('No PDF file available for conversion');
}
} catch (error) {
pdfImagesError = error instanceof Error ? error.message : 'Failed to load PDF images';
} finally {
pdfImagesLoading = false;
}
}
$effect(() => {
if (open) {
pdfImages = [];
pdfImagesLoading = false;
pdfImagesError = null;
pdfViewMode = 'pages';
}
});
$effect(() => {
if (open && isPdf && pdfViewMode === 'pages') {
loadPdfImages();
}
});
</script>
<Dialog.Root bind:open>
<Dialog.Content class="grid max-h-[90vh] max-w-5xl overflow-hidden !p-10 sm:w-auto sm:max-w-6xl">
<Dialog.Header class="flex-shrink-0">
<div class="flex items-center justify-between gap-6">
<div class="flex items-center gap-3">
{#if IconComponent}
<IconComponent class="h-5 w-5 text-muted-foreground" />
{/if}
<div>
<Dialog.Title class="text-left">{displayName}</Dialog.Title>
<div class="flex items-center gap-2 text-sm text-muted-foreground">
<span>{displayType}</span>
{#if displaySize}
<span></span>
<span>{formatFileSize(displaySize)}</span>
{/if}
</div>
</div>
</div>
{#if isPdf}
<div class="flex items-center gap-2">
<Button
variant={pdfViewMode === 'text' ? 'default' : 'outline'}
size="sm"
onclick={() => (pdfViewMode = 'text')}
disabled={pdfImagesLoading}
>
<FileText class="mr-1 h-4 w-4" />
Text
</Button>
<Button
variant={pdfViewMode === 'pages' ? 'default' : 'outline'}
size="sm"
onclick={() => {
pdfViewMode = 'pages';
loadPdfImages();
}}
disabled={pdfImagesLoading}
>
{#if pdfImagesLoading}
<div
class="mr-1 h-4 w-4 animate-spin rounded-full border-2 border-current border-t-transparent"
></div>
{:else}
<Eye class="mr-1 h-4 w-4" />
{/if}
Pages
</Button>
</div>
{/if}
</div>
</Dialog.Header>
<div class="flex-1 overflow-auto">
{#if isImage && displayPreview}
<div class="flex items-center justify-center">
<img
src={displayPreview}
alt={displayName}
class="max-h-full rounded-lg object-contain shadow-lg"
/>
</div>
{:else if isPdf && pdfViewMode === 'pages'}
{#if pdfImagesLoading}
<div class="flex items-center justify-center p-8">
<div class="text-center">
<div
class="mx-auto mb-4 h-8 w-8 animate-spin rounded-full border-4 border-primary border-t-transparent"
></div>
<p class="text-muted-foreground">Converting PDF to images...</p>
</div>
</div>
{:else if pdfImagesError}
<div class="flex items-center justify-center p-8">
<div class="text-center">
<FileText class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
<p class="mb-4 text-muted-foreground">Failed to load PDF images</p>
<p class="text-sm text-muted-foreground">{pdfImagesError}</p>
<Button class="mt-4" onclick={() => (pdfViewMode = 'text')}>View as Text</Button>
</div>
</div>
{:else if pdfImages.length > 0}
<div class="max-h-[70vh] space-y-4 overflow-auto">
{#each pdfImages as image, index (image)}
<div class="text-center">
<p class="mb-2 text-sm text-muted-foreground">Page {index + 1}</p>
<img
src={image}
alt="PDF Page {index + 1}"
class="mx-auto max-w-full rounded-lg shadow-lg"
/>
</div>
{/each}
</div>
{:else}
<div class="flex items-center justify-center p-8">
<div class="text-center">
<FileText class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
<p class="mb-4 text-muted-foreground">No PDF pages available</p>
</div>
</div>
{/if}
{:else if (isText || (isPdf && pdfViewMode === 'text')) && displayTextContent}
<div
class="max-h-[60vh] overflow-auto rounded-lg bg-muted p-4 font-mono text-sm break-words whitespace-pre-wrap"
>
{displayTextContent}
</div>
{:else if isAudio}
<div class="flex items-center justify-center p-8">
<div class="w-full max-w-md text-center">
<Music class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
{#if attachment?.type === 'audioFile'}
<audio
controls
class="mb-4 w-full"
src="data:{attachment.mimeType};base64,{attachment.base64Data}"
>
Your browser does not support the audio element.
</audio>
{:else if uploadedFile?.preview}
<audio controls class="mb-4 w-full" src={uploadedFile.preview}>
Your browser does not support the audio element.
</audio>
{:else}
<p class="mb-4 text-muted-foreground">Audio preview not available</p>
{/if}
<p class="text-sm text-muted-foreground">
{displayName}
</p>
</div>
</div>
{:else}
<div class="flex items-center justify-center p-8">
<div class="text-center">
{#if IconComponent}
<IconComponent class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
{/if}
<p class="mb-4 text-muted-foreground">Preview not available for this file type</p>
</div>
</div>
{/if}
</div>
</Dialog.Content>
</Dialog.Root>

View File

@ -1,11 +1,10 @@
<script lang="ts"> <script lang="ts">
import { ChatAttachmentImagePreview, ChatAttachmentFilePreview } from '$lib/components/app'; import { ChatAttachmentThumbnailImage, ChatAttachmentThumbnailFile } from '$lib/components/app';
import { Button } from '$lib/components/ui/button'; import { Button } from '$lib/components/ui/button';
import { ChevronLeft, ChevronRight } from '@lucide/svelte'; import { ChevronLeft, ChevronRight } from '@lucide/svelte';
import { FileTypeCategory } from '$lib/enums/files'; import { FileTypeCategory } from '$lib/enums/files';
import { getFileTypeCategory } from '$lib/utils/file-type'; import { getFileTypeCategory } from '$lib/utils/file-type';
import ChatAttachmentPreviewDialog from './ChatAttachmentPreviewDialog.svelte'; import { DialogChatAttachmentPreview, DialogChatAttachmentsViewAll } from '$lib/components/app';
import ChatAttachmentsViewAllDialog from './ChatAttachmentsViewAllDialog.svelte';
import type { ChatAttachmentDisplayItem, ChatAttachmentPreviewItem } from '$lib/types/chat'; import type { ChatAttachmentDisplayItem, ChatAttachmentPreviewItem } from '$lib/types/chat';
interface Props { interface Props {
@ -200,7 +199,7 @@
> >
{#each displayItems as item (item.id)} {#each displayItems as item (item.id)}
{#if item.isImage && item.preview} {#if item.isImage && item.preview}
<ChatAttachmentImagePreview <ChatAttachmentThumbnailImage
class="flex-shrink-0 cursor-pointer {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}" class="flex-shrink-0 cursor-pointer {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
id={item.id} id={item.id}
name={item.name} name={item.name}
@ -213,7 +212,7 @@
onClick={(event) => openPreview(item, event)} onClick={(event) => openPreview(item, event)}
/> />
{:else} {:else}
<ChatAttachmentFilePreview <ChatAttachmentThumbnailFile
class="flex-shrink-0 cursor-pointer {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}" class="flex-shrink-0 cursor-pointer {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
id={item.id} id={item.id}
name={item.name} name={item.name}
@ -256,7 +255,7 @@
{/if} {/if}
{#if previewItem} {#if previewItem}
<ChatAttachmentPreviewDialog <DialogChatAttachmentPreview
bind:open={previewDialogOpen} bind:open={previewDialogOpen}
uploadedFile={previewItem.uploadedFile} uploadedFile={previewItem.uploadedFile}
attachment={previewItem.attachment} attachment={previewItem.attachment}
@ -268,7 +267,7 @@
/> />
{/if} {/if}
<ChatAttachmentsViewAllDialog <DialogChatAttachmentsViewAll
bind:open={viewAllDialogOpen} bind:open={viewAllDialogOpen}
{uploadedFiles} {uploadedFiles}
{attachments} {attachments}

View File

@ -1,13 +1,14 @@
<script lang="ts"> <script lang="ts">
import * as Dialog from '$lib/components/ui/dialog'; import {
import { ChatAttachmentImagePreview, ChatAttachmentFilePreview } from '$lib/components/app'; ChatAttachmentThumbnailImage,
ChatAttachmentThumbnailFile,
DialogChatAttachmentPreview
} from '$lib/components/app';
import { FileTypeCategory } from '$lib/enums/files'; import { FileTypeCategory } from '$lib/enums/files';
import { getFileTypeCategory } from '$lib/utils/file-type'; import { getFileTypeCategory } from '$lib/utils/file-type';
import ChatAttachmentPreviewDialog from './ChatAttachmentPreviewDialog.svelte';
import type { ChatAttachmentDisplayItem, ChatAttachmentPreviewItem } from '$lib/types/chat'; import type { ChatAttachmentDisplayItem, ChatAttachmentPreviewItem } from '$lib/types/chat';
interface Props { interface Props {
open?: boolean;
uploadedFiles?: ChatUploadedFile[]; uploadedFiles?: ChatUploadedFile[];
attachments?: DatabaseMessageExtra[]; attachments?: DatabaseMessageExtra[];
readonly?: boolean; readonly?: boolean;
@ -18,7 +19,6 @@
} }
let { let {
open = $bindable(false),
uploadedFiles = [], uploadedFiles = [],
attachments = [], attachments = [],
readonly = false, readonly = false,
@ -127,25 +127,14 @@
} }
</script> </script>
<Dialog.Root bind:open> <div class="space-y-4">
<Dialog.Portal>
<Dialog.Overlay />
<Dialog.Content class="flex !max-h-[90vh] !max-w-6xl flex-col">
<Dialog.Header>
<Dialog.Title>All Attachments ({displayItems.length})</Dialog.Title>
<Dialog.Description class="text-sm text-muted-foreground">
View and manage all attached files
</Dialog.Description>
</Dialog.Header>
<div class="min-h-0 flex-1 space-y-6 overflow-y-auto px-1"> <div class="min-h-0 flex-1 space-y-6 overflow-y-auto px-1">
{#if fileItems.length > 0} {#if fileItems.length > 0}
<div> <div>
<h3 class="mb-3 text-sm font-medium text-foreground">Files ({fileItems.length})</h3> <h3 class="mb-3 text-sm font-medium text-foreground">Files ({fileItems.length})</h3>
<div class="flex flex-wrap items-start gap-3"> <div class="flex flex-wrap items-start gap-3">
{#each fileItems as item (item.id)} {#each fileItems as item (item.id)}
<ChatAttachmentFilePreview <ChatAttachmentThumbnailFile
class="cursor-pointer" class="cursor-pointer"
id={item.id} id={item.id}
name={item.name} name={item.name}
@ -167,7 +156,7 @@
<div class="flex flex-wrap items-start gap-3"> <div class="flex flex-wrap items-start gap-3">
{#each imageItems as item (item.id)} {#each imageItems as item (item.id)}
{#if item.preview} {#if item.preview}
<ChatAttachmentImagePreview <ChatAttachmentThumbnailImage
class="cursor-pointer" class="cursor-pointer"
id={item.id} id={item.id}
name={item.name} name={item.name}
@ -185,12 +174,10 @@
</div> </div>
{/if} {/if}
</div> </div>
</Dialog.Content> </div>
</Dialog.Portal>
</Dialog.Root>
{#if previewItem} {#if previewItem}
<ChatAttachmentPreviewDialog <DialogChatAttachmentPreview
bind:open={previewDialogOpen} bind:open={previewDialogOpen}
uploadedFile={previewItem.uploadedFile} uploadedFile={previewItem.uploadedFile}
attachment={previewItem.attachment} attachment={previewItem.attachment}

View File

@ -1,9 +1,11 @@
<script lang="ts"> <script lang="ts">
import { Square, ArrowUp } from '@lucide/svelte'; import { Square, ArrowUp } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button'; import { Button } from '$lib/components/ui/button';
import ChatFormActionFileAttachments from './ChatFormActionFileAttachments.svelte'; import {
import ChatFormActionRecord from './ChatFormActionRecord.svelte'; ChatFormActionFileAttachments,
import ChatFormModelSelector from './ChatFormModelSelector.svelte'; ChatFormActionRecord,
ChatFormModelSelector
} from '$lib/components/app';
import { config } from '$lib/stores/settings.svelte'; import { config } from '$lib/stores/settings.svelte';
import type { FileTypeCategory } from '$lib/enums/files'; import type { FileTypeCategory } from '$lib/enums/files';

View File

@ -10,6 +10,7 @@
class?: string; class?: string;
message: DatabaseMessage; message: DatabaseMessage;
onCopy?: (message: DatabaseMessage) => void; onCopy?: (message: DatabaseMessage) => void;
onContinueAssistantMessage?: (message: DatabaseMessage) => void;
onDelete?: (message: DatabaseMessage) => void; onDelete?: (message: DatabaseMessage) => void;
onEditWithBranching?: (message: DatabaseMessage, newContent: string) => void; onEditWithBranching?: (message: DatabaseMessage, newContent: string) => void;
onEditWithReplacement?: ( onEditWithReplacement?: (
@ -17,6 +18,7 @@
newContent: string, newContent: string,
shouldBranch: boolean shouldBranch: boolean
) => void; ) => void;
onEditUserMessagePreserveResponses?: (message: DatabaseMessage, newContent: string) => void;
onNavigateToSibling?: (siblingId: string) => void; onNavigateToSibling?: (siblingId: string) => void;
onRegenerateWithBranching?: (message: DatabaseMessage) => void; onRegenerateWithBranching?: (message: DatabaseMessage) => void;
siblingInfo?: ChatMessageSiblingInfo | null; siblingInfo?: ChatMessageSiblingInfo | null;
@ -26,9 +28,11 @@
class: className = '', class: className = '',
message, message,
onCopy, onCopy,
onContinueAssistantMessage,
onDelete, onDelete,
onEditWithBranching, onEditWithBranching,
onEditWithReplacement, onEditWithReplacement,
onEditUserMessagePreserveResponses,
onNavigateToSibling, onNavigateToSibling,
onRegenerateWithBranching, onRegenerateWithBranching,
siblingInfo = null siblingInfo = null
@ -133,17 +137,33 @@
onRegenerateWithBranching?.(message); onRegenerateWithBranching?.(message);
} }
function handleContinue() {
onContinueAssistantMessage?.(message);
}
function handleSaveEdit() { function handleSaveEdit() {
if (message.role === 'user') { if (message.role === 'user') {
// For user messages, trim to avoid accidental whitespace
onEditWithBranching?.(message, editedContent.trim()); onEditWithBranching?.(message, editedContent.trim());
} else { } else {
onEditWithReplacement?.(message, editedContent.trim(), shouldBranchAfterEdit); // For assistant messages, preserve exact content including trailing whitespace
// This is important for the Continue feature to work properly
onEditWithReplacement?.(message, editedContent, shouldBranchAfterEdit);
} }
isEditing = false; isEditing = false;
shouldBranchAfterEdit = false; shouldBranchAfterEdit = false;
} }
function handleSaveEditOnly() {
if (message.role === 'user') {
// For user messages, trim to avoid accidental whitespace
onEditUserMessagePreserveResponses?.(message, editedContent.trim());
}
isEditing = false;
}
function handleShowDeleteDialogChange(show: boolean) { function handleShowDeleteDialogChange(show: boolean) {
showDeleteDialog = show; showDeleteDialog = show;
} }
@ -166,6 +186,7 @@
onEditedContentChange={handleEditedContentChange} onEditedContentChange={handleEditedContentChange}
{onNavigateToSibling} {onNavigateToSibling}
onSaveEdit={handleSaveEdit} onSaveEdit={handleSaveEdit}
onSaveEditOnly={handleSaveEditOnly}
onShowDeleteDialogChange={handleShowDeleteDialogChange} onShowDeleteDialogChange={handleShowDeleteDialogChange}
{showDeleteDialog} {showDeleteDialog}
{siblingInfo} {siblingInfo}
@ -181,6 +202,7 @@
messageContent={message.content} messageContent={message.content}
onCancelEdit={handleCancelEdit} onCancelEdit={handleCancelEdit}
onConfirmDelete={handleConfirmDelete} onConfirmDelete={handleConfirmDelete}
onContinue={handleContinue}
onCopy={handleCopy} onCopy={handleCopy}
onDelete={handleDelete} onDelete={handleDelete}
onEdit={handleEdit} onEdit={handleEdit}

View File

@ -1,7 +1,10 @@
<script lang="ts"> <script lang="ts">
import { Edit, Copy, RefreshCw, Trash2 } from '@lucide/svelte'; import { Edit, Copy, RefreshCw, Trash2, ArrowRight } from '@lucide/svelte';
import { ActionButton, ConfirmationDialog } from '$lib/components/app'; import {
import ChatMessageBranchingControls from './ChatMessageBranchingControls.svelte'; ActionButton,
ChatMessageBranchingControls,
DialogConfirmation
} from '$lib/components/app';
interface Props { interface Props {
role: 'user' | 'assistant'; role: 'user' | 'assistant';
@ -18,6 +21,7 @@
onCopy: () => void; onCopy: () => void;
onEdit?: () => void; onEdit?: () => void;
onRegenerate?: () => void; onRegenerate?: () => void;
onContinue?: () => void;
onDelete: () => void; onDelete: () => void;
onConfirmDelete: () => void; onConfirmDelete: () => void;
onNavigateToSibling?: (siblingId: string) => void; onNavigateToSibling?: (siblingId: string) => void;
@ -31,6 +35,7 @@
onCopy, onCopy,
onEdit, onEdit,
onConfirmDelete, onConfirmDelete,
onContinue,
onDelete, onDelete,
onNavigateToSibling, onNavigateToSibling,
onShowDeleteDialogChange, onShowDeleteDialogChange,
@ -69,12 +74,16 @@
<ActionButton icon={RefreshCw} tooltip="Regenerate" onclick={onRegenerate} /> <ActionButton icon={RefreshCw} tooltip="Regenerate" onclick={onRegenerate} />
{/if} {/if}
{#if role === 'assistant' && onContinue}
<ActionButton icon={ArrowRight} tooltip="Continue" onclick={onContinue} />
{/if}
<ActionButton icon={Trash2} tooltip="Delete" onclick={onDelete} /> <ActionButton icon={Trash2} tooltip="Delete" onclick={onDelete} />
</div> </div>
</div> </div>
</div> </div>
<ConfirmationDialog <DialogConfirmation
bind:open={showDeleteDialog} bind:open={showDeleteDialog}
title="Delete Message" title="Delete Message"
description={deletionInfo && deletionInfo.totalCount > 1 description={deletionInfo && deletionInfo.totalCount > 1

View File

@ -2,6 +2,7 @@
import { ChatMessageThinkingBlock, MarkdownContent } from '$lib/components/app'; import { ChatMessageThinkingBlock, MarkdownContent } from '$lib/components/app';
import { useProcessingState } from '$lib/hooks/use-processing-state.svelte'; import { useProcessingState } from '$lib/hooks/use-processing-state.svelte';
import { isLoading } from '$lib/stores/chat.svelte'; import { isLoading } from '$lib/stores/chat.svelte';
import autoResizeTextarea from '$lib/utils/autoresize-textarea';
import { fade } from 'svelte/transition'; import { fade } from 'svelte/transition';
import { import {
Check, Check,
@ -39,6 +40,7 @@
onCancelEdit?: () => void; onCancelEdit?: () => void;
onCopy: () => void; onCopy: () => void;
onConfirmDelete: () => void; onConfirmDelete: () => void;
onContinue?: () => void;
onDelete: () => void; onDelete: () => void;
onEdit?: () => void; onEdit?: () => void;
onEditKeydown?: (event: KeyboardEvent) => void; onEditKeydown?: (event: KeyboardEvent) => void;
@ -65,6 +67,7 @@
messageContent, messageContent,
onCancelEdit, onCancelEdit,
onConfirmDelete, onConfirmDelete,
onContinue,
onCopy, onCopy,
onDelete, onDelete,
onEdit, onEdit,
@ -107,6 +110,12 @@
void copyToClipboard(model ?? ''); void copyToClipboard(model ?? '');
} }
$effect(() => {
if (isEditing && textareaElement) {
autoResizeTextarea(textareaElement);
}
});
function formatToolCallBadge(toolCall: ApiChatCompletionToolCall, index: number) { function formatToolCallBadge(toolCall: ApiChatCompletionToolCall, index: number) {
const callNumber = index + 1; const callNumber = index + 1;
const functionName = toolCall.function?.name?.trim(); const functionName = toolCall.function?.name?.trim();
@ -190,7 +199,10 @@
bind:value={editedContent} bind:value={editedContent}
class="min-h-[50vh] w-full resize-y rounded-2xl px-3 py-2 text-sm {INPUT_CLASSES}" class="min-h-[50vh] w-full resize-y rounded-2xl px-3 py-2 text-sm {INPUT_CLASSES}"
onkeydown={onEditKeydown} onkeydown={onEditKeydown}
oninput={(e) => onEditedContentChange?.(e.currentTarget.value)} oninput={(e) => {
autoResizeTextarea(e.currentTarget);
onEditedContentChange?.(e.currentTarget.value);
}}
placeholder="Edit assistant message..." placeholder="Edit assistant message..."
></textarea> ></textarea>
@ -335,6 +347,9 @@
{onCopy} {onCopy}
{onEdit} {onEdit}
{onRegenerate} {onRegenerate}
onContinue={currentConfig.enableContinueGeneration && !thinkingContent
? onContinue
: undefined}
{onDelete} {onDelete}
{onConfirmDelete} {onConfirmDelete}
{onNavigateToSibling} {onNavigateToSibling}

View File

@ -1,10 +1,11 @@
<script lang="ts"> <script lang="ts">
import { Check, X } from '@lucide/svelte'; import { Check, X, Send } from '@lucide/svelte';
import { Card } from '$lib/components/ui/card'; import { Card } from '$lib/components/ui/card';
import { Button } from '$lib/components/ui/button'; import { Button } from '$lib/components/ui/button';
import { ChatAttachmentsList, MarkdownContent } from '$lib/components/app'; import { ChatAttachmentsList, MarkdownContent } from '$lib/components/app';
import { INPUT_CLASSES } from '$lib/constants/input-classes'; import { INPUT_CLASSES } from '$lib/constants/input-classes';
import { config } from '$lib/stores/settings.svelte'; import { config } from '$lib/stores/settings.svelte';
import autoResizeTextarea from '$lib/utils/autoresize-textarea';
import ChatMessageActions from './ChatMessageActions.svelte'; import ChatMessageActions from './ChatMessageActions.svelte';
interface Props { interface Props {
@ -22,6 +23,7 @@
} | null; } | null;
onCancelEdit: () => void; onCancelEdit: () => void;
onSaveEdit: () => void; onSaveEdit: () => void;
onSaveEditOnly?: () => void;
onEditKeydown: (event: KeyboardEvent) => void; onEditKeydown: (event: KeyboardEvent) => void;
onEditedContentChange: (content: string) => void; onEditedContentChange: (content: string) => void;
onCopy: () => void; onCopy: () => void;
@ -43,6 +45,7 @@
deletionInfo, deletionInfo,
onCancelEdit, onCancelEdit,
onSaveEdit, onSaveEdit,
onSaveEditOnly,
onEditKeydown, onEditKeydown,
onEditedContentChange, onEditedContentChange,
onCopy, onCopy,
@ -58,6 +61,12 @@
let messageElement: HTMLElement | undefined = $state(); let messageElement: HTMLElement | undefined = $state();
const currentConfig = config(); const currentConfig = config();
$effect(() => {
if (isEditing && textareaElement) {
autoResizeTextarea(textareaElement);
}
});
$effect(() => { $effect(() => {
if (!messageElement || !message.content.trim()) return; if (!messageElement || !message.content.trim()) return;
@ -95,20 +104,34 @@
bind:value={editedContent} bind:value={editedContent}
class="min-h-[60px] w-full resize-none rounded-2xl px-3 py-2 text-sm {INPUT_CLASSES}" class="min-h-[60px] w-full resize-none rounded-2xl px-3 py-2 text-sm {INPUT_CLASSES}"
onkeydown={onEditKeydown} onkeydown={onEditKeydown}
oninput={(e) => onEditedContentChange(e.currentTarget.value)} oninput={(e) => {
autoResizeTextarea(e.currentTarget);
onEditedContentChange(e.currentTarget.value);
}}
placeholder="Edit your message..." placeholder="Edit your message..."
></textarea> ></textarea>
<div class="mt-2 flex justify-end gap-2"> <div class="mt-2 flex justify-end gap-2">
<Button class="h-8 px-3" onclick={onCancelEdit} size="sm" variant="outline"> <Button class="h-8 px-3" onclick={onCancelEdit} size="sm" variant="ghost">
<X class="mr-1 h-3 w-3" /> <X class="mr-1 h-3 w-3" />
Cancel Cancel
</Button> </Button>
<Button class="h-8 px-3" onclick={onSaveEdit} disabled={!editedContent.trim()} size="sm"> {#if onSaveEditOnly}
<Button
class="h-8 px-3"
onclick={onSaveEditOnly}
disabled={!editedContent.trim()}
size="sm"
variant="outline"
>
<Check class="mr-1 h-3 w-3" /> <Check class="mr-1 h-3 w-3" />
Save
</Button>
{/if}
<Button class="h-8 px-3" onclick={onSaveEdit} disabled={!editedContent.trim()} size="sm">
<Send class="mr-1 h-3 w-3" />
Send Send
</Button> </Button>
</div> </div>

View File

@ -3,10 +3,12 @@
import { DatabaseStore } from '$lib/stores/database'; import { DatabaseStore } from '$lib/stores/database';
import { import {
activeConversation, activeConversation,
continueAssistantMessage,
deleteMessage, deleteMessage,
navigateToSibling,
editMessageWithBranching,
editAssistantMessage, editAssistantMessage,
editMessageWithBranching,
editUserMessagePreserveResponses,
navigateToSibling,
regenerateMessageWithBranching regenerateMessageWithBranching
} from '$lib/stores/chat.svelte'; } from '$lib/stores/chat.svelte';
import { getMessageSiblings } from '$lib/utils/branching'; import { getMessageSiblings } from '$lib/utils/branching';
@ -93,6 +95,26 @@
refreshAllMessages(); refreshAllMessages();
} }
async function handleContinueAssistantMessage(message: DatabaseMessage) {
onUserAction?.();
await continueAssistantMessage(message.id);
refreshAllMessages();
}
async function handleEditUserMessagePreserveResponses(
message: DatabaseMessage,
newContent: string
) {
onUserAction?.();
await editUserMessagePreserveResponses(message.id, newContent);
refreshAllMessages();
}
async function handleDeleteMessage(message: DatabaseMessage) { async function handleDeleteMessage(message: DatabaseMessage) {
await deleteMessage(message.id); await deleteMessage(message.id);
@ -110,7 +132,9 @@
onNavigateToSibling={handleNavigateToSibling} onNavigateToSibling={handleNavigateToSibling}
onEditWithBranching={handleEditWithBranching} onEditWithBranching={handleEditWithBranching}
onEditWithReplacement={handleEditWithReplacement} onEditWithReplacement={handleEditWithReplacement}
onEditUserMessagePreserveResponses={handleEditUserMessagePreserveResponses}
onRegenerateWithBranching={handleRegenerateWithBranching} onRegenerateWithBranching={handleRegenerateWithBranching}
onContinueAssistantMessage={handleContinueAssistantMessage}
/> />
{/each} {/each}
</div> </div>

View File

@ -5,13 +5,13 @@
ChatScreenHeader, ChatScreenHeader,
ChatScreenWarning, ChatScreenWarning,
ChatMessages, ChatMessages,
ChatProcessingInfo, ChatScreenProcessingInfo,
EmptyFileAlertDialog, DialogEmptyFileAlert,
ChatErrorDialog, DialogChatError,
ServerErrorSplash, ServerErrorSplash,
ServerInfo, ServerInfo,
ServerLoadingSplash, ServerLoadingSplash,
ConfirmationDialog DialogConfirmation
} from '$lib/components/app'; } from '$lib/components/app';
import * as AlertDialog from '$lib/components/ui/alert-dialog'; import * as AlertDialog from '$lib/components/ui/alert-dialog';
import { import {
@ -299,7 +299,7 @@
class="pointer-events-none sticky right-0 bottom-0 left-0 mt-auto" class="pointer-events-none sticky right-0 bottom-0 left-0 mt-auto"
in:slide={{ duration: 150, axis: 'y' }} in:slide={{ duration: 150, axis: 'y' }}
> >
<ChatProcessingInfo /> <ChatScreenProcessingInfo />
{#if serverWarning()} {#if serverWarning()}
<ChatScreenWarning class="pointer-events-auto mx-auto max-w-[48rem] px-4" /> <ChatScreenWarning class="pointer-events-auto mx-auto max-w-[48rem] px-4" />
@ -432,7 +432,7 @@
</AlertDialog.Portal> </AlertDialog.Portal>
</AlertDialog.Root> </AlertDialog.Root>
<ConfirmationDialog <DialogConfirmation
bind:open={showDeleteDialog} bind:open={showDeleteDialog}
title="Delete Conversation" title="Delete Conversation"
description="Are you sure you want to delete this conversation? This action cannot be undone and will permanently remove all messages in this conversation." description="Are you sure you want to delete this conversation? This action cannot be undone and will permanently remove all messages in this conversation."
@ -444,7 +444,7 @@
onCancel={() => (showDeleteDialog = false)} onCancel={() => (showDeleteDialog = false)}
/> />
<EmptyFileAlertDialog <DialogEmptyFileAlert
bind:open={showEmptyFileDialog} bind:open={showEmptyFileDialog}
emptyFiles={emptyFileNames} emptyFiles={emptyFileNames}
onOpenChange={(open) => { onOpenChange={(open) => {
@ -454,7 +454,7 @@
}} }}
/> />
<ChatErrorDialog <DialogChatError
message={activeErrorDialog?.message ?? ''} message={activeErrorDialog?.message ?? ''}
onOpenChange={handleErrorDialogOpenChange} onOpenChange={handleErrorDialogOpenChange}
open={Boolean(activeErrorDialog)} open={Boolean(activeErrorDialog)}

View File

@ -1,6 +1,6 @@
<script lang="ts"> <script lang="ts">
import { Settings } from '@lucide/svelte'; import { Settings } from '@lucide/svelte';
import { ChatSettingsDialog } from '$lib/components/app'; import { DialogChatSettings } from '$lib/components/app';
import { Button } from '$lib/components/ui/button'; import { Button } from '$lib/components/ui/button';
let settingsOpen = $state(false); let settingsOpen = $state(false);
@ -20,4 +20,4 @@
</div> </div>
</header> </header>
<ChatSettingsDialog open={settingsOpen} onOpenChange={(open) => (settingsOpen = open)} /> <DialogChatSettings open={settingsOpen} onOpenChange={(open) => (settingsOpen = open)} />

View File

@ -12,20 +12,21 @@
ChevronRight, ChevronRight,
Database Database
} from '@lucide/svelte'; } from '@lucide/svelte';
import { ChatSettingsFooter, ChatSettingsFields } from '$lib/components/app'; import {
import ImportExportTab from './ImportExportTab.svelte'; ChatSettingsFooter,
import * as Dialog from '$lib/components/ui/dialog'; ChatSettingsImportExportTab,
ChatSettingsFields
} from '$lib/components/app';
import { ScrollArea } from '$lib/components/ui/scroll-area'; import { ScrollArea } from '$lib/components/ui/scroll-area';
import { config, updateMultipleConfig } from '$lib/stores/settings.svelte'; import { config, updateMultipleConfig } from '$lib/stores/settings.svelte';
import { setMode } from 'mode-watcher'; import { setMode } from 'mode-watcher';
import type { Component } from 'svelte'; import type { Component } from 'svelte';
interface Props { interface Props {
onOpenChange?: (open: boolean) => void; onSave?: () => void;
open?: boolean;
} }
let { onOpenChange, open = false }: Props = $props(); let { onSave }: Props = $props();
const settingSections: Array<{ const settingSections: Array<{
fields: SettingsFieldConfig[]; fields: SettingsFieldConfig[];
@ -52,6 +53,11 @@
{ value: 'dark', label: 'Dark', icon: Moon } { value: 'dark', label: 'Dark', icon: Moon }
] ]
}, },
{
key: 'pasteLongTextToFileLen',
label: 'Paste long text to file length',
type: 'input'
},
{ {
key: 'showMessageStats', key: 'showMessageStats',
label: 'Show message generation statistics', label: 'Show message generation statistics',
@ -68,14 +74,15 @@
type: 'checkbox' type: 'checkbox'
}, },
{ {
key: 'askForTitleConfirmation', key: 'showModelInfo',
label: 'Ask for confirmation before changing conversation title', label: 'Show model information',
type: 'checkbox' type: 'checkbox'
}, },
{ {
key: 'pasteLongTextToFileLen', key: 'enableContinueGeneration',
label: 'Paste long text to file length', label: 'Enable "Continue" button',
type: 'input' type: 'checkbox',
isExperimental: true
}, },
{ {
key: 'pdfAsImage', key: 'pdfAsImage',
@ -83,13 +90,13 @@
type: 'checkbox' type: 'checkbox'
}, },
{ {
key: 'showModelInfo', key: 'renderUserContentAsMarkdown',
label: 'Show model information', label: 'Render user content as Markdown',
type: 'checkbox' type: 'checkbox'
}, },
{ {
key: 'renderUserContentAsMarkdown', key: 'askForTitleConfirmation',
label: 'Render user content as Markdown', label: 'Ask for confirmation before changing conversation title',
type: 'checkbox' type: 'checkbox'
} }
] ]
@ -263,7 +270,6 @@
settingSections.find((section) => section.title === activeSection) || settingSections[0] settingSections.find((section) => section.title === activeSection) || settingSections[0]
); );
let localConfig: SettingsConfigType = $state({ ...config() }); let localConfig: SettingsConfigType = $state({ ...config() });
let originalTheme: string = $state('');
let canScrollLeft = $state(false); let canScrollLeft = $state(false);
let canScrollRight = $state(false); let canScrollRight = $state(false);
@ -279,18 +285,10 @@
localConfig[key] = value; localConfig[key] = value;
} }
function handleClose() {
if (localConfig.theme !== originalTheme) {
setMode(originalTheme as 'light' | 'dark' | 'system');
}
onOpenChange?.(false);
}
function handleReset() { function handleReset() {
localConfig = { ...config() }; localConfig = { ...config() };
setMode(localConfig.theme as 'light' | 'dark' | 'system'); setMode(localConfig.theme as 'light' | 'dark' | 'system');
originalTheme = localConfig.theme as string;
} }
function handleSave() { function handleSave() {
@ -341,7 +339,7 @@
} }
updateMultipleConfig(processedConfig); updateMultipleConfig(processedConfig);
onOpenChange?.(false); onSave?.();
} }
function scrollToCenter(element: HTMLElement) { function scrollToCenter(element: HTMLElement) {
@ -377,14 +375,11 @@
canScrollRight = scrollLeft < scrollWidth - clientWidth - 1; // -1 for rounding canScrollRight = scrollLeft < scrollWidth - clientWidth - 1; // -1 for rounding
} }
$effect(() => { export function reset() {
if (open) {
localConfig = { ...config() }; localConfig = { ...config() };
originalTheme = config().theme as string;
setTimeout(updateScrollButtons, 100); setTimeout(updateScrollButtons, 100);
} }
});
$effect(() => { $effect(() => {
if (scrollContainer) { if (scrollContainer) {
@ -393,18 +388,10 @@
}); });
</script> </script>
<Dialog.Root {open} onOpenChange={handleClose}> <div class="flex h-full flex-col overflow-hidden md:flex-row">
<Dialog.Content
class="z-999999 flex h-[100dvh] max-h-[100dvh] min-h-[100dvh] flex-col gap-0 rounded-none p-0
md:h-[64vh] md:max-h-[64vh] md:min-h-0 md:rounded-lg"
style="max-width: 48rem;"
>
<div class="flex flex-1 flex-col overflow-hidden md:flex-row">
<!-- Desktop Sidebar --> <!-- Desktop Sidebar -->
<div class="hidden w-64 border-r border-border/30 p-6 md:block"> <div class="hidden w-64 border-r border-border/30 p-6 md:block">
<nav class="space-y-1 py-2"> <nav class="space-y-1 py-2">
<Dialog.Title class="mb-6 flex items-center gap-2">Settings</Dialog.Title>
{#each settingSections as section (section.title)} {#each settingSections as section (section.title)}
<button <button
class="flex w-full cursor-pointer items-center gap-3 rounded-lg px-3 py-2 text-left text-sm transition-colors hover:bg-accent {activeSection === class="flex w-full cursor-pointer items-center gap-3 rounded-lg px-3 py-2 text-left text-sm transition-colors hover:bg-accent {activeSection ===
@ -424,8 +411,6 @@
<!-- Mobile Header with Horizontal Scrollable Menu --> <!-- Mobile Header with Horizontal Scrollable Menu -->
<div class="flex flex-col md:hidden"> <div class="flex flex-col md:hidden">
<div class="border-b border-border/30 py-4"> <div class="border-b border-border/30 py-4">
<Dialog.Title class="mb-6 flex items-center gap-2 px-4">Settings</Dialog.Title>
<!-- Horizontal Scrollable Category Menu with Navigation --> <!-- Horizontal Scrollable Category Menu with Navigation -->
<div class="relative flex items-center" style="scroll-padding: 1rem;"> <div class="relative flex items-center" style="scroll-padding: 1rem;">
<button <button
@ -485,7 +470,7 @@
</div> </div>
{#if currentSection.title === 'Import/Export'} {#if currentSection.title === 'Import/Export'}
<ImportExportTab /> <ChatSettingsImportExportTab />
{:else} {:else}
<div class="space-y-6"> <div class="space-y-6">
<ChatSettingsFields <ChatSettingsFields
@ -499,14 +484,10 @@
</div> </div>
<div class="mt-8 border-t pt-6"> <div class="mt-8 border-t pt-6">
<p class="text-xs text-muted-foreground"> <p class="text-xs text-muted-foreground">Settings are saved in browser's localStorage</p>
Settings are saved in browser's localStorage
</p>
</div> </div>
</div> </div>
</ScrollArea> </ScrollArea>
</div> </div>
<ChatSettingsFooter onReset={handleReset} onSave={handleSave} /> <ChatSettingsFooter onReset={handleReset} onSave={handleSave} />
</Dialog.Content>
</Dialog.Root>

View File

@ -1,5 +1,5 @@
<script lang="ts"> <script lang="ts">
import { RotateCcw } from '@lucide/svelte'; import { RotateCcw, FlaskConical } from '@lucide/svelte';
import { Checkbox } from '$lib/components/ui/checkbox'; import { Checkbox } from '$lib/components/ui/checkbox';
import { Input } from '$lib/components/ui/input'; import { Input } from '$lib/components/ui/input';
import Label from '$lib/components/ui/label/label.svelte'; import Label from '$lib/components/ui/label/label.svelte';
@ -9,7 +9,7 @@
import { supportsVision } from '$lib/stores/server.svelte'; import { supportsVision } from '$lib/stores/server.svelte';
import { getParameterInfo, resetParameterToServerDefault } from '$lib/stores/settings.svelte'; import { getParameterInfo, resetParameterToServerDefault } from '$lib/stores/settings.svelte';
import { ParameterSyncService } from '$lib/services/parameter-sync'; import { ParameterSyncService } from '$lib/services/parameter-sync';
import ParameterSourceIndicator from './ParameterSourceIndicator.svelte'; import { ChatSettingsParameterSourceIndicator } from '$lib/components/app';
import type { Component } from 'svelte'; import type { Component } from 'svelte';
interface Props { interface Props {
@ -55,11 +55,15 @@
})()} })()}
<div class="flex items-center gap-2"> <div class="flex items-center gap-2">
<Label for={field.key} class="text-sm font-medium"> <Label for={field.key} class="flex items-center gap-1.5 text-sm font-medium">
{field.label} {field.label}
{#if field.isExperimental}
<FlaskConical class="h-3.5 w-3.5 text-muted-foreground" />
{/if}
</Label> </Label>
{#if isCustomRealTime} {#if isCustomRealTime}
<ParameterSourceIndicator /> <ChatSettingsParameterSourceIndicator />
{/if} {/if}
</div> </div>
@ -97,8 +101,12 @@
</p> </p>
{/if} {/if}
{:else if field.type === 'textarea'} {:else if field.type === 'textarea'}
<Label for={field.key} class="block text-sm font-medium"> <Label for={field.key} class="block flex items-center gap-1.5 text-sm font-medium">
{field.label} {field.label}
{#if field.isExperimental}
<FlaskConical class="h-3.5 w-3.5 text-muted-foreground" />
{/if}
</Label> </Label>
<Textarea <Textarea
@ -129,11 +137,15 @@
})()} })()}
<div class="flex items-center gap-2"> <div class="flex items-center gap-2">
<Label for={field.key} class="text-sm font-medium"> <Label for={field.key} class="flex items-center gap-1.5 text-sm font-medium">
{field.label} {field.label}
{#if field.isExperimental}
<FlaskConical class="h-3.5 w-3.5 text-muted-foreground" />
{/if}
</Label> </Label>
{#if isCustomRealTime} {#if isCustomRealTime}
<ParameterSourceIndicator /> <ChatSettingsParameterSourceIndicator />
{/if} {/if}
</div> </div>
@ -214,9 +226,13 @@
for={field.key} for={field.key}
class="cursor-pointer text-sm leading-none font-medium {isDisabled class="cursor-pointer text-sm leading-none font-medium {isDisabled
? 'text-muted-foreground' ? 'text-muted-foreground'
: ''}" : ''} flex items-center gap-1.5"
> >
{field.label} {field.label}
{#if field.isExperimental}
<FlaskConical class="h-3.5 w-3.5 text-muted-foreground" />
{/if}
</label> </label>
{#if field.help || SETTING_CONFIG_INFO[field.key]} {#if field.help || SETTING_CONFIG_INFO[field.key]}

View File

@ -1,7 +1,7 @@
<script lang="ts"> <script lang="ts">
import { Download, Upload } from '@lucide/svelte'; import { Download, Upload } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button'; import { Button } from '$lib/components/ui/button';
import ConversationSelectionDialog from './ConversationSelectionDialog.svelte'; import { DialogConversationSelection } from '$lib/components/app';
import { DatabaseStore } from '$lib/stores/database'; import { DatabaseStore } from '$lib/stores/database';
import type { ExportedConversations } from '$lib/types/database'; import type { ExportedConversations } from '$lib/types/database';
import { createMessageCountMap } from '$lib/utils/conversation-utils'; import { createMessageCountMap } from '$lib/utils/conversation-utils';
@ -236,7 +236,7 @@
</div> </div>
</div> </div>
<ConversationSelectionDialog <DialogConversationSelection
conversations={availableConversations} conversations={availableConversations}
{messageCountMap} {messageCountMap}
mode="export" mode="export"
@ -245,7 +245,7 @@
onConfirm={handleExportConfirm} onConfirm={handleExportConfirm}
/> />
<ConversationSelectionDialog <DialogConversationSelection
conversations={availableConversations} conversations={availableConversations}
{messageCountMap} {messageCountMap}
mode="import" mode="import"

View File

@ -1,249 +0,0 @@
<script lang="ts">
import { Search, X } from '@lucide/svelte';
import * as Dialog from '$lib/components/ui/dialog';
import { Button } from '$lib/components/ui/button';
import { Input } from '$lib/components/ui/input';
import { Checkbox } from '$lib/components/ui/checkbox';
import { ScrollArea } from '$lib/components/ui/scroll-area';
import { SvelteSet } from 'svelte/reactivity';
interface Props {
conversations: DatabaseConversation[];
messageCountMap?: Map<string, number>;
mode: 'export' | 'import';
onCancel: () => void;
onConfirm: (selectedConversations: DatabaseConversation[]) => void;
open?: boolean;
}
let {
conversations,
messageCountMap = new Map(),
mode,
onCancel,
onConfirm,
open = $bindable(false)
}: Props = $props();
let searchQuery = $state('');
let selectedIds = $state.raw<SvelteSet<string>>(new SvelteSet(conversations.map((c) => c.id)));
let lastClickedId = $state<string | null>(null);
let filteredConversations = $derived(
conversations.filter((conv) => {
const name = conv.name || 'Untitled conversation';
return name.toLowerCase().includes(searchQuery.toLowerCase());
})
);
let allSelected = $derived(
filteredConversations.length > 0 &&
filteredConversations.every((conv) => selectedIds.has(conv.id))
);
let someSelected = $derived(
filteredConversations.some((conv) => selectedIds.has(conv.id)) && !allSelected
);
function toggleConversation(id: string, shiftKey: boolean = false) {
const newSet = new SvelteSet(selectedIds);
if (shiftKey && lastClickedId !== null) {
const lastIndex = filteredConversations.findIndex((c) => c.id === lastClickedId);
const currentIndex = filteredConversations.findIndex((c) => c.id === id);
if (lastIndex !== -1 && currentIndex !== -1) {
const start = Math.min(lastIndex, currentIndex);
const end = Math.max(lastIndex, currentIndex);
const shouldSelect = !newSet.has(id);
for (let i = start; i <= end; i++) {
if (shouldSelect) {
newSet.add(filteredConversations[i].id);
} else {
newSet.delete(filteredConversations[i].id);
}
}
selectedIds = newSet;
return;
}
}
if (newSet.has(id)) {
newSet.delete(id);
} else {
newSet.add(id);
}
selectedIds = newSet;
lastClickedId = id;
}
function toggleAll() {
if (allSelected) {
const newSet = new SvelteSet(selectedIds);
filteredConversations.forEach((conv) => newSet.delete(conv.id));
selectedIds = newSet;
} else {
const newSet = new SvelteSet(selectedIds);
filteredConversations.forEach((conv) => newSet.add(conv.id));
selectedIds = newSet;
}
}
function handleConfirm() {
const selected = conversations.filter((conv) => selectedIds.has(conv.id));
onConfirm(selected);
}
function handleCancel() {
selectedIds = new SvelteSet(conversations.map((c) => c.id));
searchQuery = '';
lastClickedId = null;
onCancel();
}
let previousOpen = $state(false);
$effect(() => {
if (open && !previousOpen) {
selectedIds = new SvelteSet(conversations.map((c) => c.id));
searchQuery = '';
lastClickedId = null;
} else if (!open && previousOpen) {
onCancel();
}
previousOpen = open;
});
</script>
<Dialog.Root bind:open>
<Dialog.Portal>
<Dialog.Overlay class="z-[1000000]" />
<Dialog.Content class="z-[1000001] max-w-2xl">
<Dialog.Header>
<Dialog.Title>
Select Conversations to {mode === 'export' ? 'Export' : 'Import'}
</Dialog.Title>
<Dialog.Description>
{#if mode === 'export'}
Choose which conversations you want to export. Selected conversations will be downloaded
as a JSON file.
{:else}
Choose which conversations you want to import. Selected conversations will be merged
with your existing conversations.
{/if}
</Dialog.Description>
</Dialog.Header>
<div class="space-y-4">
<div class="relative">
<Search class="absolute top-1/2 left-3 h-4 w-4 -translate-y-1/2 text-muted-foreground" />
<Input bind:value={searchQuery} placeholder="Search conversations..." class="pr-9 pl-9" />
{#if searchQuery}
<button
class="absolute top-1/2 right-3 -translate-y-1/2 text-muted-foreground hover:text-foreground"
onclick={() => (searchQuery = '')}
type="button"
>
<X class="h-4 w-4" />
</button>
{/if}
</div>
<div class="flex items-center justify-between text-sm text-muted-foreground">
<span>
{selectedIds.size} of {conversations.length} selected
{#if searchQuery}
({filteredConversations.length} shown)
{/if}
</span>
</div>
<div class="overflow-hidden rounded-md border">
<ScrollArea class="h-[400px]">
<table class="w-full">
<thead class="sticky top-0 z-10 bg-muted">
<tr class="border-b">
<th class="w-12 p-3 text-left">
<Checkbox
checked={allSelected}
indeterminate={someSelected}
onCheckedChange={toggleAll}
/>
</th>
<th class="p-3 text-left text-sm font-medium">Conversation Name</th>
<th class="w-32 p-3 text-left text-sm font-medium">Messages</th>
</tr>
</thead>
<tbody>
{#if filteredConversations.length === 0}
<tr>
<td colspan="3" class="p-8 text-center text-sm text-muted-foreground">
{#if searchQuery}
No conversations found matching "{searchQuery}"
{:else}
No conversations available
{/if}
</td>
</tr>
{:else}
{#each filteredConversations as conv (conv.id)}
<tr
class="cursor-pointer border-b transition-colors hover:bg-muted/50"
onclick={(e) => toggleConversation(conv.id, e.shiftKey)}
>
<td class="p-3">
<Checkbox
checked={selectedIds.has(conv.id)}
onclick={(e) => {
e.preventDefault();
e.stopPropagation();
toggleConversation(conv.id, e.shiftKey);
}}
/>
</td>
<td class="p-3 text-sm">
<div
class="max-w-[17rem] truncate"
title={conv.name || 'Untitled conversation'}
>
{conv.name || 'Untitled conversation'}
</div>
</td>
<td class="p-3 text-sm text-muted-foreground">
{messageCountMap.get(conv.id) ?? 0}
</td>
</tr>
{/each}
{/if}
</tbody>
</table>
</ScrollArea>
</div>
</div>
<Dialog.Footer>
<Button variant="outline" onclick={handleCancel}>Cancel</Button>
<Button onclick={handleConfirm} disabled={selectedIds.size === 0}>
{mode === 'export' ? 'Export' : 'Import'} ({selectedIds.size})
</Button>
</Dialog.Footer>
</Dialog.Content>
</Dialog.Portal>
</Dialog.Root>

View File

@ -2,7 +2,7 @@
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { page } from '$app/state'; import { page } from '$app/state';
import { Trash2 } from '@lucide/svelte'; import { Trash2 } from '@lucide/svelte';
import { ChatSidebarConversationItem, ConfirmationDialog } from '$lib/components/app'; import { ChatSidebarConversationItem, DialogConfirmation } from '$lib/components/app';
import ScrollArea from '$lib/components/ui/scroll-area/scroll-area.svelte'; import ScrollArea from '$lib/components/ui/scroll-area/scroll-area.svelte';
import * as Sidebar from '$lib/components/ui/sidebar'; import * as Sidebar from '$lib/components/ui/sidebar';
import * as AlertDialog from '$lib/components/ui/alert-dialog'; import * as AlertDialog from '$lib/components/ui/alert-dialog';
@ -158,7 +158,7 @@
<div class="bottom-0 z-10 bg-sidebar bg-sidebar/50 px-4 py-4 backdrop-blur-lg md:sticky"></div> <div class="bottom-0 z-10 bg-sidebar bg-sidebar/50 px-4 py-4 backdrop-blur-lg md:sticky"></div>
</ScrollArea> </ScrollArea>
<ConfirmationDialog <DialogConfirmation
bind:open={showDeleteDialog} bind:open={showDeleteDialog}
title="Delete Conversation" title="Delete Conversation"
description={selectedConversation description={selectedConversation

View File

@ -0,0 +1,78 @@
<script lang="ts">
import * as Dialog from '$lib/components/ui/dialog';
import { ChatAttachmentPreview } from '$lib/components/app';
import { formatFileSize } from '$lib/utils/file-preview';
interface Props {
open: boolean;
// Either an uploaded file or a stored attachment
uploadedFile?: ChatUploadedFile;
attachment?: DatabaseMessageExtra;
// For uploaded files
preview?: string;
name?: string;
type?: string;
size?: number;
textContent?: string;
}
let {
open = $bindable(),
uploadedFile,
attachment,
preview,
name,
type,
size,
textContent
}: Props = $props();
let chatAttachmentPreviewRef: ChatAttachmentPreview | undefined = $state();
let displayName = $derived(uploadedFile?.name || attachment?.name || name || 'Unknown File');
let displayType = $derived(
uploadedFile?.type ||
(attachment?.type === 'imageFile'
? 'image'
: attachment?.type === 'textFile'
? 'text'
: attachment?.type === 'audioFile'
? attachment.mimeType || 'audio'
: attachment?.type === 'pdfFile'
? 'application/pdf'
: type || 'unknown')
);
let displaySize = $derived(uploadedFile?.size || size);
$effect(() => {
if (open && chatAttachmentPreviewRef) {
chatAttachmentPreviewRef.reset();
}
});
</script>
<Dialog.Root bind:open>
<Dialog.Content class="grid max-h-[90vh] max-w-5xl overflow-hidden sm:w-auto sm:max-w-6xl">
<Dialog.Header>
<Dialog.Title>{displayName}</Dialog.Title>
<Dialog.Description>
{displayType}
{#if displaySize}
{formatFileSize(displaySize)}
{/if}
</Dialog.Description>
</Dialog.Header>
<ChatAttachmentPreview
bind:this={chatAttachmentPreviewRef}
{uploadedFile}
{attachment}
{preview}
{name}
{type}
{textContent}
/>
</Dialog.Content>
</Dialog.Root>

View File

@ -0,0 +1,51 @@
<script lang="ts">
import * as Dialog from '$lib/components/ui/dialog';
import { ChatAttachmentsViewAll } from '$lib/components/app';
interface Props {
open?: boolean;
uploadedFiles?: ChatUploadedFile[];
attachments?: DatabaseMessageExtra[];
readonly?: boolean;
onFileRemove?: (fileId: string) => void;
imageHeight?: string;
imageWidth?: string;
imageClass?: string;
}
let {
open = $bindable(false),
uploadedFiles = [],
attachments = [],
readonly = false,
onFileRemove,
imageHeight = 'h-24',
imageWidth = 'w-auto',
imageClass = ''
}: Props = $props();
let totalCount = $derived(uploadedFiles.length + attachments.length);
</script>
<Dialog.Root bind:open>
<Dialog.Portal>
<Dialog.Overlay />
<Dialog.Content class="flex !max-h-[90vh] !max-w-6xl flex-col">
<Dialog.Header>
<Dialog.Title>All Attachments ({totalCount})</Dialog.Title>
<Dialog.Description>View and manage all attached files</Dialog.Description>
</Dialog.Header>
<ChatAttachmentsViewAll
{uploadedFiles}
{attachments}
{readonly}
{onFileRemove}
{imageHeight}
{imageWidth}
{imageClass}
/>
</Dialog.Content>
</Dialog.Portal>
</Dialog.Root>

View File

@ -0,0 +1,37 @@
<script lang="ts">
import * as Dialog from '$lib/components/ui/dialog';
import { ChatSettings } from '$lib/components/app';
interface Props {
onOpenChange?: (open: boolean) => void;
open?: boolean;
}
let { onOpenChange, open = false }: Props = $props();
let chatSettingsRef: ChatSettings | undefined = $state();
function handleClose() {
onOpenChange?.(false);
}
function handleSave() {
onOpenChange?.(false);
}
$effect(() => {
if (open && chatSettingsRef) {
chatSettingsRef.reset();
}
});
</script>
<Dialog.Root {open} onOpenChange={handleClose}>
<Dialog.Content
class="z-999999 flex h-[100dvh] max-h-[100dvh] min-h-[100dvh] flex-col gap-0 rounded-none p-0
md:h-[64vh] md:max-h-[64vh] md:min-h-0 md:rounded-lg"
style="max-width: 48rem;"
>
<ChatSettings bind:this={chatSettingsRef} onSave={handleSave} />
</Dialog.Content>
</Dialog.Root>

View File

@ -0,0 +1,68 @@
<script lang="ts">
import * as Dialog from '$lib/components/ui/dialog';
import { ConversationSelection } from '$lib/components/app';
interface Props {
conversations: DatabaseConversation[];
messageCountMap?: Map<string, number>;
mode: 'export' | 'import';
onCancel: () => void;
onConfirm: (selectedConversations: DatabaseConversation[]) => void;
open?: boolean;
}
let {
conversations,
messageCountMap = new Map(),
mode,
onCancel,
onConfirm,
open = $bindable(false)
}: Props = $props();
let conversationSelectionRef: ConversationSelection | undefined = $state();
let previousOpen = $state(false);
$effect(() => {
if (open && !previousOpen && conversationSelectionRef) {
conversationSelectionRef.reset();
} else if (!open && previousOpen) {
onCancel();
}
previousOpen = open;
});
</script>
<Dialog.Root bind:open>
<Dialog.Portal>
<Dialog.Overlay class="z-[1000000]" />
<Dialog.Content class="z-[1000001] max-w-2xl">
<Dialog.Header>
<Dialog.Title>
Select Conversations to {mode === 'export' ? 'Export' : 'Import'}
</Dialog.Title>
<Dialog.Description>
{#if mode === 'export'}
Choose which conversations you want to export. Selected conversations will be downloaded
as a JSON file.
{:else}
Choose which conversations you want to import. Selected conversations will be merged
with your existing conversations.
{/if}
</Dialog.Description>
</Dialog.Header>
<ConversationSelection
bind:this={conversationSelectionRef}
{conversations}
{messageCountMap}
{mode}
{onCancel}
{onConfirm}
/>
</Dialog.Content>
</Dialog.Portal>
</Dialog.Root>

View File

@ -1,56 +1,63 @@
// Chat
export { default as ChatAttachmentPreview } from './chat/ChatAttachments/ChatAttachmentPreview.svelte';
export { default as ChatAttachmentThumbnailFile } from './chat/ChatAttachments/ChatAttachmentThumbnailFile.svelte';
export { default as ChatAttachmentThumbnailImage } from './chat/ChatAttachments/ChatAttachmentThumbnailImage.svelte';
export { default as ChatAttachmentsList } from './chat/ChatAttachments/ChatAttachmentsList.svelte'; export { default as ChatAttachmentsList } from './chat/ChatAttachments/ChatAttachmentsList.svelte';
export { default as ChatAttachmentFilePreview } from './chat/ChatAttachments/ChatAttachmentFilePreview.svelte'; export { default as ChatAttachmentsViewAll } from './chat/ChatAttachments/ChatAttachmentsViewAll.svelte';
export { default as ChatAttachmentImagePreview } from './chat/ChatAttachments/ChatAttachmentImagePreview.svelte';
export { default as ChatAttachmentPreviewDialog } from './chat/ChatAttachments/ChatAttachmentPreviewDialog.svelte';
export { default as ChatAttachmentsViewAllDialog } from './chat/ChatAttachments/ChatAttachmentsViewAllDialog.svelte';
export { default as ChatForm } from './chat/ChatForm/ChatForm.svelte'; export { default as ChatForm } from './chat/ChatForm/ChatForm.svelte';
export { default as ChatFormTextarea } from './chat/ChatForm/ChatFormTextarea.svelte'; export { default as ChatFormActionFileAttachments } from './chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte';
export { default as ChatFormActions } from './chat/ChatForm/ChatFormActions.svelte'; export { default as ChatFormActionRecord } from './chat/ChatForm/ChatFormActions/ChatFormActionRecord.svelte';
export { default as ChatFormActionFileAttachments } from './chat/ChatForm/ChatFormActionFileAttachments.svelte'; export { default as ChatFormActions } from './chat/ChatForm/ChatFormActions/ChatFormActions.svelte';
export { default as ChatFormActionRecord } from './chat/ChatForm/ChatFormActionRecord.svelte';
export { default as ChatFormModelSelector } from './chat/ChatForm/ChatFormModelSelector.svelte';
export { default as ChatFormHelperText } from './chat/ChatForm/ChatFormHelperText.svelte';
export { default as ChatFormFileInputInvisible } from './chat/ChatForm/ChatFormFileInputInvisible.svelte'; export { default as ChatFormFileInputInvisible } from './chat/ChatForm/ChatFormFileInputInvisible.svelte';
export { default as ChatFormHelperText } from './chat/ChatForm/ChatFormHelperText.svelte';
export { default as ChatFormModelSelector } from './chat/ChatForm/ChatFormModelSelector.svelte';
export { default as ChatFormTextarea } from './chat/ChatForm/ChatFormTextarea.svelte';
export { default as ChatMessage } from './chat/ChatMessages/ChatMessage.svelte'; export { default as ChatMessage } from './chat/ChatMessages/ChatMessage.svelte';
export { default as ChatMessages } from './chat/ChatMessages/ChatMessages.svelte'; export { default as ChatMessages } from './chat/ChatMessages/ChatMessages.svelte';
export { default as ChatMessageBranchingControls } from './chat/ChatMessages/ChatMessageBranchingControls.svelte';
export { default as ChatMessageThinkingBlock } from './chat/ChatMessages/ChatMessageThinkingBlock.svelte'; export { default as ChatMessageThinkingBlock } from './chat/ChatMessages/ChatMessageThinkingBlock.svelte';
export { default as MessageBranchingControls } from './chat/ChatMessages/ChatMessageBranchingControls.svelte';
export { default as ChatProcessingInfo } from './chat/ChatProcessingInfo.svelte';
export { default as ChatScreenHeader } from './chat/ChatScreen/ChatScreenHeader.svelte';
export { default as ChatScreenWarning } from './chat/ChatScreen/ChatScreenWarning.svelte';
export { default as ChatScreen } from './chat/ChatScreen/ChatScreen.svelte'; export { default as ChatScreen } from './chat/ChatScreen/ChatScreen.svelte';
export { default as ChatScreenHeader } from './chat/ChatScreen/ChatScreenHeader.svelte';
export { default as ChatScreenProcessingInfo } from './chat/ChatScreen/ChatScreenProcessingInfo.svelte';
export { default as ChatScreenWarning } from './chat/ChatScreen/ChatScreenWarning.svelte';
export { default as ChatSettingsDialog } from './chat/ChatSettings/ChatSettingsDialog.svelte'; export { default as ChatSettings } from './chat/ChatSettings/ChatSettings.svelte';
export { default as ChatSettingsFooter } from './chat/ChatSettings/ChatSettingsFooter.svelte'; export { default as ChatSettingsFooter } from './chat/ChatSettings/ChatSettingsFooter.svelte';
export { default as ChatSettingsFields } from './chat/ChatSettings/ChatSettingsFields.svelte'; export { default as ChatSettingsFields } from './chat/ChatSettings/ChatSettingsFields.svelte';
export { default as ImportExportTab } from './chat/ChatSettings/ImportExportTab.svelte'; export { default as ChatSettingsImportExportTab } from './chat/ChatSettings/ChatSettingsImportExportTab.svelte';
export { default as ConversationSelectionDialog } from './chat/ChatSettings/ConversationSelectionDialog.svelte'; export { default as ChatSettingsParameterSourceIndicator } from './chat/ChatSettings/ChatSettingsParameterSourceIndicator.svelte';
export { default as ParameterSourceIndicator } from './chat/ChatSettings/ParameterSourceIndicator.svelte';
export { default as ChatSidebar } from './chat/ChatSidebar/ChatSidebar.svelte'; export { default as ChatSidebar } from './chat/ChatSidebar/ChatSidebar.svelte';
export { default as ChatSidebarConversationItem } from './chat/ChatSidebar/ChatSidebarConversationItem.svelte'; export { default as ChatSidebarConversationItem } from './chat/ChatSidebar/ChatSidebarConversationItem.svelte';
export { default as ChatSidebarSearch } from './chat/ChatSidebar/ChatSidebarSearch.svelte'; export { default as ChatSidebarSearch } from './chat/ChatSidebar/ChatSidebarSearch.svelte';
export { default as ChatErrorDialog } from './dialogs/ChatErrorDialog.svelte';
export { default as EmptyFileAlertDialog } from './dialogs/EmptyFileAlertDialog.svelte';
export { default as ConversationTitleUpdateDialog } from './dialogs/ConversationTitleUpdateDialog.svelte'; // Dialogs
export { default as DialogChatAttachmentPreview } from './dialogs/DialogChatAttachmentPreview.svelte';
export { default as DialogChatAttachmentsViewAll } from './dialogs/DialogChatAttachmentsViewAll.svelte';
export { default as DialogChatError } from './dialogs/DialogChatError.svelte';
export { default as DialogChatSettings } from './dialogs/DialogChatSettings.svelte';
export { default as DialogConfirmation } from './dialogs/DialogConfirmation.svelte';
export { default as DialogConversationSelection } from './dialogs/DialogConversationSelection.svelte';
export { default as DialogConversationTitleUpdate } from './dialogs/DialogConversationTitleUpdate.svelte';
export { default as DialogEmptyFileAlert } from './dialogs/DialogEmptyFileAlert.svelte';
// Miscellanous
export { default as ActionButton } from './misc/ActionButton.svelte';
export { default as ActionDropdown } from './misc/ActionDropdown.svelte';
export { default as ConversationSelection } from './misc/ConversationSelection.svelte';
export { default as KeyboardShortcutInfo } from './misc/KeyboardShortcutInfo.svelte'; export { default as KeyboardShortcutInfo } from './misc/KeyboardShortcutInfo.svelte';
export { default as MarkdownContent } from './misc/MarkdownContent.svelte'; export { default as MarkdownContent } from './misc/MarkdownContent.svelte';
export { default as RemoveButton } from './misc/RemoveButton.svelte'; export { default as RemoveButton } from './misc/RemoveButton.svelte';
// Server
export { default as ServerStatus } from './server/ServerStatus.svelte'; export { default as ServerStatus } from './server/ServerStatus.svelte';
export { default as ServerErrorSplash } from './server/ServerErrorSplash.svelte'; export { default as ServerErrorSplash } from './server/ServerErrorSplash.svelte';
export { default as ServerLoadingSplash } from './server/ServerLoadingSplash.svelte'; export { default as ServerLoadingSplash } from './server/ServerLoadingSplash.svelte';
export { default as ServerInfo } from './server/ServerInfo.svelte'; export { default as ServerInfo } from './server/ServerInfo.svelte';
// Shared components
export { default as ActionButton } from './misc/ActionButton.svelte';
export { default as ActionDropdown } from './misc/ActionDropdown.svelte';
export { default as ConfirmationDialog } from './dialogs/ConfirmationDialog.svelte';

View File

@ -0,0 +1,205 @@
<script lang="ts">
import { Search, X } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import { Input } from '$lib/components/ui/input';
import { Checkbox } from '$lib/components/ui/checkbox';
import { ScrollArea } from '$lib/components/ui/scroll-area';
import { SvelteSet } from 'svelte/reactivity';
interface Props {
conversations: DatabaseConversation[];
messageCountMap?: Map<string, number>;
mode: 'export' | 'import';
onCancel: () => void;
onConfirm: (selectedConversations: DatabaseConversation[]) => void;
}
let { conversations, messageCountMap = new Map(), mode, onCancel, onConfirm }: Props = $props();
let searchQuery = $state('');
let selectedIds = $state.raw<SvelteSet<string>>(new SvelteSet(conversations.map((c) => c.id)));
let lastClickedId = $state<string | null>(null);
let filteredConversations = $derived(
conversations.filter((conv) => {
const name = conv.name || 'Untitled conversation';
return name.toLowerCase().includes(searchQuery.toLowerCase());
})
);
let allSelected = $derived(
filteredConversations.length > 0 &&
filteredConversations.every((conv) => selectedIds.has(conv.id))
);
let someSelected = $derived(
filteredConversations.some((conv) => selectedIds.has(conv.id)) && !allSelected
);
function toggleConversation(id: string, shiftKey: boolean = false) {
const newSet = new SvelteSet(selectedIds);
if (shiftKey && lastClickedId !== null) {
const lastIndex = filteredConversations.findIndex((c) => c.id === lastClickedId);
const currentIndex = filteredConversations.findIndex((c) => c.id === id);
if (lastIndex !== -1 && currentIndex !== -1) {
const start = Math.min(lastIndex, currentIndex);
const end = Math.max(lastIndex, currentIndex);
const shouldSelect = !newSet.has(id);
for (let i = start; i <= end; i++) {
if (shouldSelect) {
newSet.add(filteredConversations[i].id);
} else {
newSet.delete(filteredConversations[i].id);
}
}
selectedIds = newSet;
return;
}
}
if (newSet.has(id)) {
newSet.delete(id);
} else {
newSet.add(id);
}
selectedIds = newSet;
lastClickedId = id;
}
function toggleAll() {
if (allSelected) {
const newSet = new SvelteSet(selectedIds);
filteredConversations.forEach((conv) => newSet.delete(conv.id));
selectedIds = newSet;
} else {
const newSet = new SvelteSet(selectedIds);
filteredConversations.forEach((conv) => newSet.add(conv.id));
selectedIds = newSet;
}
}
function handleConfirm() {
const selected = conversations.filter((conv) => selectedIds.has(conv.id));
onConfirm(selected);
}
function handleCancel() {
selectedIds = new SvelteSet(conversations.map((c) => c.id));
searchQuery = '';
lastClickedId = null;
onCancel();
}
export function reset() {
selectedIds = new SvelteSet(conversations.map((c) => c.id));
searchQuery = '';
lastClickedId = null;
}
</script>
<div class="space-y-4">
<div class="relative">
<Search class="absolute top-1/2 left-3 h-4 w-4 -translate-y-1/2 text-muted-foreground" />
<Input bind:value={searchQuery} placeholder="Search conversations..." class="pr-9 pl-9" />
{#if searchQuery}
<button
class="absolute top-1/2 right-3 -translate-y-1/2 text-muted-foreground hover:text-foreground"
onclick={() => (searchQuery = '')}
type="button"
>
<X class="h-4 w-4" />
</button>
{/if}
</div>
<div class="flex items-center justify-between text-sm text-muted-foreground">
<span>
{selectedIds.size} of {conversations.length} selected
{#if searchQuery}
({filteredConversations.length} shown)
{/if}
</span>
</div>
<div class="overflow-hidden rounded-md border">
<ScrollArea class="h-[400px]">
<table class="w-full">
<thead class="sticky top-0 z-10 bg-muted">
<tr class="border-b">
<th class="w-12 p-3 text-left">
<Checkbox
checked={allSelected}
indeterminate={someSelected}
onCheckedChange={toggleAll}
/>
</th>
<th class="p-3 text-left text-sm font-medium">Conversation Name</th>
<th class="w-32 p-3 text-left text-sm font-medium">Messages</th>
</tr>
</thead>
<tbody>
{#if filteredConversations.length === 0}
<tr>
<td colspan="3" class="p-8 text-center text-sm text-muted-foreground">
{#if searchQuery}
No conversations found matching "{searchQuery}"
{:else}
No conversations available
{/if}
</td>
</tr>
{:else}
{#each filteredConversations as conv (conv.id)}
<tr
class="cursor-pointer border-b transition-colors hover:bg-muted/50"
onclick={(e) => toggleConversation(conv.id, e.shiftKey)}
>
<td class="p-3">
<Checkbox
checked={selectedIds.has(conv.id)}
onclick={(e) => {
e.preventDefault();
e.stopPropagation();
toggleConversation(conv.id, e.shiftKey);
}}
/>
</td>
<td class="p-3 text-sm">
<div class="max-w-[17rem] truncate" title={conv.name || 'Untitled conversation'}>
{conv.name || 'Untitled conversation'}
</div>
</td>
<td class="p-3 text-sm text-muted-foreground">
{messageCountMap.get(conv.id) ?? 0}
</td>
</tr>
{/each}
{/if}
</tbody>
</table>
</ScrollArea>
</div>
<div class="flex justify-end gap-2">
<Button variant="outline" onclick={handleCancel}>Cancel</Button>
<Button onclick={handleConfirm} disabled={selectedIds.size === 0}>
{mode === 'export' ? 'Export' : 'Import'} ({selectedIds.size})
</Button>
</div>
</div>

View File

@ -38,7 +38,8 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
max_tokens: -1, max_tokens: -1,
custom: '', // custom json-stringified object custom: '', // custom json-stringified object
// experimental features // experimental features
pyInterpreterEnabled: false pyInterpreterEnabled: false,
enableContinueGeneration: false
}; };
export const SETTING_CONFIG_INFO: Record<string, string> = { export const SETTING_CONFIG_INFO: Record<string, string> = {
@ -96,5 +97,7 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
modelSelectorEnabled: modelSelectorEnabled:
'Enable the model selector in the chat input to choose the inference model. Sends the associated model field in API requests.', 'Enable the model selector in the chat input to choose the inference model. Sends the associated model field in API requests.',
pyInterpreterEnabled: pyInterpreterEnabled:
'Enable Python interpreter using Pyodide. Allows running Python code in markdown code blocks.' 'Enable Python interpreter using Pyodide. Allows running Python code in markdown code blocks.',
enableContinueGeneration:
'Enable "Continue" button for assistant messages. Currently works only with non-reasoning models.'
}; };

View File

@ -312,7 +312,6 @@ export class ChatService {
let aggregatedContent = ''; let aggregatedContent = '';
let fullReasoningContent = ''; let fullReasoningContent = '';
let aggregatedToolCalls: ApiChatCompletionToolCall[] = []; let aggregatedToolCalls: ApiChatCompletionToolCall[] = [];
let hasReceivedData = false;
let lastTimings: ChatMessageTimings | undefined; let lastTimings: ChatMessageTimings | undefined;
let streamFinished = false; let streamFinished = false;
let modelEmitted = false; let modelEmitted = false;
@ -352,8 +351,6 @@ export class ChatService {
return; return;
} }
hasReceivedData = true;
if (!abortSignal?.aborted) { if (!abortSignal?.aborted) {
onToolCallChunk?.(serializedToolCalls); onToolCallChunk?.(serializedToolCalls);
} }
@ -415,7 +412,6 @@ export class ChatService {
if (content) { if (content) {
finalizeOpenToolCallBatch(); finalizeOpenToolCallBatch();
hasReceivedData = true;
aggregatedContent += content; aggregatedContent += content;
if (!abortSignal?.aborted) { if (!abortSignal?.aborted) {
onChunk?.(content); onChunk?.(content);
@ -424,7 +420,6 @@ export class ChatService {
if (reasoningContent) { if (reasoningContent) {
finalizeOpenToolCallBatch(); finalizeOpenToolCallBatch();
hasReceivedData = true;
fullReasoningContent += reasoningContent; fullReasoningContent += reasoningContent;
if (!abortSignal?.aborted) { if (!abortSignal?.aborted) {
onReasoningChunk?.(reasoningContent); onReasoningChunk?.(reasoningContent);
@ -446,15 +441,6 @@ export class ChatService {
if (streamFinished) { if (streamFinished) {
finalizeOpenToolCallBatch(); finalizeOpenToolCallBatch();
if (
!hasReceivedData &&
aggregatedContent.length === 0 &&
aggregatedToolCalls.length === 0
) {
const noResponseError = new Error('No response received from server. Please try again.');
throw noResponseError;
}
const finalToolCalls = const finalToolCalls =
aggregatedToolCalls.length > 0 ? JSON.stringify(aggregatedToolCalls) : undefined; aggregatedToolCalls.length > 0 ? JSON.stringify(aggregatedToolCalls) : undefined;

View File

@ -1486,6 +1486,10 @@ class ChatStore {
timestamp: Date.now() timestamp: Date.now()
}); });
// Ensure currNode points to the edited message to maintain correct path
await DatabaseStore.updateCurrentNode(this.activeConversation.id, messageToEdit.id);
this.activeConversation.currNode = messageToEdit.id;
this.updateMessageAtIndex(messageIndex, { this.updateMessageAtIndex(messageIndex, {
content: newContent, content: newContent,
timestamp: Date.now() timestamp: Date.now()
@ -1499,6 +1503,69 @@ class ChatStore {
} }
} }
/**
* Edits a user message and preserves all responses below
* Updates the message content in-place without deleting or regenerating responses
*
* **Use Case**: When you want to fix a typo or rephrase a question without losing the assistant's response
*
* **Important Behavior:**
* - Does NOT create a branch (unlike editMessageWithBranching)
* - Does NOT regenerate assistant responses
* - Only updates the user message content in the database
* - Preserves the entire conversation tree below the edited message
* - Updates conversation title if this is the first user message
*
* @param messageId - The ID of the user message to edit
* @param newContent - The new content for the message
*/
async editUserMessagePreserveResponses(messageId: string, newContent: string): Promise<void> {
if (!this.activeConversation) return;
try {
const messageIndex = this.findMessageIndex(messageId);
if (messageIndex === -1) {
console.error('Message not found for editing');
return;
}
const messageToEdit = this.activeMessages[messageIndex];
if (messageToEdit.role !== 'user') {
console.error('Only user messages can be edited with this method');
return;
}
// Simply update the message content in-place
await DatabaseStore.updateMessage(messageId, {
content: newContent,
timestamp: Date.now()
});
this.updateMessageAtIndex(messageIndex, {
content: newContent,
timestamp: Date.now()
});
// Check if first user message for title update
const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id);
const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
const isFirstUserMessage =
rootMessage && messageToEdit.parent === rootMessage.id && messageToEdit.role === 'user';
if (isFirstUserMessage && newContent.trim()) {
await this.updateConversationTitleWithConfirmation(
this.activeConversation.id,
newContent.trim(),
this.titleUpdateConfirmationCallback
);
}
this.updateConversationTimestamp();
} catch (error) {
console.error('Failed to edit user message:', error);
}
}
/** /**
* Edits a message by creating a new branch with the edited content * Edits a message by creating a new branch with the edited content
* @param messageId - The ID of the message to edit * @param messageId - The ID of the message to edit
@ -1696,6 +1763,200 @@ class ChatStore {
} }
} }
/**
* Continues generation for an existing assistant message
* @param messageId - The ID of the assistant message to continue
*/
async continueAssistantMessage(messageId: string): Promise<void> {
if (!this.activeConversation || this.isLoading) return;
try {
const messageIndex = this.findMessageIndex(messageId);
if (messageIndex === -1) {
console.error('Message not found for continuation');
return;
}
const messageToContinue = this.activeMessages[messageIndex];
if (messageToContinue.role !== 'assistant') {
console.error('Only assistant messages can be continued');
return;
}
// Race condition protection: Check if this specific conversation is already loading
// This prevents multiple rapid clicks on "Continue" from creating concurrent operations
if (this.isConversationLoading(this.activeConversation.id)) {
console.warn('Continuation already in progress for this conversation');
return;
}
this.errorDialogState = null;
this.setConversationLoading(this.activeConversation.id, true);
this.clearConversationStreaming(this.activeConversation.id);
// IMPORTANT: Fetch the latest content from the database to ensure we have
// the most up-to-date content, especially after a stopped generation
// This prevents issues where the in-memory state might be stale
const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id);
const dbMessage = allMessages.find((m) => m.id === messageId);
if (!dbMessage) {
console.error('Message not found in database for continuation');
this.setConversationLoading(this.activeConversation.id, false);
return;
}
// Use content from database as the source of truth
const originalContent = dbMessage.content;
const originalThinking = dbMessage.thinking || '';
// Get conversation context up to (but not including) the message to continue
const conversationContext = this.activeMessages.slice(0, messageIndex);
const contextWithContinue = [
...conversationContext.map((msg) => {
if ('id' in msg && 'convId' in msg && 'timestamp' in msg) {
return msg as DatabaseMessage & { extra?: DatabaseMessageExtra[] };
}
return msg as ApiChatMessageData;
}),
{
role: 'assistant' as const,
content: originalContent
}
];
let appendedContent = '';
let appendedThinking = '';
let hasReceivedContent = false;
await chatService.sendMessage(
contextWithContinue,
{
...this.getApiOptions(),
onChunk: (chunk: string) => {
hasReceivedContent = true;
appendedContent += chunk;
// Preserve originalContent exactly as-is, including any trailing whitespace
// The concatenation naturally preserves any whitespace at the end of originalContent
const fullContent = originalContent + appendedContent;
this.setConversationStreaming(
messageToContinue.convId,
fullContent,
messageToContinue.id
);
this.updateMessageAtIndex(messageIndex, {
content: fullContent
});
},
onReasoningChunk: (reasoningChunk: string) => {
hasReceivedContent = true;
appendedThinking += reasoningChunk;
const fullThinking = originalThinking + appendedThinking;
this.updateMessageAtIndex(messageIndex, {
thinking: fullThinking
});
},
onComplete: async (
finalContent?: string,
reasoningContent?: string,
timings?: ChatMessageTimings
) => {
const fullContent = originalContent + (finalContent || appendedContent);
const fullThinking = originalThinking + (reasoningContent || appendedThinking);
const updateData: {
content: string;
thinking: string;
timestamp: number;
timings?: ChatMessageTimings;
} = {
content: fullContent,
thinking: fullThinking,
timestamp: Date.now(),
timings: timings
};
await DatabaseStore.updateMessage(messageToContinue.id, updateData);
this.updateMessageAtIndex(messageIndex, updateData);
this.updateConversationTimestamp();
this.setConversationLoading(messageToContinue.convId, false);
this.clearConversationStreaming(messageToContinue.convId);
slotsService.clearConversationState(messageToContinue.convId);
},
onError: async (error: Error) => {
if (this.isAbortError(error)) {
// User cancelled - save partial continuation if any content was received
if (hasReceivedContent && appendedContent) {
const partialContent = originalContent + appendedContent;
const partialThinking = originalThinking + appendedThinking;
await DatabaseStore.updateMessage(messageToContinue.id, {
content: partialContent,
thinking: partialThinking,
timestamp: Date.now()
});
this.updateMessageAtIndex(messageIndex, {
content: partialContent,
thinking: partialThinking,
timestamp: Date.now()
});
}
this.setConversationLoading(messageToContinue.convId, false);
this.clearConversationStreaming(messageToContinue.convId);
slotsService.clearConversationState(messageToContinue.convId);
return;
}
// Non-abort error - rollback to original content
console.error('Continue generation error:', error);
// Rollback: Restore original content in UI
this.updateMessageAtIndex(messageIndex, {
content: originalContent,
thinking: originalThinking
});
// Ensure database has original content (in case of partial writes)
await DatabaseStore.updateMessage(messageToContinue.id, {
content: originalContent,
thinking: originalThinking
});
this.setConversationLoading(messageToContinue.convId, false);
this.clearConversationStreaming(messageToContinue.convId);
slotsService.clearConversationState(messageToContinue.convId);
const dialogType = error.name === 'TimeoutError' ? 'timeout' : 'server';
this.showErrorDialog(dialogType, error.message);
}
},
messageToContinue.convId
);
} catch (error) {
if (this.isAbortError(error)) return;
console.error('Failed to continue message:', error);
if (this.activeConversation) {
this.setConversationLoading(this.activeConversation.id, false);
}
}
}
/** /**
* Public methods for accessing per-conversation states * Public methods for accessing per-conversation states
*/ */
@ -1743,8 +2004,11 @@ export const refreshActiveMessages = chatStore.refreshActiveMessages.bind(chatSt
export const navigateToSibling = chatStore.navigateToSibling.bind(chatStore); export const navigateToSibling = chatStore.navigateToSibling.bind(chatStore);
export const editAssistantMessage = chatStore.editAssistantMessage.bind(chatStore); export const editAssistantMessage = chatStore.editAssistantMessage.bind(chatStore);
export const editMessageWithBranching = chatStore.editMessageWithBranching.bind(chatStore); export const editMessageWithBranching = chatStore.editMessageWithBranching.bind(chatStore);
export const editUserMessagePreserveResponses =
chatStore.editUserMessagePreserveResponses.bind(chatStore);
export const regenerateMessageWithBranching = export const regenerateMessageWithBranching =
chatStore.regenerateMessageWithBranching.bind(chatStore); chatStore.regenerateMessageWithBranching.bind(chatStore);
export const continueAssistantMessage = chatStore.continueAssistantMessage.bind(chatStore);
export const deleteMessage = chatStore.deleteMessage.bind(chatStore); export const deleteMessage = chatStore.deleteMessage.bind(chatStore);
export const getDeletionInfo = chatStore.getDeletionInfo.bind(chatStore); export const getDeletionInfo = chatStore.getDeletionInfo.bind(chatStore);
export const updateConversationName = chatStore.updateConversationName.bind(chatStore); export const updateConversationName = chatStore.updateConversationName.bind(chatStore);

View File

@ -7,6 +7,7 @@ export interface SettingsFieldConfig {
key: string; key: string;
label: string; label: string;
type: 'input' | 'textarea' | 'checkbox' | 'select'; type: 'input' | 'textarea' | 'checkbox' | 'select';
isExperimental?: boolean;
help?: string; help?: string;
options?: Array<{ value: string; label: string; icon?: typeof import('@lucide/svelte').Icon }>; options?: Array<{ value: string; label: string; icon?: typeof import('@lucide/svelte').Icon }>;
} }

View File

@ -1,7 +1,7 @@
<script lang="ts"> <script lang="ts">
import '../app.css'; import '../app.css';
import { page } from '$app/state'; import { page } from '$app/state';
import { ChatSidebar, ConversationTitleUpdateDialog } from '$lib/components/app'; import { ChatSidebar, DialogConversationTitleUpdate } from '$lib/components/app';
import { import {
activeMessages, activeMessages,
isLoading, isLoading,
@ -150,7 +150,7 @@
<Toaster richColors /> <Toaster richColors />
<ConversationTitleUpdateDialog <DialogConversationTitleUpdate
bind:open={titleUpdateDialogOpen} bind:open={titleUpdateDialogOpen}
currentTitle={titleUpdateCurrentTitle} currentTitle={titleUpdateCurrentTitle}
newTitle={titleUpdateNewTitle} newTitle={titleUpdateNewTitle}

View File

@ -0,0 +1,19 @@
<script module>
import { defineMeta } from '@storybook/addon-svelte-csf';
import { ChatSettings } from '$lib/components/app';
import { fn } from 'storybook/test';
const { Story } = defineMeta({
title: 'Components/ChatSettings',
component: ChatSettings,
parameters: {
layout: 'fullscreen'
},
args: {
onClose: fn(),
onSave: fn()
}
});
</script>
<Story name="Default" />

View File

@ -1,26 +0,0 @@
<script module>
import { defineMeta } from '@storybook/addon-svelte-csf';
import { ChatSettingsDialog } from '$lib/components/app';
import { fn } from 'storybook/test';
const { Story } = defineMeta({
title: 'Components/ChatSettingsDialog',
component: ChatSettingsDialog,
parameters: {
layout: 'fullscreen'
},
argTypes: {
open: {
control: 'boolean',
description: 'Whether the dialog is open'
}
},
args: {
onOpenChange: fn()
}
});
</script>
<Story name="Open" args={{ open: true }} />
<Story name="Closed" args={{ open: false }} />