diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 875eb766f3..e4f05258db 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -16,7 +16,7 @@ The project differentiates between 3 levels of contributors:
 - If you modified a `ggml` operator or added a new one, add the corresponding test cases to `test-backend-ops`
 - Create separate PRs for each feature or fix. Avoid combining unrelated changes in a single PR
 - Consider allowing write access to your branch for faster reviews, as reviewers can push commits directly
-- If your PR becomes stale, don't hesitate to ping the maintainers in the comments
+- If your PR becomes stale, rebase it on top of the latest `master` to get the maintainers' attention
 - Maintainers will rely on your insights and approval when making a final decision to approve and merge a PR
 - Consider adding yourself to [CODEOWNERS](CODEOWNERS) to indicate your availability for reviewing related PRs
 - Using AI to generate PRs is permitted. However, you must (1) explicitly disclose how AI was used and (2) conduct a thorough manual review before publishing the PR. Note that trivial tab autocompletions do not require disclosure.
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 4590b23921..c641989c19 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -1524,6 +1524,79 @@ class TextModel(ModelBase):
         special_vocab._set_special_token("bos", 151643)
         special_vocab.add_to_gguf(self.gguf_writer)
 
+    def _set_vocab_mistral(self):
+        if not _mistral_common_installed:
+            raise ImportError(_mistral_import_error_msg)
+
+        vocab = MistralVocab(self.dir_model)
+        logger.info(
+            f"Converting tokenizer {vocab.tokenizer_type} of size {vocab.vocab_size}."
+        )
+
+        self.gguf_writer.add_tokenizer_model(vocab.gguf_tokenizer_model)
+
+        tokens = []
+        scores = []
+        toktypes = []
+
+        for text, score, toktype in vocab.all_tokens():
+            tokens.append(text)
+            scores.append(score)
+            toktypes.append(toktype)
+
+        assert len(tokens) == vocab.vocab_size, (
+            f"token count ({len(tokens)}) != vocab size ({vocab.vocab_size})"
+        )
+
+        if vocab.tokenizer_type == MistralTokenizerType.tekken:
+            self.gguf_writer.add_tokenizer_pre("tekken")
+            self.gguf_writer.add_token_merges(
+                vocab.extract_vocab_merges_from_model()
+            )
+
+        logger.info(
+            f"Setting bos, eos, unk and pad token IDs to {vocab.bos_id}, {vocab.eos_id}, {vocab.unk_id}, {vocab.pad_id}."
+        )
+
+        self.gguf_writer.add_bos_token_id(vocab.bos_id)
+        self.gguf_writer.add_eos_token_id(vocab.eos_id)
+        self.gguf_writer.add_unk_token_id(vocab.unk_id)
+        self.gguf_writer.add_pad_token_id(vocab.pad_id)
+
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+        self.gguf_writer.add_vocab_size(vocab.vocab_size)
+
+        self.gguf_writer.add_add_bos_token(True)
+        self.gguf_writer.add_add_eos_token(False)
+
+        local_template_file_path = self.dir_model / "chat_template.jinja"
+
+        if self.is_mistral_format and local_template_file_path.is_file():
+            # Ministral-3 and other new Mistral models come with chat templates.
+            # ref: https://huggingface.co/mistralai/Ministral-3-14B-Instruct-2512/tree/main
+            logger.info("Using an existing Mistral local chat template.")
+
+            with open(local_template_file_path, "r", encoding="utf-8") as f:
+                template = f.read()
+        elif not self.is_mistral_format or not self.disable_mistral_community_chat_template:
+            template_dir = Path(__file__).parent / "models/templates/"
+
+            # For the Mistral format only, log that official tokenization and detokenization go through `mistral-common`.
+            if self.is_mistral_format:
+                logger.info(
+                    "Using a Mistral community chat template. These templates can be subject to errors in the early days or weeks after a release. "
+                    "Mistral recommends using `mistral-common` to perform tokenization and detokenization."
+                )
+            template = MistralModel.get_community_chat_template(vocab, template_dir, self.is_mistral_format)
+        else:
+            logger.info("Not using a Mistral local or community chat template. Make sure to perform tokenization and detokenization via `mistral-common`.")
+            template = None
+
+        if template is not None:
+            self.gguf_writer.add_chat_template(template)
+
 
 class MmprojModel(ModelBase):
     model_type = ModelType.MMPROJ
@@ -2294,79 +2367,6 @@ class LlamaModel(TextModel):
         if self.hf_arch == "VLlama3ForCausalLM":
             self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
 
-    def _set_vocab_mistral(self):
-        if not _mistral_common_installed:
-            raise ImportError(_mistral_import_error_msg)
-
-        vocab = MistralVocab(self.dir_model)
-        logger.info(
-            f"Converting tokenizer {vocab.tokenizer_type} of size {vocab.vocab_size}."
-        )
-
-        self.gguf_writer.add_tokenizer_model(vocab.gguf_tokenizer_model)
-
-        tokens = []
-        scores = []
-        toktypes = []
-
-        for text, score, toktype in vocab.all_tokens():
-            tokens.append(text)
-            scores.append(score)
-            toktypes.append(toktype)
-
-        assert len(tokens) == vocab.vocab_size, (
-            f"token count ({len(tokens)}) != vocab size ({vocab.vocab_size})"
-        )
-
-        if vocab.tokenizer_type == MistralTokenizerType.tekken:
-            self.gguf_writer.add_tokenizer_pre("tekken")
-            self.gguf_writer.add_token_merges(
-                vocab.extract_vocab_merges_from_model()
-            )
-
-        logger.info(
-            f"Setting bos, eos, unk and pad token IDs to {vocab.bos_id}, {vocab.eos_id}, {vocab.unk_id}, {vocab.pad_id}."
-        )
-
-        self.gguf_writer.add_bos_token_id(vocab.bos_id)
-        self.gguf_writer.add_eos_token_id(vocab.eos_id)
-        self.gguf_writer.add_unk_token_id(vocab.unk_id)
-        self.gguf_writer.add_pad_token_id(vocab.pad_id)
-
-        self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_scores(scores)
-        self.gguf_writer.add_token_types(toktypes)
-        self.gguf_writer.add_vocab_size(vocab.vocab_size)
-
-        self.gguf_writer.add_add_bos_token(True)
-        self.gguf_writer.add_add_eos_token(False)
-
-        local_template_file_path = self.dir_model / "chat_template.jinja"
-
-        if self.is_mistral_format and local_template_file_path.is_file():
-            # Ministral-3 and other new Mistral models come with chat templates.
-            # ref: https://huggingface.co/mistralai/Ministral-3-14B-Instruct-2512/tree/main
-            logger.info("Using an existing Mistral local chat template.")
-
-            with open(local_template_file_path, "r", encoding="utf-8") as f:
-                template = f.read()
-        elif not self.is_mistral_format or not self.disable_mistral_community_chat_template:
-            template_dir = Path(__file__).parent / "models/templates/"
-
-            # Log only for Mistral format that the official tokenization and detokenization is via `mistral-common`.
-            if self.is_mistral_format:
-                logger.info(
-                    "Using a Mistral community chat template. These templates can be subject to errors in early days or weeks after a release. "
-                    "Mistral recommends to use `mistral-common` to perform tokenization and detokenization."
-                )
-            template = MistralModel.get_community_chat_template(vocab, template_dir, self.is_mistral_format)
-        else:
-            logger.info("Not using a Mistral local or community chat template. Ensure to perform the tokenization and detokenization via `mistral-common`.")
-            template = None
-
-        if template is not None:
-            self.gguf_writer.add_chat_template(template)
-
     def set_vocab(self):
         if self.is_mistral_format:
             return self._set_vocab_mistral()
@@ -9924,17 +9924,109 @@ class MistralModel(LlamaModel):
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        if "yarn" in self.hparams:
-            yarn_params = self.hparams["yarn"]
-            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
-            self.gguf_writer.add_rope_scaling_factor(yarn_params["factor"])
-            self.gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_params["beta"])
-            self.gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_params["alpha"])
-            self.gguf_writer.add_rope_scaling_yarn_log_mul(1.0) # mscale_all_dim
-            self.gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"])
+        MistralModel.set_mistral_config(self.gguf_writer, self.hparams)
 
-        if "llama_4_scaling" in self.hparams:
-            self.gguf_writer.add_attn_temperature_scale(self.hparams["llama_4_scaling"]["beta"])
+    @staticmethod
+    def set_mistral_config(gguf_writer: gguf.GGUFWriter, hparams: dict):
+        if "yarn" in hparams:
+            yarn_params = hparams["yarn"]
+            gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            gguf_writer.add_rope_scaling_factor(yarn_params["factor"])
+            gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_params["beta"])
+            gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_params["alpha"])
+            gguf_writer.add_rope_scaling_yarn_log_mul(1.0) # mscale_all_dim
+            gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"])
+
+        if "llama_4_scaling" in hparams:
+            gguf_writer.add_attn_temperature_scale(hparams["llama_4_scaling"]["beta"])
+
+
+class MistralMoeModel(DeepseekV2Model):
+    model_arch = gguf.MODEL_ARCH.DEEPSEEK2
+    model_name = "Mistral"
+    hf_arch = ""
+    is_mistral_format = True
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        logger.info("Using MistralMoeModel")
+        # remap hparams from Mistral MoE format to DeepseekV2 format
+        # we do it this way so we can reuse the DeepseekV2Model set_gguf_parameters logic
+        # ref: https://github.com/vllm-project/vllm/blob/b294e28db2c5dee61bc25157664edcada8b90b31/vllm/transformers_utils/configs/mistral.py
+        config = self.hparams
+        # Mistral key -> HF key
+        config_mapping = {
+            "dim": "hidden_size",
+            "norm_eps": "rms_norm_eps",
+            "n_kv_heads": "num_key_value_heads",
+            "n_layers": "num_hidden_layers",
+            "n_heads": "num_attention_heads",
+            "hidden_dim": "intermediate_size",
+        }
+        # HF key -> (Mistral key, default value)
+        top_level_mapping_with_default = {
+            "model_type": ("model_type", "transformer"),
+            "hidden_act": ("activation", "silu"),
+            "tie_word_embeddings": ("tied_embeddings", False),
+            "max_seq_len": ("max_seq_len", config.get("max_position_embeddings", 128_000)),
+            "max_position_embeddings": ("max_position_embeddings", 128_000),
+        }
+        # mapping top-level keys
+        for key, new_key in config_mapping.items():
+            if key in config:
+                config[new_key] = config[key]
+        for new_key, (key, default_value) in top_level_mapping_with_default.items():
+            config[new_key] = config.get(key, default_value)
+        # mapping MoE-specific keys
+        moe_config_map = {
+            "route_every_n": "moe_layer_freq",
+            "first_k_dense_replace": "first_k_dense_replace",
+            "num_experts_per_tok": "num_experts_per_tok",
+            "num_experts": "n_routed_experts",
+            "expert_hidden_dim": "moe_intermediate_size",
+            "routed_scale": "routed_scaling_factor",
+            "num_shared_experts": "n_shared_experts",
+            "num_expert_groups": "n_group",
+            "num_expert_groups_per_tok": "topk_group",
+        }
+        moe = config["moe"]
+        for key, new_key in moe_config_map.items():
+            if key in moe:
+                config[new_key] = moe[key]
+        # provide missing values
+        config["topk_method"] = None
+        config["norm_topk_prob"] = True
+        config["scoring_func"] = "softmax"
+
+    def set_vocab(self):
+        self._set_vocab_mistral()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        MistralModel.set_mistral_config(self.gguf_writer, self.hparams)
+        yarn_params = self.hparams["yarn"]
+        self.gguf_writer.add_attn_temperature_length(yarn_params["original_max_position_embeddings"])
+        self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1) # mscale_all_dim * 0.1
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        if name.startswith("vision_") or name.startswith("patch_merger.") or "mm_projector" in name:
+            return []
+
+        # rename certain tensors so that we can reuse DeepseekV2Model modify_tensors logic
+        if name.endswith(".qscale_act"):
+            name = name.replace(".qscale_act", ".input_scale")
+        if name.endswith(".qscale_weight"):
+            name = name.replace(".qscale_weight", ".weight_scale")
+        if ".wkv_b." in name:
+            name = name.replace(".wkv_b.", ".kv_b_proj.")
+        if ".experts." in name:
+            name = name.replace(".experts.", ".mlp.experts.")
+            name = name.replace(".w1.", ".gate_proj.")
+            name = name.replace(".w2.", ".down_proj.")
+            name = name.replace(".w3.", ".up_proj.")
+        name = "model." + name
+
+        return super().modify_tensors(data_torch, name, bid)
 
 
 class PixtralModel(LlavaVisionModel):
@@ -10490,6 +10582,8 @@ def main() -> None:
     elif args.mmproj:
         assert hparams.get("vision_encoder") is not None, "This model does not support multimodal"
         model_class = PixtralModel
+    elif "moe" in hparams:
+        model_class = MistralMoeModel
     else:
         model_class = MistralModel
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index 1065ec6bb0..4cd3071f76 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -241,6 +241,12 @@ int main(int argc, char ** argv) {
 
     llama_batch_free(batch);
 
+    // this one is managed by common_init_result
+    //llama_free(ctx);
+
+    llama_free(ctx2);
+    llama_free(ctx3);
+
     if (result0 != result2) {
         fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
         return 1;
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index b0e10f5768..6bc762c069 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -2196,6 +2196,15 @@ extern "C" {
             int                   p2,
             int                   p3);
 
+    // pad each dimension with values on the other side of the torus (looping around)
+    GGML_API struct ggml_tensor * ggml_pad_circular(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   p0,
+            int                   p1,
+            int                   p2,
+            int                   p3);
+
     GGML_API struct ggml_tensor * ggml_pad_ext(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -2209,6 +2218,19 @@ extern "C" {
             int                   rp3
             );
 
+    // pad each dimension with values on the other side of the torus (looping around)
+    GGML_API struct ggml_tensor * ggml_pad_ext_circular(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   lp0,
+            int                   rp0,
+            int                   lp1,
+            int                   rp1,
+            int                   lp2,
+            int                   rp2,
+            int                   lp3,
+            int                   rp3);
+
     // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
     GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
             struct ggml_context * ctx,
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index e96b5c403d..88e6dc45c7 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -534,8 +534,12 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
     fs::path best_path;
 
     for (const auto & search_path : search_paths) {
-        if (!fs::exists(search_path)) {
-            GGML_LOG_DEBUG("%s: search path %s does not exist\n", __func__, path_str(search_path).c_str());
+        if (std::error_code ec; !fs::exists(search_path, ec)) {
+            if (ec) {
+                GGML_LOG_DEBUG("%s: posix_stat(%s) failure, error-message: %s\n", __func__, path_str(search_path).c_str(), ec.message().c_str());
+            } else {
+                GGML_LOG_DEBUG("%s: search path %s does not exist\n", __func__, path_str(search_path).c_str());
+            }
             continue;
         }
         fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
@@ -575,8 +579,12 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
     for (const auto & search_path : search_paths) {
         fs::path filename = backend_filename_prefix().native() + name_path.native() + backend_filename_extension().native();
         fs::path path = search_path / filename;
-        if (fs::exists(path)) {
+        if (std::error_code ec; fs::exists(path, ec)) {
             return get_reg().load_backend(path, silent);
+        } else {
+            if (ec) {
+                GGML_LOG_DEBUG("%s: posix_stat(%s) failure, error-message: %s\n", __func__, path_str(path).c_str(), ec.message().c_str());
+            }
         }
     }
     return nullptr;
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
index 544c1e2a50..23dc0433c1 100644
--- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ -2551,6 +2551,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
         case GGML_OP_ACC:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_PAD:
+            // TODO: add circular padding support for cann, see https://github.com/ggml-org/llama.cpp/pull/16985
+            return ggml_get_op_params_i32(op, 8) == 0;
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index ac16b3681b..3032783971 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -6554,8 +6554,13 @@ static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params
     ggml_compute_forward_mul_mat(params, &dst);
 }
 
+static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) {
+    return (coord + size) % size; // adding size avoids negative number weirdness
+}
+
 // ggml_compute_forward_conv_2d
 
+
 static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
                                               const ggml_tensor *         kernel,  // [KW, KH, IC, OC]
                                               const ggml_tensor *         src,     // [W, H, C, N]
@@ -7591,6 +7596,7 @@ void ggml_compute_forward_upscale(
 
 // ggml_compute_forward_pad
 
+template <bool circular_t>
 static void ggml_compute_forward_pad_f32(
     const ggml_compute_params * params,
           ggml_tensor * dst) {
@@ -7615,23 +7621,40 @@ static void ggml_compute_forward_pad_f32(
     const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
     const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
 
-    // TODO: optimize
     for (int64_t i2 = 0; i2 < ne2; ++i2) {
         for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
             for (int64_t i0 = 0; i0 < ne0; ++i0) {
                 for (int64_t i3 = 0; i3 < ne3; ++i3) {
-                    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
-
-                    if ((i0 >= lp0 && i0 < ne0 - rp0) \
-                        && (i1 >= lp1 && i1 < ne1 - rp1) \
-                        && (i2 >= lp2 && i2 < ne2 - rp2) \
-                        && (i3 >= lp3 && i3 < ne3 - rp3)) {
-                        const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
+                    // circular means wrap around on a torus, so x and y loop around
+                    if constexpr (circular_t) {
+                        const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+                        const int64_t src_i0 = ggml_wrap_around(i0 - lp0, ne00);
+                        const int64_t src_i1 = ggml_wrap_around(i1 - lp1, ne01);
+                        const int64_t src_i2 = ggml_wrap_around(i2 - lp2, ne02);
+                        const int64_t src_i3 = ggml_wrap_around(i3 - lp3, ne03);
+
+                        const int64_t src_idx =
+                            src_i3*nb03 +
+                            src_i2*nb02 +
+                            src_i1*nb01 +
+                            src_i0*nb00;
+
                         const float * src_ptr = (const float *)((char *) src0->data + src_idx);
                         dst_ptr[dst_idx] = *src_ptr;
                     } else {
-                        dst_ptr[dst_idx] = 0;
+                        const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+                        if ((i0 >= lp0 && i0 < ne0 - rp0) \
+                            && (i1 >= lp1 && i1 < ne1 - rp1) \
+                            && (i2 >= lp2 && i2 < ne2 - rp2) \
+                            && (i3 >= lp3 && i3 < ne3 - rp3)) {
+                            const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
+                            const float * src_ptr = (const float *)((char *) src0->data + src_idx);
+                            dst_ptr[dst_idx] = *src_ptr;
+                        } else {
+                            dst_ptr[dst_idx] = 0;
+                        }
                     }
                 }
             }
@@ -7639,16 +7662,20 @@ static void ggml_compute_forward_pad_f32(
     }
 }
 
+
 void ggml_compute_forward_pad(
     const ggml_compute_params * params,
           ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
-
+    const bool circular = (bool) ggml_get_op_params_i32(dst, 8);
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_pad_f32(params, dst);
+                if (circular) {
+                    ggml_compute_forward_pad_f32<true>(params, dst);
+                } else {
+                    ggml_compute_forward_pad_f32<false>(params, dst);
+                }
            } break;
        default:
            {
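The wrap-around helper is the whole trick behind the circular mode: every destination coordinate is shifted back by the left padding and wrapped into the source extent. A minimal standalone sketch of that index math (the values and helper here are illustrative, not part of ggml, but follow the same `(coord + size) % size` convention as `ggml_wrap_around` above):

```cpp
#include <cstdio>
#include <cstdint>

// same idea as ggml_wrap_around(): adding size first keeps the modulo
// result non-negative for coordinates in [-size, 0)
static int64_t wrap_around(int64_t coord, int64_t size) {
    return (coord + size) % size;
}

int main() {
    const int64_t ne0 = 4;              // source width
    const float   src[4] = {1, 2, 3, 4};
    const int64_t lp0 = 2, rp0 = 2;     // left/right circular padding

    // prints the padded row: 3 4 1 2 3 4 1 2
    for (int64_t i0 = 0; i0 < ne0 + lp0 + rp0; ++i0) {
        printf("%g ", src[wrap_around(i0 - lp0, ne0)]);
    }
    printf("\n");
    return 0;
}
```

Note that `i0 - lp0` only reaches down to `-lp0`, so a single `+ size` is sufficient as long as the padding does not exceed the source extent on that dimension.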
diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu
index be2ad1c6b6..7cf33f0ddf 100644
--- a/ggml/src/ggml-cuda/mmf.cu
+++ b/ggml/src/ggml-cuda/mmf.cu
@@ -160,9 +160,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
         case GGML_TYPE_F32:
             return ampere_mma_available(cc);
         case GGML_TYPE_F16:
-            return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
+            return volta_mma_available(cc) || turing_mma_available(cc) || (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc));
         case GGML_TYPE_BF16:
-            return ampere_mma_available(cc) || amd_wmma_available(cc);
+            return ampere_mma_available(cc) || (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc));
         default:
             return false;
     }
diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu
index 29aef33c1a..660c192e48 100644
--- a/ggml/src/ggml-cuda/pad.cu
+++ b/ggml/src/ggml-cuda/pad.cu
@@ -1,9 +1,17 @@
 #include "pad.cuh"
+#include <cstdint>
+
+__device__ __forceinline__ int64_t wrap_around(int64_t coord, int64_t size) {
+    // + size ensures negatives are handled properly
+    return (coord + size) % size;
+}
 
 static __global__ void pad_f32(const float * src, float * dst,
     const int lp0, const int rp0, const int lp1, const int rp1,
     const int lp2, const int rp2, const int lp3, const int rp3,
-    const int ne0, const int ne1, const int ne2, const int ne3) {
+    const int ne0, const int ne1, const int ne2, const int ne3,
+    const bool circular) {
     // blockIdx.z: i3*ne2+i2
     // blockIdx.y: i1
     // blockIDx.x: i0 / CUDA_PAD_BLOCK_SIZE
@@ -12,61 +20,84 @@ static __global__ void pad_f32(const float * src, float * dst,
     int i1 = blockIdx.y;
     int i2 = blockIdx.z % ne2;
     int i3 = blockIdx.z / ne2;
+
     if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
         return;
     }
 
-    // operation
-    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
-    if ((i0 >= lp0 && i0 < ne0 - rp0) &&
-        (i1 >= lp1 && i1 < ne1 - rp1) &&
-        (i2 >= lp2 && i2 < ne2 - rp2) &&
-        (i3 >= lp3 && i3 < ne3 - rp3)) {
-        const int64_t i00 = i0 - lp0;
-        const int64_t i01 = i1 - lp1;
-        const int64_t i02 = i2 - lp2;
-        const int64_t i03 = i3 - lp3;
-        const int64_t ne02 = ne2 - lp2 - rp2;
-        const int64_t ne01 = ne1 - lp1 - rp1;
-        const int64_t ne00 = ne0 - lp0 - rp0;
+    const int64_t dst_idx = i3 * (ne0 * ne1 * ne2) + i2 * (ne0 * ne1) + i1 * ne0 + i0;
 
-        const int64_t src_idx = i03*(ne00*ne01*ne02) + i02*(ne00*ne01) + i01*ne00 + i00;
+    if (!circular) {
+        if ((i0 >= lp0 && i0 < ne0 - rp0) && (i1 >= lp1 && i1 < ne1 - rp1) && (i2 >= lp2 && i2 < ne2 - rp2) &&
+            (i3 >= lp3 && i3 < ne3 - rp3)) {
+            const int64_t i00 = i0 - lp0;
+            const int64_t i01 = i1 - lp1;
+            const int64_t i02 = i2 - lp2;
+            const int64_t i03 = i3 - lp3;
+            const int64_t ne02 = ne2 - lp2 - rp2;
+            const int64_t ne01 = ne1 - lp1 - rp1;
+            const int64_t ne00 = ne0 - lp0 - rp0;
+
+            const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
+
+            dst[dst_idx] = src[src_idx];
+        } else {
+            dst[dst_idx] = 0.0f;
+        }
+    }
+    // circular means on a torus, so x and y wrap around
+    else {
+        const int64_t ne00 = ne0 - lp0 - rp0;
+        const int64_t ne01 = ne1 - lp1 - rp1;
+        const int64_t ne02 = ne2 - lp2 - rp2;
+        const int64_t ne03 = ne3 - lp3 - rp3;
+
+        const int64_t i00 = wrap_around(i0 - lp0, ne00);
+        const int64_t i01 = wrap_around(i1 - lp1, ne01);
+        const int64_t i02 = wrap_around(i2 - lp2, ne02);
+        const int64_t i03 = wrap_around(i3 - lp3, ne03);
+
+        const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
 
         dst[dst_idx] = src[src_idx];
-    } else {
-        dst[dst_idx] = 0.0f;
     }
 }
 
+
 static void pad_f32_cuda(const float * src, float * dst,
     const int lp0, const int rp0, const int lp1, const int rp1,
     const int lp2, const int rp2, const int lp3, const int rp3,
-    const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
-    int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
-    dim3 gridDim(num_blocks, ne1, ne2*ne3);
-    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, ne2, ne3);
+    const int ne0, const int ne1, const int ne2, const int ne3,
+    const bool circular, cudaStream_t stream) {
+    int  num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne1, ne2 * ne3);
+    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst,
+        lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
+        ne0, ne1, ne2, ne3, circular);
 }
 
 void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
+    const ggml_tensor * src0   = dst->src[0];
+    const float *       src0_d = (const float *) src0->data;
+    float *             dst_d  = (float *) dst->data;
+    cudaStream_t        stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
     GGML_ASSERT(ggml_is_contiguous(src0));
 
-    const int32_t lp0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t rp0 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t lp1 = ((const int32_t*)(dst->op_params))[2];
-    const int32_t rp1 = ((const int32_t*)(dst->op_params))[3];
-    const int32_t lp2 = ((const int32_t*)(dst->op_params))[4];
-    const int32_t rp2 = ((const int32_t*)(dst->op_params))[5];
-    const int32_t lp3 = ((const int32_t*)(dst->op_params))[6];
-    const int32_t rp3 = ((const int32_t*)(dst->op_params))[7];
+    const int32_t lp0 = ((const int32_t *) (dst->op_params))[0];
+    const int32_t rp0 = ((const int32_t *) (dst->op_params))[1];
+    const int32_t lp1 = ((const int32_t *) (dst->op_params))[2];
+    const int32_t rp1 = ((const int32_t *) (dst->op_params))[3];
+    const int32_t lp2 = ((const int32_t *) (dst->op_params))[4];
+    const int32_t rp2 = ((const int32_t *) (dst->op_params))[5];
+    const int32_t lp3 = ((const int32_t *) (dst->op_params))[6];
+    const int32_t rp3 = ((const int32_t *) (dst->op_params))[7];
+    const int32_t circular = ((const int32_t *) (dst->op_params))[8];
 
     pad_f32_cuda(src0_d, dst_d,
         lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
-        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+        (bool) circular, stream);
 }
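For reference, a minimal sketch of how the new operator is driven from user code on the CPU backend. The sizes are made up, and `ggml_graph_compute_with_ctx()` is assumed to be the CPU-side helper declared in `ggml-cpu.h` (as in the current tree):

```cpp
#include "ggml.h"
#include "ggml-cpu.h"

int main() {
    // small CPU-only context; the memory size is an arbitrary example
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // 4x2 input, circularly padded by one column (dim 0) and one row (dim 1)
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 2);
    for (int i = 0; i < 8; ++i) {
        ((float *) a->data)[i] = (float) i;
    }

    struct ggml_tensor * out = ggml_pad_circular(ctx, a, 1, 1, 0, 0);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    // out is 5x3; source row [0,1,2,3] becomes [0,1,2,3,0], and the extra
    // row at i1 == 2 wraps back to row 0
    ggml_free(ctx);
    return 0;
}
```

Like `ggml_pad`, the short form pads on the right of each dimension only; `ggml_pad_ext_circular` exposes the left pads as well.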
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index c53ec506cd..7b7d1c1233 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -574,22 +574,26 @@ ggml_metal_rsets_t ggml_metal_rsets_init(void) {
     // the requests stop after a certain amount of time (keep_alive_s) of inactivity
     dispatch_queue_t d_queue = dispatch_get_global_queue(QOS_CLASS_DEFAULT, 0);
     dispatch_group_async(res->d_group, d_queue, ^{
-        while (!atomic_load_explicit(&res->d_stop, memory_order_relaxed)) {
-            if (atomic_load_explicit(&res->d_loop, memory_order_relaxed) > 0) {
-                [res->lock lock];
+#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
+        if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
+            while (!atomic_load_explicit(&res->d_stop, memory_order_relaxed)) {
+                if (atomic_load_explicit(&res->d_loop, memory_order_relaxed) > 0) {
+                    [res->lock lock];
 
-                for (int i = 0; i < (int) res->data.count; ++i) {
-                    [res->data[i] requestResidency];
+                    for (int i = 0; i < (int) res->data.count; ++i) {
+                        [res->data[i] requestResidency];
+                    }
+
+                    atomic_fetch_sub_explicit(&res->d_loop, 1, memory_order_relaxed);
+
+                    [res->lock unlock];
                 }
 
-                atomic_fetch_sub_explicit(&res->d_loop, 1, memory_order_relaxed);
-
-                [res->lock unlock];
+                // half a second
+                usleep(500 * 1000);
             }
-
-            // half a second
-            usleep(500 * 1000);
-        }
+        }
+#endif
     });
 
     return res;
@@ -600,6 +604,7 @@ void ggml_metal_rsets_free(ggml_metal_rsets_t rsets) {
         return;
     }
 
+    // note: if you hit this assert, most likely you haven't deallocated all Metal resources before exiting
     GGML_ASSERT([rsets->data count] == 0);
 
     atomic_store_explicit(&rsets->d_stop, true, memory_order_relaxed);
@@ -787,9 +792,6 @@ ggml_metal_device_t ggml_metal_device_init(void) {
 
         dev->rsets = nil;
     }
-
-    // --------------------------------------------------
-
     // print MTL GPU family:
     GGML_LOG_INFO("%s: GPU name: %s\n", __func__, dev->props.name);
 
@@ -1035,6 +1037,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_POOL_2D:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_PAD:
+            // TODO: add circular padding support for metal, see https://github.com/ggml-org/llama.cpp/pull/16985
+            if (ggml_get_op_params_i32(op, 8) != 0) {
+                return false;
+            }
+
             return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) && (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
         case GGML_OP_PAD_REFLECT_1D:
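The Metal gate above and the OpenCL/SYCL/CANN gates in the neighboring hunks all follow the same convention: `op_params[0..7]` carry the left/right pads and the ninth slot carries the circular flag set by `ggml_pad_ext_circular()`. A sketch of that shared pattern for a hypothetical backend (the function name is illustrative; `ggml_get_op_params_i32()` is the internal helper from `ggml-impl.h`):

```cpp
#include "ggml.h"
#include "ggml-impl.h" // ggml_get_op_params_i32()

// op_params[8] == 1 means circular (wrap-around) padding was requested;
// a backend without a circular kernel must decline so the op falls back
// to a backend that supports it (e.g. the CPU)
static bool my_backend_supports_pad(const struct ggml_tensor * op) {
    if (ggml_get_op_params_i32(op, 8) != 0) {
        return false; // circular padding not implemented here yet
    }
    return op->src[0]->type == GGML_TYPE_F32;
}
```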
diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
index 277a30d30e..0d37587f60 100644
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -3083,6 +3083,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
         case GGML_OP_REPEAT:
             return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
         case GGML_OP_PAD:
+            // TODO: add circular padding support for opencl, see https://github.com/ggml-org/llama.cpp/pull/16985
+            if (ggml_get_op_params_i32(op, 8) != 0) {
+                return false;
+            }
             return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         case GGML_OP_UPSCALE: {
             ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & 0xFF);
diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp
index cf656b6a08..18a45d2d96 100644
--- a/ggml/src/ggml-rpc/ggml-rpc.cpp
+++ b/ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -1257,7 +1257,8 @@ bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
     char hash_str[17];
     snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
     fs::path cache_file = fs::path(cache_dir) / hash_str;
-    if (!fs::exists(cache_file)) {
+    std::error_code ec;
+    if (!fs::exists(cache_file, ec)) {
        return false;
     }
     std::ifstream ifs(cache_file, std::ios::binary);
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index a264ade0b7..7449a91609 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -4613,6 +4613,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_ACC:
             return true;
         case GGML_OP_PAD:
+            // TODO: add circular padding support for sycl, see https://github.com/ggml-org/llama.cpp/pull/16985
+            if (ggml_get_op_params_i32(op, 8) != 0) {
+                return false;
+            }
            return ggml_is_contiguous(op->src[0]);
        case GGML_OP_LEAKY_RELU:
        case GGML_OP_TIMESTEP_EMBEDDING:
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 4a6ec1be77..2787f63d0e 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -1050,6 +1050,7 @@ struct vk_op_pad_push_constants {
     uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
     uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
     uint32_t misalign_offsets;
+    uint32_t circular;
 
     uint32_t lp0; uint32_t rp0;
     uint32_t lp1; uint32_t rp1;
@@ -1092,6 +1093,7 @@ static vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor
     p.rp2 = dst->op_params[5];
     p.lp3 = dst->op_params[6];
     p.rp3 = dst->op_params[7];
+    p.circular = dst->op_params[8];
 
     return p; // fastdiv values and offsets are initialized later in ggml_vk_op
 }
@@ -3537,7 +3539,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
                                              SHADER_REDUCTION_MODE_SHMEM;
 
         for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
-            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32", arr_dmmv_f16_f32_f32_len[reduc], arr_dmmv_f16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[reduc], arr_dmmv_bf16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[reduc], arr_dmmv_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
@@ -3561,7 +3563,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
-            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[reduc], arr_dmmv_bf16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[reduc], arr_dmmv_q4_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
@@ -3607,7 +3609,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
         }
 
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {wg_size_subgroup, 1}, 1, false, use_subgroups, force_subgroup_size);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", arr_dmmv_id_f16_f32_f32_len[reduc], arr_dmmv_id_f16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", arr_dmmv_id_bf16_f32_f32_len[reduc], arr_dmmv_id_bf16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", arr_dmmv_id_q4_0_f32_f32_len[reduc], arr_dmmv_id_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
@@ -4033,10 +4035,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     for (auto &s : device->pipeline_solve_tri_f32) {
         const vk_solve_tri_pipeline_state &state = s.first;
+
+        // Max number of rows to load at a time, limited by shared memory
+        const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((state.N + state.K) * sizeof(float));
+        // Need at least K invocations, and prefer a minimum of 128 to spread out loading shared memory
+        const uint32_t block_size = std::max(128u, 1u << (uint32_t)ceilf(log2f(float(state.K))));
+
         ggml_vk_create_pipeline(
             device, s.second, "solve_tri_f32", solve_tri_f32_len, solve_tri_f32_data, "main", 3,
-            sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K }, 1, true);
+            sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K, batch_N, block_size }, 1, true);
     }
 
 #define IM2COL(bda) \
@@ -4174,9 +4182,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
 
     for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
-        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX],      "topk_moe_f32_early_softmax_"+std::to_string(i),     topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i}, 1, true);
-        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i}, 1, true);
-        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX],       "topk_moe_f32_late_softmax"+std::to_string(i),       topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i}, 1, true);
+        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX],      "topk_moe_f32_early_softmax_"+std::to_string(i),     topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i}, 1, true, device->subgroup_size);
+        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i}, 1, true, device->subgroup_size);
+        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX],       "topk_moe_f32_late_softmax"+std::to_string(i),       topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i}, 1, true, device->subgroup_size);
     }
 
     for (auto &c : compiles) {
@@ -5062,7 +5070,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
     }
 }
 
-static bool ggml_vk_instance_validation_ext_available();
+static bool ggml_vk_instance_layer_settings_available();
 static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
 
 static bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions);
 static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev);
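The `batch_N`/`block_size` computation in the solve_tri pipeline setup above is easy to sanity-check by hand. A standalone sketch with an assumed 48 KiB shared-memory limit (real devices report their own `maxComputeSharedMemorySize`; the sizes are taken from one of the new tests):

```cpp
#include <cstdio>
#include <algorithm>
#include <cmath>

int main() {
    const unsigned shmem = 48 * 1024; // assumed maxComputeSharedMemorySize
    const unsigned N = 128, K = 31;   // A is N x N, B has K columns

    // rows of A (N floats each) plus rows of B (K floats each) per batch
    const unsigned batch_N = shmem / ((N + K) * sizeof(float));
    // at least K invocations, rounded up to a power of two, minimum 128
    const unsigned block_size = std::max(128u, 1u << (unsigned) ceilf(log2f((float) K)));

    printf("batch_N = %u, block_size = %u\n", batch_N, block_size); // 77, 128
    return 0;
}
```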
@@ -5091,19 +5099,19 @@ static void ggml_vk_instance_init() {
 
     vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version };
 
     const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
-    const bool validation_ext = ggml_vk_instance_validation_ext_available();
+    const bool layer_settings = ggml_vk_instance_layer_settings_available();
 #ifdef __APPLE__
     const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
 #endif
     const bool debug_utils_ext = ggml_vk_instance_debug_utils_ext_available(instance_extensions) && getenv("GGML_VK_DEBUG_MARKERS") != nullptr;
     std::vector<const char *> layers;
 
-    if (validation_ext) {
+    if (layer_settings) {
         layers.push_back("VK_LAYER_KHRONOS_validation");
     }
     std::vector<const char *> extensions;
-    if (validation_ext) {
-        extensions.push_back("VK_EXT_validation_features");
+    if (layer_settings) {
+        extensions.push_back("VK_EXT_layer_settings");
     }
 #ifdef __APPLE__
     if (portability_enumeration_ext) {
@@ -5113,26 +5121,24 @@ static void ggml_vk_instance_init() {
     if (debug_utils_ext) {
         extensions.push_back("VK_EXT_debug_utils");
     }
-    vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions);
+
+    VkBool32 enable_best_practice = layer_settings;
+    std::vector<vk::LayerSettingEXT> settings = {
+        {
+            "VK_LAYER_KHRONOS_validation",
+            "validate_best_practices",
+            vk::LayerSettingTypeEXT::eBool32,
+            1,
+            &enable_best_practice
+        },
+    };
+    vk::LayerSettingsCreateInfoEXT layer_setting_info(settings);
+
+    vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions, &layer_setting_info);
 #ifdef __APPLE__
     if (portability_enumeration_ext) {
         instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR;
     }
 #endif
 
-    std::vector<vk::ValidationFeatureEnableEXT> features_enable;
-    vk::ValidationFeaturesEXT validation_features;
-
-    if (validation_ext) {
-        features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices };
-        validation_features = {
-            features_enable,
-            {},
-        };
-        validation_features.setPNext(nullptr);
-        instance_create_info.setPNext(&validation_features);
-        GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n");
-    }
 
     vk_instance.instance = vk::createInstance(instance_create_info);
     vk_instance_initialized = true;
@@ -14027,10 +14033,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 const uint32_t N = op->src[0]->ne[0];
                 const uint32_t K = op->src[1]->ne[0];
                 // K dimension limited to workgroup size
-                if (K > 128) {
+                if (K > 1u << device->max_workgroup_size_log2) {
                     return false;
                 }
-                if (N * N * sizeof(float) + N * K * sizeof(float) > device->properties.limits.maxComputeSharedMemorySize) {
+                const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((N + K) * sizeof(float));
+
+                if (batch_N == 0) {
                     return false;
                 }
                 return true;
@@ -14227,21 +14235,21 @@ ggml_backend_reg_t ggml_backend_vk_reg() {
 }
 
 // Extension availability
-static bool ggml_vk_instance_validation_ext_available() {
+static bool ggml_vk_instance_layer_settings_available() {
 #ifdef GGML_VULKAN_VALIDATE
     // Check if validation layer provides the extension
     const std::string layer_name = "VK_LAYER_KHRONOS_validation";
     for (const auto& layer : vk::enumerateInstanceLayerProperties()) {
         if (layer_name == layer.layerName.data()) {
             for (const auto& ext : vk::enumerateInstanceExtensionProperties(layer_name)) {
-                if (strcmp("VK_EXT_validation_features", ext.extensionName.data()) == 0) {
+                if (strcmp("VK_EXT_layer_settings", ext.extensionName.data()) == 0) {
                     return true;
                 }
             }
         }
     }
 
-    std::cerr << "ggml_vulkan: WARNING: Validation layer or layer extension VK_EXT_validation_features not found." << std::endl;
+    std::cerr << "ggml_vulkan: WARNING: Validation layer or layer extension VK_EXT_layer_settings not found." << std::endl;
 #endif
     return false;
 }
(strcmp("VK_EXT_layer_settings", ext.extensionName.data()) == 0) { return true; } } } } - std::cerr << "ggml_vulkan: WARNING: Validation layer or layer extension VK_EXT_validation_features not found." << std::endl; + std::cerr << "ggml_vulkan: WARNING: Validation layer or layer extension VK_EXT_layer_settings not found." << std::endl; #endif return false; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp b/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp index f3c8176872..5abd2f6fc6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp @@ -8,6 +8,7 @@ layout (push_constant) uniform parameter uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; uint misalign_offsets; + uint circular; uint lp0; uint rp0; uint lp1; uint rp1; @@ -18,6 +19,10 @@ layout (push_constant) uniform parameter uint get_aoffset() { return p.misalign_offsets >> 16; } uint get_doffset() { return p.misalign_offsets & 0xFFFF; } +uint wrap_around(int coord, uint size) { + return (uint(coord + int(size))) % size; // add size to avoid issues with negative +} + layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; @@ -40,10 +45,20 @@ void main() { const uint src0_idx = (i3 - p.lp3)*p.nb03 + (i2 - p.lp2)*p.nb02 + (i1 - p.lp1)*p.nb01 + (i0 - p.lp0)*p.nb00; const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10; - const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 && - i1 >= p.lp1 && i1 < p.ne11 - p.rp1 && - i2 >= p.lp2 && i2 < p.ne12 - p.rp2 && - i3 >= p.lp3 && i3 < p.ne13 - p.rp3; + if (p.circular != 0u) { + const uint ci0 = wrap_around(int(i0) - int(p.lp0), p.ne00); + const uint ci1 = wrap_around(int(i1) - int(p.lp1), p.ne01); + const uint ci2 = wrap_around(int(i2) - int(p.lp2), p.ne02); + const uint ci3 = wrap_around(int(i3) - int(p.lp3), p.ne03); + const uint circular_src_idx = ci3*p.nb03 + ci2*p.nb02 + ci1*p.nb01 + ci0*p.nb00; + data_d[get_doffset() + dst_idx] = D_TYPE(data_a[get_aoffset() + circular_src_idx]); + } else { + const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 && + i1 >= p.lp1 && i1 < p.ne11 - p.rp1 && + i2 >= p.lp2 && i2 < p.ne12 - p.rp2 && + i3 >= p.lp3 && i3 < p.ne13 - p.rp3; + data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f); + } + - data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? 
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp b/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp
index 253a9e7efe..3b65145032 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp
@@ -5,8 +5,9 @@
 layout (constant_id = 1) const uint N = 64;
 layout (constant_id = 2) const uint K = 32;
+layout (constant_id = 3) const uint BATCH_N = 32;
 
-layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 4, local_size_y = 1, local_size_z = 1) in;
 
 uint a_base, b_base, x_base;
 
@@ -22,8 +23,8 @@ void store_x(uint r, uint c, FLOAT_TYPE v) {
     data_d[x_base + r * p.nb21 + c * p.nb20] = D_TYPE(v);
 }
 
-shared FLOAT_TYPE shA[N * N];
-shared FLOAT_TYPE shB[N * K];
+shared FLOAT_TYPE shA[BATCH_N * N];
+shared FLOAT_TYPE shB[BATCH_N * K];
 
 void main() {
     const uint batch = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
@@ -39,34 +40,42 @@ void main() {
     b_base = get_boffset() + i2 * p.nb12 + i3 * p.nb13;
     x_base = get_doffset() + i2 * p.nb22 + i3 * p.nb23;
 
-    // Load the A matrix into shA
-    [[unroll]] for (uint i = 0; i < N * N; i += gl_WorkGroupSize.x) {
-        uint idx = i + tid;
-        if (((N * N) % gl_WorkGroupSize.x == 0) || idx < N * N) {
-            shA[idx] = get_a(idx / N, idx % N);
-        }
-    }
-    // Load the B matrix into shB
-    [[unroll]] for (uint i = 0; i < N * K; i += gl_WorkGroupSize.x) {
-        uint idx = i + tid;
-        if (((N * K) % gl_WorkGroupSize.x == 0) || idx < N * K) {
-            shB[idx] = get_b(idx / K, idx % K);
-        }
-    }
-    barrier();
     FLOAT_TYPE X[N];
-    // Each thread solves one column
-    if (tid < K) {
-        [[unroll]] for (int r = 0; r < N; ++r) {
-            FLOAT_TYPE b = shB[r * K + tid];
-            // Compute x[r,c] = (b[r,c] - sum(a[r,c]*x[c])) / a[r,r]
-            [[unroll]] for (int c = 0; c < r; ++c) {
-                b -= shA[r * N + c] * X[c];
+
+    // Loop over batches of rows
+    [[unroll]] for (uint row_base = 0; row_base < N; row_base += BATCH_N) {
+        const uint cur_N = min(BATCH_N, N - row_base);
+
+        // Load the A matrix batch into shA
+        [[unroll]] for (uint i = 0; i < cur_N * N; i += gl_WorkGroupSize.x) {
+            uint idx = i + tid;
+            if (((cur_N * N) % gl_WorkGroupSize.x == 0) || idx < cur_N * N) {
+                shA[idx] = get_a(row_base + idx / N, idx % N);
             }
-            FLOAT_TYPE x = b / shA[r * N + r];
-            X[r] = x;
-            store_x(r, tid, x);
         }
+        // Load the B matrix batch into shB
+        [[unroll]] for (uint i = 0; i < cur_N * K; i += gl_WorkGroupSize.x) {
+            uint idx = i + tid;
+            if (((cur_N * K) % gl_WorkGroupSize.x == 0) || idx < cur_N * K) {
+                shB[idx] = get_b(row_base + idx / K, idx % K);
+            }
+        }
+        barrier();
+
+        // Each thread solves one column
+        if (tid < K) {
+            [[unroll]] for (uint row_offset = 0; row_offset < cur_N; ++row_offset) {
+                uint r = row_base + row_offset;
+                FLOAT_TYPE b = shB[row_offset * K + tid];
+                // Compute x[r,c] = (b[r,c] - sum(a[r,c]*x[c])) / a[r,r]
+                [[unroll]] for (int c = 0; c < r; ++c) {
+                    b -= shA[row_offset * N + c] * X[c];
+                }
+                FLOAT_TYPE x = b / shA[row_offset * N + r];
+                X[r] = x;
+                store_x(r, tid, x);
+            }
+        }
+        barrier();
     }
 }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
index bc1c278bf4..5cd0785d20 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
@@ -75,7 +75,7 @@ void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit
 }
 
 void main() {
-    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
+    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID;
     if (row >= n_rows) {
         return;
     }
@@ -83,17 +83,18 @@ void main() {
     const uint logits_offset = n_experts * row;
     const uint weights_offset = n_expert_used * row;
     const uint ids_offset = n_experts * row;
+    const uint lane = gl_SubgroupInvocationID;
 
     float wt[experts_per_thread];
 
     [[unroll]] for (uint i = 0; i < n_experts; i += WARP_SIZE) {
-        const uint expert = i + gl_LocalInvocationID.x;
+        const uint expert = i + lane;
         wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
     }
 
     if (!late_softmax) {
-        softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false);
+        softmax_warp_inplace(wt, n_experts, lane, false);
     }
 
     // at this point, each thread holds a portion of softmax,
@@ -111,11 +112,11 @@ void main() {
 
     for (int k = 0; k < n_expert_used; k++) {
         float max_val = wt[0];
-        uint max_expert = gl_LocalInvocationID.x;
+        uint max_expert = lane;
 
         [[unroll]] for (int i = 1; i < experts_per_thread; i++) {
-            const uint expert = gl_LocalInvocationID.x + i * WARP_SIZE;
+            const uint expert = lane + i * WARP_SIZE;
             if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
                 max_val = wt[i];
                 max_expert = expert;
@@ -132,11 +133,11 @@ void main() {
             }
         }
 
-        if ((k & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
+        if ((k & (WARP_SIZE - 1)) == lane) {
             output_weights[k / WARP_SIZE] = max_val;
         }
 
-        if ((max_expert & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
+        if ((max_expert & (WARP_SIZE - 1)) == lane) {
             wt[max_expert / WARP_SIZE] = -INFINITY;
 
             ids[ids_offset + k] = max_expert;
@@ -158,12 +159,12 @@ void main() {
     }
 
     if (late_softmax) {
-        softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true);
+        softmax_warp_inplace(output_weights, n_expert_used, lane, true);
    }
 
     [[unroll]]
     for (uint i = 0; i < experts_per_thread; ++i) {
-        uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
+        uint idx = i * WARP_SIZE + lane;
         if (idx < n_expert_used) {
             weights[weights_offset + idx] = output_weights[i];
         }
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 17cf4d84bb..9f5cdc1398 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -4947,6 +4947,18 @@ struct ggml_tensor * ggml_pad(
     return ggml_pad_ext(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
 }
 
+// ggml_pad_circular
+
+struct ggml_tensor * ggml_pad_circular(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   p0,
+        int                   p1,
+        int                   p2,
+        int                   p3) {
+    return ggml_pad_ext_circular(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
+}
+
 struct ggml_tensor * ggml_pad_ext(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
@@ -4973,6 +4985,7 @@ struct ggml_tensor * ggml_pad_ext(
     ggml_set_op_params_i32(result, 5, rp2);
     ggml_set_op_params_i32(result, 6, lp3);
     ggml_set_op_params_i32(result, 7, rp3);
+    ggml_set_op_params_i32(result, 8, 0); // not circular by default
 
     result->op = GGML_OP_PAD;
 
@@ -4981,6 +4994,25 @@ struct ggml_tensor * ggml_pad_ext(
     return result;
 }
 
+// ggml_pad_ext_circular
+
+struct ggml_tensor * ggml_pad_ext_circular(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   lp0,
+        int                   rp0,
+        int                   lp1,
+        int                   rp1,
+        int                   lp2,
+        int                   rp2,
+        int                   lp3,
+        int                   rp3
+        ) {
+    struct ggml_tensor * result = ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
+    ggml_set_op_params_i32(result, 8, 1); // circular
+    return result;
+}
+
 // ggml_pad_reflect_1d
 
 struct ggml_tensor * ggml_pad_reflect_1d(
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index a7b0973979..d9c87da194 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -376,6 +376,7 @@ class TensorNameMap:
             "model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker
             "model.layers.{bid}.feed_forward.gate",               # lfm2moe
             "model.layers.{bid}.mlp.router.gate",                 # afmoe
+            "layers.{bid}.gate",                                  # mistral-large
         ),
 
         MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@@ -450,6 +451,7 @@ class TensorNameMap:
             "model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
             "model.layers.{bid}.feed_forward.down_proj",
             "model.layers.{bid}.mlp.shared_mlp.up_proj",             # hunyuan
+            "layers.{bid}.shared_experts.w3",                        # mistral-large
         ),
 
         MODEL_TENSOR.FFN_UP_CHEXP: (
@@ -496,6 +498,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.shared_experts.gate_proj",         # deepseek deepseek2
             "model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
             "model.layers.{bid}.mlp.shared_mlp.gate_proj",             # hunyuan
+            "layers.{bid}.shared_experts.w1",                          # mistral-large
         ),
 
         MODEL_TENSOR.FFN_GATE_CHEXP: (
@@ -557,6 +560,7 @@ class TensorNameMap:
             "model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
             "model.layers.{bid}.shared_mlp.output_linear",             # granitemoe
             "model.layers.{bid}.mlp.shared_mlp.down_proj",             # hunyuan
+            "layers.{bid}.shared_experts.w2",                          # mistral-large
         ),
 
         MODEL_TENSOR.FFN_DOWN_CHEXP: (
@@ -924,14 +928,17 @@ class TensorNameMap:
 
         MODEL_TENSOR.ATTN_Q_A: (
             "model.layers.{bid}.self_attn.q_a_proj", # deepseek2
+            "layers.{bid}.attention.wq_a",           # mistral-large
         ),
 
         MODEL_TENSOR.ATTN_Q_B: (
             "model.layers.{bid}.self_attn.q_b_proj", # deepseek2
+            "layers.{bid}.attention.wq_b",           # mistral-large
         ),
 
         MODEL_TENSOR.ATTN_KV_A_MQA: (
             "model.layers.{bid}.self_attn.kv_a_proj_with_mqa", # deepseek2
+            "layers.{bid}.attention.wkv_a_with_mqa",           # mistral-large
         ),
 
         MODEL_TENSOR.ATTN_KV_B: (
@@ -940,18 +947,22 @@ class TensorNameMap:
 
         MODEL_TENSOR.ATTN_K_B: (
             "model.layers.{bid}.self_attn.k_b_proj", # deepseek2
+            "layers.{bid}.attention.k_b_proj",       # mistral-large
         ),
 
         MODEL_TENSOR.ATTN_V_B: (
             "model.layers.{bid}.self_attn.v_b_proj", # deepseek2
+            "layers.{bid}.attention.v_b_proj",       # mistral-large
         ),
 
         MODEL_TENSOR.ATTN_Q_A_NORM: (
             "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
+            "layers.{bid}.attention.q_a_norm",            # mistral-large
         ),
 
         MODEL_TENSOR.ATTN_KV_A_NORM: (
             "model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2
+            "layers.{bid}.attention.kv_a_norm",            # mistral-large
         ),
 
         MODEL_TENSOR.ATTN_SUB_NORM: (
diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index 764833749e..351dcb7baa 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -666,7 +666,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
     std::map<std::string, std::string> mapped;
     int blk_id = 0;
-    int pruned_attention_w = 0;
 
     // make a list of weights
     std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
@@ -674,11 +673,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
     for (const auto & it : ml.weights_map) {
         const std::string remapped_name(remap_layer(it.first, prune_list, mapped, blk_id));
         if (remapped_name.empty()) {
-            if (it.first.find("attn_v.weight") != std::string::npos ||
-                it.first.find("attn_qkv.weight") != std::string::npos ||
-                it.first.find("attn_kv_b.weight") != std::string::npos) {
-                pruned_attention_w++;
-            }
             LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
             continue;
         }
@@ -703,7 +697,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         });
     }
 
-    bool is_clip_model = false;
     for (const auto * it : tensors) {
         const struct ggml_tensor * tensor = it->tensor;
@@ -717,30 +710,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
             qs.has_output = true;
         }
-
-        is_clip_model |= name.rfind("mm.", 0) == 0; // check the "mm." prefix
     }
 
     qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
 
-    // sanity checks for models that have attention layers
-    if (qs.n_attention_wv != 0 && !is_clip_model)
-    {
-        int32_t n_layer_all = model.hparams.n_layer;
-        if (llama_model_has_encoder(&model)) {
-            // now n_layer_all is the number of attention layers in the encoder
-            // for each decoder block, there are 2 attention layers
-            n_layer_all += 2 * model.hparams.dec_n_layer;
-        }
-
-        // note: for linear-attention models (such as Qwen3 Next) this is the number of linear layers
-        const int32_t n_layer_recr = std::count(model.hparams.recurrent_layer_arr.begin(), model.hparams.recurrent_layer_arr.end(), true);
-
-        LLAMA_LOG_INFO("%s: n_layer_all = %d, n_layer_recr = %d, pruned_attention_w = %d\n", __func__, n_layer_all, n_layer_recr, pruned_attention_w);
-
-        GGML_ASSERT((qs.n_attention_wv == n_layer_all - pruned_attention_w - n_layer_recr) && "n_attention_wv is unexpected");
-    }
-
     size_t total_size_org = 0;
     size_t total_size_new = 0;
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 9d7b0152af..2e94a53da2 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -5604,21 +5604,24 @@ struct test_pad : public test_case {
     const std::array<int64_t, 4> ne_a;
     const int pad_0;
     const int pad_1;
+    const bool circular;
 
     std::string vars() override {
-        return VARS_TO_STR4(type, ne_a, pad_0, pad_1);
+        return VARS_TO_STR5(type, ne_a, pad_0, pad_1, circular);
     }
 
     test_pad(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne_a = {512, 512, 1, 1},
-            int pad_0 = 1, int pad_1 = 1)
-        : type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1) {}
+            int pad_0 = 1, int pad_1 = 1, bool circular = false)
+        : type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1), circular(circular) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
         ggml_set_name(a, "a");
 
-        ggml_tensor * out = ggml_pad(ctx, a, pad_0, pad_1, 0, 0);
+        ggml_tensor * out = circular
+            ? ggml_pad_circular(ctx, a, pad_0, pad_1, 0, 0)
+            : ggml_pad(ctx, a, pad_0, pad_1, 0, 0);
         ggml_set_name(out, "out");
 
         return out;
@@ -5638,17 +5641,19 @@ struct test_pad_ext : public test_case {
     const int lp3;
     const int rp3;
     const bool v;
+    const bool circular;
 
     std::string vars() override {
-        return VARS_TO_STR11(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, v);
+        return VARS_TO_STR12(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, v, circular);
     }
 
     test_pad_ext(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne_a = {512, 512, 3, 1},
             int lp0 = 1, int rp0 = 1, int lp1 = 1, int rp1 = 1,
             int lp2 = 1, int rp2 = 1, int lp3 = 1, int rp3 = 1,
-            bool v = false)
-        : type(type), ne_a(ne_a), lp0(lp0), rp0(rp0), lp1(lp1), rp1(rp1), lp2(lp2), rp2(rp2), lp3(lp3), rp3(rp3), v(v) {}
+            bool v = false, bool circular = false)
+        : type(type), ne_a(ne_a), lp0(lp0), rp0(rp0), lp1(lp1), rp1(rp1), lp2(lp2), rp2(rp2), lp3(lp3), rp3(rp3),
+          v(v), circular(circular) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
@@ -5659,7 +5664,9 @@ struct test_pad_ext : public test_case {
             ggml_set_name(a, "view of a");
         }
 
-        ggml_tensor * out = ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
+        ggml_tensor * out = circular
+            ? ggml_pad_ext_circular(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3)
+            : ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
         ggml_set_name(out, "out");
 
         return out;
@@ -6204,6 +6211,15 @@ struct test_solve_tri : public test_case {
     std::string vars() override {
         return VARS_TO_STR3(type, ne_lhs, ne_rhs);
     }
 
+    uint64_t op_flops(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        int64_t n = ne_lhs[0];
+        int64_t k = ne_rhs[0];
+        int64_t batch = ne_lhs[2] * ne_lhs[3];
+        // n * (n + 1) / 2 non-zero elements of lhs, 2 flops each, for each col of rhs
+        return n * (n + 1) * k * batch;
+    }
+
     test_solve_tri(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne_lhs = { 10, 10, 4, 3 },
             std::array<int64_t, 4> ne_rhs = { 3, 10, 4, 3 }
@@ -7773,6 +7789,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {9, 9, 1280, 1}));
     test_cases.emplace_back(new test_acc());
     test_cases.emplace_back(new test_pad());
+    test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {33, 17, 2, 1}, 4, 3, true)); // circular
     test_cases.emplace_back(new test_pad_ext());
     test_cases.emplace_back(new test_pad_reflect_1d());
     test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
@@ -7816,10 +7833,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 42, 42, 5, 2 }, { 10, 42, 5, 2 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 10, 64, 2, 2 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 100, 100, 4, 4 }, { 41, 100, 4, 4 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 4 }, { 31, 128, 4, 4 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 300, 64, 4, 4 }));
 
     for (bool v : {false, true}) {
-        test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v));
-        test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v));
+        for (bool circular : {false, true}) {
+            test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v, circular));
+            test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v, circular));
+        }
     }
 
     for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 576 }) {
@@ -8016,6 +8037,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
+    // qwen3next with CHUNK_SIZE 64
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 8, 32 }, { 64, 64, 8, 32 }));
+    // qwen3next with CHUNK_SIZE 128
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 32 }, { 128, 128, 4, 32 }));
 
     test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 }));
     test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));
diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz
index 4527cb3356..2db04e9522 100644
Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ
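
As a sanity check on the solve_tri op_flops estimate added above, the arithmetic for one of the new qwen3next perf cases works out as follows (illustrative numbers only):

    // ne_lhs = {64, 64, 8, 32}, ne_rhs = {64, 64, 8, 32}
    // n = 64, k = 64, batch = 8 * 32 = 256
    // flops = n * (n + 1) * k * batch
    //       = 64 * 65 * 64 * 256
    //       = 68,157,440   (~68 MFLOP per op)
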
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte
index 2315611361..96ed56a775 100644
--- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte
+++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte
@@ -3,6 +3,7 @@
 	import { copyToClipboard, isIMEComposing } from '$lib/utils';
 	import ChatMessageAssistant from './ChatMessageAssistant.svelte';
 	import ChatMessageUser from './ChatMessageUser.svelte';
+	import ChatMessageSystem from './ChatMessageSystem.svelte';
 
 	interface Props {
 		class?: string;
@@ -140,8 +141,7 @@
 	}
 
 	function handleSaveEdit() {
-		if (message.role === 'user') {
-			// For user messages, trim to avoid accidental whitespace
+		if (message.role === 'user' || message.role === 'system') {
 			onEditWithBranching?.(message, editedContent.trim());
 		} else {
 			// For assistant messages, preserve exact content including trailing whitespace
@@ -167,7 +167,28 @@
 	}
 
-{#if message.role === 'user'}
+{#if message.role === 'system'}
+
+{:else if message.role === 'user'}
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageSystem.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageSystem.svelte
new file mode 100644
--- /dev/null
+++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageSystem.svelte
+<script lang="ts">
+	import { Check, X } from '@lucide/svelte';
+	import { Card } from '$lib/components/ui/card';
+	import { Button } from '$lib/components/ui/button';
+	import { MarkdownContent } from '$lib/components/app';
+	import { INPUT_CLASSES } from '$lib/constants/input-classes';
+	import { config } from '$lib/stores/settings.svelte';
+	import ChatMessageActions from './ChatMessageActions.svelte';
+
+	interface Props {
+		class?: string;
+		message: DatabaseMessage;
+		isEditing: boolean;
+		editedContent: string;
+		siblingInfo?: ChatMessageSiblingInfo | null;
+		showDeleteDialog: boolean;
+		deletionInfo: {
+			totalCount: number;
+			userMessages: number;
+			assistantMessages: number;
+			messageTypes: string[];
+		} | null;
+		onCancelEdit: () => void;
+		onSaveEdit: () => void;
+		onEditKeydown: (event: KeyboardEvent) => void;
+		onEditedContentChange: (content: string) => void;
+		onCopy: () => void;
+		onEdit: () => void;
+		onDelete: () => void;
+		onConfirmDelete: () => void;
+		onNavigateToSibling?: (siblingId: string) => void;
+		onShowDeleteDialogChange: (show: boolean) => void;
+		textareaElement?: HTMLTextAreaElement;
+	}
+
+	let {
+		class: className = '',
+		message,
+		isEditing,
+		editedContent,
+		siblingInfo = null,
+		showDeleteDialog,
+		deletionInfo,
+		onCancelEdit,
+		onSaveEdit,
+		onEditKeydown,
+		onEditedContentChange,
+		onCopy,
+		onEdit,
+		onDelete,
+		onConfirmDelete,
+		onNavigateToSibling,
+		onShowDeleteDialogChange,
+		textareaElement = $bindable()
+	}: Props = $props();
+
+	let isMultiline = $state(false);
+	let messageElement: HTMLElement | undefined = $state();
+	let isExpanded = $state(false);
+	let contentHeight = $state(0);
+	const MAX_HEIGHT = 200; // pixels
+	const currentConfig = config();
+
+	let showExpandButton = $derived(contentHeight > MAX_HEIGHT);
+
+	$effect(() => {
+		if (!messageElement || !message.content.trim()) return;
+
+		if (message.content.includes('\n')) {
+			isMultiline = true;
+		}
+
+		const resizeObserver = new ResizeObserver((entries) => {
+			for (const entry of entries) {
+				const element = entry.target as HTMLElement;
+				const estimatedSingleLineHeight = 24;
+
+				isMultiline = element.offsetHeight > estimatedSingleLineHeight * 1.5;
+				contentHeight = element.scrollHeight;
+			}
+		});
+
+		resizeObserver.observe(messageElement);
+
+		return () => {
+			resizeObserver.disconnect();
+		};
+	});
+
+	function toggleExpand() {
+		isExpanded = !isExpanded;
+	}
+</script>
+
+	{#if isEditing}
+
+
+
+
+
+
+
+	{:else}
+		{#if message.content.trim()}
+
+
+		{/if}
+
+		{#if isExpanded && showExpandButton}
+
+
+		{/if}
+
+
+
+		{/if}
+
+		{#if message.timestamp}
+
+
+		{/if}
+	{/if}
+{/if}
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageUser.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageUser.svelte
index 8556cbef5b..3d2b8dd35b 100644
--- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageUser.svelte
+++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageUser.svelte
@@ -145,7 +145,7 @@
 	{#if message.content.trim()}
 		{#if currentConfig.renderUserContentAsMarkdown}
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessages.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessages.svelte
index f307f829bc..2e5f57cb61 100644
--- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessages.svelte
+++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessages.svelte
@@ -2,6 +2,7 @@
 	import { ChatMessage } from '$lib/components/app';
 	import { chatStore } from '$lib/stores/chat.svelte';
 	import { conversationsStore, activeConversation } from '$lib/stores/conversations.svelte';
+	import { config } from '$lib/stores/settings.svelte';
 	import { getMessageSiblings } from '$lib/utils';
 
 	interface Props {
@@ -13,6 +14,7 @@
 	let { class: className, messages = [], onUserAction }: Props = $props();
 
 	let allConversationMessages = $state([]);
+	const currentConfig = config();
 
 	function refreshAllMessages() {
 		const conversation = activeConversation();
@@ -40,7 +42,12 @@
 			return [];
 		}
 
-		return messages.map((message) => {
+		// Filter out system messages if showSystemMessage is false
+		const filteredMessages = currentConfig.showSystemMessage
+			? messages
+			: messages.filter((msg) => msg.type !== 'system');
+
+		return filteredMessages.map((message) => {
 			const siblingInfo = getMessageSiblings(allConversationMessages, message.id);
 
 			return {
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte
index 67df20439c..45640e42a0 100644
--- a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte
+++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte
@@ -36,12 +36,6 @@
 			title: 'General',
 			icon: Settings,
 			fields: [
-				{ key: 'apiKey', label: 'API Key', type: 'input' },
-				{
-					key: 'systemMessage',
-					label: 'System Message (will be disabled if left empty)',
-					type: 'textarea'
-				},
 				{
 					key: 'theme',
 					label: 'Theme',
@@ -52,6 +46,12 @@
 						{ value: 'dark', label: 'Dark', icon: Moon }
 					]
 				},
+				{ key: 'apiKey', label: 'API Key', type: 'input' },
+				{
+					key: 'systemMessage',
+					label: 'System Message',
+					type: 'textarea'
+				},
 				{
 					key: 'pasteLongTextToFileLen',
 					label: 'Paste long text to file length',
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte
index ef46e2d176..a6f51f47d6 100644
--- a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte
+++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte
@@ -95,7 +95,7 @@
 				{#if field.help || SETTING_CONFIG_INFO[field.key]}

-						{field.help || SETTING_CONFIG_INFO[field.key]}
+						{@html field.help || SETTING_CONFIG_INFO[field.key]}
 
 				{/if}
 			{:else if field.type === 'textarea'}
@@ -112,13 +112,28 @@
 					value={String(localConfig[field.key] ?? '')}
 					onchange={(e) => onConfigChange(field.key, e.currentTarget.value)}
 					placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] ?? 'none'}`}
-					class="min-h-[100px] w-full md:max-w-2xl"
+					class="min-h-[10rem] w-full md:max-w-2xl"
 				/>
+
 				{#if field.help || SETTING_CONFIG_INFO[field.key]}
 
 						{field.help || SETTING_CONFIG_INFO[field.key]}
 
 				{/if}
+
+				{#if field.key === 'systemMessage'}
+
+
+							onConfigChange('showSystemMessage', Boolean(checked))}
+						/>
+
+
+
+				{/if}
 			{:else if field.type === 'select'}
 				{@const selectedOption = field.options?.find(
 					(opt: { value: string; label: string; icon?: Component }) =>
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebar.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebar.svelte
index a113acbcb7..1d313e284e 100644
--- a/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebar.svelte
+++ b/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebar.svelte
@@ -8,6 +8,7 @@
 	import * as AlertDialog from '$lib/components/ui/alert-dialog';
 	import Input from '$lib/components/ui/input/input.svelte';
 	import { conversationsStore, conversations } from '$lib/stores/conversations.svelte';
+	import { chatStore } from '$lib/stores/chat.svelte';
 	import ChatSidebarActions from './ChatSidebarActions.svelte';
 
 	const sidebar = Sidebar.useSidebar();
@@ -98,6 +99,10 @@
 		await goto(`#/chat/${id}`);
 	}
+
+	function handleStopGeneration(id: string) {
+		chatStore.stopGenerationForChat(id);
+	}
 
@@ -132,6 +137,7 @@
 					onSelect={selectConversation}
 					onEdit={handleEditConversation}
 					onDelete={handleDeleteConversation}
+					onStop={handleStopGeneration}
 				/>
 			{/each}
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebarConversationItem.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebarConversationItem.svelte
index 82685f83a5..bf2fa4f9e9 100644
--- a/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebarConversationItem.svelte
+++ b/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebarConversationItem.svelte
@@ -1,6 +1,7 @@