diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index d740dac065..23e23ca8c7 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -1395,6 +1395,14 @@ static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { builder.consume_reasoning_with_xml_tool_calls(form, "<seed:think>", "</seed:think>"); } +static void common_chat_parse_solar_open(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("<|think|>", "<|end|><|begin|>assistant<|content|>"); + + // TODO: Tool calling + + builder.add_content(builder.consume_rest()); +} + static void common_chat_parse_content_only(common_chat_msg_parser & builder) { builder.try_parse_reasoning("<think>", "</think>"); builder.add_content(builder.consume_rest()); } @@ -1479,6 +1487,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) { case COMMON_CHAT_FORMAT_XIAOMI_MIMO: common_chat_parse_xiaomi_mimo(builder); break; + case COMMON_CHAT_FORMAT_SOLAR_OPEN: + common_chat_parse_solar_open(builder); + break; default: throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); } diff --git a/common/chat.cpp b/common/chat.cpp index be44c8abb0..b98ab21ce1 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -380,8 +380,8 @@ std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) { const auto & function = tool.at("function"); result.push_back({ /* .name = */ function.at("name"), - /* .description = */ function.at("description"), - /* .parameters = */ function.at("parameters").dump(), + /* .description = */ function.value("description", ""), + /* .parameters = */ function.value("parameters", json::object()).dump(), }); } } @@ -669,6 +669,7 @@ const char * common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: return "Qwen3 Coder"; case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5"; case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo"; + case COMMON_CHAT_FORMAT_SOLAR_OPEN: return "Solar Open"; case COMMON_CHAT_FORMAT_PEG_SIMPLE: return "peg-simple"; case COMMON_CHAT_FORMAT_PEG_NATIVE: return "peg-native"; case COMMON_CHAT_FORMAT_PEG_CONSTRUCTED: return "peg-constructed"; @@ -2517,6 +2518,27 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp return data; } +static common_chat_params common_chat_params_init_solar_open(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; + + // TODO: Reasoning effort + json additional_context = {}; + + data.prompt = apply(tmpl, inputs, std::nullopt, std::nullopt, additional_context); + data.format = COMMON_CHAT_FORMAT_SOLAR_OPEN; + + data.preserved_tokens = { + "<|think|>", + "<|content|>", + "<|begin|>", + "<|end|>", + }; + + // TODO: Tool calling + + return data; +} + static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; data.prompt = apply(tmpl, inputs); @@ -2780,6 +2802,13 @@ static common_chat_params common_chat_templates_apply_jinja( return common_chat_params_init_magistral(tmpl, params); } + // Solar Open + if (src.find("<|tool_response:begin|>") != std::string::npos && + src.find("<|tool_response:name|>") != std::string::npos && + src.find("<|tool_response:result|>") != std::string::npos) { + return common_chat_params_init_solar_open(tmpl, params); + } + // Plain handler (no tools) if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { return common_chat_params_init_without_tools(tmpl, params);
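For context, the parser above implies an assistant turn laid out roughly as `<|think|>…reasoning…<|end|><|begin|>assistant<|content|>…answer…`. That layout is inferred from the `try_parse_reasoning` markers and the preserved-token list; the actual Solar Open Jinja template is not part of this diff. A self-contained sketch of that split, useful for eyeballing what the parser is expected to produce:

```cpp
#include <iostream>
#include <string>

// Illustrative only: mirrors try_parse_reasoning("<|think|>",
// "<|end|><|begin|>assistant<|content|>") on a Solar-Open-style turn.
// The turn layout is an assumption derived from the marker strings above.
static void split_solar_turn(const std::string & turn) {
    static const std::string open  = "<|think|>";
    static const std::string close = "<|end|><|begin|>assistant<|content|>";

    const size_t c = turn.find(close);
    if (turn.rfind(open, 0) == 0 && c != std::string::npos) {
        std::cout << "reasoning: " << turn.substr(open.size(), c - open.size()) << "\n";
        std::cout << "content  : " << turn.substr(c + close.size()) << "\n";
    } else {
        std::cout << "content  : " << turn << "\n"; // no reasoning block
    }
}

int main() {
    split_solar_turn("<|think|>check the units first<|end|><|begin|>assistant<|content|>42");
}
```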
diff --git a/common/chat.h b/common/chat.h index 6085510a40..8bd4a325ff 100644 --- a/common/chat.h +++ b/common/chat.h @@ -124,6 +124,7 @@ enum common_chat_format { COMMON_CHAT_FORMAT_QWEN3_CODER_XML, COMMON_CHAT_FORMAT_APRIEL_1_5, COMMON_CHAT_FORMAT_XIAOMI_MIMO, + COMMON_CHAT_FORMAT_SOLAR_OPEN, // These are intended to be parsed by the PEG parser COMMON_CHAT_FORMAT_PEG_SIMPLE, diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index edc0ed539d..3340a0a7dc 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1062,6 +1062,9 @@ class TextModel(ModelBase): if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273": # ref: https://huggingface.co/alvarobartt/grok-2-tokenizer res = "grok-2" + if chkhsh == "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df": + # ref: https://huggingface.co/aari1995/German_Semantic_V3 + res = "jina-v2-de" if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5": # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B res = "llama-bpe" @@ -1230,6 +1233,12 @@ class TextModel(ModelBase): if chkhsh == "4a2e2abae11ca2b86d570fc5b44be4d5eb5e72cc8f22dd136a94b37da83ab665": # ref: https://huggingface.co/KORMo-Team/KORMo-tokenizer res = "kormo" + if chkhsh == "9d70134b369a70e5735009b6de918f7581b5211f7c074d1f89f753aea8248af1": + # ref: https://huggingface.co/tencent/Youtu-LLM-2B + res = "youtu" + if chkhsh == "16389f0a1f51ee53e562ffd51c371dc508639ab0e4261502071836e50e223e91": + # ref: https://huggingface.co/upstage/Solar-Open-100B + res = "solar-open" if res is None: logger.warning("\n") @@ -2486,6 +2495,7 @@ class StableLMModel(TextModel): "VLlama3ForCausalLM", "LlavaForConditionalGeneration", "VoxtralForConditionalGeneration", + "IQuestCoderForCausalLM", "LlamaModel") class LlamaModel(TextModel): model_arch = gguf.MODEL_ARCH.LLAMA @@ -5284,13 +5294,14 @@ class BertModel(TextModel): self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1)) # convert to phantom space vocab - def phantom(tok): - if tok.startswith("[") and tok.endswith("]"): + def phantom(tok, toktype): + if toktype == gguf.TokenType.CONTROL: return tok if tok.startswith("##"): return tok[2:] return "\u2581" + tok - tokens = list(map(phantom, tokens)) + assert len(tokens) == len(toktypes) + tokens = list(map(phantom, tokens, toktypes)) # add vocab to gguf self.gguf_writer.add_tokenizer_model("bert") @@ -6404,6 +6415,17 @@ class ARwkv7Model(Rwkv7Model): self.gguf_writer.add_head_count(0) +@ModelBase.register("MaincoderForCausalLM") +class MaincoderModel(TextModel): + model_arch = gguf.MODEL_ARCH.MAINCODER + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + if (head_dim := self.hparams.get("head_dim")) is not None: + self.gguf_writer.add_rope_dimension_count(head_dim) + + @ModelBase.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM") class MambaModel(TextModel): model_arch = gguf.MODEL_ARCH.MAMBA @@ -7181,6 +7203,7 @@ class DeepseekModel(TextModel): "DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM", "KimiVLForConditionalGeneration", + "YoutuForCausalLM", ) class DeepseekV2Model(TextModel): model_arch = gguf.MODEL_ARCH.DEEPSEEK2 @@ -7247,7 +7270,15 @@ class DeepseekV2Model(TextModel): super().set_gguf_parameters() hparams = self.hparams - self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) + # first_k_dense_replace: number of leading layers using dense FFN instead of MoE + # For non-MoE models 
(like Youtu), set to n_layer to use dense FFN for all layers + # For MoE models (like DeepSeek-V2), this is the number of leading non-MoE layers + has_moe = hparams.get("n_routed_experts") is not None + first_k_dense_replace = hparams.get("first_k_dense_replace") + if first_k_dense_replace is None: + # Default: if no MoE, all layers are dense; if MoE, none are dense + first_k_dense_replace = hparams["num_hidden_layers"] if not has_moe else 0 + self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace) self.gguf_writer.add_vocab_size(hparams["vocab_size"]) if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None: self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) @@ -7259,11 +7290,24 @@ class DeepseekV2Model(TextModel): self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) self.gguf_writer.add_value_length_mla(hparams["v_head_dim"]) - self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) - self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) - self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) - self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) - self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) + # MoE parameters (required by C++ code for DEEPSEEK2 arch) + # For non-MoE models like Youtu, use intermediate_size as expert_feed_forward_length + moe_intermediate_size = self.find_hparam(["moe_intermediate_size", "intermediate_size"], optional=False) + self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) + + if (n_routed_experts := hparams.get("n_routed_experts")) is not None: + self.gguf_writer.add_expert_count(n_routed_experts) + + # expert_shared_count is required by C++ code, default to 0 for non-MoE models + n_shared_experts = hparams.get("n_shared_experts", 0) + self.gguf_writer.add_expert_shared_count(n_shared_experts) + + # When not set, C++ code will use scale_w = false to skip the no-op scaling + if (routed_scaling_factor := hparams.get("routed_scaling_factor")) is not None: + self.gguf_writer.add_expert_weights_scale(routed_scaling_factor) + + if (norm_topk_prob := hparams.get("norm_topk_prob")) is not None and norm_topk_prob: + self.gguf_writer.add_expert_weights_norm(norm_topk_prob) self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) @@ -7279,10 +7323,17 @@ class DeepseekV2Model(TextModel): # skip vision tensors and remove "language_model." 
for Kimi-VL if "vision_tower" in name or "multi_modal_projector" in name: return [] - + if name.startswith("siglip2.") or name.startswith("merger."): + return [] if name.startswith("language_model."): name = name.replace("language_model.", "") + # skip lm_head.weight if tie_word_embeddings is True + if self.hparams.get("tie_word_embeddings", False): + if name == "lm_head.weight" or name == "model.lm_head.weight": + logger.info("Skipping tied output layer 'lm_head.weight' (will use token_embd.weight)") + return [] + # rename e_score_correction_bias tensors if name.endswith("e_score_correction_bias"): name = name.replace("e_score_correction_bias", "e_score_correction.bias") @@ -10617,6 +10668,79 @@ class JanusProVisionModel(MmprojModel): return [] +@ModelBase.register("YOUTUVLForConditionalGeneration", "YOUTUVLForCausalLM") +class YOUTUVLVisionModel(MmprojModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.hparams_vision is not None + self.hparams_vision["image_size"] = self.hparams_vision.get("image_size", 560) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.YOUTUVL) + self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-6)) + + # Handle activation function + hidden_act = str(self.hparams.get("hidden_act", "gelu_pytorch_tanh")).lower() + if hidden_act in ("gelu", "gelu_pytorch_tanh", "gelu_fast", "gelu_new", "gelu_accurate"): + self.gguf_writer.add_vision_use_gelu(True) + elif hidden_act == "silu": + self.gguf_writer.add_vision_use_silu(True) + else: + raise ValueError(f"Unsupported activation function for YOUTUVL: {hidden_act}") + + self.gguf_writer.add_vision_spatial_merge_size(self.hparams.get("spatial_merge_size", 2)) + + window_size = self.hparams.get("window_size") + if window_size is not None: + self.gguf_writer.add_vision_window_size(window_size) + # fullatt_block_indexes contains explicit layer indices that use full attention + # e.g., [2, 5, 8, 11] means layers 2, 5, 8, 11 use full attention + # All other layers use window attention + fullatt_block_indexes = self.hparams.get("fullatt_block_indexes") + assert fullatt_block_indexes is not None, "fullatt_block_indexes is required for youtuvl" + # Store the explicit layer indices for YoutuVL (irregular pattern approach) + self.gguf_writer.add_vision_wa_layer_indexes(layers=fullatt_block_indexes) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + # Skip language model tensors + skip_prefixes = ('lm_head.', 'model.layers.', 'model.embed_tokens.', 'model.norm.') + if name.startswith(skip_prefixes): + return [] + + # Try to map the tensor using TensorNameMap (handles vision encoder and projector) + try: + new_name = self.map_tensor_name(name) + return [(new_name, data_torch)] + except ValueError: + # If mapping fails, log warning and skip + logger.warning(f"Cannot map tensor: {name}") + return [] + + +@ModelBase.register("SolarOpenForCausalLM") +class SolarOpenModel(Glm4MoeModel): + model_arch = gguf.MODEL_ARCH.GLM4_MOE + + def set_vocab(self): + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + tokens, toktypes, tokpre = self.get_vocab_base() + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + 
self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<unk>"]) + special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|startoftext|>"]) + special_vocab.add_to_gguf(self.gguf_writer) + + ###### CONVERSION LOGIC ###### diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 4378378309..74c67e6a9c 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -145,6 +145,8 @@ models = [ {"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", }, {"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", }, {"name": "kormo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/KORMo-Team/KORMo-tokenizer", }, + {"name": "youtu", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", }, + {"name": "solar-open", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", }, ] # some models are known to be broken upstream, so we will skip them as exceptions @@ -165,6 +167,8 @@ pre_computed_hashes = [ {"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"}, {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"}, {"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"}, + # jina-v2-de variants + {"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"}, ] diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 4ed5f35774..a9d1778641 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -358,7 +358,7 @@ extern "C" { typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data); // Compare the output of two backends - GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node); + GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes); // Tensor initialization GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
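The header change above turns the single `test_node` filter into an array plus count, so one call can now verify several outputs. A hedged usage sketch; only the signature comes from the header, the callback and call site are hypothetical:

```cpp
#include "ggml-backend.h"

// Compares one pair of nodes; the return value is forwarded by the walker
// (e.g. compute NMSE between t1 and t2 here and return whether it is acceptable).
static bool eval_cb(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data) {
    (void) node_index; (void) t1; (void) t2; (void) user_data;
    return true;
}

static bool compare_outputs(ggml_backend_t b1, ggml_backend_t b2, struct ggml_cgraph * gf,
                            struct ggml_tensor * out_a, struct ggml_tensor * out_b) {
    const struct ggml_tensor * nodes[] = { out_a, out_b };
    // The whole graph is computed on both backends; eval_cb fires once per
    // listed node found in the graph, and the call asserts if none matches.
    return ggml_backend_compare_graph_backend(b1, b2, gf, eval_cb, /*user_data=*/nullptr,
                                              nodes, sizeof(nodes)/sizeof(nodes[0]));
}
```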
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 8547ecc849..1b59924b8c 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -2053,7 +2053,7 @@ void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) { ggml_free(copy.ctx_unallocated); } -bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node) { +bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes) { struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph); if (copy.buffer == NULL) { return false; @@ -2064,22 +2064,22 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t assert(g1->n_nodes == g2->n_nodes); - if (test_node != nullptr) { - // Compute the whole graph and only test the output for a specific tensor + if (num_test_nodes != 0) { + GGML_ASSERT(test_nodes); + // Compute the whole graph and only test the output for specific tensors ggml_backend_graph_compute(backend1, g1); ggml_backend_graph_compute(backend2, g2); - int test_node_idx = -1; + bool verified = false; for (int i = 0; i < g1->n_nodes; i++) { - struct ggml_tensor * t1 = g1->nodes[i]; - if (t1 == test_node) { - test_node_idx = i; - break; + for (size_t j = 0; j < num_test_nodes; ++j) { + if (g1->nodes[i] == test_nodes[j]) { + callback(i, g1->nodes[i], g2->nodes[i], user_data); + verified = true; + } } } - GGML_ASSERT(test_node_idx != -1); - - callback(test_node_idx, g1->nodes[test_node_idx], g2->nodes[test_node_idx], user_data); + GGML_ASSERT(verified); } else { for (int i = 0; i < g1->n_nodes; i++) { struct ggml_tensor * t1 = g1->nodes[i];
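The cpy.cu changes that follow widen all index and stride parameters to `int64_t`, and the removed asserts at the end of the hunk previously capped copies at `INT_MAX` bytes. Once that cap is lifted, 32-bit products like `i01*nb01` can wrap. A minimal host-side illustration of the failure mode the wider types avoid:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // A plausible row index and row stride (bytes) for a >2 GiB tensor:
    // the 32-bit product exceeds INT32_MAX, the 64-bit one is what the
    // kernels now compute.
    const int64_t i01  = 70000;
    const int64_t nb01 = 40000;
    const int32_t bad  = (int32_t) ((uint32_t) i01 * (uint32_t) nb01); // negative
    const int64_t good = i01 * nb01;                                   // 2,800,000,000
    std::printf("32-bit offset: %d\n64-bit offset: %lld\n", bad, (long long) good);
}
```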
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index c4ceb4fc57..ee84303ef0 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -12,11 +12,11 @@ const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd dimension if available const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows template <cpy_kernel_t cpy_1> -static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { - const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; +static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, + const int64_t nb12, const int64_t nb13) { + const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { return; @@ -40,10 +40,10 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne, } template <typename T> -static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { +static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, + const int64_t nb12, const int64_t nb13) { const T* src = reinterpret_cast<const T*>(cx); T* dst = reinterpret_cast<T*>(cdst); @@ -117,60 +117,60 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) { } template <cpy_kernel_t cpy_blck, int qk> -static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { - const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; +static __global__ void cpy_f32_q(const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, + const int64_t nb12, const int64_t nb13) { + const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk; if (i >= ne) { return; } - const int i03 = i/(ne00 * ne01 * ne02); - const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); - const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; - const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; - const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; + const int64_t i03 = i/(ne00 * ne01 * ne02); + const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); + const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; + const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; + const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; - const int i13 = i/(ne10 * ne11 * ne12); - const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); - const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; - const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; - const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13; + const int64_t i13 = i/(ne10 * ne11 * ne12); + const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); + const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; + const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; + const int64_t dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13; cpy_blck(cx + x_offset, cdst + dst_offset); } template <cpy_kernel_t cpy_blck, int qk> -static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { - const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; +static __global__ void cpy_q_f32(const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, + const int64_t nb12, const int64_t nb13) { + const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk; if (i >= ne) { return; } - const int i03 = i/(ne00 * ne01 * ne02); - const int i02 = (i - 
i03*ne00*ne01*ne02 )/ (ne00*ne01); + const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; + const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; + const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; - const int i13 = i/(ne10 * ne11 * ne12); - const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); - const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; - const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; - const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13; + const int64_t i13 = i/(ne10 * ne11 * ne12); + const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); + const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; + const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; + const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13; cpy_blck(cx + x_offset, cdst + dst_offset); } template static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) { - const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; + const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { return; @@ -188,19 +188,20 @@ static void ggml_cpy_scalar_contiguous_cuda( cudaStream_t stream) { const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_scalar_contiguous<<>> (cx, cdst, ne); } template static void ggml_cpy_scalar_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { if (transposed) { GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed - int ne00n, ne01n, ne02n; + int64_t ne00n, ne01n, ne02n; if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here ne00n = ne00; ne01n = ne01; @@ -211,143 +212,159 @@ static void ggml_cpy_scalar_cuda( ne02n = 1; } - dim3 dimGrid( (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, - (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, - (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM); + int64_t grid_x = (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D; + int64_t grid_y = (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D; + int64_t grid_z = (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM; + GGML_ASSERT(grid_x < UINT_MAX); + GGML_ASSERT(grid_y < USHRT_MAX); + GGML_ASSERT(grid_z < USHRT_MAX); + dim3 dimGrid(grid_x, grid_y, grid_z); dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1); cpy_scalar_transpose<<>> (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } else { - const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_scalar><<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } } static void 
ggml_cpy_f32_q8_0_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK8_0 == 0); - const int num_blocks = ne / QK8_0; + const int64_t num_blocks = ne / QK8_0; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q8_0_f32_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { - const int num_blocks = ne; + const int64_t num_blocks = ne; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_q_f32<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q4_0_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK4_0 == 0); - const int num_blocks = ne / QK4_0; + const int64_t num_blocks = ne / QK4_0; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q4_0_f32_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, - const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, - const int nb10, const int nb11, const int nb12, const int nb13, + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, + const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { - const int num_blocks = ne; + const int64_t num_blocks = ne; + 
GGML_ASSERT(num_blocks < UINT_MAX); cpy_q_f32, QK4_0><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q4_1_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK4_1 == 0); - const int num_blocks = ne / QK4_1; + const int64_t num_blocks = ne / QK4_1; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q4_1_f32_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, - const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, - const int nb10, const int nb11, const int nb12, const int nb13, + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, + const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { - const int num_blocks = ne; + const int64_t num_blocks = ne; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_q_f32, QK4_1><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q5_0_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK5_0 == 0); - const int num_blocks = ne / QK5_0; + const int64_t num_blocks = ne / QK5_0; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q5_0_f32_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, - const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, - const int nb10, const int nb11, const int nb12, const int nb13, + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, + const int64_t nb10, 
const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { - const int num_blocks = ne; + const int64_t num_blocks = ne; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_q_f32, QK5_0><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q5_1_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK5_1 == 0); - const int num_blocks = ne / QK5_1; + const int64_t num_blocks = ne / QK5_1; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q5_1_f32_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, - const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, - const int nb10, const int nb11, const int nb12, const int nb13, + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, + const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { - const int num_blocks = ne; + const int64_t num_blocks = ne; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_q_f32, QK5_1><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_iq4_nl_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK4_NL == 0); - const int num_blocks = ne / QK4_NL; + const int64_t num_blocks = ne / QK4_NL; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } @@ -356,9 +373,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); - GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX); - GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX); - const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu 
b/ggml/src/ggml-cuda/ggml-cuda.cu index 55e1c20c96..84eccea3f7 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -201,16 +201,6 @@ static ggml_cuda_device_info ggml_cuda_init() { GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES); int64_t total_vram = 0; -#ifdef GGML_CUDA_FORCE_MMQ - GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__); -#else - GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__); -#endif // GGML_CUDA_FORCE_MMQ -#ifdef GGML_CUDA_FORCE_CUBLAS - GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__); -#else - GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__); -#endif // GGML_CUDA_FORCE_CUBLAS GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count); std::vector<std::pair<int, std::string>> turing_devices_without_mma; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index acf2aa9184..a50b12b6f3 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2181,7 +2181,11 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) { const bool has_mask = op->src[3] != nullptr; - if (ggml_metal_op_flash_attn_ext_use_vec(op)) { + // note: the non-vec kernel requires more extra memory, so always reserve for it + GGML_ASSERT(OP_FLASH_ATTN_EXT_NCPSG >= OP_FLASH_ATTN_EXT_VEC_NCPSG); + + //if (ggml_metal_op_flash_attn_ext_use_vec(op)) { + if (false) { // note: always reserve the padding space to avoid graph reallocations //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0; const bool has_kvpad = true; diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 164b39d01e..d7c8ad8c16 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -1517,10 +1517,12 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) { struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false); graph->n_nodes = n_nodes; std::unordered_map<uint64_t, const rpc_tensor *> tensor_ptrs; + tensor_ptrs.reserve(n_tensors); for (uint32_t i = 0; i < n_tensors; i++) { - tensor_ptrs[tensors[i].id] = &tensors[i]; + tensor_ptrs.emplace(tensors[i].id, &tensors[i]); } std::unordered_map<uint64_t, ggml_tensor *> tensor_map; + tensor_map.reserve(n_nodes); for (uint32_t i = 0; i < n_nodes; i++) { int64_t id; memcpy(&id, &nodes[i], sizeof(id)); diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 493ee9c9a4..16254457bb 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -434,8 +434,15 @@ static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, GGML_OP_RESHAPE }; + +static constexpr std::initializer_list<ggml_op> topk_moe_sigmoid_norm_bias{ GGML_OP_UNARY, GGML_OP_RESHAPE, GGML_OP_ADD, + GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, + GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP, + GGML_OP_DIV, GGML_OP_RESHAPE }; + static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS }; + static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }; @@ -464,6 +471,32 @@ static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_edges { { 9, 0, 8 }, // reshape->src[0] == div };
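Each edge triple in these lists reads `{ consumer node, src slot, producer node }`, with all indices relative to the first node of the candidate run; the fusion pass verifies the graph is wired exactly like the dumped pattern that follows. A stand-alone sketch of that check (names hypothetical, the real matcher lives elsewhere in this file):

```cpp
#include <array>
#include <cstddef>
#include <initializer_list>

struct node { const node * src[4]; };

// Returns true iff nodes[base + consumer]->src[slot] == nodes[base + producer]
// for every {consumer, slot, producer} triple, i.e. the run matches the pattern.
static bool edges_match(const node * const * nodes, size_t base,
                        std::initializer_list<std::array<int, 3>> edges) {
    for (const auto & e : edges) {
        if (nodes[base + e[0]]->src[e[1]] != nodes[base + e[2]]) {
            return false;
        }
    }
    return true;
}
```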
+//node #436 ( UNARY): ffn_moe_probs-10 ( 256K) [Vulka ] use=2: ffn_moe_logits-10 ( 256K) [Vulka ] +//node #437 ( RESHAPE): ffn_moe_probs-10 (re ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ] +//node #438 ( ADD): ffn_moe_probs_biased ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ] blk.10.exp_probs_b.b ( 0K) [Vulka ] +//node #439 ( ARGSORT): ffn_moe_argsort-10 ( 256K) [Vulka ] use=1: ffn_moe_probs_biased ( 256K) [Vulka ] +//node #440 ( VIEW): ffn_moe_topk-10 ( 255K) [Vulka ] use=3: ffn_moe_argsort-10 ( 256K) [Vulka ] +//node #441 ( GET_ROWS): ffn_moe_weights-10 ( 12K) [Vulka ] use=1: ffn_moe_probs-10 (re ( 256K) [Vulka ] ffn_moe_topk-10 ( 255K) [Vulka ] +//node #442 ( RESHAPE): ffn_moe_weights-10 ( ( 12K) [Vulka ] use=2: ffn_moe_weights-10 ( 12K) [Vulka ] +//node #443 ( SUM_ROWS): ffn_moe_weights_sum- ( 2K) [Vulka ] use=1: ffn_moe_weights-10 ( ( 12K) [Vulka ] +//node #444 ( CLAMP): ffn_moe_weights_sum_ ( 2K) [Vulka ] use=1: ffn_moe_weights_sum- ( 2K) [Vulka ] +//node #445 ( DIV): ffn_moe_weights_norm ( 12K) [Vulka ] use=1: ffn_moe_weights-10 ( ( 12K) [Vulka ] ffn_moe_weights_sum_ ( 2K) [Vulka ] +//node #446 ( RESHAPE): ffn_moe_weights_norm ( 12K) [Vulka ] use=1: ffn_moe_weights_norm ( 12K) [Vulka ] +static constexpr std::initializer_list<std::array<int, 3>> topk_moe_sigmoid_norm_bias_edges { + { 1, 0, 0 }, // reshape->src[0] == sigmoid + { 2, 0, 0 }, // add->src[0] == sigmoid + { 3, 0, 2 }, // argsort->src[0] == add + { 4, 0, 3 }, // view->src[0] == argsort + { 5, 0, 1 }, // get_rows->src[0] == reshape + { 5, 1, 4 }, // get_rows->src[1] == view + { 6, 0, 5 }, // reshape->src[0] == get_rows + { 7, 0, 6 }, // sum_rows->src[0] == reshape + { 8, 0, 7 }, // clamp->src[0] == sum_rows + { 9, 0, 6 }, // div->src[0] == reshape + { 9, 1, 8 }, // div->src[1] == clamp + {10, 0, 9 }, // reshape->src[0] == div +}; + // same as early_softmax_norm but ending after the get_rows static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges { { 1, 0, 0 }, // reshape->src[0] == softmax @@ -491,16 +524,10 @@ enum topk_moe_mode { TOPK_MOE_EARLY_SOFTMAX, TOPK_MOE_EARLY_SOFTMAX_NORM, TOPK_MOE_LATE_SOFTMAX, + TOPK_MOE_SIGMOID_NORM_BIAS, TOPK_MOE_COUNT, }; -static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) { - topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM : - num == topk_moe_early_softmax.size() - 1 ? 
TOPK_MOE_EARLY_SOFTMAX : - TOPK_MOE_LATE_SOFTMAX; - return mode; -} - static constexpr std::initializer_list> rope_view_set_rows_edges { { 1, 0, 0 }, // view->src[0] == rope { 2, 0, 1 }, // set_rows->src[0] == view @@ -738,6 +765,9 @@ struct vk_device_struct { vk_pipeline pipeline_topk_f32[num_topk_pipelines]; vk_pipeline pipeline_sum_rows_f32; vk_pipeline pipeline_cumsum_f32; + vk_pipeline pipeline_cumsum_small_f32; + vk_pipeline pipeline_cumsum_multipass1_f32; + vk_pipeline pipeline_cumsum_multipass2_f32; vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_count_equal_i32; std::map pipeline_solve_tri_f32; @@ -766,7 +796,7 @@ struct vk_device_struct { vk_pipeline pipeline_count_experts; // [2] is for whether to take n_experts from spec constant (0) or push constant (1) - vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT][2]; + vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2]; std::vector all_pipelines; @@ -1181,6 +1211,11 @@ struct vk_op_topk_moe_push_constants { uint32_t n_expert_used; float clamp_min; float clamp_max; + uint32_t gating_func; + uint32_t has_bias; + uint32_t with_norm; + float output_scale; + float output_bias; }; struct vk_op_add_id_push_constants { @@ -1771,6 +1806,8 @@ struct ggml_backend_vk_context { // Bit 'i' means nodes[start_of_fusion + i] writes to memory. // If there's no fusion, bit 0 is still set. int fused_ops_write_mask {}; + topk_moe_mode fused_topk_moe_mode {}; + bool fused_topk_moe_scale {}; // for GGML_VK_PERF_LOGGER std::unique_ptr perf_logger; @@ -2668,7 +2705,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec switch (src0_type) { case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: - lut_size = 2*2048; + lut_size = 2*2048 + 4*2048; break; case GGML_TYPE_IQ2_XXS: lut_size = 8*256; @@ -3593,6 +3630,7 @@ static void ggml_vk_load_shaders(vk_device& device) { uint32_t rm_kq = 2; uint32_t rm_stdq_int = 1; uint32_t rm_kq_int = 1; + auto const &rm_iq_int = [](uint32_t i) { return i == 0 ? 
8u : 4u; }; if (device->vendor_id == VK_VENDOR_ID_AMD) { if (device->architecture == AMD_GCN) { rm_stdq = 2; @@ -3696,6 +3734,10 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_q8_1_f32", arr_dmmv_iq1_s_q8_1_f32_len[reduc], arr_dmmv_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_q8_1_f32", arr_dmmv_iq1_m_q8_1_f32_len[reduc], arr_dmmv_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int); + } #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT } @@ -3742,6 +3784,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_q8_1_f32", 
arr_dmmv_id_iq1_s_q8_1_f32_len[reduc], arr_dmmv_id_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_q8_1_f32", arr_dmmv_id_iq1_m_q8_1_f32_len[reduc], arr_dmmv_id_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int); } #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT } @@ -3749,6 +3794,7 @@ static void ggml_vk_load_shaders(vk_device& device) { #if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) GGML_UNUSED(rm_stdq_int); GGML_UNUSED(rm_kq_int); + GGML_UNUSED(rm_iq_int); #endif // dequant shaders @@ -4135,7 +4181,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size); + const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4; + ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_cumsum_small_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size, 1 }, 1, true, true, device->subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass1_f32, "cumsum_multipass1_f32", cumsum_multipass1_f32_len, cumsum_multipass1_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass2_f32, "cumsum_multipass2_f32", cumsum_multipass2_f32_len, cumsum_multipass2_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); @@ -4291,9 +4341,7 @@ static void ggml_vk_load_shaders(vk_device& device) { for (uint32_t use_push = 0; use_push < 2; ++use_push) { for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) { - ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX][use_push], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i}, 1, true, true, device->subgroup_size); - ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM][use_push], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i}, 1, true, true, device->subgroup_size); - ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX][use_push], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i}, 1, true, true, device->subgroup_size); + ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][use_push], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 4, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i}, 1, true, true, device->subgroup_size); } } @@ -5584,6 +5632,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: break; default: return nullptr; @@ -5740,6 +5790,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: break; default: return nullptr; @@ -7005,7 +7057,7 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ // Quantization overhead is not worth it for small k switch (device->vendor_id) { case VK_VENDOR_ID_NVIDIA: - if (src0_type == GGML_TYPE_Q2_K) { + if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) { return true; } @@ -8684,10 +8736,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const if (ctx->num_additional_fused_ops) { uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); GGML_ASSERT(idx < num_topk_moe_pipelines); - topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); // use n_experts from push constant if it's not equal to the power of two spec constant bool use_push = dst->ne[0] != (1u << idx); - return ctx->device->pipeline_topk_moe[idx][mode][use_push]; + return ctx->device->pipeline_topk_moe[idx][use_push]; } if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { @@ -8760,7 +8811,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_CUMSUM: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_cumsum_f32; + if (src0->ne[0] <= 512) { + return ctx->device->pipeline_cumsum_small_f32; + } else { + return ctx->device->pipeline_cumsum_f32; + } } return nullptr; case GGML_OP_SOLVE_TRI:
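The new TOPK_MOE_SIGMOID_NORM_BIAS path handled below fuses the sigmoid/add-bias/argsort/get-rows/normalize chain from the node dump earlier: gates come from a sigmoid, experts are ranked on `probs + bias`, but the returned weights are the unbiased probabilities normalized over the selected experts. A scalar reference of that math as I read the fused ops (not the actual shader):

```cpp
#include <algorithm>
#include <cmath>
#include <numeric>
#include <utility>
#include <vector>

// Sigmoid-gated top-k with a per-expert selection bias: select on probs+bias,
// weight by the unbiased probs, normalized by the (clamped) sum of selected probs.
std::vector<std::pair<int, float>> topk_sigmoid_bias(
        const std::vector<float> & logits, const std::vector<float> & bias,
        int k, float clamp_min, float clamp_max) {
    const int n = (int) logits.size();
    std::vector<float> probs(n), biased(n);
    for (int e = 0; e < n; ++e) {
        probs[e]  = 1.0f / (1.0f + std::exp(-logits[e]));
        biased[e] = probs[e] + bias[e];
    }
    std::vector<int> idx(n);
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                      [&](int a, int b) { return biased[a] > biased[b]; });
    float sum = 0.0f;
    for (int j = 0; j < k; ++j) sum += probs[idx[j]];
    sum = std::min(std::max(sum, clamp_min), clamp_max);
    std::vector<std::pair<int, float>> out;
    for (int j = 0; j < k; ++j) out.push_back({idx[j], probs[idx[j]] / sum});
    return out;
}
```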
@@ -10346,14 +10401,16 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub } static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) { - topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); + topk_moe_mode mode = ctx->fused_topk_moe_mode; ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0]; - ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] : - (mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] : - cgraph->nodes[node_idx + 5]; - ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3]; + ggml_tensor * bias = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 2]->src[1] : logits; + ggml_tensor * weights = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; + ggml_tensor * ids = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 4] : + (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : + cgraph->nodes[node_idx + 3]; GGML_ASSERT(logits->type == GGML_TYPE_F32); + GGML_ASSERT(bias->type == GGML_TYPE_F32); GGML_ASSERT(weights->type == GGML_TYPE_F32); GGML_ASSERT(ids->type == GGML_TYPE_I32); @@ -10368,6 +10425,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer(ctx, logits); + vk_subbuffer bias_buf = ggml_vk_tensor_subbuffer(ctx, bias); vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer(ctx, weights); vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer(ctx, ids); @@ -10375,18 +10433,45 @@ pc.n_rows = n_rows; pc.n_experts_push = n_experts; pc.n_expert_used = n_expert_used; + pc.clamp_min = -std::numeric_limits<float>::infinity(); + pc.clamp_max = std::numeric_limits<float>::infinity(); if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) { ggml_tensor * clamp = cgraph->nodes[node_idx + 7]; + GGML_ASSERT(clamp->op == GGML_OP_CLAMP); pc.clamp_min = ggml_get_op_params_f32(clamp, 0); pc.clamp_max = ggml_get_op_params_f32(clamp, 1); } + if (mode == TOPK_MOE_SIGMOID_NORM_BIAS) { + ggml_tensor * clamp = cgraph->nodes[node_idx + 8]; + GGML_ASSERT(clamp->op == GGML_OP_CLAMP); + pc.clamp_min = ggml_get_op_params_f32(clamp, 0); + pc.clamp_max = ggml_get_op_params_f32(clamp, 1); + } + +#define GATING_FUNC_SOFTMAX 0 +#define GATING_FUNC_SIGMOID 1 +#define GATING_FUNC_SOFTMAX_WEIGHT 2 + + pc.gating_func = mode == TOPK_MOE_SIGMOID_NORM_BIAS ? GATING_FUNC_SIGMOID : + mode == TOPK_MOE_LATE_SOFTMAX ? GATING_FUNC_SOFTMAX_WEIGHT : + GATING_FUNC_SOFTMAX; + pc.has_bias = mode == TOPK_MOE_SIGMOID_NORM_BIAS; + pc.with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM || mode == TOPK_MOE_SIGMOID_NORM_BIAS; + if (ctx->fused_topk_moe_scale) { + GGML_ASSERT(weights->op == GGML_OP_SCALE); + pc.output_scale = ggml_get_op_params_f32(weights, 0); + pc.output_bias = ggml_get_op_params_f32(weights, 1); + } else { + pc.output_scale = 1.0f; + pc.output_bias = 0.0f; + } GGML_ASSERT(n_expert_used <= n_experts); const uint32_t rows_per_block = 4; std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, weights_buf, ids_buf}, pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, bias_buf, weights_buf, ids_buf}, pc, elements); } static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop) {
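The cumsum multipass split in the next hunk is a classic scan decomposition: pass 1 produces per-block partial sums and writes each block total to a temp buffer, pass 2 adds the prefix of preceding block totals onto every element. A CPU sketch of the same decomposition (block size is a free parameter here; the shaders themselves are not shown in this diff):

```cpp
#include <algorithm>
#include <vector>

// Two-pass inclusive prefix sum over one row, mimicking the
// cumsum_multipass1/cumsum_multipass2 split described above.
static void cumsum_two_pass(std::vector<float> & row, size_t block) {
    const size_t nblocks = (row.size() + block - 1) / block;
    std::vector<float> totals(nblocks, 0.0f); // the "temp buffer"

    // pass 1: scan within each block, record the block total
    for (size_t b = 0; b < nblocks; ++b) {
        float acc = 0.0f;
        for (size_t i = b * block; i < std::min(row.size(), (b + 1) * block); ++i) {
            acc += row[i];
            row[i] = acc;
        }
        totals[b] = acc;
    }
    // pass 2: add the sum of all preceding block totals
    float carry = 0.0f;
    for (size_t b = 0; b < nblocks; ++b) {
        for (size_t i = b * block; i < std::min(row.size(), (b + 1) * block); ++i) {
            row[i] += carry;
        }
        carry += totals[b];
    }
}
```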
+    if (dst->ne[0] <= 4096 || ggml_nrows(dst) >= ctx->device->shader_core_count) {
+        ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, pc);
+        return;
+    }
+
+    // First pass computes partial sums within a block, and stores the last partial
+    // to the temp buffer. Second pass sums the block partials from the temp buffer
+    // and adds that to the result of the first pass.
+    vk_pipeline pipeline1 = ctx->device->pipeline_cumsum_multipass1_f32;
+    vk_pipeline pipeline2 = ctx->device->pipeline_cumsum_multipass2_f32;
+    GGML_ASSERT(pipeline1 != nullptr && pipeline2 != nullptr);
+
+    ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
+    ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);
+
+    std::array<uint32_t, 3> elements;
+
+    elements[0] = dst->ne[0];
+    elements[1] = (uint32_t)ggml_nrows(dst);
+    elements[2] = 1;
+
+    size_t temp_size = sizeof(float) * elements[0] * ggml_nrows(dst);
+
+    if (ctx->prealloc_size_split_k < temp_size) {
+        ctx->prealloc_size_split_k = temp_size;
+        ggml_vk_preallocate_buffers(ctx, subctx);
+    }
+
+    vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src0);
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
+    vk_subbuffer temp_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
+
+    if (ctx->prealloc_split_k_need_sync) {
+        ggml_vk_sync_buffers(ctx, subctx);
+    }
+
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, {src_buf, dst_buf, temp_buf}, pc, elements);
+    ggml_vk_sync_buffers(ctx, subctx);
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, {src_buf, dst_buf, temp_buf}, pc, elements);
+
+    ctx->prealloc_split_k_need_sync = true;
 }
 
 static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -12128,6 +12255,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         break;
     case GGML_OP_UNARY:
+        if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
+            ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
+            break;
+        }
+
         switch (ggml_get_unary_op(node)) {
         case GGML_UNARY_OP_EXP:
         case GGML_UNARY_OP_SILU:
@@ -12175,7 +12307,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         break;
     case GGML_OP_SOFT_MAX:
-        if (ctx->num_additional_fused_ops) {
+        if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
             ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
         } else {
             ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node);
@@ -12195,7 +12327,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         break;
     case GGML_OP_ARGSORT:
-        if (ctx->num_additional_fused_ops) {
+        if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
             ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
         } else {
             ggml_vk_argsort(ctx, compute_ctx, src0, node);
@@ -13048,6 +13180,24 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
         get_rows = cgraph->nodes[node_idx + 4];
         argsort = cgraph->nodes[node_idx + 2];
         break;
+    case TOPK_MOE_SIGMOID_NORM_BIAS:
+        softmax = cgraph->nodes[node_idx + 0]; // really sigmoid
+        weights = cgraph->nodes[node_idx + 10];
+        get_rows = cgraph->nodes[node_idx + 5];
+        argsort = cgraph->nodes[node_idx + 3];
+        if (ggml_get_unary_op(softmax) != GGML_UNARY_OP_SIGMOID) {
+            return false;
+        }
+        // bias is expected to be 1D
+        if (ggml_nrows(cgraph->nodes[node_idx + 2]->src[1]) != 1 ||
+            !ggml_is_contiguous(cgraph->nodes[node_idx + 2]->src[1])) {
+            return false;
+        }
+        // sigmoid fusion seems to generate infinities on moltenvk
+        if (ctx->device->driver_id == vk::DriverId::eMoltenvk) {
+ return false; + } + break; case TOPK_MOE_EARLY_SOFTMAX: softmax = cgraph->nodes[node_idx + 0]; weights = cgraph->nodes[node_idx + 4]; @@ -13071,26 +13221,28 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc probs = probs->src[0]; ggml_tensor * selection_probs = argsort->src[0]; - if (probs != selection_probs) { + if (probs != selection_probs && mode != TOPK_MOE_SIGMOID_NORM_BIAS) { return false; } - const float * op_params = (const float *)softmax->op_params; - - float scale = op_params[0]; - float max_bias = op_params[1]; - if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) { return false; } - if (scale != 1.0f || max_bias != 0.0f) { - return false; - } + if (softmax->op == GGML_OP_SOFT_MAX) { + const float * op_params = (const float *)softmax->op_params; - // don't fuse when masks or sinks are present - if (softmax->src[1] || softmax->src[2]) { - return false; + float scale = op_params[0]; + float max_bias = op_params[1]; + + if (scale != 1.0f || max_bias != 0.0f) { + return false; + } + + // don't fuse when masks or sinks are present + if (softmax->src[1] || softmax->src[2]) { + return false; + } } const int n_expert = softmax->ne[0]; @@ -13363,6 +13515,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg total_mul_mat_bytes += bytes; } + ctx->fused_topk_moe_mode = TOPK_MOE_COUNT; + ctx->fused_topk_moe_scale = false; const char *fusion_string {}; if (!ctx->device->disable_fusion) { uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i); @@ -13408,13 +13562,23 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1; // view of argsort writes to memory ctx->fused_ops_write_mask |= 1 << 3; + ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM; fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM"; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) && + ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) { + ctx->num_additional_fused_ops = topk_moe_sigmoid_norm_bias.size() - 1; + // view of argsort writes to memory + ctx->fused_ops_write_mask |= 1 << 4; + ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS; + fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS"; } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) && ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) { ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1; // view of argsort writes to memory ctx->fused_ops_write_mask |= 1 << 3; + ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX; fusion_string = "TOPK_MOE_EARLY_SOFTMAX"; } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) && ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) && @@ -13422,8 +13586,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1; // view of argsort writes to memory ctx->fused_ops_write_mask |= 1 << 1; + ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX; fusion_string = "TOPK_MOE_LATE_SOFTMAX"; } + if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) { + // Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano. 
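+                    // When this matches, the trailing GGML_OP_SCALE node gets no dispatch of
+                    // its own: ggml_vk_topk_moe reads its params into the push constants and
+                    // topk_moe.comp applies them in the epilogue (sketch, names from this patch):
+                    //   pc.output_scale = ggml_get_op_params_f32(weights, 0);
+                    //   pc.output_bias  = ggml_get_op_params_f32(weights, 1);
+                    //   // ... and in the shader:
+                    //   weights[weights_offset + idx] = output_scale * output_weights[i] + output_bias;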
+ if (ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops - 1, { GGML_OP_DIV, GGML_OP_RESHAPE, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 }) || + ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) { + ctx->fused_topk_moe_scale = true; + ctx->num_additional_fused_ops++; + } + } } ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops; @@ -13602,6 +13775,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * if (keep_pattern(topk_moe_early_softmax_norm)) { continue; } + if (keep_pattern(topk_moe_sigmoid_norm_bias)) { + continue; + } if (keep_pattern(topk_moe_early_softmax)) { continue; } @@ -13628,6 +13804,7 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * } // Don't pull forward nodes from fusion patterns if (match_pattern(topk_moe_early_softmax_norm, j) || + match_pattern(topk_moe_sigmoid_norm_bias, j) || match_pattern(topk_moe_early_softmax, j) || match_pattern(topk_moe_late_softmax, j)) { continue; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp b/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp index a4c8fc354e..75e3c3b0eb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp @@ -14,6 +14,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; layout (constant_id = 0) const uint BLOCK_SIZE = 128; layout (constant_id = 1) const uint SUBGROUP_SIZE = 32; +layout (constant_id = 2) const uint ELEM_PER_THREAD = 4; #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) @@ -38,32 +39,45 @@ void main() { last_sum = 0; } - uint col = tid; - uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE); + uint col = tid * ELEM_PER_THREAD; + uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE * ELEM_PER_THREAD); for (int i = 0; i < num_iter; ++i) { - FLOAT_TYPE v = 0; - if (col < p.n_cols) { - v = FLOAT_TYPE(data_a[src_idx + col]); + FLOAT_TYPE v[ELEM_PER_THREAD]; + FLOAT_TYPE thread_sum = 0; + [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) { + if (col + j < p.n_cols) { + thread_sum += FLOAT_TYPE(data_a[src_idx + col + j]); + } + v[j] = thread_sum; } - v = subgroupInclusiveAdd(v); + thread_sum = subgroupExclusiveAdd(thread_sum); + [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) { + v[j] += thread_sum; + } // Store the largest partial sum for each subgroup, then add the partials for all // lower subgroups and the final partial sum from the previous iteration. 
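+        // The step above is a standard vectorized scan (sketch for one subgroup;
+        // E = ELEM_PER_THREAD, indices illustrative):
+        //   v[0..E-1]  = running sums of this thread's E inputs  (serial, in registers)
+        //   thread_sum = subgroupExclusiveAdd(per-thread total)  (sum over lower lanes)
+        //   v[j]      += thread_sum  -> inclusive scan across the whole subgroup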
if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) { - partial[subgroup_id] = v; + partial[subgroup_id] = v[ELEM_PER_THREAD - 1]; } barrier(); - for (int j = 0; j < subgroup_id; ++j) { - v += partial[j]; + for (int s = 0; s < subgroup_id; ++s) { + [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) { + v[j] += partial[s]; + } + } + [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) { + v[j] += last_sum; } - v += last_sum; barrier(); if (tid == BLOCK_SIZE - 1) { - last_sum = v; + last_sum = v[ELEM_PER_THREAD - 1]; } - if (col < p.n_cols) { - data_d[dst_idx + col] = D_TYPE(v); + [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) { + if (col + j < p.n_cols) { + data_d[dst_idx + col + j] = D_TYPE(v[j]); + } } - col += BLOCK_SIZE; + col += BLOCK_SIZE * ELEM_PER_THREAD; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp new file mode 100644 index 0000000000..6d39f927fc --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp @@ -0,0 +1,60 @@ +#version 450 + +#include "types.glsl" +#include "sum_rows.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; +layout (binding = 2) writeonly buffer T {D_TYPE data_t[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 128; +layout (constant_id = 1) const uint SUBGROUP_SIZE = 32; + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.y; + const uint tid = gl_LocalInvocationID.x; + const uint col = gl_GlobalInvocationID.x; + + const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L); + const uint i03_offset = i03 * p.ne01*p.ne02; + const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L); + const uint i01 = row - i03_offset - i02*p.ne01; + + const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03; + const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13; + + uint subgroup_id = tid / SUBGROUP_SIZE; + + FLOAT_TYPE v = 0; + if (col < p.n_cols) { + v = FLOAT_TYPE(data_a[src_idx + col]); + } + v = subgroupInclusiveAdd(v); + + // Store the largest partial sum for each subgroup, then add the partials for all + // lower subgroups and the final partial sum from the previous iteration. 
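+// Each workgroup scans a single block of one row; the cross-block carry is
+// resolved by pass 2. Pass 1 publishes its block total as
+//   data_t[gl_WorkGroupID.x + gl_NumWorkGroups.x * row]
+// and pass 2 adds sum(data_t[0 .. gl_WorkGroupID.x - 1]) to every element
+// written below.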
+ if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) { + partial[subgroup_id] = v; + } + barrier(); + for (int j = 0; j < subgroup_id; ++j) { + v += partial[j]; + } + barrier(); + if (tid == BLOCK_SIZE - 1) { + data_t[gl_WorkGroupID.x + gl_NumWorkGroups.x * row] = v; + } + if (col < p.n_cols) { + data_d[dst_idx + col] = D_TYPE(v); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp new file mode 100644 index 0000000000..e401893466 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp @@ -0,0 +1,66 @@ +#version 450 + +#include "types.glsl" +#include "sum_rows.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) buffer D {D_TYPE data_d[];}; +layout (binding = 2) readonly buffer T {D_TYPE data_t[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 128; +layout (constant_id = 1) const uint SUBGROUP_SIZE = 32; + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +shared FLOAT_TYPE temp[BLOCK_SIZE / SUBGROUP_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.y; + const uint tid = gl_LocalInvocationID.x; + + const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L); + const uint i03_offset = i03 * p.ne01*p.ne02; + const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L); + const uint i01 = row - i03_offset - i02*p.ne01; + + const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03; + const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13; + + const uint col = gl_GlobalInvocationID.x; + + float v = 0; + // prefetch value we're adding to + if (col < p.n_cols) { + v = data_d[dst_idx + col]; + } + + // compute the sum of all previous blocks + uint c = tid; + float sum = 0; + while (c < gl_WorkGroupID.x) { + sum += data_t[c + gl_NumWorkGroups.x * row]; + c += BLOCK_SIZE; + } + + sum = subgroupAdd(sum); + if (gl_SubgroupInvocationID == 0) { + temp[gl_SubgroupID] = sum; + } + barrier(); + sum = 0; + [[unroll]] for (uint s = 0; s < BLOCK_SIZE / SUBGROUP_SIZE; ++s) { + sum += temp[s]; + } + + // Add the sum to what the first pass computed + if (col < p.n_cols) { + data_d[dst_idx + col] = v + sum; + } +} + diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp index 15f005be3e..ff5f43979d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp @@ -14,6 +14,8 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; #define K_PER_ITER 8 #elif defined(DATA_A_QUANT_K) #define K_PER_ITER 16 +#elif defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M) +#define K_PER_ITER 32 #else #error unimplemented #endif @@ -49,6 +51,15 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1]; cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2]; cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3]; +#elif K_PER_ITER == 32 + cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 ]; + cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 1]; + 
cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 2]; + cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 3]; + cache_b_qs[4] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 4]; + cache_b_qs[5] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 5]; + cache_b_qs[6] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 6]; + cache_b_qs[7] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 7]; #else #error unimplemented #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl index 2389ea0b1e..6ddbed309d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl @@ -377,3 +377,118 @@ FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { return FLOAT_TYPE(float(cache_b_ds.x) * float(d_scale) * float(q_sum)); } #endif + +#if defined(DATA_A_IQ1_S) +void repack8(uint ib, uint iqs, out i32vec4 out0, out i32vec4 out1) { + const uint ib32 = iqs / 32; + + const uint qh = data_a[ib].qh[ib32]; + + const uint qs16_0 = data_a_packed16[ib].qs[(4 * ib32 + 0) / 2]; + const uint qs16_1 = data_a_packed16[ib].qs[(4 * ib32 + 2) / 2]; + + const uint qs0 = qs16_0 & 0xFF; + const uint qs1 = qs16_0 >> 8; + const uint qs2 = qs16_1 & 0xFF; + const uint qs3 = qs16_1 >> 8; + + const uint hi0 = bitfieldExtract(qh, 3 * int(0), 3); + const uint hi1 = bitfieldExtract(qh, 3 * int(1), 3); + const uint hi2 = bitfieldExtract(qh, 3 * int(2), 3); + const uint hi3 = bitfieldExtract(qh, 3 * int(3), 3); + + const int32_t grid0 = int32_t(iq1s_grid_gpu[qs0 | (hi0 << 8)]); + const int32_t grid1 = int32_t(iq1s_grid_gpu[qs1 | (hi1 << 8)]); + const int32_t grid2 = int32_t(iq1s_grid_gpu[qs2 | (hi2 << 8)]); + const int32_t grid3 = int32_t(iq1s_grid_gpu[qs3 | (hi3 << 8)]); + + out0 = i32vec4((grid0 >> 0) & 0x0F0F0F0F, + (grid0 >> 4) & 0x0F0F0F0F, + (grid1 >> 0) & 0x0F0F0F0F, + (grid1 >> 4) & 0x0F0F0F0F); + out1 = i32vec4((grid2 >> 0) & 0x0F0F0F0F, + (grid2 >> 4) & 0x0F0F0F0F, + (grid3 >> 0) & 0x0F0F0F0F, + (grid3 >> 4) & 0x0F0F0F0F); +} + +vec2 get_dm(uint ib, uint iqs) { + const uint ib32 = iqs / 32; + + const uint qh = data_a[ib].qh[ib32]; + const float delta = ((qh & 0x8000) != 0) ? 
-IQ1S_DELTA : IQ1S_DELTA; + + const float d = float(data_a[ib].d); + const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); + + // the -1 cancels out the bias in iq1s_grid_gpu + return FLOAT_TYPE_VEC2(dl, dl * (delta - 1)); +} + +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + int32_t q_sum = 0; + + const uint ib_k = ib_a / 8; + const uint iqs_k = (ib_a % 8) * 32 + iqs * 32; + + i32vec4 qs_a0; + i32vec4 qs_a1; + repack8(ib_k, iqs_k, qs_a0, qs_a1); + + const vec2 dm = get_dm(ib_k, iqs_k); + + q_sum += dotPacked4x8EXT(qs_a0.x, cache_b_qs[0]); + q_sum += dotPacked4x8EXT(qs_a0.y, cache_b_qs[1]); + q_sum += dotPacked4x8EXT(qs_a0.z, cache_b_qs[2]); + q_sum += dotPacked4x8EXT(qs_a0.w, cache_b_qs[3]); + q_sum += dotPacked4x8EXT(qs_a1.x, cache_b_qs[4]); + q_sum += dotPacked4x8EXT(qs_a1.y, cache_b_qs[5]); + q_sum += dotPacked4x8EXT(qs_a1.z, cache_b_qs[6]); + q_sum += dotPacked4x8EXT(qs_a1.w, cache_b_qs[7]); + + return FLOAT_TYPE(float(cache_b_ds.x) * float(dm.x) * float(q_sum) + float(dm.y) * float(cache_b_ds.y)); +} +#endif + +#if defined(DATA_A_IQ1_M) +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + const uint ib_k = ib_a / 8; + const uint iqs_k = (ib_a % 8) * 32 + iqs * 32; + + const uint ib32 = iqs_k / 32; + const uint ib64 = ib32 / 2; + + const uint16_t[4] scales = data_a[ib_k].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + + const uint qs32 = data_a_packed32[ib_k].qs[ib32]; + const uint qh16 = data_a_packed16[ib_k].qh[ib32]; + + float sum = 0; + const uint sc = data_a[ib_k].scales[ib64]; + [[unroll]] for (int l = 0; l < 4; ++l) { + const uint ib16 = 2 * ib32 + l / 2; + const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); + const uint qh = qh16 >> (4 * l); + const uint qs = (qs32 >> (8 * l)) & 0xFF; + const float delta = ((qh & 8) != 0) ? 
-IQ1M_DELTA : IQ1M_DELTA; + + const int32_t grid = int32_t(iq1s_grid_gpu[qs | ((qh & 7) << 8)]); + + int32_t q_sum = 0; + q_sum += dotPacked4x8EXT((grid >> 0) & 0x0F0F0F0F, cache_b_qs[2 * l + 0]); + q_sum += dotPacked4x8EXT((grid >> 4) & 0x0F0F0F0F, cache_b_qs[2 * l + 1]); + + int32_t y_sum = 0; + y_sum += dotPacked4x8EXT(int(0x01010101), cache_b_qs[2 * l + 0]); + y_sum += dotPacked4x8EXT(int(0x01010101), cache_b_qs[2 * l + 1]); + + // the -1 cancels out the bias in iq1s_grid_gpu + sum += dl * (q_sum + y_sum * (delta - 1)); + } + sum *= float(cache_b_ds.x); + + return sum; +} +#endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp index b83a2b9d2d..4bf6d2bcb0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp @@ -7,6 +7,10 @@ #include "types.glsl" +#define GATING_FUNC_SOFTMAX 0 +#define GATING_FUNC_SIGMOID 1 +#define GATING_FUNC_SOFTMAX_WEIGHT 2 + layout (push_constant) uniform parameter { uint n_rows; @@ -14,15 +18,18 @@ layout (push_constant) uniform parameter uint n_expert_used; float clamp_min; float clamp_max; + uint gating_func; + uint has_bias; + uint with_norm; + float output_scale; + float output_bias; }; layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; layout(constant_id = 0) const uint WARP_SIZE = 32; layout(constant_id = 1) const uint n_experts_spec = 512; -layout(constant_id = 2) const bool with_norm = true; -layout(constant_id = 3) const bool late_softmax = false; -layout(constant_id = 4) const bool nexperts_use_push = false; +layout(constant_id = 2) const bool nexperts_use_push = false; uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec; @@ -31,8 +38,9 @@ uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec; const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE); layout (binding = 0, std430) readonly buffer Logits {float logits[];}; -layout (binding = 1, std430) writeonly buffer Weights {float weights[];}; -layout (binding = 2, std430) writeonly buffer Ids {uint ids[];}; +layout (binding = 1, std430) readonly buffer BiasProbs {float bias[];}; +layout (binding = 2, std430) writeonly buffer Weights {float weights[];}; +layout (binding = 3, std430) writeonly buffer Ids {uint ids[];}; const float INFINITY = 1.0 / 0.0; @@ -87,20 +95,40 @@ void main() { } const uint logits_offset = n_experts * row; + const uint bias_offset = 0; // 1D const uint weights_offset = n_expert_used * row; const uint ids_offset = n_experts * row; const uint lane = gl_SubgroupInvocationID; - float wt[experts_per_thread]; + float probs[experts_per_thread]; [[unroll]] for (uint i = 0; i < n_experts; i += WARP_SIZE) { const uint expert = i + lane; - wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY; + probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? 
logits[logits_offset + expert] : -INFINITY; } - if (!late_softmax) { - softmax_warp_inplace(wt, n_experts, lane, nexperts_use_push); + if (gating_func == GATING_FUNC_SOFTMAX) { + softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push); + } else if (gating_func == GATING_FUNC_SIGMOID) { + [[unroll]] + for (int i = 0; i < experts_per_thread; i++) { + probs[i] = 1.f / (1.f + exp(-probs[i])); + } + } + + float selection_probs[experts_per_thread]; + if (has_bias != 0) { + [[unroll]] + for (uint i = 0; i < n_experts; i += WARP_SIZE) { + const uint expert = i + lane; + selection_probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? probs[i / WARP_SIZE] + bias[bias_offset + expert] : -INFINITY; + } + } else { + [[unroll]] + for (int i = 0; i < experts_per_thread; i++) { + selection_probs[i] = probs[i]; + } } // at this point, each thread holds a portion of softmax, @@ -117,14 +145,16 @@ void main() { } for (int k = 0; k < n_expert_used; k++) { - float max_val = wt[0]; + float max_val = probs[0]; + float max_val_s = selection_probs[0]; uint max_expert = lane; [[unroll]] for (int i = 1; i < experts_per_thread; i++) { const uint expert = lane + i * WARP_SIZE; - if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) { - max_val = wt[i]; + if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i] > max_val_s) { + max_val = probs[i]; + max_val_s = selection_probs[i]; max_expert = expert; } } @@ -132,9 +162,11 @@ void main() { [[unroll]] for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) { const float val = subgroupShuffleXor(max_val, mask); + const float val_s = subgroupShuffleXor(max_val_s, mask); const uint expert = subgroupShuffleXor(max_expert, mask); - if (val > max_val || (val == max_val && expert < max_expert)) { + if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) { max_val = val; + max_val_s = val_s; max_expert = expert; } } @@ -144,16 +176,14 @@ void main() { } if ((max_expert & (WARP_SIZE - 1)) == lane) { - wt[max_expert / WARP_SIZE] = -INFINITY; + selection_probs[max_expert / WARP_SIZE] = -INFINITY; ids[ids_offset + k] = max_expert; - if (with_norm) { - wt_sum += max_val; - } + wt_sum += max_val; } } - if (with_norm) { + if (with_norm != 0) { wt_sum = subgroupAdd(wt_sum); wt_sum = clamp(wt_sum, clamp_min, clamp_max); const float inv_sum = 1.0f / wt_sum; @@ -164,7 +194,7 @@ void main() { } } - if (late_softmax) { + if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) { softmax_warp_inplace(output_weights, n_expert_used, lane, true); } @@ -172,7 +202,7 @@ void main() { for (uint i = 0; i < experts_per_thread; ++i) { uint idx = i * WARP_SIZE + lane; if (idx < n_expert_used) { - weights[weights_offset + idx] = output_weights[i]; + weights[weights_offset + idx] = output_scale * output_weights[i] + output_bias; } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index 402a2a8397..bdb2c09259 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -396,6 +396,12 @@ struct block_iq1_s { uint16_t qh[QUANT_K_IQ1_S/32]; }; +struct block_iq1_s_packed16 { + float16_t d; + uint16_t qs[QUANT_K_IQ1_S/8/2]; + uint16_t qh[QUANT_K_IQ1_S/32]; +}; + #define QUANT_K_IQ1_M 256 #define QUANT_R_IQ1_M 1 @@ -405,6 +411,18 @@ struct block_iq1_m { uint16_t scales[QUANT_K_IQ1_M/64]; }; +struct block_iq1_m_packed16 { + uint16_t qs[QUANT_K_IQ1_M/8/2]; + uint16_t qh[QUANT_K_IQ1_M/16/2]; + uint16_t 
scales[QUANT_K_IQ1_M/64]; +}; + +struct block_iq1_m_packed32 { + uint32_t qs[QUANT_K_IQ1_M/8/4]; + uint32_t qh[QUANT_K_IQ1_M/16/4]; + uint32_t scales[QUANT_K_IQ1_M/64/2]; +}; + struct block_iq1_m_packed64 { uint64_t qs[QUANT_K_IQ1_M/8/8]; uint64_t qh[QUANT_K_IQ1_M/16/8]; @@ -415,12 +433,15 @@ struct block_iq1_m_packed64 { #define QUANT_K QUANT_K_IQ1_S #define QUANT_R QUANT_R_IQ1_S #define A_TYPE block_iq1_s +#define A_TYPE_PACKED16 block_iq1_s_packed16 #endif #if defined(DATA_A_IQ1_M) #define QUANT_K QUANT_K_IQ1_M #define QUANT_R QUANT_R_IQ1_M #define A_TYPE block_iq1_m +#define A_TYPE_PACKED16 block_iq1_m_packed16 +#define A_TYPE_PACKED32 block_iq1_m_packed32 #endif #if defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M) @@ -559,7 +580,270 @@ const uint[1024] iq1s_grid_const = { 0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557 }; +// Same content as iq1s_grid_const except each 2-bit value is expanded to 4-bit +// and has 1 added to it (allows packed values to be extracted with & 0x0F0F0F0F +// and 0xF0F0F0F0). +const uint32_t[2048] iq1s_grid_gpu_const = { + 0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000, + 0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101, + 0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200, + 0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212, + 0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011, + 0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111, + 0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220, + 0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022, + 0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220, + 0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101, + 0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110, + 0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111, + 0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010, + 0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210, + 0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221, + 0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021, + 0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002, + 0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101, + 0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101, + 0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211, + 0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110, + 0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022, + 0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121, + 0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220, + 0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 
0x00110001, + 0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101, + 0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102, + 0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012, + 0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010, + 0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111, + 0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122, + 0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222, + 0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001, + 0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102, + 0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101, + 0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000, + 0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101, + 0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112, + 0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110, + 0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211, + 0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012, + 0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111, + 0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120, + 0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122, + 0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121, + 0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221, + 0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001, + 0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101, + 0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101, + 0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011, + 0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111, + 0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011, + 0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122, + 0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121, + 0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222, + 0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101, + 0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000, + 0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200, + 0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110, + 0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112, + 0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222, + 0x00210021, 
0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021, + 0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121, + 0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201, + 0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200, + 0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101, + 0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011, + 0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010, + 0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211, + 0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121, + 0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000, + 0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202, + 0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202, + 0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211, + 0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112, + 0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020, + 0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121, + 0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222, + 0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102, + 0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100, + 0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110, + 0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011, + 0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111, + 0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110, + 0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121, + 0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222, + 0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201, + 0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102, + 0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201, + 0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012, + 0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010, + 0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010, + 0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110, + 0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011, + 0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212, + 0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021, + 0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021, + 0x11011022, 0x11011121, 0x11011122, 
0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021, + 0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101, + 0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101, + 0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100, + 0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010, + 0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111, + 0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010, + 0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111, + 0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120, + 0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120, + 0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101, + 0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001, + 0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201, + 0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210, + 0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211, + 0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111, + 0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112, + 0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211, + 0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010, + 0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021, + 0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122, + 0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221, + 0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102, + 0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100, + 0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101, + 0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101, + 0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101, + 0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012, + 0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110, + 0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112, + 0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210, + 0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210, + 0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210, + 0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010, + 0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110, + 0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122, + 0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 
0x10111220, 0x10111221, 0x10121020, + 0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021, + 0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022, + 0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120, + 0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222, + 0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221, + 0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001, + 0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102, + 0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201, + 0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012, + 0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111, + 0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012, + 0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110, + 0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110, + 0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121, + 0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221, + 0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220, + 0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222, + 0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000, + 0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201, + 0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012, + 0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011, + 0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212, + 0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221, + 0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121, + 0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202, + 0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202, + 0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002, + 0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101, + 0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210, + 0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112, + 0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011, + 0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011, + 0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210, + 0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020, + 0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220, + 0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 
0x12201222, + 0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222, + 0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001, + 0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010, + 0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111, + 0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010, + 0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110, + 0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221, + 0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122, + 0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202, + 0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100, + 0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101, + 0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112, + 0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111, + 0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211, + 0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222, + 0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221, + 0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022, + 0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101, + 0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211, + 0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111, + 0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111, + 0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010, + 0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121, + 0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222, + 0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000, + 0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202, + 0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000, + 0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202, + 0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110, + 0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110, + 0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222, + 0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120, + 0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022, + 0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101, + 0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202, + 0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110, + 0x20100112, 
0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110, + 0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111, + 0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111, + 0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120, + 0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121, + 0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001, + 0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202, + 0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001, + 0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200, + 0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011, + 0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212, + 0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012, + 0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110, + 0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012, + 0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111, + 0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020, + 0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121, + 0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222, + 0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102, + 0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102, + 0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101, + 0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212, + 0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210, + 0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111, + 0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212, + 0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221, + 0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121, + 0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002, + 0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000, + 0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202, + 0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112, + 0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111, + 0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020, + 0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221, + 0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022, + 0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100, + 0x21211101, 0x21211200, 0x21211202, 
0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201, + 0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112, + 0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211, + 0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012, + 0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121, + 0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020, + 0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120, + 0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200, + 0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200, + 0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110, + 0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011, + 0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222, + 0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020, + 0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222, +}; + shared uint16_t iq1s_grid[2048]; +shared uint32_t iq1s_grid_gpu[2048]; #define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) @@ -573,6 +857,12 @@ void init_iq_shmem(uvec3 wgsize) iq1s_grid[2*idx+1] = g.y; } } + [[unroll]] for (uint i = 0; i < iq1s_grid_gpu_const.length(); i += wgsize.x) { + uint idx = i + gl_LocalInvocationIndex.x; + if (iq1s_grid_gpu_const.length() % wgsize.x == 0 || idx < iq1s_grid_gpu_const.length()) { + iq1s_grid_gpu[idx] = iq1s_grid_gpu_const[idx]; + } + } barrier(); } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 4a83378374..5b61ff9ca2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -685,7 +685,7 @@ void process_shaders() { // mul mat vec with integer dot product #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname)) { + if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname) || tname == "iq1_s" || tname == "iq1_m") { string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); @@ -944,6 +944,8 @@ void process_shaders() { string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); string_to_spv("cumsum_f32", "cumsum.comp", 
merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("cumsum_multipass1_f32", "cumsum_multipass1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("cumsum_multipass2_f32", "cumsum_multipass2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("count_experts", "count_experts.comp", merge_maps(base_dict, {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}})); @@ -1123,7 +1125,7 @@ void write_output_files() { for (const std::string& btype : btypes) { for (const auto& tname : type_names) { - if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname)) { + if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname) && tname != "iq1_s" && tname != "iq1_m") { continue; } hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n"; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c2a0f41c1b..c8feca5679 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -294,7 +294,9 @@ class Keys: USE_GELU = "clip.use_gelu" USE_SILU = "clip.use_silu" N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl + WA_LAYER_INDEXES = "clip.vision.wa_layer_indexes" # used by youtuvl IS_DEEPSTACK_LAYERS = "clip.vision.is_deepstack_layers" + WINDOW_SIZE = "clip.vision.window_size" class Attention: HEAD_COUNT = "clip.vision.attention.head_count" @@ -452,6 +454,7 @@ class MODEL_ARCH(IntEnum): MISTRAL3 = auto() MIMO2 = auto() LLAMA_EMBED = auto() + MAINCODER = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -850,6 +853,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.MISTRAL3: "mistral3", MODEL_ARCH.MIMO2: "mimo2", MODEL_ARCH.LLAMA_EMBED: "llama-embed", + MODEL_ARCH.MAINCODER: "maincoder", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -3257,6 +3261,22 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.MAINCODER: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], # TODO } @@ -3494,6 +3514,7 @@ class VisionProjectorType: LFM2A = "lfm2a" # audio MUSIC_FLAMINGO = "musicflamingo" # audio GLM4V = "glm4v" + YOUTUVL = "youtuvl" # Items here are (block size, type size) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 6a4a504f8d..612a978e4c 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1129,11 +1129,40 @@ class GGUFWriter: self.add_uint32(Keys.ClipVision.Projector.SCALE_FACTOR, value) def add_vision_n_wa_pattern(self, value: int) -> None: + """Add window attention pattern interval for vision models. + + This defines the pattern interval for window attention vs full attention layers. + For example, if n_wa_pattern=4, then layers 3, 7, 11, ... use full attention, + while other layers use window attention. + + Used by models like Qwen2.5-VL where full attention layers follow a regular pattern. + """ self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value) + def add_vision_wa_layer_indexes(self, layers: Sequence[int]) -> None: + """Add explicit layer indexes that use full attention in vision models. 
+ + This specifies the exact layer indices (0-based) that should use full attention + instead of window attention. All other layers will use window attention. + + Args: + layers: List of layer indices that use full attention (e.g., [3, 7, 11, 15]) + + Used by models like YoutuVL where full attention layers are explicitly specified + rather than following a regular pattern. + + Difference from add_vision_n_wa_pattern: + - n_wa_pattern: Defines a regular interval pattern (every Nth layer uses full attention) + - wa_layer_indexes: Explicitly lists which layers use full attention (irregular pattern) + """ + self.add_array(Keys.ClipVision.WA_LAYER_INDEXES, layers) + def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None: self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers) + def add_vision_window_size(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.WINDOW_SIZE, value) + # audio models def add_audio_projection_dim(self, value: int) -> None: diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 115df6c7c3..64dd4ddca5 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1221,6 +1221,7 @@ class TensorNameMap: MODEL_TENSOR.V_MMPROJ: ( "multi_modal_projector.linear_{bid}", "visual.merger.mlp.{bid}", # qwen2vl + "merger.mlp.{bid}", ), MODEL_TENSOR.V_MMPROJ_FC: ( @@ -1258,6 +1259,7 @@ class TensorNameMap: "visual.patch_embed.proj", # qwen2vl "vision_tower.patch_embed.proj", # kimi-vl "model.vision.patch_embedding.proj", # cogvlm + "siglip2.vision_model.embeddings.patch_embedding", ), MODEL_TENSOR.V_ENC_EMBD_NORM: ( @@ -1291,6 +1293,7 @@ class TensorNameMap: "vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral "visual.blocks.{bid}.attn.q", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated + "siglip2.vision_model.encoder.layers.{bid}.self_attn.q_proj", # youtuvl ), MODEL_TENSOR.V_ENC_ATTN_Q_NORM: ( @@ -1308,6 +1311,7 @@ class TensorNameMap: "vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral "visual.blocks.{bid}.attn.k", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated + "siglip2.vision_model.encoder.layers.{bid}.self_attn.k_proj", ), MODEL_TENSOR.V_ENC_ATTN_K_NORM: ( @@ -1325,6 +1329,7 @@ class TensorNameMap: "vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral "visual.blocks.{bid}.attn.v", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated + "siglip2.vision_model.encoder.layers.{bid}.self_attn.v_proj", ), MODEL_TENSOR.V_ENC_INPUT_NORM: ( @@ -1339,6 +1344,7 @@ class TensorNameMap: "visual.blocks.{bid}.norm1", # qwen2vl "vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1) "model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm + "siglip2.vision_model.encoder.layers.{bid}.layer_norm1", ), MODEL_TENSOR.V_ENC_ATTN_O: ( @@ -1354,6 +1360,7 @@ class TensorNameMap: "visual.blocks.{bid}.attn.proj", # qwen2vl "vision_tower.encoder.blocks.{bid}.wo", # kimi-vl "model.vision.transformer.layers.{bid}.attention.dense", # cogvlm + "siglip2.vision_model.encoder.layers.{bid}.self_attn.out_proj", # youtuvl ), MODEL_TENSOR.V_ENC_POST_ATTN_NORM: ( @@ -1368,6 +1375,7 @@ class TensorNameMap: "visual.blocks.{bid}.norm2", # qwen2vl "vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1) "model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm + "siglip2.vision_model.encoder.layers.{bid}.layer_norm2", ), MODEL_TENSOR.V_ENC_FFN_UP: ( @@ -1383,6 
+1391,7 @@ class TensorNameMap: "visual.blocks.{bid}.mlp.linear_fc1", # qwen3vl "vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1) "model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm + "siglip2.vision_model.encoder.layers.{bid}.mlp.fc1", ), MODEL_TENSOR.V_ENC_FFN_GATE: ( @@ -1404,6 +1413,7 @@ class TensorNameMap: "visual.blocks.{bid}.mlp.linear_fc2", # qwen3vl "vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1) "model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm + "siglip2.vision_model.encoder.layers.{bid}.mlp.fc2", ), MODEL_TENSOR.V_LAYER_SCALE_1: ( @@ -1430,6 +1440,7 @@ class TensorNameMap: "visual.merger.ln_q", # qwen2vl "vision_tower.encoder.final_layernorm", # kimi-vl "visual.post_layernorm", # glm4v + "siglip2.vision_model.post_layernorm", ), MODEL_TENSOR.V_MM_POST_NORM: ( @@ -1446,6 +1457,7 @@ class TensorNameMap: "multi_modal_projector.pre_norm", "pre_mm_projector_norm", "model.vision.linear_proj.norm1", # cogvlm + "merger.ln_q", ), MODEL_TENSOR.V_MM_SOFT_EMB_NORM: ( diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 762ea65c71..b0932794d4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -87,6 +87,7 @@ add_library(llama models/llada.cpp models/llama-iswa.cpp models/llama.cpp + models/maincoder.cpp models/mamba.cpp models/mimo2-iswa.cpp models/minicpm3.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 94a6807eac..93fed1a9a3 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -118,6 +118,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_MISTRAL3, "mistral3" }, { LLM_ARCH_MIMO2, "mimo2" }, { LLM_ARCH_LLAMA_EMBED, "llama-embed" }, + { LLM_ARCH_MAINCODER, "maincoder" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -2234,6 +2235,23 @@ static std::set llm_get_tensor_names(llm_arch arch) { return { LLM_TENSOR_TOKEN_EMBD, }; + case LLM_ARCH_MAINCODER: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; default: GGML_ABORT("unknown architecture for tensor mapping"); } diff --git a/src/llama-arch.h b/src/llama-arch.h index 714ead4025..57e470a9f3 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -122,6 +122,7 @@ enum llm_arch { LLM_ARCH_MISTRAL3, LLM_ARCH_MIMO2, LLM_ARCH_LLAMA_EMBED, + LLM_ARCH_MAINCODER, LLM_ARCH_UNKNOWN, }; diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index fc6a6223cf..b54ebbd155 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -74,6 +74,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS }, { "grok-2", LLM_CHAT_TEMPLATE_GROK_2 }, { "pangu-embedded", LLM_CHAT_TEMPLATE_PANGU_EMBED }, + { "solar-open", LLM_CHAT_TEMPLATE_SOLAR_OPEN }, }; llm_chat_template llm_chat_template_from_str(const std::string & name) { @@ -216,6 +217,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_GROK_2; } else if (tmpl_contains(LU8("[unused9]系统:[unused10]"))) { return LLM_CHAT_TEMPLATE_PANGU_EMBED; + } else if (tmpl_contains("<|begin|>") && tmpl_contains("<|end|>") && tmpl_contains("<|content|>")) { + return LLM_CHAT_TEMPLATE_SOLAR_OPEN; } return LLM_CHAT_TEMPLATE_UNKNOWN; } @@ -845,6 +848,14 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "[unused9]助手:"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_SOLAR_OPEN) { + for 
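+        // renders e.g. (hypothetical single-turn chat, with add_ass):
+        //   <|begin|>user<|content|>Hi<|end|><|begin|>assistant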
(auto message : chat) { + std::string role(message->role); + ss << "<|begin|>" << role << "<|content|>" << message->content << "<|end|>"; + } + if (add_ass) { + ss << "<|begin|>assistant"; + } } else { // template not supported return -1; diff --git a/src/llama-chat.h b/src/llama-chat.h index 684efb4d67..e1f795249c 100644 --- a/src/llama-chat.h +++ b/src/llama-chat.h @@ -54,6 +54,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_SEED_OSS, LLM_CHAT_TEMPLATE_GROK_2, LLM_CHAT_TEMPLATE_PANGU_EMBED, + LLM_CHAT_TEMPLATE_SOLAR_OPEN, LLM_CHAT_TEMPLATE_UNKNOWN, }; diff --git a/src/llama-mmap.cpp b/src/llama-mmap.cpp index 23b648a2e3..232005e140 100644 --- a/src/llama-mmap.cpp +++ b/src/llama-mmap.cpp @@ -240,9 +240,10 @@ struct llama_file::impl { throw std::runtime_error("unexpectedly reached end of file"); } } else { - bool successful = false; - while (!successful) { - off_t ret = read(fd, ptr, len); + size_t bytes_read = 0; + while (bytes_read < len) { + const size_t to_read = len - bytes_read; + ssize_t ret = ::read(fd, reinterpret_cast(ptr) + bytes_read, to_read); if (ret == -1) { if (errno == EINTR) { @@ -251,10 +252,16 @@ struct llama_file::impl { throw std::runtime_error(format("read error: %s", strerror(errno))); } if (ret == 0) { + // EOF: allow if this read was only pulling alignment padding past file end + off_t pos = lseek(fd, 0, SEEK_CUR); + if (pos != -1 && (size_t) pos == size) { + std::memset(reinterpret_cast(ptr) + bytes_read, 0, len - bytes_read); + return; + } throw std::runtime_error("unexpectedly reached end of file"); } - successful = true; + bytes_read += (size_t) ret; } } } diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 5e664c8c57..6e6ca48507 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -126,6 +126,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; case LLM_TYPE_80B_A3B: return "80B.A3B"; case LLM_TYPE_100B_A6B: return "100B.A6B"; + case LLM_TYPE_102B_A12B: return "102B.A12B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; case LLM_TYPE_230B_A10B: return "230B.A10B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; @@ -1109,6 +1110,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_MAINCODER: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_1B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_QWEN3VL: { ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false); @@ -1682,7 +1691,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { @@ -1778,6 +1787,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { switch (hparams.n_layer) { case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) + case 48: type = LLM_TYPE_102B_A12B; break; // Solar Open case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 
(92 layers + 1 NextN layer) default: type = LLM_TYPE_UNKNOWN; } @@ -3320,7 +3330,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, layer.ffn_gate ? n_ff : n_ff * 2}, 0); + + const auto tn_ffn_up_weight = tn(LLM_TENSOR_FFN_UP, "weight", i); + ggml_tensor * t_ffn_up = ml.get_tensor_meta(tn_ffn_up_weight.str().c_str()); + const int64_t n_ffn_up = t_ffn_up ? t_ffn_up->ne[1] : n_ff; + + GGML_ASSERT(n_ffn_up == n_ff || n_ffn_up == n_ff * 2); + layer.ffn_up = create_tensor(tn_ffn_up_weight, {n_embd, n_ffn_up}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ffn_up}, TENSOR_NOT_REQUIRED); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); @@ -4776,7 +4793,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -4839,7 +4860,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -5206,9 +5231,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, flags); layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, flags); layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, flags); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, flags); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, flags); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, flags); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, TENSOR_NOT_REQUIRED | flags); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, TENSOR_NOT_REQUIRED | flags); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, TENSOR_NOT_REQUIRED | flags); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags); @@ -6761,6 +6786,37 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_exp_probs_b = 
create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); } } break; + case LLM_ARCH_MAINCODER: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -7406,6 +7462,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique>(*this, params); } break; + case LLM_ARCH_MAINCODER: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_DECI: { llm = std::make_unique(*this, params); @@ -7440,7 +7500,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_MODERN_BERT: { - llm = std::make_unique>(*this, params); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_NEO_BERT: { @@ -8014,6 +8074,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ERNIE4_5_MOE: case LLM_ARCH_MISTRAL3: case LLM_ARCH_LLAMA_EMBED: + case LLM_ARCH_MAINCODER: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 diff --git a/src/llama-model.h b/src/llama-model.h index f4f44a92b6..79200a0d97 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -119,6 +119,7 @@ enum llm_type { LLM_TYPE_31B_A3_5B, LLM_TYPE_80B_A3B, // Qwen3 Next LLM_TYPE_100B_A6B, + LLM_TYPE_102B_A12B, // Solar-Open LLM_TYPE_106B_A12B, // GLM-4.5-Air LLM_TYPE_230B_A10B, // Minimax M2 LLM_TYPE_235B_A22B, diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index cd4092ca07..a20c6525e4 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -314,6 +314,12 @@ struct llm_tokenizer_bpe : llm_tokenizer { "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_YOUTU: + regex_exprs = { + "[가-힣ㄱ-ㆎ]+|[!…“”‘’—:;,、-〿︰-﹏]+|[ㄅ-ㄯ]+|[一-龥぀-ゟ゠-ヿ]+", + 
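+            // ^ splits off Hangul, (CJK) punctuation, bopomofo and Han/kana
+            //   runs first; the pattern below looks like the usual
+            //   contraction-aware Latin/number/punctuation splitter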
"[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER: regex_exprs = { "[\r\n]", @@ -355,6 +361,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { case LLAMA_VOCAB_PRE_TYPE_STABLELM2: case LLAMA_VOCAB_PRE_TYPE_QWEN2: case LLAMA_VOCAB_PRE_TYPE_HUNYUAN: + case LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN: regex_exprs = { // original regex from tokenizer.json // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" @@ -1860,6 +1867,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "deepseek-v3") { pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM; clean_spaces = false; + } else if ( + tokenizer_pre == "youtu") { + pre_type = LLAMA_VOCAB_PRE_TYPE_YOUTU; + clean_spaces = false; + ignore_merges = true; } else if ( tokenizer_pre == "falcon") { pre_type = LLAMA_VOCAB_PRE_TYPE_FALCON; @@ -2015,6 +2027,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "minimax-m2") { pre_type = LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2; clean_spaces = false; + } else if ( + tokenizer_pre == "solar-open") { + pre_type = LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN; + clean_spaces = false; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -2187,6 +2203,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // for now, we apply this workaround to find the tokens based on their text for (const auto & t : token_to_id) { + auto & attr = id_to_token[t.second].attr; + // find EOT token: "<|eot_id|>", "<|im_end|>", "", etc. if (special_eot_id == LLAMA_TOKEN_NULL) { if (false @@ -2202,10 +2220,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "" // smoldocling ) { special_eot_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } @@ -2216,10 +2234,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|eom_id|>" ) { special_eom_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. 
its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } @@ -2236,10 +2254,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|code_prefix|>" // GLM-4.5 ) { special_fim_pre_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } @@ -2256,10 +2274,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|code_suffix|>" // GLM-4.5 ) { special_fim_suf_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } @@ -2276,10 +2294,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|code_middle|>" // GLM-4.5 ) { special_fim_mid_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } @@ -2293,10 +2311,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "" ) { special_fim_pad_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } @@ -2311,10 +2329,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "" // Granite ) { special_fim_rep_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. 
its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } @@ -2325,15 +2343,41 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|file_sep|>" // Qwen ) { special_fim_sep_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } } + // auto-detect unused tokens: e.g. control tokens with the word "unused" + // ideally, these tokens should be marked as unused during conversion + { + uint32_t n_unused = 0; + + for (const auto & t : token_to_id) { + auto & attr = id_to_token[t.second].attr; + + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + continue; + } + + if ((attr & LLAMA_TOKEN_ATTR_UNUSED) == 0) { + if (strstr(t.first.c_str(), "unused") != NULL) { + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_UNUSED); + } + } + + if (attr & LLAMA_TOKEN_ATTR_UNUSED) { + n_unused++; + } + } + + LLAMA_LOG_INFO("%s: %u unused tokens\n", __func__, n_unused); + } + // maintain a list of tokens that cause end-of-generation // this is currently determined based on the token text, which is obviously not ideal // ref: https://github.com/ggerganov/llama.cpp/issues/9606 @@ -2352,12 +2396,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } for (const auto & t : token_to_id) { + auto & attr = id_to_token[t.second].attr; + if (false || t.first == "<|eot_id|>" || t.first == "<|im_end|>" || t.first == "<|end|>" || t.first == "<|return|>" // o200k_harmony || t.first == "<|call|>" // o200k_harmony + || t.first == "<|flush|>" // solar-open + || t.first == "<|calls|>" // solar-open || t.first == "" || t.first == "<|endoftext|>" || t.first == "<|eom_id|>" @@ -2367,24 +2415,28 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "" // smoldocling ) { special_eog_ids.insert(t.second); - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. 
its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } else { - // token is control, but not marked as EOG -> print a debug log - if (id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && special_eog_ids.count(t.second) == 0) { - LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n", - __func__, t.second, t.first.c_str()); + if (attr & LLAMA_TOKEN_ATTR_CONTROL && !(attr & LLAMA_TOKEN_ATTR_UNUSED)) { + // token is control, but not marked as EOG -> print a debug log + if (special_eog_ids.count(t.second) == 0) { + LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n", + __func__, t.second, t.first.c_str()); + } } } } // @ngxson : quick hack for gpt-oss, always render these tokens for (const auto & t : token_to_id) { + auto & attr = id_to_token[t.second].attr; + if (t.first == "<|channel|>" || t.first == "<|message|>" || t.first == "<|start|>" || t.first == "<|constrain|>") { - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_USER_DEFINED; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_USER_DEFINED); } } @@ -2404,34 +2456,42 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__); } - // TODO: workaround for o200k_harmony tokenizer: the "<|end|>" token should not be EOG - // we don't have a good way to detect this, so for now, if we have "<|return|>" and "<|call|>" tokens, + // TODO: workaround for o200k_harmony and solar-open tokenizer: the "<|end|>" token should not be EOG + // we don't have a good way to detect this, so for now, if we have "<|return|>" and "<|call|>" tokens ("<|calls|>" and "<|flush|>" for solar-open), // we remove the "<|end|>" token from the EOG list { bool has_return = false; bool has_call = false; bool has_end = false; + bool has_flush = false; llama_token end_id = LLAMA_TOKEN_NULL; LLAMA_LOG_INFO("%s: printing all EOG tokens:\n", __func__); for (auto tid : special_eog_ids) { - LLAMA_LOG_INFO("%s: - %d ('%s')\n", __func__, tid, id_to_token[tid].text.c_str()); + auto & text = id_to_token[tid].text; - if (id_to_token[tid].text == "<|return|>") { + LLAMA_LOG_INFO("%s: - %d ('%s')\n", __func__, tid, text.c_str()); + + if (text == "<|return|>") { has_return = true; - } else if (id_to_token[tid].text == "<|call|>") { + } else if (text == "<|call|>" || text == "<|calls|>") { has_call = true; - } else if (id_to_token[tid].text == "<|end|>") { + } else if (text == "<|flush|>") { + has_flush = true; + } else if (text == "<|end|>") { has_end = true; end_id = tid; } } - if (has_return && has_call && has_end) { + if ((has_return && has_call && has_end) || (has_call && has_flush && has_end)) { special_eog_ids.erase(end_id); - id_to_token[end_id].attr = LLAMA_TOKEN_ATTR_USER_DEFINED; - LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>' tokens, removing '<|end|>' token from EOG list\n", __func__); + + auto & attr = id_to_token[end_id].attr; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_USER_DEFINED); + + LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>', or '<|calls|>' and '<|flush|>' tokens, removing '<|end|>' token from EOG list\n", __func__); } } } diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 55f8f3923c..2b240a5491 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -51,6 +51,8 @@ enum 
llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40, LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41, LLAMA_VOCAB_PRE_TYPE_AFMOE = 42, + LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN = 43, + LLAMA_VOCAB_PRE_TYPE_YOUTU = 44, }; struct LLM_KV; diff --git a/src/models/bert.cpp b/src/models/bert.cpp index 3274fa3b99..bca0e254fc 100644 --- a/src/models/bert.cpp +++ b/src/models/bert.cpp @@ -142,11 +142,13 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params LLM_FFN_GELU, LLM_FFN_SEQ, il); cb(cur, "ffn_out", il); } else if (model.arch == LLM_ARCH_JINA_BERT_V2) { + const bool up_contains_gate = !model.layers[il].ffn_gate && model.layers[il].ffn_up->ne[1] != hparams.n_ff(); + auto type_op = up_contains_gate ? LLM_FFN_GEGLU : LLM_FFN_GELU; cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, - model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_GEGLU, LLM_FFN_PAR, il); + type_op, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); } else { cur = build_ffn(cur, diff --git a/src/models/cogvlm.cpp b/src/models/cogvlm.cpp index edf0d1424c..0ceae3aaeb 100644 --- a/src/models/cogvlm.cpp +++ b/src/models/cogvlm.cpp @@ -3,12 +3,14 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; - float kq_scale = 1.0f / sqrtf(float(n_embd_head)); + const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); - ggml_tensor *inpL, *cur; + ggml_tensor * inpL; + ggml_tensor * cur; + inpL = build_inp_embd(model.tok_embd); ggml_tensor * inp_pos = build_inp_pos(); @@ -44,7 +46,7 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_pa } ggml_tensor * inpSA = inpL; - cur = build_norm(inpSA, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cur = build_norm(inpSA, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); // build self attention { diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index 49382874ba..ca63a62ad1 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -215,7 +215,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, hparams.expert_weights_norm, - true, hparams.expert_weights_scale, + hparams.expert_weights_scale, hparams.expert_weights_scale, (llama_expert_gating_func_type) hparams.expert_gating_func, il); cb(moe_out, "ffn_moe_out", il); diff --git a/src/models/gemma-embedding.cpp b/src/models/gemma-embedding.cpp index 90a98f7abf..944c198bf9 100644 --- a/src/models/gemma-embedding.cpp +++ b/src/models/gemma-embedding.cpp @@ -1,7 +1,5 @@ #include "models.h" - - llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k; @@ -12,10 +10,8 @@ llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, inpL = build_inp_embd(model.tok_embd); // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) - if (ubatch.token) { - inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); - cb(inpL, "inp_scaled", -1); - } + inpL = ggml_scale(ctx0, inpL, ubatch.token ? 
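+    // the scale op is now emitted unconditionally (factor 1.0f for raw
+    // embedding inputs), presumably so the graph topology no longer depends
+    // on whether the batch carries tokens or embeddings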
sqrtf(n_embd) : 1.0f); + cb(inpL, "inp_scaled", -1); // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); diff --git a/src/models/gemma3.cpp b/src/models/gemma3.cpp index ae60ef4790..dec3fc4b8b 100644 --- a/src/models/gemma3.cpp +++ b/src/models/gemma3.cpp @@ -10,10 +10,9 @@ llm_build_gemma3::llm_build_gemma3(const llama_model & model, const llm_gr inpL = build_inp_embd(model.tok_embd); // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) - if (ubatch.token) { - inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); - cb(inpL, "inp_scaled", -1); - } + inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); + cb(inpL, "inp_scaled", -1); + // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); diff --git a/src/models/gemma3n-iswa.cpp b/src/models/gemma3n-iswa.cpp index a0bdd6a15a..9c7b3ba0bb 100644 --- a/src/models/gemma3n-iswa.cpp +++ b/src/models/gemma3n-iswa.cpp @@ -1,7 +1,5 @@ #include "models.h" - - llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model), @@ -15,10 +13,9 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const inpL = build_inp_embd(model.tok_embd); // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) - if (ubatch.token) { - inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); - cb(inpL, "inp_scaled", -1); - } + inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); + cb(inpL, "inp_scaled", -1); + // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); @@ -248,7 +245,7 @@ ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) { // equivalent to get_per_layer_inputs() in python code // output shape: [n_embd_altup, n_layer, n_tokens] ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() { - auto inp = std::make_unique(); + auto inp = std::make_unique(); ggml_tensor * inp_per_layer; if (ubatch.token) { inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); diff --git a/src/models/maincoder.cpp b/src/models/maincoder.cpp new file mode 100644 index 0000000000..da57308167 --- /dev/null +++ b/src/models/maincoder.cpp @@ -0,0 +1,117 @@ +#include "models.h" + +llm_build_maincoder::llm_build_maincoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + 
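+            // per-head views: [n_embd_head, n_head(_kv), n_tokens]; note that
+            // below RoPE is applied before the per-head Q/K RMS norms — the
+            // reverse of most QK-norm models, presumably matching the
+            // reference implementation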
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/src/models/models.h b/src/models/models.h index e2cd4e484f..72b2b760c6 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -312,6 +312,10 @@ struct llm_build_llama_iswa : public llm_graph_context { llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_maincoder : public llm_graph_context { + llm_build_maincoder(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_mamba : public llm_graph_context_mamba { llm_build_mamba(const llama_model & model, const llm_graph_params & params); }; @@ -332,7 +336,6 @@ struct llm_build_mistral3 : public llm_graph_context { llm_build_mistral3(const llama_model & model, const llm_graph_params & params); }; -template struct llm_build_modern_bert : public llm_graph_context { llm_build_modern_bert(const llama_model & model, const llm_graph_params & params); }; diff --git a/src/models/modern-bert.cpp b/src/models/modern-bert.cpp index c7809bdedf..6df418ecda 100644 --- a/src/models/modern-bert.cpp +++ b/src/models/modern-bert.cpp @@ -1,7 +1,6 @@ #include "models.h" -template -llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -24,13 +23,7 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, co auto * inp_attn = build_attn_inp_no_cache(); for (int il = 0; il < n_layer; 
++il) { - float freq_base_l = 0.0f; - - if constexpr (iswa) { - freq_base_l = model.get_rope_freq_base(cparams, il); - } else { - freq_base_l = freq_base; - } + float freq_base_l = model.get_rope_freq_base(cparams, il); cur = inpL; @@ -120,7 +113,3 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, co res->t_embd = cur; ggml_build_forward_expand(gf, cur); } - -// Explicit template instantiations -template struct llm_build_modern_bert; -template struct llm_build_modern_bert; diff --git a/src/unicode.cpp b/src/unicode.cpp index bb44edfadd..b47dcbe619 100644 --- a/src/unicode.cpp +++ b/src/unicode.cpp @@ -964,6 +964,11 @@ std::vector unicode_regex_split(const std::string & text, const std { "\\p{P}", unicode_cpt_flags::PUNCTUATION }, { "\\p{M}", unicode_cpt_flags::ACCENT_MARK }, { "\\p{S}", unicode_cpt_flags::SYMBOL }, + { "\\p{Lu}", unicode_cpt_flags::LETTER }, // Uppercase letter + { "\\p{Ll}", unicode_cpt_flags::LETTER }, // Lowercase letter + { "\\p{Lt}", unicode_cpt_flags::LETTER }, // Titlecase letter + { "\\p{Lm}", unicode_cpt_flags::LETTER }, // Modifier letter + { "\\p{Lo}", unicode_cpt_flags::LETTER }, // Other letter }; static const std::map k_ucat_cpt = { @@ -1074,22 +1079,26 @@ std::vector unicode_regex_split(const std::string & text, const std continue; } - if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() && + // Match \p{...} Unicode properties of varying lengths + if (regex_expr[i + 0] == '\\' && i + 3 < regex_expr.size() && regex_expr[i + 1] == 'p' && - regex_expr[i + 2] == '{' && - regex_expr[i + 4] == '}') { - const std::string pat = regex_expr.substr(i, 5); - if (k_ucat_enum.find(pat) != k_ucat_enum.end()) { - if (!inside) { - regex_expr_collapsed += '['; + regex_expr[i + 2] == '{') { + // Find the closing brace + size_t closing_brace = regex_expr.find('}', i + 3); + if (closing_brace != std::string::npos && closing_brace <= i + 10) { // reasonable limit + const std::string pat = regex_expr.substr(i, closing_brace - i + 1); + if (k_ucat_enum.find(pat) != k_ucat_enum.end()) { + if (!inside) { + regex_expr_collapsed += '['; + } + regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat)); + regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat)); + if (!inside) { + regex_expr_collapsed += ']'; + } + i = closing_brace; + continue; } - regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat)); - regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat)); - if (!inside) { - regex_expr_collapsed += ']'; - } - i += 4; - continue; } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0b981b1788..6dedd8de58 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1158,6 +1158,7 @@ struct test_case { } virtual bool run_whole_graph() { return false; } + virtual std::vector fusion_test_nodes() { return {}; } ggml_cgraph * gf = nullptr; ggml_cgraph * gb = nullptr; @@ -1391,7 +1392,13 @@ struct test_case { GGML_UNUSED(index); }; - const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud, run_whole_graph() ? out : nullptr); + std::vector fused_nodes_to_verify = fusion_test_nodes(); + if (fused_nodes_to_verify.size() == 0 && run_whole_graph()) { + fused_nodes_to_verify.push_back(out); + } + const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud, + run_whole_graph() ? 
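+        // a test can now name several interior tensors to verify via
+        // fusion_test_nodes(); when none are given, whole-graph runs fall
+        // back to comparing only the final output `out`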
fused_nodes_to_verify.data() : nullptr, + fused_nodes_to_verify.size()); ggml_backend_buffer_free(buf); @@ -5180,6 +5187,8 @@ struct test_topk_moe : public test_case { const bool bias_probs; const MoeGatingFunc gating_func; const float scale_w; + ggml_tensor * weights {}; + ggml_tensor * selected_experts {}; test_topk_moe(std::array ne = { 10, 5, 1, 1 }, int n_expert_used = 1, @@ -5217,16 +5226,16 @@ struct test_topk_moe : public test_case { ggml_tensor * selection_probs = probs; if (bias_probs) { - ggml_tensor * exp_probs_b = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data()); + ggml_tensor * exp_probs_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]); ggml_set_name(exp_probs_b, "exp_probs_b"); selection_probs = ggml_add(ctx, probs, exp_probs_b); ggml_set_name(selection_probs, "selection_probs"); } - ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens] + selected_experts = ggml_argsort_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens] ggml_set_name(selected_experts, "selected_experts"); - ggml_tensor * weights = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] + weights = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] ggml_set_name(weights, "weights"); if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) { @@ -5252,6 +5261,21 @@ struct test_topk_moe : public test_case { ggml_set_name(weights, "weights"); return weights; } + // Verify two outputs + std::vector fusion_test_nodes() override { return { selected_experts, weights }; } + + // allow output in arbitrary order + double err(const float * a, const float * b, size_t n) override { + std::vector a2(n); + std::vector b2(n); + for (size_t i = 0; i < n; ++i) { + a2[i] = a[i]; + b2[i] = b[i]; + } + std::sort(a2.begin(), a2.end()); + std::sort(b2.begin(), b2.end()); + return nmse(a2.data(), b2.data(), n); + } }; struct test_mul_mat_vec_fusion : public test_case { diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index a78627604e..a07c81fba6 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -724,6 +724,30 @@ static void test_tools_oaicompat_json_conversion() { "]" ), common_chat_tools_to_json_oaicompat({special_function_tool}).dump(2)); + + { + auto tools_no_params = common_chat_tools_parse_oaicompat(json::parse( + R"([{"type": "function", "function": {"name": "test_func", "description": "A test"}}])")); + assert_equals((size_t) 1, tools_no_params.size()); + assert_equals(std::string("test_func"), tools_no_params[0].name); + assert_equals(std::string("A test"), tools_no_params[0].description); + assert_equals(std::string("{}"), tools_no_params[0].parameters); + } + { + auto tools_no_desc = common_chat_tools_parse_oaicompat(json::parse( + R"([{"type": "function", "function": {"name": "test_func", "parameters": {"type": "object"}}}])")); + assert_equals((size_t) 1, tools_no_desc.size()); + assert_equals(std::string("test_func"), tools_no_desc[0].name); + assert_equals(std::string(""), tools_no_desc[0].description); + } + { + auto tools_minimal = common_chat_tools_parse_oaicompat(json::parse( + R"([{"type": "function", "function": {"name": "test_func"}}])")); + assert_equals((size_t) 1, tools_minimal.size()); + assert_equals(std::string("test_func"), tools_minimal[0].name); + assert_equals(std::string(""), tools_minimal[0].description); + assert_equals(std::string("{}"), tools_minimal[0].parameters); + 
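+        // the three cases above pin down the lenient defaults: a missing
+        // "description" parses as "" and a missing "parameters" as "{}"
+        // rather than being an error; a further case one might add
+        // (hypothetical, mirroring those above):
+        //
+        // {
+        //     auto tools_full = common_chat_tools_parse_oaicompat(json::parse(
+        //         R"([{"type": "function", "function": {"name": "f", "description": "d", "parameters": {"type": "object"}}}])"));
+        //     assert_equals(std::string("d"), tools_full[0].description);
+        // }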
} } static void test_template_output_parsers() { diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 317d5f19fd..4b9022cb58 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -27,6 +27,7 @@ add_library(mtmd models/qwen3vl.cpp models/siglip.cpp models/whisper-enc.cpp + models/youtuvl.cpp ) set_target_properties(mtmd PROPERTIES diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 1ed0741883..df7e479765 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -45,13 +45,14 @@ #define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size" #define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers" -#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" -#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" -#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution" -#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern" -#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size" -#define KEY_MINICPMV_VERSION "clip.minicpmv_version" -#define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num" +#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" +#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" +#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution" +#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern" +#define KEY_WIN_ATTN_LAYER_INDEXES "clip.vision.wa_layer_indexes" +#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size" +#define KEY_MINICPMV_VERSION "clip.minicpmv_version" +#define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num" // audio-specific #define KEY_AUDIO_PROJ_TYPE "clip.audio.projector_type" // for models with mixed modalities @@ -188,6 +189,7 @@ enum projector_type { PROJECTOR_TYPE_JANUS_PRO, PROJECTOR_TYPE_LFM2A, PROJECTOR_TYPE_GLM4V, + PROJECTOR_TYPE_YOUTUVL, PROJECTOR_TYPE_UNKNOWN, }; @@ -218,6 +220,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"}, { PROJECTOR_TYPE_LFM2A, "lfm2a"}, { PROJECTOR_TYPE_GLM4V, "glm4v"}, + { PROJECTOR_TYPE_YOUTUVL, "youtuvl"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index 1e5aa87b98..702e10151a 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -61,6 +61,7 @@ struct clip_hparams { std::unordered_set vision_feature_layer; int32_t attn_window_size = 0; int32_t n_wa_pattern = 0; + std::unordered_set wa_layer_indexes; // explicit layer indexes that use full attention (for irregular patterns like YoutuVL) // audio int32_t n_mel_bins = 0; // whisper preprocessor diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index fb08dd258c..9f551e8f3c 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -846,6 +846,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_YOUTUVL: + { + builder = std::make_unique(ctx, img); + } break; default: GGML_ABORT("missing cgraph builder"); } @@ -1159,6 +1163,20 @@ struct clip_model_loader { LOG_WRN("%s: more info: https://github.com/ggml-org/llama.cpp/issues/16842\n\n", __func__); } } break; + case PROJECTOR_TYPE_YOUTUVL: + { + hparams.n_merge = 2; + get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false); + get_u32(KEY_ATTN_WINDOW_SIZE, hparams.attn_window_size, true); + std::vector wa_layer_indexes_vec; + get_arr_int(KEY_WIN_ATTN_LAYER_INDEXES, wa_layer_indexes_vec, true); + for (auto & layer : 
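+                    // copied into an unordered_set for O(1) per-layer lookups;
+                    // note the 62500-token limit set below is (8000/16/2)^2,
+                    // i.e. 250 merged tokens per side, not 250 in total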
wa_layer_indexes_vec) { + hparams.wa_layer_indexes.insert(layer); + } + // support max_height * max_width = 8000 * 8000. 8000/16/2 = 250 image tokens + hparams.set_limit_image_tokens(1, 62500); + hparams.set_warmup_n_tokens(16*16); // avoid OOM on warmup + } break; case PROJECTOR_TYPE_GLM4V: { hparams.rope_theta = 10000.0f; @@ -1227,7 +1245,14 @@ struct clip_model_loader { LOG_INF("%s: has_llava_proj: %d\n", __func__, hparams.has_llava_projector); LOG_INF("%s: minicpmv_version: %d\n", __func__, hparams.minicpmv_version); LOG_INF("%s: n_merge: %d\n", __func__, hparams.n_merge); - LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern); + LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern); + if (!hparams.wa_layer_indexes.empty()) { + LOG_INF("%s: wa_layer_indexes: ", __func__); + for (auto & layer : hparams.wa_layer_indexes) { + LOG_INF("%d ", layer); + } + LOG_INF("\n"); + } if (hparams.image_min_pixels > 0) { LOG_INF("%s: image_min_pixels: %d%s\n", __func__, hparams.image_min_pixels, hparams.custom_image_min_tokens > 0 ? " (custom value)" : ""); } @@ -1495,6 +1520,14 @@ struct clip_model_loader { model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); } break; + case PROJECTOR_TYPE_YOUTUVL: + { + model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM); // merger.ln_q (RMS norm) + model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); // merger.mlp.0 + model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias")); + model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); // merger.mlp.2 + model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); + } break; case PROJECTOR_TYPE_GLM4V: { model.projection = get_tensor(TN_MM_PROJECTOR); @@ -2697,6 +2730,57 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str // res_imgs->data[0] = *res; res_imgs->entries.push_back(std::move(img_f32)); } break; + case PROJECTOR_TYPE_YOUTUVL: + { + const int patch_size = params.patch_size; // typically 16 + const int merge_size = params.n_merge; // typically 2 + const int align_size = patch_size * merge_size; // 32 + + const int max_num_patches = params.image_max_pixels > 0 ? 
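+                // max patches = pixel budget / patch area; the 256 fallback
+                // for models that leave image_max_pixels unset looks like a
+                // conservative default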
+ params.image_max_pixels / (patch_size * patch_size) : 256; + + // Linear search for optimal scale to fit within max_num_patches + float scale = 1.0f; + int target_height = original_size.height; + int target_width = original_size.width; + + auto get_scaled_image_size = [align_size](float scale, int size) -> int { + float scaled_size = size * scale; + // Round up to nearest multiple of align_size + int aligned = static_cast(std::ceil(scaled_size / align_size)) * align_size; + // Ensure at least one patch + return std::max(align_size, aligned); + }; + + // Linear search with 0.02 step size + while (scale > 0.0f) { + target_height = get_scaled_image_size(scale, original_size.height); + target_width = get_scaled_image_size(scale, original_size.width); + + int num_patches_h = target_height / patch_size; + int num_patches_w = target_width / patch_size; + int num_patches = num_patches_h * num_patches_w; + + if (num_patches > max_num_patches) { + scale -= 0.02f; + } else { + break; + } + } + + clip_image_size new_size = {target_width, target_height}; + + // Resize the image + clip_image_u8 resized; + img_tool::resize(*img, resized, new_size, img_tool::RESIZE_ALGO_BILINEAR, false); + + // Normalize to float32 + clip_image_f32_ptr img_f32(clip_image_f32_init()); + normalize_image_u8_to_f32(resized, *img_f32, params.image_mean, params.image_std); + + // Add to results + res_imgs->entries.push_back(std::move(img_f32)); + } break; case PROJECTOR_TYPE_IDEFICS3: { @@ -2929,6 +3013,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_GLM4V: + case PROJECTOR_TYPE_YOUTUVL: return (img->nx / params.patch_size) / 2; default: break; @@ -2944,6 +3029,7 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_GLM4V: + case PROJECTOR_TYPE_YOUTUVL: return (img->ny / params.patch_size) / 2; default: break; @@ -3004,6 +3090,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_GLM4V: + case PROJECTOR_TYPE_YOUTUVL: { // dynamic size (2 conv, so double patch size) int x_patch = img->nx / (params.patch_size * 2); @@ -3131,7 +3218,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima const int pos_w = image_size_width / patch_size; const int pos_h = image_size_height / patch_size; - const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl auto get_inp_tensor = [&gf](const char * name) { ggml_tensor * inp = ggml_graph_get_tensor(gf, name); @@ -3280,9 +3366,11 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima set_input_i32("positions", positions); } break; case PROJECTOR_TYPE_QWEN25VL: + case PROJECTOR_TYPE_YOUTUVL: { // pw * ph = number of tokens output by ViT after apply patch merger // ipw * ipw = number of vision token been processed inside ViT + const bool use_window_attn = ctx->model.proj_type == PROJECTOR_TYPE_QWEN25VL ? 
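+            // qwen2.5vl marks full-attention layers via a regular interval
+            // (n_wa_pattern), while youtuvl lists them explicitly
+            // (wa_layer_indexes)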
hparams.n_wa_pattern > 0 : !hparams.wa_layer_indexes.empty(); const int merge_ratio = 2; const int pw = image_size_width / patch_size / merge_ratio; const int ph = image_size_height / patch_size / merge_ratio; @@ -3293,7 +3381,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima std::vector inv_idx(ph * pw); if (use_window_attn) { - const int attn_window_size = 112; + const int attn_window_size = hparams.attn_window_size > 0 ? hparams.attn_window_size : 112; const int grid_window = attn_window_size / patch_size / merge_ratio; int dst = 0; // [num_vision_tokens, num_vision_tokens] attention mask tensor @@ -3531,6 +3619,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_JANUS_PRO: + case PROJECTOR_TYPE_YOUTUVL: return ctx->model.mm_1_b->ne[0]; case PROJECTOR_TYPE_QWEN3VL: // main path + deepstack paths diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index e08c33f353..74e94f60ec 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -27,6 +27,11 @@ struct clip_graph_qwen3vl : clip_graph { ggml_cgraph * build() override; }; +struct clip_graph_youtuvl : clip_graph { + clip_graph_youtuvl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; + struct clip_graph_minicpmv : clip_graph { clip_graph_minicpmv(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; diff --git a/tools/mtmd/models/youtuvl.cpp b/tools/mtmd/models/youtuvl.cpp new file mode 100644 index 0000000000..ffbf2be554 --- /dev/null +++ b/tools/mtmd/models/youtuvl.cpp @@ -0,0 +1,179 @@ +#include "models.h" + +ggml_cgraph * clip_graph_youtuvl::build() { + GGML_ASSERT(model.class_embedding == nullptr); + const int batch_size = 1; + const bool use_window_attn = !hparams.wa_layer_indexes.empty(); + const int n_pos = n_patches; + const int num_position_ids = n_pos * 4; + const int m = 2; + const int Wp = n_patches_x; + const int Hp = n_patches_y; + const int Hm = Hp / m; + const int Wm = Wp / m; + norm_type norm_t = NORM_TYPE_NORMAL; + + int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; + + ggml_tensor * inp = build_inp_raw(); + + // change conv3d to linear + // reshape and permute to get patches, permute from (patch_size, m, Wm, patch_size, m, Hm, C) to (C, patch_size, patch_size, m, m, Wm, Hm) + { + inp = ggml_reshape_4d( + ctx0, inp, + Wm * m * patch_size, m * patch_size, Hm, 3); + inp = ggml_permute(ctx0, inp, 1, 2, 3, 0); + inp = ggml_cont_4d( + ctx0, inp, + m * patch_size * 3, Wm, m * patch_size, Hm); + + inp = ggml_permute(ctx0, inp, 0, 2, 1, 3); + inp = ggml_cont_4d( + ctx0, inp, + m * patch_size * 3, patch_size, m, Hm * Wm); + + inp = ggml_permute(ctx0, inp, 1, 0, 2, 3); + inp = ggml_cont_4d( + ctx0, inp, + patch_size, 3, patch_size, Hm * Wm * m * m); + + inp = ggml_permute(ctx0, inp, 2, 0, 1, 3); + inp = ggml_cont_3d( + ctx0, inp, + 3*patch_size* patch_size, Hm * Wm * m * m, 1); + } + inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp); + + if (model.patch_bias) { + inp = ggml_add(ctx0, inp, model.patch_bias); + } + + inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches); + + ggml_tensor * inpL = inp; + ggml_tensor * window_mask = nullptr; + ggml_tensor * window_idx = nullptr; + ggml_tensor * inv_window_idx = nullptr; + + ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids); + ggml_set_name(positions, "positions"); + 
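+    // num_position_ids = 4*n_pos — presumably one entry per M-RoPE section
+    // per patch (four equal d_head/4 sections with GGML_ROPE_TYPE_VISION),
+    // filled host-side like the qwen2.5vl positions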
+ + // pre-layernorm + if (model.pre_ln_w) { + inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1); + } + if (use_window_attn) { + inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4); + ggml_set_name(inv_window_idx, "inv_window_idx"); + ggml_set_input(inv_window_idx); + // mask for window attention + window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_pos, n_pos); + ggml_set_name(window_mask, "window_mask"); + ggml_set_input(window_mask); + + // if flash attn is used, the mask must be cast to f16 + if (flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) { + window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16); + } + + // inpL shape: [n_embd, n_patches_x * n_patches_y, batch_size] + GGML_ASSERT(batch_size == 1); + inpL = ggml_reshape_2d(ctx0, inpL, n_embd * 4, n_patches_x * n_patches_y * batch_size / 4); + inpL = ggml_get_rows(ctx0, inpL, inv_window_idx); + inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_patches_x * n_patches_y, batch_size); + } + + // loop over layers + for (int il = 0; il < n_layer; il++) { + const auto & layer = model.layers[il]; + const bool full_attn = use_window_attn ? hparams.wa_layer_indexes.count(il) > 0 : true; + + ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states + + // layernorm1 + cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il); + // self-attention + { + ggml_tensor * Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b); + ggml_tensor * Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b); + ggml_tensor * Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b); + + Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches); + Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches); + Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_patches); + + Qcur = ggml_rope_multi( ctx0, Qcur, positions, nullptr, d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); + Kcur = ggml_rope_multi( ctx0, Kcur, positions, nullptr, d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); + + ggml_tensor * attn_mask = full_attn ? nullptr : window_mask;
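+ + // Layers listed in hparams.wa_layer_indexes attend globally (null mask); all other + // layers are restricted to their local window through window_mask, whose contents are + // computed on the CPU side in clip_image_batch_encode.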
+ + cur = build_attn(layer.o_w, layer.o_b, + Qcur, Kcur, Vcur, attn_mask, kq_scale, il); + } + // re-add the layer input, i.e., the residual connection + cur = ggml_add(ctx0, cur, inpL); + + inpL = cur; // inpL = residual, cur = hidden_states + + // layernorm2 + cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il); + + // ffn + cur = build_ffn(cur, + layer.ff_up_w, layer.ff_up_b, + nullptr, nullptr, + layer.ff_down_w, layer.ff_down_b, + hparams.ffn_op, il); + + // residual 2 + cur = ggml_add(ctx0, inpL, cur); + + inpL = cur; + } + + ggml_tensor * embeddings = inpL; + if (use_window_attn) { + const int spatial_merge_unit = 4; + window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / spatial_merge_unit); + ggml_set_name(window_idx, "window_idx"); + ggml_set_input(window_idx); + GGML_ASSERT(batch_size == 1); + embeddings = ggml_reshape_2d(ctx0, embeddings, n_embd * spatial_merge_unit, n_patches / spatial_merge_unit); + embeddings = ggml_get_rows(ctx0, embeddings, window_idx); + embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd, n_patches, batch_size); + cb(embeddings, "window_order_restored", -1); + } + + // post-layernorm (part of Siglip2VisionTransformer, applied after encoder) + if (model.post_ln_w) { + embeddings = build_norm(embeddings, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer); + } + + // Now apply the merger (VLPatchMerger): + // 1. Apply RMS norm (ln_q in VLPatchMerger) + embeddings = build_norm(embeddings, model.mm_input_norm_w, nullptr, NORM_TYPE_RMS, 1e-6, -1); + cb(embeddings, "merger_normed", -1); + + // 2. Reshape for spatial merge (merge 2x2 patches) + embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size); + cb(embeddings, "merger_reshaped", -1); + + embeddings = build_ffn(embeddings, + model.mm_0_w, model.mm_0_b, + nullptr, nullptr, + model.mm_1_w, model.mm_1_b, + FFN_GELU, + -1); + ggml_build_forward_expand(gf, embeddings); + + return gf; +} diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index b0b5ab42ab..fca55b76f8 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -283,7 +283,7 @@ struct mtmd_context { // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md img_end = "[IMG_END]"; - } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL || proj == PROJECTOR_TYPE_QWEN3VL) { + } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL || proj == PROJECTOR_TYPE_QWEN3VL || proj == PROJECTOR_TYPE_YOUTUVL) { // <|vision_start|> ... (image embeddings) ... <|vision_end|> img_beg = "<|vision_start|>"; img_end = "<|vision_end|>"; diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index d1c10eed91..b3983b2b17 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/webui/src/lib/utils/clipboard.ts b/tools/server/webui/src/lib/utils/clipboard.ts index 91e8ea75ae..940e64c8ff 100644 --- a/tools/server/webui/src/lib/utils/clipboard.ts +++ b/tools/server/webui/src/lib/utils/clipboard.ts @@ -65,10 +65,7 @@ export async function copyCodeToClipboard( successMessage = 'Code copied to clipboard', errorMessage = 'Failed to copy code' ): Promise<void> { - const doc = new DOMParser().parseFromString(rawCode, 'text/html'); - const decodedCode = doc.body.textContent ??
rawCode; - - return copyToClipboard(decodedCode, successMessage, errorMessage); + return copyToClipboard(rawCode, successMessage, errorMessage); } /**