diff --git a/common/common.cpp b/common/common.cpp index 4dc95dcba2..f3cc55247e 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -26,7 +26,6 @@ #include #include #include -#include #include #include @@ -60,6 +59,14 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif +common_time_meas::common_time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {} + +common_time_meas::~common_time_meas() { + if (t_start_us >= 0) { + t_acc += ggml_time_us() - t_start_us; + } +} + // // CPU utils // diff --git a/common/common.h b/common/common.h index f42c083faa..de5b404dd8 100644 --- a/common/common.h +++ b/common/common.h @@ -2,17 +2,15 @@ #pragma once +#include "ggml-opt.h" +#include "llama-cpp.h" + #include #include #include #include #include #include -#include -#include - -#include "ggml-opt.h" -#include "llama-cpp.h" #ifdef _WIN32 #define DIRECTORY_SEPARATOR '\\' @@ -30,6 +28,15 @@ #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf" +struct common_time_meas { + common_time_meas(int64_t & t_acc, bool disable = false); + ~common_time_meas(); + + const int64_t t_start_us; + + int64_t & t_acc; +}; + struct common_adapter_lora_info { std::string path; float scale; diff --git a/common/sampling.cpp b/common/sampling.cpp index c69d525b5b..7a6b7be1e0 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -3,9 +3,10 @@ #include "common.h" #include "log.h" -#include -#include #include +#include +#include +#include // the ring buffer works similarly to std::deque, but with a fixed capacity // TODO: deduplicate with llama-impl.h @@ -112,6 +113,13 @@ struct common_sampler { llama_token_data_array cur_p; + void reset() { + prev.clear(); + + llama_sampler_reset(grmr); + llama_sampler_reset(chain); + } + void set_logits(struct llama_context * ctx, int idx) { const auto * logits = llama_get_logits_ith(ctx, idx); @@ -128,6 +136,12 @@ struct common_sampler { cur_p = { cur.data(), cur.size(), -1, false }; } + + common_time_meas tm() { + return common_time_meas(t_total_us, params.no_perf); + } + + mutable int64_t t_total_us = 0; }; std::string common_params_sampling::print() const { @@ -298,6 +312,8 @@ void common_sampler_free(struct common_sampler * gsmpl) { } void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) { + const auto tm = gsmpl->tm(); + if (accept_grammar) { llama_sampler_accept(gsmpl->grmr, token); } @@ -308,9 +324,7 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo } void common_sampler_reset(struct common_sampler * gsmpl) { - llama_sampler_reset(gsmpl->grmr); - - llama_sampler_reset(gsmpl->chain); + gsmpl->reset(); } struct common_sampler * common_sampler_clone(common_sampler * gsmpl) { @@ -327,16 +341,54 @@ struct common_sampler * common_sampler_clone(common_sampler * gsmpl) { void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) { // TODO: measure grammar performance + const double t_sampling_ms = gsmpl ? 1e-3*gsmpl->t_total_us : 0; + + llama_perf_sampler_data data_smpl; + llama_perf_context_data data_ctx; + + memset(&data_smpl, 0, sizeof(data_smpl)); + memset(&data_ctx, 0, sizeof(data_ctx)); + if (gsmpl) { - llama_perf_sampler_print(gsmpl->chain); + auto & data = data_smpl; + + data = llama_perf_sampler(gsmpl->chain); + + // note: the sampling time includes the samplers time + extra time spent in common/sampling + LOG_INF("%s: sampling time = %10.2f ms\n", __func__, t_sampling_ms); + LOG_INF("%s: samplers time = %10.2f ms / %5d tokens\n", __func__, data.t_sample_ms, data.n_sample); } + if (ctx) { - llama_perf_context_print(ctx); + auto & data = data_ctx; + + data = llama_perf_context(ctx); + + const double t_end_ms = 1e-3 * ggml_time_us(); + + const double t_total_ms = t_end_ms - data.t_start_ms; + const double t_unacc_ms = t_total_ms - (t_sampling_ms + data.t_p_eval_ms + data.t_eval_ms); + const double t_unacc_pc = 100.0 * t_unacc_ms / t_total_ms; + + LOG_INF("%s: load time = %10.2f ms\n", __func__, data.t_load_ms); + LOG_INF("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval); + LOG_INF("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval); + LOG_INF("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval)); + LOG_INF("%s: unaccounted time = %10.2f ms / %5.1f %% (total - sampling - prompt eval - eval) / (total)\n", __func__, t_unacc_ms, t_unacc_pc); + LOG_INF("%s: graphs reused = %10d\n", __func__, data.n_reused); + llama_memory_breakdown_print(ctx); } } llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) { + llama_synchronize(ctx); + + // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations + const auto tm = gsmpl->tm(); + gsmpl->set_logits(ctx, idx); auto & grmr = gsmpl->grmr; @@ -428,6 +480,8 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) { // helpers llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) { + const auto tm = gsmpl->tm(); + auto * res = &gsmpl->cur_p; if (do_sort && !res->sorted) { diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0cc3df0975..8743202ad6 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1673,11 +1673,9 @@ class GPTNeoXModel(TextModel): model_arch = gguf.MODEL_ARCH.GPTNEOX def set_gguf_parameters(self): - block_count = self.hparams["num_hidden_layers"] - self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_rope_dimension_count( int(self.hparams["rotary_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])), @@ -1735,7 +1733,7 @@ class BloomModel(TextModel): self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed)) self.gguf_writer.add_embedding_length(n_embed) self.gguf_writer.add_feed_forward_length(4 * n_embed) - self.gguf_writer.add_block_count(self.hparams["n_layer"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count_kv(n_head) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) @@ -1798,10 +1796,9 @@ class MPTModel(TextModel): self.gguf_writer.add_unk_token_id(0) def set_gguf_parameters(self): - block_count = self.hparams["n_layers"] self.gguf_writer.add_context_length(self.hparams["max_seq_len"]) self.gguf_writer.add_embedding_length(self.hparams["d_model"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(4 * self.hparams["d_model"]) self.gguf_writer.add_head_count(self.hparams["n_heads"]) if kv_n_heads := self.hparams["attn_config"].get("kv_n_heads"): @@ -1834,7 +1831,6 @@ class OrionModel(TextModel): self._set_vocab_sentencepiece() def set_gguf_parameters(self): - block_count = self.hparams["num_hidden_layers"] head_count = self.hparams["num_attention_heads"] head_count_kv = self.hparams.get("num_key_value_heads", head_count) @@ -1852,7 +1848,7 @@ class OrionModel(TextModel): self.gguf_writer.add_tensor_data_layout("Meta AI original pth") self.gguf_writer.add_context_length(ctx_length) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_head_count(head_count) self.gguf_writer.add_head_count_kv(head_count_kv) @@ -1869,7 +1865,6 @@ class BaichuanModel(TextModel): self._set_vocab_sentencepiece() def set_gguf_parameters(self): - block_count = self.hparams["num_hidden_layers"] head_count = self.hparams["num_attention_heads"] head_count_kv = self.hparams.get("num_key_value_heads", head_count) @@ -1886,7 +1881,7 @@ class BaichuanModel(TextModel): self.gguf_writer.add_tensor_data_layout("Meta AI original pth") self.gguf_writer.add_context_length(ctx_length) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) self.gguf_writer.add_head_count(head_count) @@ -1993,7 +1988,6 @@ class XverseModel(TextModel): special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): - block_count = self.hparams["num_hidden_layers"] head_count = self.hparams["num_attention_heads"] head_count_kv = self.hparams.get("num_key_value_heads", head_count) @@ -2010,7 +2004,7 @@ class XverseModel(TextModel): self.gguf_writer.add_tensor_data_layout("Meta AI original pth") self.gguf_writer.add_context_length(ctx_length) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) self.gguf_writer.add_head_count(head_count) @@ -2053,10 +2047,6 @@ class FalconModel(TextModel): model_arch = gguf.MODEL_ARCH.FALCON def set_gguf_parameters(self): - block_count = self.hparams.get("num_hidden_layers") - if block_count is None: - block_count = self.hparams["n_layer"] # old name - n_head = self.hparams.get("num_attention_heads") if n_head is None: n_head = self.hparams["n_head"] # old name @@ -2069,7 +2059,7 @@ class FalconModel(TextModel): self.gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_feed_forward_length(4 * self.hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count_kv(n_head_kv) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) @@ -2107,12 +2097,10 @@ class StarCoderModel(TextModel): model_arch = gguf.MODEL_ARCH.STARCODER def set_gguf_parameters(self): - block_count = self.hparams["n_layer"] - self.gguf_writer.add_context_length(self.hparams["n_positions"]) self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(self.hparams["n_head"]) self.gguf_writer.add_head_count_kv(1) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) @@ -2142,14 +2130,12 @@ class RefactModel(TextModel): multiple_of = 256 ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - block_count = self.hparams["n_layer"] - # refact uses Alibi. So this is from config.json which might be used by training. self.gguf_writer.add_context_length(self.hparams["n_positions"]) self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) self.gguf_writer.add_feed_forward_length(ff_dim) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(self.hparams["n_head"]) self.gguf_writer.add_head_count_kv(1) self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"]) @@ -2196,11 +2182,10 @@ class StableLMModel(TextModel): def set_gguf_parameters(self): hparams = self.hparams - block_count = hparams["num_hidden_layers"] self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"]) self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"]))) @@ -3151,7 +3136,7 @@ class DbrxModel(TextModel): def set_gguf_parameters(self): ffn_config = self.hparams["ffn_config"] attn_config = self.hparams["attn_config"] - self.gguf_writer.add_block_count(self.hparams["n_layers"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_context_length(self.hparams["max_seq_len"]) self.gguf_writer.add_embedding_length(self.hparams["d_model"]) @@ -3353,7 +3338,7 @@ class QwenModel(TextModel): def set_gguf_parameters(self): self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) - self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"]) @@ -4384,7 +4369,7 @@ class GPT2Model(TextModel): model_arch = gguf.MODEL_ARCH.GPT2 def set_gguf_parameters(self): - self.gguf_writer.add_block_count(self.hparams["n_layer"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_context_length(self.hparams["n_ctx"]) self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"]) @@ -4416,8 +4401,6 @@ class Phi2Model(TextModel): model_arch = gguf.MODEL_ARCH.PHI2 def set_gguf_parameters(self): - block_count = self.find_hparam(["num_hidden_layers", "n_layer"]) - rot_pct = self.find_hparam(["partial_rotary_factor"]) n_embd = self.find_hparam(["hidden_size", "n_embd"]) n_head = self.find_hparam(["num_attention_heads", "n_head"]) @@ -4426,7 +4409,7 @@ class Phi2Model(TextModel): self.gguf_writer.add_embedding_length(n_embd) self.gguf_writer.add_feed_forward_length(4 * n_embd) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count_kv(n_head) self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_epsilon", "layer_norm_eps"])) @@ -4544,8 +4527,6 @@ class Phi3MiniModel(TextModel): special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): - block_count = self.find_hparam(["num_hidden_layers", "n_layer"]) - n_embd = self.find_hparam(["hidden_size", "n_embd"]) n_head = self.find_hparam(["num_attention_heads", "n_head"]) n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"]) @@ -4559,7 +4540,7 @@ class Phi3MiniModel(TextModel): self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds) self.gguf_writer.add_embedding_length(n_embd) self.gguf_writer.add_feed_forward_length(self.find_hparam(["intermediate_size"])) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count_kv(n_head_kv) self.gguf_writer.add_layer_norm_rms_eps(rms_eps) @@ -4679,12 +4660,11 @@ class PlamoModel(TextModel): def set_gguf_parameters(self): hparams = self.hparams - block_count = hparams["num_hidden_layers"] self.gguf_writer.add_context_length(4096) # not in config.json self.gguf_writer.add_embedding_length(hparams["hidden_size"]) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(hparams["num_attention_heads"]) self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"]) @@ -4807,7 +4787,6 @@ class Plamo2Model(TextModel): def set_gguf_parameters(self): hparams = self.hparams - block_count = hparams["num_hidden_layers"] self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) # Which layers are Mamba layers @@ -4819,10 +4798,10 @@ class Plamo2Model(TextModel): num_attention_heads = [] if mamba_enabled: - for i in range(block_count): - if block_count <= (mamba_step // 2): + for i in range(self.block_count): + if self.block_count <= (mamba_step // 2): # use attention in last layer - is_mamba = (i != block_count - 1) + is_mamba = (i != self.block_count - 1) else: is_mamba = (i % mamba_step) != (mamba_step // 2) if is_mamba: @@ -4840,7 +4819,7 @@ class Plamo2Model(TextModel): self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096)) self.gguf_writer.add_key_length(hparams.get("hidden_size_per_head", 128)) self.gguf_writer.add_value_length(hparams.get("hidden_size_per_head", 128)) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06)) self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000)) @@ -4897,12 +4876,10 @@ class CodeShellModel(TextModel): model_arch = gguf.MODEL_ARCH.CODESHELL def set_gguf_parameters(self): - block_count = self.hparams["n_layer"] - self.gguf_writer.add_context_length(self.hparams["n_positions"]) self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(self.hparams["n_head"]) self.gguf_writer.add_head_count_kv(self.hparams["num_query_groups"]) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) @@ -5044,7 +5021,7 @@ class InternLM2Model(TextModel): def set_gguf_parameters(self): self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) - self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"]) @@ -5665,11 +5642,10 @@ class GemmaModel(TextModel): def set_gguf_parameters(self): hparams = self.hparams - block_count = hparams["num_hidden_layers"] self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) self.gguf_writer.add_head_count(hparams["num_attention_heads"]) self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"]) @@ -5705,11 +5681,10 @@ class Gemma2Model(TextModel): def set_gguf_parameters(self): hparams = self.hparams - block_count = hparams["num_hidden_layers"] self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) self.gguf_writer.add_head_count(hparams["num_attention_heads"]) self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"]) @@ -5753,12 +5728,11 @@ class Gemma3Model(TextModel): def set_gguf_parameters(self): hparams = self.hparams - block_count = hparams["num_hidden_layers"] # some default values are not specified in the hparams self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 131072)) self.gguf_writer.add_embedding_length(hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 8)) self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6)) @@ -6034,7 +6008,6 @@ class Rwkv6Model(TextModel): self._set_vocab_rwkv_world() def set_gguf_parameters(self): - block_count = self.hparams["num_hidden_layers"] head_size = self.hparams["head_size"] hidden_size = self.hparams["hidden_size"] layer_norm_eps = self.hparams["layer_norm_epsilon"] @@ -6046,7 +6019,7 @@ class Rwkv6Model(TextModel): # RWKV isn't context limited self.gguf_writer.add_context_length(1048576) self.gguf_writer.add_embedding_length(hidden_size) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_layer_norm_eps(layer_norm_eps) self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers) self.gguf_writer.add_wkv_head_size(head_size) @@ -6110,7 +6083,6 @@ class RWKV6Qwen2Model(Rwkv6Model): self._set_vocab_gpt2() def set_gguf_parameters(self): - block_count = self.hparams["num_hidden_layers"] num_attention_heads = self.hparams["num_attention_heads"] num_key_value_heads = self.hparams["num_key_value_heads"] hidden_size = self.hparams["hidden_size"] @@ -6123,7 +6095,7 @@ class RWKV6Qwen2Model(Rwkv6Model): # RWKV isn't context limited self.gguf_writer.add_context_length(1048576) self.gguf_writer.add_embedding_length(hidden_size) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_wkv_head_size(head_size) self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim) self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim) @@ -6164,7 +6136,6 @@ class Rwkv7Model(TextModel): return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32 def set_gguf_parameters(self): - block_count = self.hparams["num_hidden_layers"] try: head_size = self.hparams["head_size"] layer_norm_eps = self.hparams["layer_norm_epsilon"] @@ -6189,7 +6160,7 @@ class Rwkv7Model(TextModel): # RWKV isn't context limited self.gguf_writer.add_context_length(1048576) self.gguf_writer.add_embedding_length(hidden_size) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_layer_norm_eps(layer_norm_eps) self.gguf_writer.add_wkv_head_size(head_size) self.gguf_writer.add_decay_lora_rank(lora_rank_decay) @@ -6283,7 +6254,6 @@ class ARwkv7Model(Rwkv7Model): self._set_vocab_gpt2() def set_gguf_parameters(self): - block_count = self.hparams["num_hidden_layers"] hidden_size = self.hparams["hidden_size"] head_size = self.hparams["head_size"] rms_norm_eps = self.hparams["rms_norm_eps"] @@ -6300,7 +6270,7 @@ class ARwkv7Model(Rwkv7Model): # RWKV isn't context limited self.gguf_writer.add_context_length(1048576) self.gguf_writer.add_embedding_length(hidden_size) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) self.gguf_writer.add_wkv_head_size(head_size) self.gguf_writer.add_decay_lora_rank(lora_rank_decay) @@ -7524,7 +7494,7 @@ class T5Model(TextModel): self.gguf_writer.add_context_length(n_ctx) self.gguf_writer.add_embedding_length(self.hparams["d_model"]) self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"]) - self.gguf_writer.add_block_count(self.hparams["num_layers"]) + self.gguf_writer.add_block_count(self.block_count) if (dec_n_layer := self.hparams.get("num_decoder_layers")) is not None: self.gguf_writer.add_decoder_block_count(dec_n_layer) self.gguf_writer.add_head_count(self.hparams["num_heads"]) @@ -7663,7 +7633,7 @@ class T5EncoderModel(TextModel): self.gguf_writer.add_context_length(n_ctx) self.gguf_writer.add_embedding_length(self.hparams["d_model"]) self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"]) - self.gguf_writer.add_block_count(self.hparams["num_layers"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(self.hparams["num_heads"]) self.gguf_writer.add_key_length(self.hparams["d_kv"]) self.gguf_writer.add_value_length(self.hparams["d_kv"]) @@ -7726,7 +7696,7 @@ class JaisModel(TextModel): self._set_vocab_gpt2() def set_gguf_parameters(self): - self.gguf_writer.add_block_count(self.hparams["n_layer"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_context_length(self.hparams["n_positions"]) self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) self.gguf_writer.add_feed_forward_length(self.hparams["n_inner"]) @@ -8068,7 +8038,7 @@ class ChatGLMModel(TextModel): self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed)) self.gguf_writer.add_embedding_length(n_embed) self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", self.hparams.get("intermediate_size", 4 * n_embed))) - self.gguf_writer.add_block_count(self.hparams.get("num_layers", self.hparams["num_hidden_layers"])) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count_kv(n_head_kv) self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon",1e-5)) @@ -8150,7 +8120,6 @@ class ExaoneModel(TextModel): num_kv_heads = hparams.get("num_key_value_heads", num_heads) layer_norm_eps = hparams["layer_norm_epsilon"] intermediate_size = hparams["intermediate_size"] if "intermediate_size" in hparams else 4 * embed_dim - num_layers = hparams["num_layers"] # ignore for now as EXAONE-3.0-7.8B-Instruct attentino_dropout is 0.0 # attention_dropout_rate = hparams["attention_dropout"] # ignore for now as EXAONE-3.0-7.8B-Instruct embed_dropout is 0.0 @@ -8161,7 +8130,7 @@ class ExaoneModel(TextModel): self.gguf_writer.add_context_length(max_position_embeddings) self.gguf_writer.add_layer_norm_rms_eps(layer_norm_eps) self.gguf_writer.add_feed_forward_length(intermediate_size) - self.gguf_writer.add_block_count(num_layers) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_file_type(self.ftype) if (rope_theta := self.hparams.get("rope_theta")) is not None: diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index befe8ab9cc..57c6cd0df1 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -277,10 +277,15 @@ def parse_args() -> argparse.Namespace: return parser.parse_args() -def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]: +def load_hparams_from_hf(hf_model_id: str) -> tuple[dict[str, Any], Path | None]: + from huggingface_hub import try_to_load_from_cache + # normally, adapter does not come with base model config, we need to load it from AutoConfig config = AutoConfig.from_pretrained(hf_model_id) - return config.to_dict() + cache_dir = try_to_load_from_cache(hf_model_id, "config.json") + cache_dir = Path(cache_dir).parent if isinstance(cache_dir, str) else None + + return config.to_dict(), cache_dir if __name__ == '__main__': @@ -325,13 +330,13 @@ if __name__ == '__main__': # load base model if base_model_id is not None: logger.info(f"Loading base model from Hugging Face: {base_model_id}") - hparams = load_hparams_from_hf(base_model_id) + hparams, dir_base_model = load_hparams_from_hf(base_model_id) elif dir_base_model is None: if "base_model_name_or_path" in lparams: model_id = lparams["base_model_name_or_path"] logger.info(f"Loading base model from Hugging Face: {model_id}") try: - hparams = load_hparams_from_hf(model_id) + hparams, dir_base_model = load_hparams_from_hf(model_id) except OSError as e: logger.error(f"Failed to load base model config: {e}") logger.error("Please try downloading the base model and add its path to --base") @@ -480,6 +485,7 @@ if __name__ == '__main__': dir_lora_model=dir_lora, lora_alpha=alpha, hparams=hparams, + remote_hf_model_id=base_model_id, ) logger.info("Exporting model...") diff --git a/docs/ops.md b/docs/ops.md index 4ada4384fc..62a921e8f7 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -17,12 +17,12 @@ Legend: | ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ | | ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | -| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | +| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | | ADD_ID | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | -| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | +| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | -| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | +| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | | CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | | CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | | CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ | @@ -43,9 +43,9 @@ Legend: | ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ | | EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ | | EXPM1 | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | -| FILL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| FILL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | | FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | -| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | +| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | | GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | | GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | | GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | @@ -87,7 +87,7 @@ Legend: | ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | | ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | -| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | +| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | | RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | @@ -99,7 +99,7 @@ Legend: | SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | | SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | | SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | +| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ | | SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | | SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | @@ -107,7 +107,7 @@ Legend: | SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | | SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | -| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ | +| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ | | SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | | SUM | ❌ | ✅ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | | SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | @@ -116,6 +116,6 @@ Legend: | TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ❌ | | TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | +| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | | UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | | XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | diff --git a/docs/ops/Vulkan.csv b/docs/ops/Vulkan.csv index 290bdd1215..8073930e94 100644 --- a/docs/ops/Vulkan.csv +++ b/docs/ops/Vulkan.csv @@ -5,8 +5,8 @@ "Vulkan0","SGN","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" "Vulkan0","NEG","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","NEG","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" -"Vulkan0","STEP","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","STEP","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" +"Vulkan0","STEP","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","STEP","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","TANH","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","TANH","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","ELU","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" @@ -29,18 +29,18 @@ "Vulkan0","EXP","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","EXPM1","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" "Vulkan0","EXPM1","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" -"Vulkan0","SOFTPLUS","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","SOFTPLUS","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" +"Vulkan0","SOFTPLUS","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","SOFTPLUS","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","GELU_ERF","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" -"Vulkan0","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" -"Vulkan0","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" -"Vulkan0","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" -"Vulkan0","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" +"Vulkan0","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","ABS","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan" "Vulkan0","ABS","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","Vulkan" "Vulkan0","SGN","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan" @@ -89,8 +89,8 @@ "Vulkan0","SGN","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" "Vulkan0","NEG","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","NEG","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" -"Vulkan0","STEP","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","STEP","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" +"Vulkan0","STEP","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","STEP","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","TANH","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","TANH","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","ELU","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" @@ -113,18 +113,18 @@ "Vulkan0","EXP","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","EXPM1","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" "Vulkan0","EXPM1","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" -"Vulkan0","SOFTPLUS","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","SOFTPLUS","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" +"Vulkan0","SOFTPLUS","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","SOFTPLUS","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","GELU_ERF","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" -"Vulkan0","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" -"Vulkan0","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" -"Vulkan0","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" -"Vulkan0","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" +"Vulkan0","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","ABS","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan" "Vulkan0","ABS","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","Vulkan" "Vulkan0","SGN","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan" @@ -5654,7 +5654,7 @@ "Vulkan0","SUB","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan" "Vulkan0","MUL","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan" "Vulkan0","DIV","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan" -"Vulkan0","ADD1","type=f32,ne=[10,5,4,3]","support","0","no","Vulkan" +"Vulkan0","ADD1","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan" "Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=0.000000,inplace=0","support","1","yes","Vulkan" "Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=1.000000,inplace=0","support","1","yes","Vulkan" "Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=1.000000,inplace=1","support","1","yes","Vulkan" @@ -8632,10 +8632,10 @@ "Vulkan0","COS","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan" "Vulkan0","CLAMP","type=f16,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","0","no","Vulkan" "Vulkan0","LEAKY_RELU","type=f16,ne_a=[10,5,4,3],negative_slope=0.100000","support","0","no","Vulkan" -"Vulkan0","FLOOR","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan" -"Vulkan0","CEIL","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan" -"Vulkan0","ROUND","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan" -"Vulkan0","TRUNC","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan" +"Vulkan0","FLOOR","type=f16,ne=[10,2,2,2]","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f16,ne=[10,2,2,2]","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f16,ne=[10,2,2,2]","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f16,ne=[10,2,2,2]","support","1","yes","Vulkan" "Vulkan0","SQR","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" "Vulkan0","SQRT","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" "Vulkan0","LOG","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" @@ -8643,10 +8643,10 @@ "Vulkan0","COS","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" "Vulkan0","CLAMP","type=f16,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","0","no","Vulkan" "Vulkan0","LEAKY_RELU","type=f16,ne_a=[7,1,5,3],negative_slope=0.100000","support","0","no","Vulkan" -"Vulkan0","FLOOR","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" -"Vulkan0","CEIL","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" -"Vulkan0","ROUND","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" -"Vulkan0","TRUNC","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" +"Vulkan0","FLOOR","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","SQR","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan" "Vulkan0","SQRT","type=f32,ne=[10,3,3,2]","support","1","yes","Vulkan" "Vulkan0","LOG","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan" @@ -8654,10 +8654,10 @@ "Vulkan0","COS","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan" "Vulkan0","CLAMP","type=f32,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan" "Vulkan0","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","1","yes","Vulkan" -"Vulkan0","FLOOR","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan" -"Vulkan0","CEIL","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan" -"Vulkan0","ROUND","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan" -"Vulkan0","TRUNC","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan" +"Vulkan0","FLOOR","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan" "Vulkan0","SQR","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","SQRT","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","LOG","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" @@ -8665,10 +8665,10 @@ "Vulkan0","COS","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","CLAMP","type=f32,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan" "Vulkan0","LEAKY_RELU","type=f32,ne_a=[7,1,5,3],negative_slope=0.100000","support","1","yes","Vulkan" -"Vulkan0","FLOOR","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan" -"Vulkan0","CEIL","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan" -"Vulkan0","ROUND","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan" -"Vulkan0","TRUNC","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan" +"Vulkan0","FLOOR","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,1,1],n_past=5","support","1","yes","Vulkan" "Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,3,1],n_past=5","support","1","yes","Vulkan" "Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,3,2],n_past=5","support","1","yes","Vulkan" @@ -9478,7 +9478,7 @@ "Vulkan0","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","0","no","Vulkan" "Vulkan0","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","0","no","Vulkan" "Vulkan0","ROLL","shift0=3,shift1=-2,shift3=1,shift4=-1","support","1","yes","Vulkan" -"Vulkan0","ARANGE","type=f32,start=0.000000,stop=10.000000,step=1.000000","support","0","no","Vulkan" +"Vulkan0","ARANGE","type=f32,start=0.000000,stop=10.000000,step=1.000000","support","1","yes","Vulkan" "Vulkan0","TIMESTEP_EMBEDDING","type=f32,ne_a=[2,1,1,1],dim=320,max_period=10000","support","1","yes","Vulkan" "Vulkan0","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","1","yes","Vulkan" "Vulkan0","CUMSUM","type=f32,ne=[10,5,4,3]","support","0","no","Vulkan" @@ -9487,9 +9487,9 @@ "Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","0","no","Vulkan" "Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","0","no","Vulkan" "Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","0","no","Vulkan" -"Vulkan0","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","0","no","Vulkan" -"Vulkan0","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","0","no","Vulkan" -"Vulkan0","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","0","no","Vulkan" +"Vulkan0","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","1","yes","Vulkan" +"Vulkan0","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","1","yes","Vulkan" +"Vulkan0","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","1","yes","Vulkan" "Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","0","no","Vulkan" "Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","0","no","Vulkan" "Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","0","no","Vulkan" diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index cefa39a57c..80c693ce61 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -4,10 +4,10 @@ #include "llama.h" #include "ggml.h" +#include #include #include #include -#include /** * This the arbitrary data which will be passed to each callback. @@ -37,23 +37,23 @@ static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) { return u.f; } -static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) { +static float ggml_get_float_value(const uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) { size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0]; float v; if (type == GGML_TYPE_F16) { - v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]); + v = ggml_fp16_to_fp32(*(const ggml_fp16_t *) &data[i]); } else if (type == GGML_TYPE_F32) { - v = *(float *) &data[i]; + v = *(const float *) &data[i]; } else if (type == GGML_TYPE_I64) { - v = (float) *(int64_t *) &data[i]; + v = (float) *(const int64_t *) &data[i]; } else if (type == GGML_TYPE_I32) { - v = (float) *(int32_t *) &data[i]; + v = (float) *(const int32_t *) &data[i]; } else if (type == GGML_TYPE_I16) { - v = (float) *(int16_t *) &data[i]; + v = (float) *(const int16_t *) &data[i]; } else if (type == GGML_TYPE_I8) { - v = (float) *(int8_t *) &data[i]; + v = (float) *(const int8_t *) &data[i]; } else if (type == GGML_TYPE_BF16) { - v = ggml_compute_bf16_to_fp32(*(ggml_bf16_t *) &data[i]); + v = ggml_compute_bf16_to_fp32(*(const ggml_bf16_t *) &data[i]); } else { GGML_ABORT("fatal error"); } diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index b883556edf..d0cab0bcb9 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -392,9 +392,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}") if (EXTRACTED_NUMBER GREATER_EQUAL 10) - list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64) + list(APPEND ARCH_FLAGS -mcpu=power10) elseif (EXTRACTED_NUMBER EQUAL 9) - list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64) + list(APPEND ARCH_FLAGS -mcpu=power9) elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le") list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native) else() diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp index 1d5b44f9fe..55a00f008a 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp @@ -39,7 +39,7 @@ #include "kernels.h" -#define NELEMS(x) sizeof(x) / sizeof(*x) +#define NELEMS(x) (sizeof(x) / sizeof(*x)) template static inline size_t kernel_offs_fn3(size_t a, size_t b, size_t c) { @@ -635,6 +635,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { }, #endif #endif + { /* Sentinel */ } }; static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = { @@ -803,6 +804,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = { /* .op_type = */ GGML_TYPE_F32, }, #endif + { /* Sentinel */ } }; ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) { @@ -810,7 +812,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) { #if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8) - for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { + for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) { if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu && gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type && gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type && @@ -820,7 +822,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c } } if (!kernel) { - for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) { + for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) { if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu && gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type && gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type && @@ -830,6 +832,10 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c } } } +#else + GGML_UNUSED(gemm_gemv_kernels); + GGML_UNUSED(gemm_gemv_kernels_q8); + GGML_UNUSED(cpu_features); #endif } @@ -840,12 +846,14 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) ggml_kleidiai_kernels * kernels = nullptr; #if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8) - for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { + for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) { if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) { kernels = &gemm_gemv_kernels[i]; break; } } +#else + GGML_UNUSED(features); #endif return kernels; @@ -855,12 +863,14 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features) ggml_kleidiai_kernels * kernels = nullptr; #if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8) - for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) { + for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) { if ((features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu) { kernels = &gemm_gemv_kernels_q8[i]; break; } } +#else + GGML_UNUSED(features); #endif return kernels; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index b6209588db..41e89d83c2 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -9696,13 +9696,12 @@ static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params for (int64_t i00 = 0; i00 < n; ++i00) { float sum = 0.0f; for (int64_t t = 0; t < i00; ++t) { - sum += A_batch[i00 * n + t] * X_batch[i01 * n + t]; + sum += A_batch[i00 * n + t] * X_batch[t * k + i01]; } const float diag = A_batch[i00 * n + i00]; GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix"); - - X_batch[i01 * n + i00] = (B_batch[i00 * k + i01] - sum) / diag; + X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag; } } } diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index 74c74d1a28..101a9c086b 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -160,18 +160,18 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { #define GGML_F32xt svfloat32_t #define GGML_F32xt_ZERO svdup_n_f32(0.0f) #define GGML_F32xt_SET1(x) svdup_n_f32(x) -#define GGML_F32xt_LOAD_IMPL(pg, a, ...) svld1_f32(pg, a) -#define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__) -#define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b) -#define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_LOAD_IMPL(pg, a) svld1_f32(pg, a) +#define GGML_F32xt_LOAD(a) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, a) +#define GGML_F32xt_STORE_IMPL(pg, a, b) svst1_f32(pg, a, b) +#define GGML_F32xt_STORE(a, b) GGML_F32xt_STORE_IMPL(DEFAULT_PG, a, b) #define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a) -#define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_FMA(a, b, c) GGML_F32xt_FMA_IMPL(DEFAULT_PG, a, b, c) #define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b) -#define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_ADD(a, b) GGML_F32xt_ADD_IMPL(DEFAULT_PG, a, b) #define GGML_F32xt_MUL_IMPL(pg, a, b) svmul_f32_m(pg, a, b) -#define GGML_F32xt_MUL(...) GGML_F32xt_MUL_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_MUL(a, b) GGML_F32xt_MUL_IMPL(DEFAULT_PG, a, b) #define GGML_F32xt_REDUCE_ONE_IMPL(pg, a) svaddv(pg, a) -#define GGML_F32xt_REDUCE_ONE(...) GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_REDUCE_ONE(a) GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, a) #define GGML_F32xt_REDUCE_IMPL(pg, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \ { \ sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum2); \ @@ -183,7 +183,8 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum5); \ (res) = (ggml_float) GGML_F32xt_REDUCE_ONE(sum1); \ } -#define GGML_F32xt_REDUCE(...) GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_REDUCE(res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \ + GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) #define GGML_F32_VEC GGML_F32xt #define GGML_F32_VEC_ZERO GGML_F32xt_ZERO @@ -206,11 +207,11 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { #define GGML_F32Cxt_STORE(dst_ptr, src_vec) svst1_f16(DEFAULT_PG16, (__fp16 *)(dst_ptr), (src_vec)) #define GGML_F32Cxt_FMA_IMPL(pg, a, b, c) svmad_f16_x(pg, b, c, a) -#define GGML_F32Cxt_FMA(...) GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, __VA_ARGS__) +#define GGML_F32Cxt_FMA(a, b, c) GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, a, b, c) #define GGML_F32Cxt_ADD_IMPL(pg, a, b) svadd_f16_x(pg, a, b) -#define GGML_F32Cxt_ADD(...) GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, __VA_ARGS__) +#define GGML_F32Cxt_ADD(a, b) GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, a, b) #define GGML_F32Cxt_MUL_IMPL(pg, a, b) svmul_f16_x(pg, a, b) -#define GGML_F32Cxt_MUL(...) GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, __VA_ARGS__) +#define GGML_F32Cxt_MUL(a, b) GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, a, b) #define GGML_F32Cxt_REDUCE GGML_F16xt_REDUCE_MIXED #define GGML_F16x_VEC GGML_F32Cxt @@ -224,7 +225,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { #define GGML_F16x_VEC_REDUCE GGML_F32Cxt_REDUCE #define GGML_F16xt_REDUCE_ONE_IMPL(pg, a) svaddv_f16(pg, a) -#define GGML_F16xt_REDUCE_ONE(...) GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, __VA_ARGS__) +#define GGML_F16xt_REDUCE_ONE(a) GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, a) #define GGML_F16xt_REDUCE_MIXED_IMPL(pg16, res, sum1, sum2, sum3, sum4) \ { \ @@ -234,7 +235,8 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { __fp16 sum_f16 = svaddv_f16(pg16, sum1); \ (res) = (ggml_float) sum_f16; \ } -#define GGML_F16xt_REDUCE_MIXED(...) GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, __VA_ARGS__) +#define GGML_F16xt_REDUCE_MIXED(res, sum1, sum2, sum3, sum4) \ + GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, res, sum1, sum2, sum3, sum4) // F16 NEON diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index ac59f1fe85..e0e9540433 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -698,60 +698,61 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { } inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) { -#if defined(GGML_SIMD) - #if defined(__ARM_FEATURE_SVE) - const int sve_register_length = svcntb() * 8; - const int ggml_f16_epr = sve_register_length / 16; - const int ggml_f16_step = 2 * ggml_f16_epr; +#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE) + const int sve_register_length = svcntb() * 8; + const int ggml_f16_epr = sve_register_length / 16; + const int ggml_f16_step = 2 * ggml_f16_epr; - GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v); - const int np = (n & ~(ggml_f16_step - 1)); - svfloat16_t ay1, ay2; + GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v); + const int np = (n & ~(ggml_f16_step - 1)); + svfloat16_t ay1, ay2; - for (int i = 0; i < np; i += ggml_f16_step) { - ay1 = GGML_F16x_VEC_LOAD(y + i + 0*ggml_f16_epr, 0); - ay1 = GGML_F16x_VEC_MUL(ay1, vx); - GGML_F16x_VEC_STORE(y + i + 0*ggml_f16_epr, ay1, 0); + for (int i = 0; i < np; i += ggml_f16_step) { + ay1 = GGML_F16x_VEC_LOAD(y + i + 0*ggml_f16_epr, 0); + ay1 = GGML_F16x_VEC_MUL(ay1, vx); + GGML_F16x_VEC_STORE(y + i + 0*ggml_f16_epr, ay1, 0); - ay2 = GGML_F16x_VEC_LOAD(y + i + 1*ggml_f16_epr, 1); - ay2 = GGML_F16x_VEC_MUL(ay2, vx); - GGML_F16x_VEC_STORE(y + i + 1*ggml_f16_epr, ay2, 1); + ay2 = GGML_F16x_VEC_LOAD(y + i + 1*ggml_f16_epr, 1); + ay2 = GGML_F16x_VEC_MUL(ay2, vx); + GGML_F16x_VEC_STORE(y + i + 1*ggml_f16_epr, ay2, 1); + } + // leftovers + // maximum number of leftover elements will be less that ggmlF_16x_epr. Apply predicated svmad on available elements only + if (np < n) { + svbool_t pg = svwhilelt_b16(np, n); + svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np)); + svfloat16_t out = svmul_f16_m(pg, hy, vx); + svst1_f16(pg, (__fp16 *)(y + np), out); + } +#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh) + for (int i = 0, vl; i < n; i += vl) { + vl = __riscv_vsetvl_e16m2(n - i); + vfloat16m2_t vy = __riscv_vle16_v_f16m2((_Float16 *)&y[i], vl); + vfloat32m4_t vy32 = __riscv_vfwcvt_f_f_v_f32m4(vy, vl); + vy32 = __riscv_vfmul_vf_f32m4(vy32, v, vl); + vy = __riscv_vfncvt_f_f_w_f16m2(vy32, vl); + __riscv_vse16_v_f16m2((_Float16 *)&y[i], vy, vl); + } +#elif defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_MUL(ay[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); } - // leftovers - // maximum number of leftover elements will be less that ggmlF_16x_epr. Apply predicated svmad on available elements only - if (np < n) { - svbool_t pg = svwhilelt_b16(np, n); - svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np)); - svfloat16_t out = svmul_f16_m(pg, hy, vx); - svst1_f16(pg, (__fp16 *)(y + np), out); - } - #elif defined(__riscv_v_intrinsic) - // todo: RVV impl - // scalar - for (int i = 0; i < n; ++i) { - y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); - } - #else - const int np = (n & ~(GGML_F16_STEP - 1)); + } - GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); - - GGML_F16_VEC ay[GGML_F16_ARR]; - - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_MUL(ay[j], vx); - - GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); - } - } - - // leftovers - for (int i = np; i < n; ++i) { - y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); - } - #endif + // leftovers + for (int i = np; i < n; ++i) { + y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); + } #else // scalar for (int i = 0; i < n; ++i) { diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 50612237c8..c1afde9627 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -384,7 +384,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg char * src1_ddc = (char *) src1->data; const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1); - const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) && src0->ne[3] == 1; + const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) && + src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0); if (src0->type == src1->type && contiguous_srcs) { GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 7d792e60cf..8fe0899bb5 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3001,6 +3001,10 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope, const ggml_tensor * view, const ggml_tensor * set_rows) { + + if (rope->op != GGML_OP_ROPE || view->op != GGML_OP_VIEW || set_rows->op != GGML_OP_SET_ROWS) { + return false; + } // ne3 not tested if (rope->src[0]->ne[3] != 1) { return false; @@ -3744,10 +3748,110 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t return ctx->description.c_str(); } +#if defined(__linux__) +// Helper function to get available memory from /proc/meminfo for UMA systems +static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_kb, long * free_swap_kb) { + FILE * meminfo_file = nullptr; + // 2KB buffer for reading /proc/meminfo since it does not report size info, should be enough + const size_t BUFFER_SIZE = 2048; + auto file_buffer = std::make_unique(BUFFER_SIZE); + size_t bytes_read = 0; + long huge_tlb_total_pages = -1; + long huge_tlb_free_pages = -1; + long huge_tlb_page_size = -1; + + if (available_memory_kb == nullptr || free_swap_kb == nullptr) { + return false; + } + + meminfo_file = fopen("/proc/meminfo", "r"); + if (meminfo_file == nullptr) { + GGML_LOG_ERROR("%s: failed to open /proc/meminfo\n", __func__); + return false; + } + + // Read file into buffer + bytes_read = fread(file_buffer.get(), 1, BUFFER_SIZE - 1, meminfo_file); + fclose(meminfo_file); + + if (bytes_read == 0) { + GGML_LOG_ERROR("%s: failed to read from /proc/meminfo\n", __func__); + return false; + } + file_buffer[bytes_read] = '\0'; + + *available_memory_kb = -1; + *free_swap_kb = -1; + + // Parse the file buffer line by line + char * line = file_buffer.get(); + char * line_next; + while (line < file_buffer.get() + bytes_read) { + // Find the end of the current line + line_next = strchr(line, '\n'); + if (line_next != nullptr) { + *line_next = '\0'; + line_next++; + } else { + line_next = file_buffer.get() + bytes_read; + } + + long value; + if (sscanf(line, "MemAvailable: %ld kB", &value) == 1) { + *available_memory_kb = value; + } else if (sscanf(line, "SwapFree: %ld kB", &value) == 1) { + *free_swap_kb = value; + } else if (sscanf(line, "HugePages_Total: %ld", &value) == 1) { + huge_tlb_total_pages = value; + } else if (sscanf(line, "HugePages_Free: %ld", &value) == 1) { + huge_tlb_free_pages = value; + } else if (sscanf(line, "Hugepagesize: %ld kB", &value) == 1) { + huge_tlb_page_size = value; + } + + line = line_next; + } + + if (huge_tlb_total_pages != 0 && huge_tlb_total_pages != -1) { + *available_memory_kb = huge_tlb_free_pages * huge_tlb_page_size; + + // Hugetlbfs pages are not swappable. + *free_swap_kb = 0; + } + + GGML_LOG_DEBUG("%s: final available_memory_kb: %ld\n", __func__, *available_memory_kb); + return true; +} +#endif // defined(__linux__) + static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; ggml_cuda_set_device(ctx->device); CUDA_CHECK(cudaMemGetInfo(free, total)); + +// ref: https://github.com/ggml-org/llama.cpp/pull/17368 +#if defined(__linux__) + // Check if this is a UMA (Unified Memory Architecture) system + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, ctx->device)); + + // Check if UMA is explicitly enabled via environment variable + bool uma_env = getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr; + bool is_uma = prop.unifiedAddressing > 0 || uma_env; + + if (is_uma) { + // For UMA systems (like DGX Spark), use system memory info + long available_memory_kb = 0; + long free_swap_kb = 0; + + if (ggml_backend_cuda_get_available_uma_memory(&available_memory_kb, &free_swap_kb) && available_memory_kb > 0) { + *free = (size_t)available_memory_kb * 1024; + } else { + GGML_LOG_ERROR("%s: /proc/meminfo reading failed, using cudaMemGetInfo\n", __func__); + } + } +#endif // defined(__linux__) + } static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) { diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 366c54ebec..e46970dad4 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -11,6 +11,7 @@ #include #include #include +#include static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) { if (!t) { diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 11262c1989..f83dfdaef6 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -406,8 +406,8 @@ enum shader_reduction_mode { SHADER_REDUCTION_MODE_COUNT, }; +// argsort pipelines for up to 1<<10 invocations per workgroup static constexpr uint32_t num_argsort_pipelines = 11; -static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1); static constexpr uint32_t num_topk_moe_pipelines = 10; static constexpr std::initializer_list topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, @@ -526,6 +526,7 @@ struct vk_device_struct { bool multi_add; bool shader_int64; bool buffer_device_address; + bool vulkan_memory_model; bool add_rms_fusion; uint32_t partials_binding_alignment; @@ -539,6 +540,9 @@ struct vk_device_struct { uint32_t subgroup_max_size; bool subgroup_require_full_support; + // floor(log2(maxComputeWorkGroupInvocations)) + uint32_t max_workgroup_size_log2 {}; + bool coopmat_support; bool coopmat_acc_f32_support {}; bool coopmat_acc_f16_support {}; @@ -638,6 +642,7 @@ struct vk_device_struct { vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_cpy_transpose_16, pipeline_cpy_transpose_32; vk_pipeline pipeline_set_rows_i32[GGML_TYPE_COUNT]; vk_pipeline pipeline_set_rows_i64[GGML_TYPE_COUNT]; vk_pipeline pipeline_norm_f32; @@ -664,6 +669,20 @@ struct vk_device_struct { vk_pipeline pipeline_hardsigmoid[2]; vk_pipeline pipeline_hardswish[2]; vk_pipeline pipeline_abs[2]; + vk_pipeline pipeline_softplus[2]; + vk_pipeline pipeline_step[2]; + vk_pipeline pipeline_round[2]; + vk_pipeline pipeline_ceil[2]; + vk_pipeline pipeline_floor[2]; + vk_pipeline pipeline_trunc[2]; + + vk_pipeline pipeline_add1_f16_f16; + vk_pipeline pipeline_add1_f16_f32; + vk_pipeline pipeline_add1_f32_f32; + + vk_pipeline pipeline_arange_f32; + + vk_pipeline pipeline_fill_f32; vk_pipeline pipeline_geglu[2]; vk_pipeline pipeline_reglu[2]; @@ -683,6 +702,7 @@ struct vk_device_struct { vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16; vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; vk_pipeline pipeline_argsort_f32[num_argsort_pipelines]; + vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines]; vk_pipeline pipeline_sum_rows_f32; vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_count_equal_i32; @@ -1173,8 +1193,14 @@ struct vk_op_soft_max_push_constants { struct vk_op_argsort_push_constants { uint32_t ncols; + uint32_t ncols_padded; + uint32_t ncols_padded_log2; uint32_t nrows; - int32_t order; + uint32_t order; + uint32_t outer_start; + uint32_t outer_end; + uint32_t inner_start; + uint32_t inner_end; }; struct vk_op_im2col_push_constants { @@ -2901,15 +2927,15 @@ static void ggml_vk_load_shaders(vk_device& device) { if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } \ } else { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } \ } \ } \ @@ -3697,6 +3723,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_32, "cpy_transpose_32", cpy_transpose_32_len, cpy_transpose_32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_16, "cpy_transpose_16", cpy_transpose_16_len, cpy_transpose_16_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); + if (device->float_controls_rte_fp16) { ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); @@ -3826,6 +3855,12 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_UNARY(hardsigmoid) CREATE_UNARY(hardswish) CREATE_UNARY(abs) + CREATE_UNARY(softplus) + CREATE_UNARY(step) + CREATE_UNARY(round) + CREATE_UNARY(ceil) + CREATE_UNARY(floor) + CREATE_UNARY(trunc) #undef CREATE_UNARY #define CREATE_UNARY_RTE(name) \ @@ -3839,6 +3874,14 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_UNARY_RTE(exp) #undef CREATE_UNARY_RTE + ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f16, "add1_f16_f16", add1_f16_f16_len, add1_f16_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + #define CREATE_GLU(name) \ if (device->float_controls_rte_fp16) { \ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ @@ -3891,7 +3934,15 @@ static void ggml_vk_load_shaders(vk_device& device) { } for (uint32_t i = 0; i < num_argsort_pipelines; ++i) { - ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<max_workgroup_size_log2); + if (i <= device->max_workgroup_size_log2 && + 2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) { + const uint32_t NCOLS_PADDED_LOG2 = i; + ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true); + } + const uint32_t WG_UNROLL_FACTOR = BLOCK_SIZE > 1 ? 2 : 1; + BLOCK_SIZE /= WG_UNROLL_FACTOR; + ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true); } ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); @@ -4292,6 +4343,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated; + device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations))); + std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); // Try to find a non-graphics compute queue and transfer-focused queues @@ -4431,6 +4484,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->shader_int64 = device_features2.features.shaderInt64; device->buffer_device_address = vk12_features.bufferDeviceAddress; + device->vulkan_memory_model = vk12_features.vulkanMemoryModel; if (device->subgroup_size_control) { device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize; @@ -6247,6 +6301,17 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const // Choose "contiguous copy" shader if src/dst are contiguous bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst)); + // Use optimized "transpose" shader if src dim1 is the innermost dimension. + bool transpose = dst && src->nb[1] == ggml_type_size(to) && ggml_are_same_shape(dst, src); + + if (transpose && src->type == to) { + if (ggml_type_size(to) == 4) { + return ctx->device->pipeline_cpy_transpose_32; + } else if (ggml_type_size(to) == 2) { + return ctx->device->pipeline_cpy_transpose_16; + } + } + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) { if (contig) { return ctx->device->pipeline_contig_cpy_f32_f32; @@ -8242,6 +8307,18 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_ABS: return ctx->device->pipeline_abs[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_SOFTPLUS: + return ctx->device->pipeline_softplus[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_STEP: + return ctx->device->pipeline_step[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_ROUND: + return ctx->device->pipeline_round[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_CEIL: + return ctx->device->pipeline_ceil[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_FLOOR: + return ctx->device->pipeline_floor[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_TRUNC: + return ctx->device->pipeline_trunc[dst->type == GGML_TYPE_F16]; default: break; } @@ -8344,19 +8421,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; } - case GGML_OP_ARGSORT: - if (ctx->num_additional_fused_ops) { - uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); - GGML_ASSERT(idx < num_topk_moe_pipelines); - topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); - return ctx->device->pipeline_topk_moe[idx][mode]; - } - - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { - uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); - return ctx->device->pipeline_argsort_f32[idx]; - } - return nullptr; case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: @@ -8449,7 +8513,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const case GGML_OP_CONV_TRANSPOSE_2D: if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { - std::array elements; + std::array elements{}; if (op == GGML_OP_CONV_2D) elements = ggml_vk_get_conv_elements(dst); else if (op == GGML_OP_CONV_TRANSPOSE_2D) elements = ggml_vk_get_conv_transpose_2d_elements(dst); vk_conv_shapes shape; @@ -8527,6 +8591,27 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } } return nullptr; + case GGML_OP_ADD1: + if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_add1_f16_f16; + } + if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_add1_f16_f32; + } + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_add1_f32_f32; + } + return nullptr; + case GGML_OP_ARANGE: + if (dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_arange_f32; + } + return nullptr; + case GGML_OP_FILL: + if (dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_fill_f32; + } + return nullptr; default: return nullptr; } @@ -8748,8 +8833,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); break; case GGML_OP_ARGSORT: - elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 }; - elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + GGML_ASSERT(0); break; case GGML_OP_IM2COL: { @@ -8817,6 +8901,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_SUB: case GGML_OP_DIV: case GGML_OP_MUL: + case GGML_OP_ADD1: + case GGML_OP_ARANGE: + case GGML_OP_FILL: case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_SQRT: @@ -8858,6 +8945,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } else { elements = { ne, 1, 1 }; } + + if (pipeline == ctx->device->pipeline_cpy_transpose_32 || + pipeline == ctx->device->pipeline_cpy_transpose_16) { + // 32x32 tiles + elements[0] = (uint32_t)CEIL_DIV(dst->ne[0], 32); + elements[1] = (uint32_t)CEIL_DIV(dst->ne[1], 32); + elements[2] = (uint32_t)(dst->ne[2]*dst->ne[3]); + elements[0] = std::min(elements[0], ctx->device->properties.limits.maxComputeWorkGroupCount[0]); + elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); + } } break; case GGML_OP_ADD_ID: { @@ -9423,6 +9521,63 @@ static void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, cons ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst)); } +static void ggml_vk_add1(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ADD1, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }); +} + +static void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { + VK_LOG_DEBUG("ggml_vk_arange(dst=" << dst << ", ne=" << ggml_nelements(dst) << ")"); + + vk_op_push_constants pc = { + (uint32_t)ggml_nelements(dst), + 1, + ggml_get_op_params_f32(dst, 0), + ggml_get_op_params_f32(dst, 2), + }; + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_ARANGE); + GGML_ASSERT(pipeline != nullptr); + + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false); + + std::array elements = { (uint32_t)ggml_nelements(dst), 1, 1 }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements); +} + +static void ggml_vk_fill(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { + VK_LOG_DEBUG("ggml_vk_fill(dst=" << dst << ", ne=" << ggml_nelements(dst) << ")"); + + vk_op_push_constants pc = { + (uint32_t)ggml_nelements(dst), + 1, + ggml_get_op_params_f32(dst, 0), + 0.0f, + }; + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_FILL); + GGML_ASSERT(pipeline != nullptr); + + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false); + + std::array elements = { (uint32_t)ggml_nelements(dst), 1, 1 }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements); +} + static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst)); } @@ -9865,16 +10020,89 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons } static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { - int32_t * op_params = (int32_t *)dst->op_params; + const uint32_t * op_params = (const uint32_t *)dst->op_params; uint32_t ncols = src0->ne[0]; uint32_t nrows = ggml_nrows(src0); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, { - ncols, - nrows, - op_params[0], - }); + uint32_t ncols_pad_log2 = (uint32_t)ceilf(log2f(float(ncols))); + uint32_t ncolsp2 = 1 << ncols_pad_log2; + + vk_op_argsort_push_constants pc { ncols, ncolsp2, ncols_pad_log2, nrows, op_params[0], 0, 0, 0, 0, }; + + // Pick the largest workgroup size <= ncolsp2 + uint32_t pipeline_idx = std::min(ncols_pad_log2, num_argsort_pipelines - 1); + + // Use the "small" argsort shader if the whole sort can be done by a single workgroup. + bool use_small = ncols_pad_log2 <= ctx->device->max_workgroup_size_log2 && + ctx->device->pipeline_argsort_f32[pipeline_idx] != nullptr; + + vk_pipeline pipeline = use_small ? ctx->device->pipeline_argsort_f32[pipeline_idx] + : ctx->device->pipeline_argsort_large_f32[pipeline_idx]; + + vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0); + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst); + vk_subbuffer subbuf1 = dst_buf; + + // Reserve space for ivec2 per element, with rows padded to a power of two + if (!use_small) { + const size_t x_sz = size_t{ncolsp2} * nrows * 2 * sizeof(int); + + if (ctx->prealloc_size_x < x_sz) { + ctx->prealloc_size_x = x_sz; + ggml_vk_preallocate_buffers(ctx, subctx); + } + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + subbuf1 = { ctx->prealloc_x, 0, ctx->prealloc_x->size }; + } + + std::array elements; + + elements[0] = ncolsp2; + elements[1] = std::min((uint32_t)ggml_nrows(src0), ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + elements[2] = 1; + + // First dispatch initializes tmp_idx and does the first N passes where + // there is only communication between threads in the same workgroup. + { + vk_op_argsort_push_constants pc2 = pc; + pc2.outer_start = 0; + pc2.outer_end = std::min(ncols_pad_log2, ctx->device->max_workgroup_size_log2); + pc2.inner_start = 0; + pc2.inner_end = 100; + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements); + } + if (!use_small) { + ggml_vk_sync_buffers(ctx, subctx); + // Loop over outer/inner passes, synchronizing between each pass. + for (uint32_t outer = ctx->device->max_workgroup_size_log2; outer < ncols_pad_log2; ++outer) { + for (uint32_t inner = 0; inner < outer + 1; ++inner) { + vk_op_argsort_push_constants pc2 = pc; + pc2.outer_start = outer; + pc2.outer_end = outer + 1; + pc2.inner_start = inner; + pc2.inner_end = inner + 1; + // When the inner idx is large enough, there's only communication + // within a workgroup. So the remaining inner iterations can all + // run in the same dispatch. + if (outer - inner < pipeline_idx) { + pc2.inner_end = 100; + inner = outer; + pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx]; + } else { + // Smaller workgroup empirically seems to perform better + pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx - 2]; + } + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements); + ggml_vk_sync_buffers(ctx, subctx); + } + } + ctx->prealloc_x_need_sync = true; + } } static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { @@ -11182,6 +11410,12 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_SOFTPLUS: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_ROUND: + case GGML_UNARY_OP_CEIL: + case GGML_UNARY_OP_FLOOR: + case GGML_UNARY_OP_TRUNC: break; default: return false; @@ -11223,6 +11457,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: + case GGML_OP_ADD1: + case GGML_OP_ARANGE: + case GGML_OP_FILL: case GGML_OP_CONCAT: case GGML_OP_UPSCALE: case GGML_OP_SCALE: @@ -11435,6 +11672,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_UPSCALE: ggml_vk_upscale(ctx, compute_ctx, src0, node); + break; + case GGML_OP_ADD1: + ggml_vk_add1(ctx, compute_ctx, src0, src1, node); + + break; + case GGML_OP_ARANGE: + ggml_vk_arange(ctx, compute_ctx, node); + + break; + case GGML_OP_FILL: + ggml_vk_fill(ctx, compute_ctx, node); + break; case GGML_OP_SCALE: ggml_vk_scale(ctx, compute_ctx, src0, node); @@ -11519,6 +11768,12 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_SOFTPLUS: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_ROUND: + case GGML_UNARY_OP_CEIL: + case GGML_UNARY_OP_FLOOR: + case GGML_UNARY_OP_TRUNC: ggml_vk_unary(ctx, compute_ctx, src0, node); break; default: @@ -11721,6 +11976,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: + case GGML_OP_ADD1: + case GGML_OP_ARANGE: + case GGML_OP_FILL: case GGML_OP_ADD_ID: case GGML_OP_CONCAT: case GGML_OP_UPSCALE: @@ -11792,6 +12050,12 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_SOFTPLUS: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_ROUND: + case GGML_UNARY_OP_CEIL: + case GGML_UNARY_OP_FLOOR: + case GGML_UNARY_OP_TRUNC: buf = tensor->buffer; break; default: @@ -13394,6 +13658,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_SOFTPLUS: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_ROUND: + case GGML_UNARY_OP_CEIL: + case GGML_UNARY_OP_FLOOR: + case GGML_UNARY_OP_TRUNC: return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && @@ -13695,10 +13965,25 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_LOG: return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; case GGML_OP_ARGSORT: - return op->ne[0] <= max_argsort_cols; + { + if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) { + return false; + } + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + auto device = ggml_vk_get_device(ctx->device); + // pipeline_argsort_large_f32 requires vulkan memory model. + if (device->vulkan_memory_model) { + return true; + } else { + return op->ne[0] <= (1 << device->max_workgroup_size_log2); + } + } case GGML_OP_UPSCALE: case GGML_OP_ACC: case GGML_OP_CONCAT: + case GGML_OP_ADD1: + case GGML_OP_ARANGE: + case GGML_OP_FILL: case GGML_OP_SCALE: case GGML_OP_PAD: case GGML_OP_ROLL: @@ -14181,6 +14466,16 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * } else if (tensor->op == GGML_OP_SCALE) { const float * params = (const float *)tensor->op_params; tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]); + } else if (tensor->op == GGML_OP_ADD1) { + tensor_clone = ggml_add1(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_ARANGE) { + const float start = ggml_get_op_params_f32(tensor, 0); + const float stop = ggml_get_op_params_f32(tensor, 1); + const float step = ggml_get_op_params_f32(tensor, 2); + tensor_clone = ggml_arange(ggml_ctx, start, stop, step); + } else if (tensor->op == GGML_OP_FILL) { + const float value = ggml_get_op_params_f32(tensor, 0); + tensor_clone = ggml_fill(ggml_ctx, tensor_clone, value); } else if (tensor->op == GGML_OP_SQR) { tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SQRT) { @@ -14294,6 +14589,24 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_UNARY_OP_ABS: tensor_clone = ggml_abs(ggml_ctx, src_clone[0]); break; + case GGML_UNARY_OP_SOFTPLUS: + tensor_clone = ggml_softplus(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_STEP: + tensor_clone = ggml_step(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_ROUND: + tensor_clone = ggml_round(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_CEIL: + tensor_clone = ggml_ceil(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_FLOOR: + tensor_clone = ggml_floor(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_TRUNC: + tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]); + break; default: std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp new file mode 100644 index 0000000000..db60725d4c --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp @@ -0,0 +1,28 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +#include "types.glsl" +#include "generic_binary_head.glsl" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset()])); + + idx += num_threads; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp b/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp new file mode 100644 index 0000000000..f4936eeada --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp @@ -0,0 +1,20 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + // p.param1 = start, p.param2 = step + float value = p.param1 + p.param2 * float(i); + data_d[i] = D_TYPE(value); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp index c4e68bc023..0fc2b9b725 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp @@ -4,28 +4,27 @@ #include "types.glsl" layout(constant_id = 0) const int BLOCK_SIZE = 1024; -layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10; +layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10; #define ASC 0 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) buffer D {int data_d[];}; +layout (binding = 2) writeonly buffer D {int data_d[];}; layout (push_constant) uniform parameter { uint ncols; + uint ncols_padded; + uint ncols_padded_log2; uint nrows; uint order; + uint outer_start; + uint outer_end; + uint inner_start; + uint inner_end; } p; -shared int dst_row[BLOCK_SIZE]; -shared A_TYPE a_sh[BLOCK_SIZE]; - -void swap(uint idx0, uint idx1) { - int tmp = dst_row[idx0]; - dst_row[idx0] = dst_row[idx1]; - dst_row[idx1] = tmp; -} +shared ivec2 dst_row[BLOCK_SIZE]; void argsort(bool needs_bounds_check, const uint row) { // bitonic sort @@ -34,11 +33,10 @@ void argsort(bool needs_bounds_check, const uint row) { const uint row_offset = row * p.ncols; // initialize indices - dst_row[col] = col; - a_sh[col] = data_a[row_offset + col]; + dst_row[col] = ivec2(col, floatBitsToInt(data_a[row_offset + col])); barrier(); - uint num_outer_loop_iters = BLOCK_SIZE_LOG2; + uint num_outer_loop_iters = NCOLS_PADDED_LOG2; [[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) { uint num_inner_loop_iters = outer_idx + 1; [[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) { @@ -47,14 +45,15 @@ void argsort(bool needs_bounds_check, const uint row) { int idx_0 = (col & k) == 0 ? col : ixj; int idx_1 = (col & k) == 0 ? ixj : col; - int sh_idx_0 = dst_row[idx_0]; - int sh_idx_1 = dst_row[idx_1]; - bool idx_0_oob = needs_bounds_check ? sh_idx_0 >= p.ncols : false; - bool idx_1_oob = needs_bounds_check ? sh_idx_1 >= p.ncols : false; + ivec2 sh_idx_0 = dst_row[idx_0]; + ivec2 sh_idx_1 = dst_row[idx_1]; + bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false; + bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false; if ((idx_0_oob || - (!idx_1_oob && a_sh[sh_idx_0] > a_sh[sh_idx_1])) && (ixj > col)) { - swap(idx_0, idx_1); + (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y))) && (ixj > col)) { + dst_row[idx_0] = sh_idx_1; + dst_row[idx_1] = sh_idx_0; } barrier(); @@ -63,9 +62,9 @@ void argsort(bool needs_bounds_check, const uint row) { if (col < p.ncols) { if (p.order == ASC) { - data_d[row_offset + col] = dst_row[col]; + data_d[row_offset + col] = dst_row[col].x; } else { - data_d[row_offset + p.ncols - col - 1] = dst_row[col]; + data_d[row_offset + p.ncols - col - 1] = dst_row[col].x; } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp new file mode 100644 index 0000000000..920bac6bb8 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp @@ -0,0 +1,114 @@ +#version 450 +#extension GL_EXT_control_flow_attributes : enable +#extension GL_KHR_memory_scope_semantics : enable +#pragma use_vulkan_memory_model + +#include "types.glsl" + +layout(constant_id = 0) const int BLOCK_SIZE = 1024; +layout(constant_id = 1) const int WG_UNROLL_FACTOR = 2; +#define ASC 0 + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) workgroupcoherent buffer B {ivec2 tmp_idx[];}; +layout (binding = 2) workgroupcoherent buffer D {int data_d[];}; + +layout (push_constant) uniform parameter { + uint ncols; + uint ncols_padded; + uint ncols_padded_log2; + uint nrows; + uint order; + uint outer_start; + uint outer_end; + uint inner_start; + uint inner_end; +} p; + +void argsort(bool needs_bounds_check, const uint row) { + // bitonic sort + int col = int(gl_GlobalInvocationID.x); + col = (col % BLOCK_SIZE) + (col / BLOCK_SIZE) * BLOCK_SIZE * WG_UNROLL_FACTOR; + + const uint row_offset = row * p.ncols; + uint idx_offset = row * p.ncols_padded; + + bool need_barrier = false; + + // initialize indices + if (p.outer_start == 0 && p.inner_start == 0) { + [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) { + uint c = u*BLOCK_SIZE + col; + if (c < p.ncols_padded) { + ivec2 v = ivec2(c, floatBitsToInt(data_a[row_offset + c])); + tmp_idx[idx_offset + c] = v; + } + } + need_barrier = true; + } + + [[unroll]] for (uint outer_idx = p.outer_start, k = (2 << outer_idx); outer_idx < p.outer_end; k *= 2, outer_idx++) { + uint inner_end = min(p.inner_end, outer_idx + 1); + for (uint j = k >> (p.inner_start + 1), inner_idx = p.inner_start; inner_idx < inner_end; j /= 2, inner_idx++) { + if (need_barrier) { + controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease); + } + need_barrier = true; + [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) { + int c = u*BLOCK_SIZE + col; + const int ixj = int(c ^ j); + + if (ixj < c) { + continue; + } + + int idx_0 = (c & k) == 0 ? c : ixj; + int idx_1 = (c & k) == 0 ? ixj : c; + + ivec2 sh_idx_0 = tmp_idx[idx_offset + idx_0]; + ivec2 sh_idx_1 = tmp_idx[idx_offset + idx_1]; + bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false; + bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false; + + if ((idx_0_oob || + (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y)))) { + tmp_idx[idx_offset + idx_0] = sh_idx_1; + tmp_idx[idx_offset + idx_1] = sh_idx_0; + } + } + } + } + + if (p.outer_end == p.ncols_padded_log2 && + p.inner_end >= p.ncols_padded_log2 + 1) { + controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease); + [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) { + uint c = u*BLOCK_SIZE + col; + if (c < p.ncols) { + if (p.order == ASC) { + data_d[row_offset + c] = tmp_idx[idx_offset + c].x; + } else { + data_d[row_offset + p.ncols - c - 1] = tmp_idx[idx_offset + c].x; + } + } + } + } +} + +void main() { + if (p.ncols == p.ncols_padded) { + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + argsort(false, row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + } else { + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + argsort(true, row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp b/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp new file mode 100644 index 0000000000..0028d3721d --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(ceil(x)); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp new file mode 100644 index 0000000000..220ccc9111 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp @@ -0,0 +1,67 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +// workgroup does 32x32 tile, but uses 32x8 threads +#define TILE_DIM 32 +layout(local_size_x = 32, local_size_y = 8, local_size_z = 1) in; + +shared uint sh[TILE_DIM][TILE_DIM + 1]; + +void iter(uvec3 wg_id) { + const uint tile_col = wg_id.x; + const uint tile_row = wg_id.y; + + const uint tid_col = gl_LocalInvocationID.x; + const uint tid_row = gl_LocalInvocationID.y; + + const uint i2 = wg_id.z % p.ne12; + const uint i3 = wg_id.z / p.ne12; + const uint i02 = i2; + const uint i03 = i3; + + // The workgroup does TILE_DIM x TILE_DIM, but swaps the LSBs of the + // src coords to make memory accesses contiguous, dst has tid.x in i0, + // src has tid.x in i01 + + [[unroll]] for (uint y = 0; y < 4; ++y) { + const uint i00 = tile_col * TILE_DIM + tid_row + 8 * y; + const uint i01 = tile_row * TILE_DIM + tid_col; + if (i00 < p.ne00 && i01 < p.ne01 && i02 < p.ne02 && i03 < p.ne03) { + const uint src_idx = i00 * p.nb00 + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03; + sh[tid_row + 8 * y][tid_col] = uint(data_a[get_aoffset() + src_idx]); + } + } + + barrier(); + + [[unroll]] for (uint y = 0; y < 4; ++y) { + const uint i0 = tile_col * TILE_DIM + tid_col; + const uint i1 = tile_row * TILE_DIM + tid_row + 8 * y; + if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) { + const uint dst_idx = i0 * p.nb10 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; + // load transposed + data_d[get_doffset() + dst_idx] = D_TYPE(sh[tid_col][tid_row + 8 * y]); + } + } +} + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +void main() { + uint z = gl_WorkGroupID.z; + uint y = gl_WorkGroupID.y; + bool need_barrier = false; + for (uint z = gl_WorkGroupID.z; z < p.ne12 * p.ne13; z += gl_NumWorkGroups.z) { + for (uint y = gl_WorkGroupID.y; y < CEIL_DIV(p.ne11, TILE_DIM); y += gl_NumWorkGroups.y) { + for (uint x = gl_WorkGroupID.x; x < CEIL_DIV(p.ne10, TILE_DIM); x += gl_NumWorkGroups.x) { + if (need_barrier) { + barrier(); + } + need_barrier = true; + iter(uvec3(x, y, z)); + } + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp b/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp new file mode 100644 index 0000000000..a56be76c61 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp @@ -0,0 +1,19 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + // p.param1 = fill value + data_d[i] = D_TYPE(p.param1); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp b/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp new file mode 100644 index 0000000000..20017eb184 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(floor(x)); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/round.comp b/ggml/src/ggml-vulkan/vulkan-shaders/round.comp new file mode 100644 index 0000000000..e6155dcbf3 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/round.comp @@ -0,0 +1,29 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + float result; + // Round halfway cases away from zero as roundf does. + if (x >= 0.0) { + result = floor(x + 0.5); + } else { + result = ceil(x - 0.5); + } + data_d[i] = D_TYPE(result); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp b/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp new file mode 100644 index 0000000000..323e3cdea4 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp @@ -0,0 +1,23 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + const float result = (x > 20.0f) ? x : log(1.0f + exp(x)); + data_d[i] = D_TYPE(result); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/step.comp b/ggml/src/ggml-vulkan/vulkan-shaders/step.comp new file mode 100644 index 0000000000..654a2124e0 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/step.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(x >= 0.0f ? 1.0f : 0.0f); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp b/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp new file mode 100644 index 0000000000..cf1b76d3bb --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(trunc(x)); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 9c207f1e46..bc992068f8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -734,6 +734,9 @@ void process_shaders() { string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); + string_to_spv("cpy_transpose_16", "copy_transpose.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}}); + string_to_spv("cpy_transpose_32", "copy_transpose.comp", {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}}); + for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); @@ -843,6 +846,25 @@ void process_shaders() { string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("add1_f16_f16", "add1.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("add1_f16_f32", "add1.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("add1_f32_f32", "add1.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("arange_f32", "arange.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("fill_f32", "fill.comp", {{"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("step_f16", "step.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("step_f32", "step.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("round_f16", "round.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("round_f32", "round.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("ceil_f16", "ceil.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("ceil_f32", "ceil.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("floor_f16", "floor.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("floor_f32", "floor.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("trunc_f16", "trunc.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("trunc_f32", "trunc.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + for (auto rte : {false, true}) { std::string suffix = rte ? "_rte" : ""; string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); @@ -889,6 +911,7 @@ void process_shaders() { string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); + string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}}); string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 46173585f2..c9056b59c7 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -7b6abb2b92fcef35cb01c6ce6ada9bd85306522d +781baf2a14d9e0aaee542b2e1bb918bfc4132199 diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index bed706bb24..bdd337e952 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -6,8 +6,10 @@ #include #include +#include #include +#define MAX_REPETITION_THRESHOLD 2000 // // helpers // @@ -345,7 +347,9 @@ const char * llama_grammar_parser::parse_sequence( size_t last_sym_start = rule.size(); const char * pos = src; - auto handle_repetitions = [&](int min_times, int max_times) { + // use UINT64_MAX as the empty value because we aligned to the proper unsigned long type so -1 can't be used + // (though it's technically the same as -1 now) + auto handle_repetitions = [&](unsigned long min_times, unsigned long max_times) { if (last_sym_start == rule.size()) { throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); @@ -373,20 +377,20 @@ const char * llama_grammar_parser::parse_sequence( rule.resize(last_sym_start); } else { // Repeat the previous elements (min_times - 1) times - for (int i = 1; i < min_times; i++) { + for (unsigned long i = 1; i < min_times; i++) { rule.insert(rule.end(), prev_rule.begin(), prev_rule.end()); } } uint32_t last_rec_rule_id = 0; - auto n_opt = max_times < 0 ? 1 : max_times - min_times; + auto n_opt = max_times == UINT64_MAX ? 1 : max_times - min_times; llama_grammar_rule rec_rule(prev_rule); - for (int i = 0; i < n_opt; i++) { + for (unsigned long i = 0; i < n_opt; i++) { rec_rule.resize(prev_rule.size()); uint32_t rec_rule_id = generate_symbol_id( rule_name); - if (i > 0 || max_times < 0) { - rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); + if (i > 0 || max_times == UINT64_MAX) { + rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times == UINT64_MAX ? rec_rule_id : last_rec_rule_id}); } rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); rec_rule.push_back({LLAMA_GRETYPE_END, 0}); @@ -478,10 +482,10 @@ const char * llama_grammar_parser::parse_sequence( throw std::runtime_error(std::string("expecting an int at ") + pos); } const char * int_end = parse_int(pos); - int min_times = std::stoul(std::string(pos, int_end - pos)); + unsigned long min_times = std::stoul(std::string(pos, int_end - pos)); pos = parse_space(int_end, is_nested); - int max_times = -1; + unsigned long max_times = UINT64_MAX; if (*pos == '}') { max_times = min_times; @@ -502,6 +506,9 @@ const char * llama_grammar_parser::parse_sequence( } else { throw std::runtime_error(std::string("expecting ',' at ") + pos); } + if (min_times > MAX_REPETITION_THRESHOLD || (max_times != UINT64_MAX && max_times > MAX_REPETITION_THRESHOLD)) { + throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions")); + } handle_repetitions(min_times, max_times); } else { break; diff --git a/src/llama-impl.cpp b/src/llama-impl.cpp index 6ec709dd32..c7a1880aad 100644 --- a/src/llama-impl.cpp +++ b/src/llama-impl.cpp @@ -20,10 +20,10 @@ static llama_logger_state g_logger_state; time_meas::time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {} time_meas::~time_meas() { - if (t_start_us >= 0) { - t_acc += ggml_time_us() - t_start_us; - } + if (t_start_us >= 0) { + t_acc += ggml_time_us() - t_start_us; } +} void llama_log_set(ggml_log_callback log_callback, void * user_data) { ggml_log_set(log_callback, user_data); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index adb3f8810e..3f4a729bc3 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -472,9 +472,6 @@ static void llama_sampler_chain_reset(struct llama_sampler * smpl) { for (auto * smpl : chain->samplers) { llama_sampler_reset(smpl); } - - chain->t_sample_us = 0; - chain->n_sample = 0; } static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) { @@ -2670,8 +2667,7 @@ struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * c void llama_perf_sampler_print(const struct llama_sampler * chain) { const auto data = llama_perf_sampler(chain); - LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", - __func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample); + LLAMA_LOG_INFO("%s: samplers time = %10.2f ms / %5d runs\n", __func__, data.t_sample_ms, data.n_sample); } void llama_perf_sampler_reset(struct llama_sampler * chain) { @@ -2681,5 +2677,6 @@ void llama_perf_sampler_reset(struct llama_sampler * chain) { auto * ctx = (struct llama_sampler_chain *) chain->ctx; - ctx->t_sample_us = ctx->n_sample = 0; + ctx->t_sample_us = 0; + ctx->n_sample = 0; } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 267bead8c4..2bb4b12224 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2776,24 +2776,34 @@ struct test_cpy : public test_case { struct test_cont : public test_case { const ggml_type type; const std::array ne; + bool use_view_slice; std::string vars() override { - return VARS_TO_STR2(type, ne); + return VARS_TO_STR3(type, ne, use_view_slice); } test_cont(ggml_type type = GGML_TYPE_F32, - std::array ne = {10, 10, 10, 1}) - : type(type), ne(ne) {} + std::array ne = {10, 10, 10, 1}, + bool use_view_slice = false) + : type(type), ne(ne), use_view_slice(use_view_slice) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_set_param(src); ggml_set_name(src, "src"); - src = ggml_transpose(ctx, src); - ggml_set_name(src, "src_transposed"); - ggml_tensor * out = ggml_cont(ctx, src); + ggml_tensor * dst; + if (use_view_slice) { + dst = ggml_view_4d(ctx, src, src->ne[0], 1, src->ne[2], src->ne[3], + src->nb[1], src->nb[2], src->nb[3], src->nb[0] * (src->ne[1] - 1)); + ggml_set_name(dst, "src_view_slice"); + } else { + dst = ggml_transpose(ctx, src); + ggml_set_name(dst, "src_transposed"); + } + + ggml_tensor * out = ggml_cont(ctx, dst); ggml_set_name(out, "out"); return out; @@ -6945,16 +6955,17 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0})); - test_cases.emplace_back(new test_cont()); - test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1})); - test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 3 ,5})); - test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 3, 5 ,7})); - test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 1, 1 ,1})); - test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 1, 3 ,5})); - test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 3, 5 ,7})); - test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 1, 1 ,1})); - test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 1, 3 ,5})); - test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 3, 5 ,7})); + for (ggml_type type_dst : { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16 }) { + for (bool use_view_slice : { true, false }) { + for (std::array ne : std::initializer_list>{ {2, 1, 1, 1}, {2, 1, 3, 5}, + {2, 3, 5, 7}, {1, 4, 4, 1}, {1, 8, 17, 1}, {10, 10, 10, 1} }) { + if (use_view_slice && (type_dst == GGML_TYPE_F16 || type_dst == GGML_TYPE_BF16)) { + continue; // TODO: add after WebGPU is fixed + } + test_cases.emplace_back(new test_cont(type_dst, ne, use_view_slice)); + } + } + } auto add_test_bin_bcast = [&](ggml_type type, std::array ne, std::array nr) { for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) { @@ -7015,6 +7026,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16)); test_cases.emplace_back(new test_add1()); + test_cases.emplace_back(new test_add1(GGML_TYPE_F32, {1024, 1024, 1, 1})); test_cases.emplace_back(new test_scale()); test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f)); test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f, true)); // inplace test @@ -7354,9 +7366,13 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_clamp (type, {7, 1, 5, 3})); test_cases.emplace_back(new test_leaky_relu(type, {7, 1, 5, 3})); test_cases.emplace_back(new test_floor (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_floor (type, { 1024, 1024, 1, 1 })); test_cases.emplace_back(new test_ceil (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_ceil (type, { 1024, 1024, 1, 1 })); test_cases.emplace_back(new test_round (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_round (type, { 1024, 1024, 1, 1 })); test_cases.emplace_back(new test_trunc (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_trunc (type, { 1024, 1024, 1, 1 })); } test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5)); @@ -7501,13 +7517,15 @@ static std::vector> make_test_cases_eval() { } for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) { - test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order)); + for (uint32_t i = 4; i <= 1024*1024; i *= 2) { + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {i-1, 1, 1, 1})); + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {i, 1, 1, 1})); + } test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1023, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1025, 2, 1, 3}, order)); - test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // many backends only handle up to 1024 test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2047, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order)); @@ -7556,6 +7574,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1})); test_cases.emplace_back(new test_roll()); test_cases.emplace_back(new test_arange()); + test_cases.emplace_back(new test_arange(GGML_TYPE_F32, 0.0f, 1048576.0f, 1.0f)); test_cases.emplace_back(new test_timestep_embedding()); test_cases.emplace_back(new test_leaky_relu()); @@ -7583,6 +7602,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_fill(0.0f)); test_cases.emplace_back(new test_fill(2.0f, GGML_TYPE_F32, { 303, 207, 11, 3 })); test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F32, { 800, 600, 4, 4 })); + test_cases.emplace_back(new test_fill(3.5f, GGML_TYPE_F32, { 2048, 512, 2, 2 })); test_cases.emplace_back(new test_solve_tri()); test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 11, 11, 1, 1 }, { 5, 11, 1, 1 })); diff --git a/tools/main/main.cpp b/tools/main/main.cpp index 33e8862335..78b42267b5 100644 --- a/tools/main/main.cpp +++ b/tools/main/main.cpp @@ -147,11 +147,15 @@ int main(int argc, char ** argv) { return 1; } - auto * mem = llama_get_memory(ctx); - + llama_memory_t mem = llama_get_memory(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); + + // note: the time for chat template initialization is not negligible: auto chat_templates = common_chat_templates_init(model, params.chat_template); + // start measuring performance timings from here + llama_perf_context_reset(ctx); + LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads); auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index 48e341dbd1..097c9440be 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/webui/.gitignore b/tools/server/webui/.gitignore index cc54bb717f..051d884b08 100644 --- a/tools/server/webui/.gitignore +++ b/tools/server/webui/.gitignore @@ -25,3 +25,4 @@ vite.config.ts.timestamp-* *storybook.log storybook-static +*.code-workspace \ No newline at end of file diff --git a/tools/server/webui/package-lock.json b/tools/server/webui/package-lock.json index a11b87ad50..4af5e86ab9 100644 --- a/tools/server/webui/package-lock.json +++ b/tools/server/webui/package-lock.json @@ -2109,9 +2109,9 @@ } }, "node_modules/@sveltejs/kit": { - "version": "2.48.4", - "resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.48.4.tgz", - "integrity": "sha512-TGFX1pZUt9qqY20Cv5NyYvy0iLWHf2jXi8s+eCGsig7jQMdwZWKUFMR6TbvFNhfDSUpc1sH/Y5EHv20g3HHA3g==", + "version": "2.48.5", + "resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.48.5.tgz", + "integrity": "sha512-/rnwfSWS3qwUSzvHynUTORF9xSJi7PCR9yXkxUOnRrNqyKmCmh3FPHH+E9BbgqxXfTevGXBqgnlh9kMb+9T5XA==", "dev": true, "license": "MIT", "dependencies": { @@ -5087,9 +5087,9 @@ "license": "MIT" }, "node_modules/js-yaml": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", - "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz", + "integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==", "dev": true, "license": "MIT", "dependencies": { diff --git a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreview.svelte b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreview.svelte new file mode 100644 index 0000000000..212b1fe890 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreview.svelte @@ -0,0 +1,273 @@ + + +
+
+ {#if isPdf} +
+ + + +
+ {/if} +
+ +
+ {#if isImage && displayPreview} +
+ {displayName} +
+ {:else if isPdf && pdfViewMode === 'pages'} + {#if pdfImagesLoading} +
+
+
+ +

Converting PDF to images...

+
+
+ {:else if pdfImagesError} +
+
+ + +

Failed to load PDF images

+ +

{pdfImagesError}

+ + +
+
+ {:else if pdfImages.length > 0} +
+ {#each pdfImages as image, index (image)} +
+

Page {index + 1}

+ + PDF Page {index + 1} +
+ {/each} +
+ {:else} +
+
+ + +

No PDF pages available

+
+
+ {/if} + {:else if (isText || (isPdf && pdfViewMode === 'text')) && displayTextContent} +
+ {displayTextContent} +
+ {:else if isAudio} +
+
+ + + {#if attachment?.type === 'audioFile'} + + {:else if uploadedFile?.preview} + + {:else} +

Audio preview not available

+ {/if} + +

+ {displayName} +

+
+
+ {:else} +
+
+ {#if IconComponent} + + {/if} + +

Preview not available for this file type

+
+
+ {/if} +
+
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreviewDialog.svelte b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreviewDialog.svelte deleted file mode 100644 index 8a3389b657..0000000000 --- a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreviewDialog.svelte +++ /dev/null @@ -1,314 +0,0 @@ - - - - - -
-
- {#if IconComponent} - - {/if} - -
- {displayName} - -
- {displayType} - - {#if displaySize} - - - {formatFileSize(displaySize)} - {/if} -
-
-
- - {#if isPdf} -
- - - -
- {/if} -
-
- -
- {#if isImage && displayPreview} -
- {displayName} -
- {:else if isPdf && pdfViewMode === 'pages'} - {#if pdfImagesLoading} -
-
-
- -

Converting PDF to images...

-
-
- {:else if pdfImagesError} -
-
- - -

Failed to load PDF images

- -

{pdfImagesError}

- - -
-
- {:else if pdfImages.length > 0} -
- {#each pdfImages as image, index (image)} -
-

Page {index + 1}

- - PDF Page {index + 1} -
- {/each} -
- {:else} -
-
- - -

No PDF pages available

-
-
- {/if} - {:else if (isText || (isPdf && pdfViewMode === 'text')) && displayTextContent} -
- {displayTextContent} -
- {:else if isAudio} -
-
- - - {#if attachment?.type === 'audioFile'} - - {:else if uploadedFile?.preview} - - {:else} -

Audio preview not available

- {/if} - -

- {displayName} -

-
-
- {:else} -
-
- {#if IconComponent} - - {/if} - -

Preview not available for this file type

-
-
- {/if} -
-
-
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentFilePreview.svelte b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentThumbnailFile.svelte similarity index 100% rename from tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentFilePreview.svelte rename to tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentThumbnailFile.svelte diff --git a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentImagePreview.svelte b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentThumbnailImage.svelte similarity index 100% rename from tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentImagePreview.svelte rename to tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentThumbnailImage.svelte diff --git a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentsList.svelte b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentsList.svelte index a2aea0232a..050c793316 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentsList.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentsList.svelte @@ -1,11 +1,10 @@ - - - - - - - All Attachments ({displayItems.length}) - - View and manage all attached files - - - -
- {#if fileItems.length > 0} -
-

Files ({fileItems.length})

-
- {#each fileItems as item (item.id)} - openPreview(item, event)} - /> - {/each} -
-
- {/if} - - {#if imageItems.length > 0} -
-

Images ({imageItems.length})

-
- {#each imageItems as item (item.id)} - {#if item.preview} - openPreview(item, event)} - /> - {/if} - {/each} -
-
- {/if} +
+
+ {#if fileItems.length > 0} +
+

Files ({fileItems.length})

+
+ {#each fileItems as item (item.id)} + openPreview(item, event)} + /> + {/each} +
- - - + {/if} + + {#if imageItems.length > 0} +
+

Images ({imageItems.length})

+
+ {#each imageItems as item (item.id)} + {#if item.preview} + openPreview(item, event)} + /> + {/if} + {/each} +
+
+ {/if} +
+
{#if previewItem} - import { Square, ArrowUp } from '@lucide/svelte'; import { Button } from '$lib/components/ui/button'; - import ChatFormActionFileAttachments from './ChatFormActionFileAttachments.svelte'; - import ChatFormActionRecord from './ChatFormActionRecord.svelte'; - import ChatFormModelSelector from './ChatFormModelSelector.svelte'; + import { + ChatFormActionFileAttachments, + ChatFormActionRecord, + ChatFormModelSelector + } from '$lib/components/app'; import { config } from '$lib/stores/settings.svelte'; import type { FileTypeCategory } from '$lib/enums/files'; diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte index e47a5a7dba..ae0dc2ed9f 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte @@ -10,6 +10,7 @@ class?: string; message: DatabaseMessage; onCopy?: (message: DatabaseMessage) => void; + onContinueAssistantMessage?: (message: DatabaseMessage) => void; onDelete?: (message: DatabaseMessage) => void; onEditWithBranching?: (message: DatabaseMessage, newContent: string) => void; onEditWithReplacement?: ( @@ -17,6 +18,7 @@ newContent: string, shouldBranch: boolean ) => void; + onEditUserMessagePreserveResponses?: (message: DatabaseMessage, newContent: string) => void; onNavigateToSibling?: (siblingId: string) => void; onRegenerateWithBranching?: (message: DatabaseMessage) => void; siblingInfo?: ChatMessageSiblingInfo | null; @@ -26,9 +28,11 @@ class: className = '', message, onCopy, + onContinueAssistantMessage, onDelete, onEditWithBranching, onEditWithReplacement, + onEditUserMessagePreserveResponses, onNavigateToSibling, onRegenerateWithBranching, siblingInfo = null @@ -133,17 +137,33 @@ onRegenerateWithBranching?.(message); } + function handleContinue() { + onContinueAssistantMessage?.(message); + } + function handleSaveEdit() { if (message.role === 'user') { + // For user messages, trim to avoid accidental whitespace onEditWithBranching?.(message, editedContent.trim()); } else { - onEditWithReplacement?.(message, editedContent.trim(), shouldBranchAfterEdit); + // For assistant messages, preserve exact content including trailing whitespace + // This is important for the Continue feature to work properly + onEditWithReplacement?.(message, editedContent, shouldBranchAfterEdit); } isEditing = false; shouldBranchAfterEdit = false; } + function handleSaveEditOnly() { + if (message.role === 'user') { + // For user messages, trim to avoid accidental whitespace + onEditUserMessagePreserveResponses?.(message, editedContent.trim()); + } + + isEditing = false; + } + function handleShowDeleteDialogChange(show: boolean) { showDeleteDialog = show; } @@ -166,6 +186,7 @@ onEditedContentChange={handleEditedContentChange} {onNavigateToSibling} onSaveEdit={handleSaveEdit} + onSaveEditOnly={handleSaveEditOnly} onShowDeleteDialogChange={handleShowDeleteDialogChange} {showDeleteDialog} {siblingInfo} @@ -181,6 +202,7 @@ messageContent={message.content} onCancelEdit={handleCancelEdit} onConfirmDelete={handleConfirmDelete} + onContinue={handleContinue} onCopy={handleCopy} onDelete={handleDelete} onEdit={handleEdit} diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageActions.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageActions.svelte index c16a3105cb..ff335c328c 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageActions.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageActions.svelte @@ -1,7 +1,10 @@ - - -
- - + +
+
+ +
+ - -
-
- Settings - - -
- - -
-
- {#each settingSections as section (section.title)} - - {/each} -
-
- - +
+
+ {#each settingSections as section (section.title)} + + {/each}
+ +
- - -
-
- - - {#if currentSection.title === 'Import/Export'} - - {:else} -
- -
- {/if} -
- -
-

- Settings are saved in browser's localStorage -

-
-
-
+
- - - + +
+
+ + + {#if currentSection.title === 'Import/Export'} + + {:else} +
+ +
+ {/if} +
+ +
+

Settings are saved in browser's localStorage

+
+
+
+
+ + diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte index d17f7e4229..8834e3e3e1 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte @@ -1,5 +1,5 @@ - - - - - - - - - Select Conversations to {mode === 'export' ? 'Export' : 'Import'} - - - - {#if mode === 'export'} - Choose which conversations you want to export. Selected conversations will be downloaded - as a JSON file. - {:else} - Choose which conversations you want to import. Selected conversations will be merged - with your existing conversations. - {/if} - - - -
-
- - - - - {#if searchQuery} - - {/if} -
- -
- - {selectedIds.size} of {conversations.length} selected - {#if searchQuery} - ({filteredConversations.length} shown) - {/if} - -
- -
- - - - - - - - - - - - - {#if filteredConversations.length === 0} - - - - {:else} - {#each filteredConversations as conv (conv.id)} - toggleConversation(conv.id, e.shiftKey)} - > - - - - - - - {/each} - {/if} - -
- - Conversation NameMessages
- {#if searchQuery} - No conversations found matching "{searchQuery}" - {:else} - No conversations available - {/if} -
- { - e.preventDefault(); - e.stopPropagation(); - toggleConversation(conv.id, e.shiftKey); - }} - /> - -
- {conv.name || 'Untitled conversation'} -
-
- {messageCountMap.get(conv.id) ?? 0} -
-
-
-
- - - - - - -
-
-
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebar.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebar.svelte index 5976e5dd03..34f3da53ea 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebar.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebar.svelte @@ -2,7 +2,7 @@ import { goto } from '$app/navigation'; import { page } from '$app/state'; import { Trash2 } from '@lucide/svelte'; - import { ChatSidebarConversationItem, ConfirmationDialog } from '$lib/components/app'; + import { ChatSidebarConversationItem, DialogConfirmation } from '$lib/components/app'; import ScrollArea from '$lib/components/ui/scroll-area/scroll-area.svelte'; import * as Sidebar from '$lib/components/ui/sidebar'; import * as AlertDialog from '$lib/components/ui/alert-dialog'; @@ -158,7 +158,7 @@
- + import * as Dialog from '$lib/components/ui/dialog'; + import { ChatAttachmentPreview } from '$lib/components/app'; + import { formatFileSize } from '$lib/utils/file-preview'; + + interface Props { + open: boolean; + // Either an uploaded file or a stored attachment + uploadedFile?: ChatUploadedFile; + attachment?: DatabaseMessageExtra; + // For uploaded files + preview?: string; + name?: string; + type?: string; + size?: number; + textContent?: string; + } + + let { + open = $bindable(), + uploadedFile, + attachment, + preview, + name, + type, + size, + textContent + }: Props = $props(); + + let chatAttachmentPreviewRef: ChatAttachmentPreview | undefined = $state(); + + let displayName = $derived(uploadedFile?.name || attachment?.name || name || 'Unknown File'); + + let displayType = $derived( + uploadedFile?.type || + (attachment?.type === 'imageFile' + ? 'image' + : attachment?.type === 'textFile' + ? 'text' + : attachment?.type === 'audioFile' + ? attachment.mimeType || 'audio' + : attachment?.type === 'pdfFile' + ? 'application/pdf' + : type || 'unknown') + ); + + let displaySize = $derived(uploadedFile?.size || size); + + $effect(() => { + if (open && chatAttachmentPreviewRef) { + chatAttachmentPreviewRef.reset(); + } + }); + + + + + + {displayName} + + {displayType} + {#if displaySize} + • {formatFileSize(displaySize)} + {/if} + + + + + + diff --git a/tools/server/webui/src/lib/components/app/dialogs/DialogChatAttachmentsViewAll.svelte b/tools/server/webui/src/lib/components/app/dialogs/DialogChatAttachmentsViewAll.svelte new file mode 100644 index 0000000000..8f6ca76d42 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/dialogs/DialogChatAttachmentsViewAll.svelte @@ -0,0 +1,51 @@ + + + + + + + + + All Attachments ({totalCount}) + View and manage all attached files + + + + + + diff --git a/tools/server/webui/src/lib/components/app/dialogs/ChatErrorDialog.svelte b/tools/server/webui/src/lib/components/app/dialogs/DialogChatError.svelte similarity index 100% rename from tools/server/webui/src/lib/components/app/dialogs/ChatErrorDialog.svelte rename to tools/server/webui/src/lib/components/app/dialogs/DialogChatError.svelte diff --git a/tools/server/webui/src/lib/components/app/dialogs/DialogChatSettings.svelte b/tools/server/webui/src/lib/components/app/dialogs/DialogChatSettings.svelte new file mode 100644 index 0000000000..e9aaa1000b --- /dev/null +++ b/tools/server/webui/src/lib/components/app/dialogs/DialogChatSettings.svelte @@ -0,0 +1,37 @@ + + + + + + + diff --git a/tools/server/webui/src/lib/components/app/dialogs/ConfirmationDialog.svelte b/tools/server/webui/src/lib/components/app/dialogs/DialogConfirmation.svelte similarity index 100% rename from tools/server/webui/src/lib/components/app/dialogs/ConfirmationDialog.svelte rename to tools/server/webui/src/lib/components/app/dialogs/DialogConfirmation.svelte diff --git a/tools/server/webui/src/lib/components/app/dialogs/DialogConversationSelection.svelte b/tools/server/webui/src/lib/components/app/dialogs/DialogConversationSelection.svelte new file mode 100644 index 0000000000..1f8ea64bed --- /dev/null +++ b/tools/server/webui/src/lib/components/app/dialogs/DialogConversationSelection.svelte @@ -0,0 +1,68 @@ + + + + + + + + + + Select Conversations to {mode === 'export' ? 'Export' : 'Import'} + + + {#if mode === 'export'} + Choose which conversations you want to export. Selected conversations will be downloaded + as a JSON file. + {:else} + Choose which conversations you want to import. Selected conversations will be merged + with your existing conversations. + {/if} + + + + + + + diff --git a/tools/server/webui/src/lib/components/app/dialogs/ConversationTitleUpdateDialog.svelte b/tools/server/webui/src/lib/components/app/dialogs/DialogConversationTitleUpdate.svelte similarity index 100% rename from tools/server/webui/src/lib/components/app/dialogs/ConversationTitleUpdateDialog.svelte rename to tools/server/webui/src/lib/components/app/dialogs/DialogConversationTitleUpdate.svelte diff --git a/tools/server/webui/src/lib/components/app/dialogs/EmptyFileAlertDialog.svelte b/tools/server/webui/src/lib/components/app/dialogs/DialogEmptyFileAlert.svelte similarity index 100% rename from tools/server/webui/src/lib/components/app/dialogs/EmptyFileAlertDialog.svelte rename to tools/server/webui/src/lib/components/app/dialogs/DialogEmptyFileAlert.svelte diff --git a/tools/server/webui/src/lib/components/app/index.ts b/tools/server/webui/src/lib/components/app/index.ts index a695f99747..54bd8d5aa3 100644 --- a/tools/server/webui/src/lib/components/app/index.ts +++ b/tools/server/webui/src/lib/components/app/index.ts @@ -1,56 +1,63 @@ +// Chat + +export { default as ChatAttachmentPreview } from './chat/ChatAttachments/ChatAttachmentPreview.svelte'; +export { default as ChatAttachmentThumbnailFile } from './chat/ChatAttachments/ChatAttachmentThumbnailFile.svelte'; +export { default as ChatAttachmentThumbnailImage } from './chat/ChatAttachments/ChatAttachmentThumbnailImage.svelte'; export { default as ChatAttachmentsList } from './chat/ChatAttachments/ChatAttachmentsList.svelte'; -export { default as ChatAttachmentFilePreview } from './chat/ChatAttachments/ChatAttachmentFilePreview.svelte'; -export { default as ChatAttachmentImagePreview } from './chat/ChatAttachments/ChatAttachmentImagePreview.svelte'; -export { default as ChatAttachmentPreviewDialog } from './chat/ChatAttachments/ChatAttachmentPreviewDialog.svelte'; -export { default as ChatAttachmentsViewAllDialog } from './chat/ChatAttachments/ChatAttachmentsViewAllDialog.svelte'; +export { default as ChatAttachmentsViewAll } from './chat/ChatAttachments/ChatAttachmentsViewAll.svelte'; export { default as ChatForm } from './chat/ChatForm/ChatForm.svelte'; -export { default as ChatFormTextarea } from './chat/ChatForm/ChatFormTextarea.svelte'; -export { default as ChatFormActions } from './chat/ChatForm/ChatFormActions.svelte'; -export { default as ChatFormActionFileAttachments } from './chat/ChatForm/ChatFormActionFileAttachments.svelte'; -export { default as ChatFormActionRecord } from './chat/ChatForm/ChatFormActionRecord.svelte'; -export { default as ChatFormModelSelector } from './chat/ChatForm/ChatFormModelSelector.svelte'; -export { default as ChatFormHelperText } from './chat/ChatForm/ChatFormHelperText.svelte'; +export { default as ChatFormActionFileAttachments } from './chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte'; +export { default as ChatFormActionRecord } from './chat/ChatForm/ChatFormActions/ChatFormActionRecord.svelte'; +export { default as ChatFormActions } from './chat/ChatForm/ChatFormActions/ChatFormActions.svelte'; export { default as ChatFormFileInputInvisible } from './chat/ChatForm/ChatFormFileInputInvisible.svelte'; +export { default as ChatFormHelperText } from './chat/ChatForm/ChatFormHelperText.svelte'; +export { default as ChatFormModelSelector } from './chat/ChatForm/ChatFormModelSelector.svelte'; +export { default as ChatFormTextarea } from './chat/ChatForm/ChatFormTextarea.svelte'; export { default as ChatMessage } from './chat/ChatMessages/ChatMessage.svelte'; export { default as ChatMessages } from './chat/ChatMessages/ChatMessages.svelte'; +export { default as ChatMessageBranchingControls } from './chat/ChatMessages/ChatMessageBranchingControls.svelte'; export { default as ChatMessageThinkingBlock } from './chat/ChatMessages/ChatMessageThinkingBlock.svelte'; -export { default as MessageBranchingControls } from './chat/ChatMessages/ChatMessageBranchingControls.svelte'; -export { default as ChatProcessingInfo } from './chat/ChatProcessingInfo.svelte'; - -export { default as ChatScreenHeader } from './chat/ChatScreen/ChatScreenHeader.svelte'; -export { default as ChatScreenWarning } from './chat/ChatScreen/ChatScreenWarning.svelte'; export { default as ChatScreen } from './chat/ChatScreen/ChatScreen.svelte'; +export { default as ChatScreenHeader } from './chat/ChatScreen/ChatScreenHeader.svelte'; +export { default as ChatScreenProcessingInfo } from './chat/ChatScreen/ChatScreenProcessingInfo.svelte'; +export { default as ChatScreenWarning } from './chat/ChatScreen/ChatScreenWarning.svelte'; -export { default as ChatSettingsDialog } from './chat/ChatSettings/ChatSettingsDialog.svelte'; +export { default as ChatSettings } from './chat/ChatSettings/ChatSettings.svelte'; export { default as ChatSettingsFooter } from './chat/ChatSettings/ChatSettingsFooter.svelte'; export { default as ChatSettingsFields } from './chat/ChatSettings/ChatSettingsFields.svelte'; -export { default as ImportExportTab } from './chat/ChatSettings/ImportExportTab.svelte'; -export { default as ConversationSelectionDialog } from './chat/ChatSettings/ConversationSelectionDialog.svelte'; -export { default as ParameterSourceIndicator } from './chat/ChatSettings/ParameterSourceIndicator.svelte'; +export { default as ChatSettingsImportExportTab } from './chat/ChatSettings/ChatSettingsImportExportTab.svelte'; +export { default as ChatSettingsParameterSourceIndicator } from './chat/ChatSettings/ChatSettingsParameterSourceIndicator.svelte'; export { default as ChatSidebar } from './chat/ChatSidebar/ChatSidebar.svelte'; export { default as ChatSidebarConversationItem } from './chat/ChatSidebar/ChatSidebarConversationItem.svelte'; export { default as ChatSidebarSearch } from './chat/ChatSidebar/ChatSidebarSearch.svelte'; -export { default as ChatErrorDialog } from './dialogs/ChatErrorDialog.svelte'; -export { default as EmptyFileAlertDialog } from './dialogs/EmptyFileAlertDialog.svelte'; -export { default as ConversationTitleUpdateDialog } from './dialogs/ConversationTitleUpdateDialog.svelte'; +// Dialogs +export { default as DialogChatAttachmentPreview } from './dialogs/DialogChatAttachmentPreview.svelte'; +export { default as DialogChatAttachmentsViewAll } from './dialogs/DialogChatAttachmentsViewAll.svelte'; +export { default as DialogChatError } from './dialogs/DialogChatError.svelte'; +export { default as DialogChatSettings } from './dialogs/DialogChatSettings.svelte'; +export { default as DialogConfirmation } from './dialogs/DialogConfirmation.svelte'; +export { default as DialogConversationSelection } from './dialogs/DialogConversationSelection.svelte'; +export { default as DialogConversationTitleUpdate } from './dialogs/DialogConversationTitleUpdate.svelte'; +export { default as DialogEmptyFileAlert } from './dialogs/DialogEmptyFileAlert.svelte'; + +// Miscellanous + +export { default as ActionButton } from './misc/ActionButton.svelte'; +export { default as ActionDropdown } from './misc/ActionDropdown.svelte'; +export { default as ConversationSelection } from './misc/ConversationSelection.svelte'; export { default as KeyboardShortcutInfo } from './misc/KeyboardShortcutInfo.svelte'; - export { default as MarkdownContent } from './misc/MarkdownContent.svelte'; - export { default as RemoveButton } from './misc/RemoveButton.svelte'; +// Server + export { default as ServerStatus } from './server/ServerStatus.svelte'; export { default as ServerErrorSplash } from './server/ServerErrorSplash.svelte'; export { default as ServerLoadingSplash } from './server/ServerLoadingSplash.svelte'; export { default as ServerInfo } from './server/ServerInfo.svelte'; - -// Shared components -export { default as ActionButton } from './misc/ActionButton.svelte'; -export { default as ActionDropdown } from './misc/ActionDropdown.svelte'; -export { default as ConfirmationDialog } from './dialogs/ConfirmationDialog.svelte'; diff --git a/tools/server/webui/src/lib/components/app/misc/ConversationSelection.svelte b/tools/server/webui/src/lib/components/app/misc/ConversationSelection.svelte new file mode 100644 index 0000000000..e2095e0876 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/misc/ConversationSelection.svelte @@ -0,0 +1,205 @@ + + +
+
+ + + + + {#if searchQuery} + + {/if} +
+ +
+ + {selectedIds.size} of {conversations.length} selected + {#if searchQuery} + ({filteredConversations.length} shown) + {/if} + +
+ +
+ + + + + + + + + + + + + {#if filteredConversations.length === 0} + + + + {:else} + {#each filteredConversations as conv (conv.id)} + toggleConversation(conv.id, e.shiftKey)} + > + + + + + + + {/each} + {/if} + +
+ + Conversation NameMessages
+ {#if searchQuery} + No conversations found matching "{searchQuery}" + {:else} + No conversations available + {/if} +
+ { + e.preventDefault(); + e.stopPropagation(); + toggleConversation(conv.id, e.shiftKey); + }} + /> + +
+ {conv.name || 'Untitled conversation'} +
+
+ {messageCountMap.get(conv.id) ?? 0} +
+
+
+ +
+ + + +
+
diff --git a/tools/server/webui/src/lib/constants/settings-config.ts b/tools/server/webui/src/lib/constants/settings-config.ts index 7547832d95..c25ea23f37 100644 --- a/tools/server/webui/src/lib/constants/settings-config.ts +++ b/tools/server/webui/src/lib/constants/settings-config.ts @@ -38,7 +38,8 @@ export const SETTING_CONFIG_DEFAULT: Record = max_tokens: -1, custom: '', // custom json-stringified object // experimental features - pyInterpreterEnabled: false + pyInterpreterEnabled: false, + enableContinueGeneration: false }; export const SETTING_CONFIG_INFO: Record = { @@ -96,5 +97,7 @@ export const SETTING_CONFIG_INFO: Record = { modelSelectorEnabled: 'Enable the model selector in the chat input to choose the inference model. Sends the associated model field in API requests.', pyInterpreterEnabled: - 'Enable Python interpreter using Pyodide. Allows running Python code in markdown code blocks.' + 'Enable Python interpreter using Pyodide. Allows running Python code in markdown code blocks.', + enableContinueGeneration: + 'Enable "Continue" button for assistant messages. Currently works only with non-reasoning models.' }; diff --git a/tools/server/webui/src/lib/services/chat.ts b/tools/server/webui/src/lib/services/chat.ts index 1908d83909..aa83910b27 100644 --- a/tools/server/webui/src/lib/services/chat.ts +++ b/tools/server/webui/src/lib/services/chat.ts @@ -312,7 +312,6 @@ export class ChatService { let aggregatedContent = ''; let fullReasoningContent = ''; let aggregatedToolCalls: ApiChatCompletionToolCall[] = []; - let hasReceivedData = false; let lastTimings: ChatMessageTimings | undefined; let streamFinished = false; let modelEmitted = false; @@ -352,8 +351,6 @@ export class ChatService { return; } - hasReceivedData = true; - if (!abortSignal?.aborted) { onToolCallChunk?.(serializedToolCalls); } @@ -415,7 +412,6 @@ export class ChatService { if (content) { finalizeOpenToolCallBatch(); - hasReceivedData = true; aggregatedContent += content; if (!abortSignal?.aborted) { onChunk?.(content); @@ -424,7 +420,6 @@ export class ChatService { if (reasoningContent) { finalizeOpenToolCallBatch(); - hasReceivedData = true; fullReasoningContent += reasoningContent; if (!abortSignal?.aborted) { onReasoningChunk?.(reasoningContent); @@ -446,15 +441,6 @@ export class ChatService { if (streamFinished) { finalizeOpenToolCallBatch(); - if ( - !hasReceivedData && - aggregatedContent.length === 0 && - aggregatedToolCalls.length === 0 - ) { - const noResponseError = new Error('No response received from server. Please try again.'); - throw noResponseError; - } - const finalToolCalls = aggregatedToolCalls.length > 0 ? JSON.stringify(aggregatedToolCalls) : undefined; diff --git a/tools/server/webui/src/lib/stores/chat.svelte.ts b/tools/server/webui/src/lib/stores/chat.svelte.ts index 5b5a9d74a5..c70b9580cb 100644 --- a/tools/server/webui/src/lib/stores/chat.svelte.ts +++ b/tools/server/webui/src/lib/stores/chat.svelte.ts @@ -1486,6 +1486,10 @@ class ChatStore { timestamp: Date.now() }); + // Ensure currNode points to the edited message to maintain correct path + await DatabaseStore.updateCurrentNode(this.activeConversation.id, messageToEdit.id); + this.activeConversation.currNode = messageToEdit.id; + this.updateMessageAtIndex(messageIndex, { content: newContent, timestamp: Date.now() @@ -1499,6 +1503,69 @@ class ChatStore { } } + /** + * Edits a user message and preserves all responses below + * Updates the message content in-place without deleting or regenerating responses + * + * **Use Case**: When you want to fix a typo or rephrase a question without losing the assistant's response + * + * **Important Behavior:** + * - Does NOT create a branch (unlike editMessageWithBranching) + * - Does NOT regenerate assistant responses + * - Only updates the user message content in the database + * - Preserves the entire conversation tree below the edited message + * - Updates conversation title if this is the first user message + * + * @param messageId - The ID of the user message to edit + * @param newContent - The new content for the message + */ + async editUserMessagePreserveResponses(messageId: string, newContent: string): Promise { + if (!this.activeConversation) return; + + try { + const messageIndex = this.findMessageIndex(messageId); + if (messageIndex === -1) { + console.error('Message not found for editing'); + return; + } + + const messageToEdit = this.activeMessages[messageIndex]; + if (messageToEdit.role !== 'user') { + console.error('Only user messages can be edited with this method'); + return; + } + + // Simply update the message content in-place + await DatabaseStore.updateMessage(messageId, { + content: newContent, + timestamp: Date.now() + }); + + this.updateMessageAtIndex(messageIndex, { + content: newContent, + timestamp: Date.now() + }); + + // Check if first user message for title update + const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id); + const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null); + const isFirstUserMessage = + rootMessage && messageToEdit.parent === rootMessage.id && messageToEdit.role === 'user'; + + if (isFirstUserMessage && newContent.trim()) { + await this.updateConversationTitleWithConfirmation( + this.activeConversation.id, + newContent.trim(), + this.titleUpdateConfirmationCallback + ); + } + + this.updateConversationTimestamp(); + } catch (error) { + console.error('Failed to edit user message:', error); + } + } + /** * Edits a message by creating a new branch with the edited content * @param messageId - The ID of the message to edit @@ -1696,6 +1763,200 @@ class ChatStore { } } + /** + * Continues generation for an existing assistant message + * @param messageId - The ID of the assistant message to continue + */ + async continueAssistantMessage(messageId: string): Promise { + if (!this.activeConversation || this.isLoading) return; + + try { + const messageIndex = this.findMessageIndex(messageId); + if (messageIndex === -1) { + console.error('Message not found for continuation'); + return; + } + + const messageToContinue = this.activeMessages[messageIndex]; + if (messageToContinue.role !== 'assistant') { + console.error('Only assistant messages can be continued'); + return; + } + + // Race condition protection: Check if this specific conversation is already loading + // This prevents multiple rapid clicks on "Continue" from creating concurrent operations + if (this.isConversationLoading(this.activeConversation.id)) { + console.warn('Continuation already in progress for this conversation'); + return; + } + + this.errorDialogState = null; + this.setConversationLoading(this.activeConversation.id, true); + this.clearConversationStreaming(this.activeConversation.id); + + // IMPORTANT: Fetch the latest content from the database to ensure we have + // the most up-to-date content, especially after a stopped generation + // This prevents issues where the in-memory state might be stale + const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id); + const dbMessage = allMessages.find((m) => m.id === messageId); + + if (!dbMessage) { + console.error('Message not found in database for continuation'); + this.setConversationLoading(this.activeConversation.id, false); + + return; + } + + // Use content from database as the source of truth + const originalContent = dbMessage.content; + const originalThinking = dbMessage.thinking || ''; + + // Get conversation context up to (but not including) the message to continue + const conversationContext = this.activeMessages.slice(0, messageIndex); + + const contextWithContinue = [ + ...conversationContext.map((msg) => { + if ('id' in msg && 'convId' in msg && 'timestamp' in msg) { + return msg as DatabaseMessage & { extra?: DatabaseMessageExtra[] }; + } + return msg as ApiChatMessageData; + }), + { + role: 'assistant' as const, + content: originalContent + } + ]; + + let appendedContent = ''; + let appendedThinking = ''; + let hasReceivedContent = false; + + await chatService.sendMessage( + contextWithContinue, + { + ...this.getApiOptions(), + + onChunk: (chunk: string) => { + hasReceivedContent = true; + appendedContent += chunk; + // Preserve originalContent exactly as-is, including any trailing whitespace + // The concatenation naturally preserves any whitespace at the end of originalContent + const fullContent = originalContent + appendedContent; + + this.setConversationStreaming( + messageToContinue.convId, + fullContent, + messageToContinue.id + ); + + this.updateMessageAtIndex(messageIndex, { + content: fullContent + }); + }, + + onReasoningChunk: (reasoningChunk: string) => { + hasReceivedContent = true; + appendedThinking += reasoningChunk; + + const fullThinking = originalThinking + appendedThinking; + + this.updateMessageAtIndex(messageIndex, { + thinking: fullThinking + }); + }, + + onComplete: async ( + finalContent?: string, + reasoningContent?: string, + timings?: ChatMessageTimings + ) => { + const fullContent = originalContent + (finalContent || appendedContent); + const fullThinking = originalThinking + (reasoningContent || appendedThinking); + + const updateData: { + content: string; + thinking: string; + timestamp: number; + timings?: ChatMessageTimings; + } = { + content: fullContent, + thinking: fullThinking, + timestamp: Date.now(), + timings: timings + }; + + await DatabaseStore.updateMessage(messageToContinue.id, updateData); + + this.updateMessageAtIndex(messageIndex, updateData); + + this.updateConversationTimestamp(); + + this.setConversationLoading(messageToContinue.convId, false); + this.clearConversationStreaming(messageToContinue.convId); + slotsService.clearConversationState(messageToContinue.convId); + }, + + onError: async (error: Error) => { + if (this.isAbortError(error)) { + // User cancelled - save partial continuation if any content was received + if (hasReceivedContent && appendedContent) { + const partialContent = originalContent + appendedContent; + const partialThinking = originalThinking + appendedThinking; + + await DatabaseStore.updateMessage(messageToContinue.id, { + content: partialContent, + thinking: partialThinking, + timestamp: Date.now() + }); + + this.updateMessageAtIndex(messageIndex, { + content: partialContent, + thinking: partialThinking, + timestamp: Date.now() + }); + } + + this.setConversationLoading(messageToContinue.convId, false); + this.clearConversationStreaming(messageToContinue.convId); + slotsService.clearConversationState(messageToContinue.convId); + + return; + } + + // Non-abort error - rollback to original content + console.error('Continue generation error:', error); + + // Rollback: Restore original content in UI + this.updateMessageAtIndex(messageIndex, { + content: originalContent, + thinking: originalThinking + }); + + // Ensure database has original content (in case of partial writes) + await DatabaseStore.updateMessage(messageToContinue.id, { + content: originalContent, + thinking: originalThinking + }); + + this.setConversationLoading(messageToContinue.convId, false); + this.clearConversationStreaming(messageToContinue.convId); + slotsService.clearConversationState(messageToContinue.convId); + + const dialogType = error.name === 'TimeoutError' ? 'timeout' : 'server'; + this.showErrorDialog(dialogType, error.message); + } + }, + messageToContinue.convId + ); + } catch (error) { + if (this.isAbortError(error)) return; + console.error('Failed to continue message:', error); + if (this.activeConversation) { + this.setConversationLoading(this.activeConversation.id, false); + } + } + } + /** * Public methods for accessing per-conversation states */ @@ -1743,8 +2004,11 @@ export const refreshActiveMessages = chatStore.refreshActiveMessages.bind(chatSt export const navigateToSibling = chatStore.navigateToSibling.bind(chatStore); export const editAssistantMessage = chatStore.editAssistantMessage.bind(chatStore); export const editMessageWithBranching = chatStore.editMessageWithBranching.bind(chatStore); +export const editUserMessagePreserveResponses = + chatStore.editUserMessagePreserveResponses.bind(chatStore); export const regenerateMessageWithBranching = chatStore.regenerateMessageWithBranching.bind(chatStore); +export const continueAssistantMessage = chatStore.continueAssistantMessage.bind(chatStore); export const deleteMessage = chatStore.deleteMessage.bind(chatStore); export const getDeletionInfo = chatStore.getDeletionInfo.bind(chatStore); export const updateConversationName = chatStore.updateConversationName.bind(chatStore); diff --git a/tools/server/webui/src/lib/types/settings.d.ts b/tools/server/webui/src/lib/types/settings.d.ts index b85b0597d0..b47842b66e 100644 --- a/tools/server/webui/src/lib/types/settings.d.ts +++ b/tools/server/webui/src/lib/types/settings.d.ts @@ -7,6 +7,7 @@ export interface SettingsFieldConfig { key: string; label: string; type: 'input' | 'textarea' | 'checkbox' | 'select'; + isExperimental?: boolean; help?: string; options?: Array<{ value: string; label: string; icon?: typeof import('@lucide/svelte').Icon }>; } diff --git a/tools/server/webui/src/routes/+layout.svelte b/tools/server/webui/src/routes/+layout.svelte index b08bd59c15..dfe094c079 100644 --- a/tools/server/webui/src/routes/+layout.svelte +++ b/tools/server/webui/src/routes/+layout.svelte @@ -1,7 +1,7 @@ + + diff --git a/tools/server/webui/src/stories/ChatSettingsDialog.stories.svelte b/tools/server/webui/src/stories/ChatSettingsDialog.stories.svelte deleted file mode 100644 index 1e53f70708..0000000000 --- a/tools/server/webui/src/stories/ChatSettingsDialog.stories.svelte +++ /dev/null @@ -1,26 +0,0 @@ - - - - -