Compare commits

..

29 Commits

Author SHA1 Message Date
Georgi Gerganov a554a1ecc7
context : fix reserve token padding to n_seqs (#18536) 2026-01-03 15:45:34 +02:00
Johannes Gäßler 0f2e42ca1d
CUDA: only allocate FA tmp buffer if needed (#18564) 2026-01-03 13:55:53 +01:00
pl752 9dba9f5352
(Bugfix, ggml-cuda) Pool alloc count fix + small size computation type adjustment (#18559)
* CUDA: Fixed obj byte size instead of obj count being passed to pool alloc (fattn-common, dst_tmp_meta)

* CUDA: Explicitly casted some of the int alloc counts before multiplication in argsort

---------

Co-authored-by: pl752 <maximpl752@gmail.com>
2026-01-03 11:13:40 +01:00
Shouyu bcfc8c3cec
ggml-hexagon: optimize activation function (#18393)
* refactor: refactor silu

* refactor: optimize swiglu

* refactor: remove unncessary if in swiglu

* refactor: refactor swiglu_oai

* chore: fix formatting issue
2026-01-02 21:24:24 -08:00
Jeff Bolz 18ddaea2ae
vulkan: Optimize GGML_OP_CUMSUM (#18417)
* vulkan: Optimize GGML_OP_CUMSUM

There are two paths: The preexisting one that does a whole row per workgroup
in a single shader, and one that splits each row into multiple blocks and does
two passes. The first pass computes partials within a block, the second adds
the block partials to compute the final result. The multipass shader is used
when there are a small number of large rows.

In the whole-row shader, handle multiple elements per invocation.

* use 2 ELEM_PER_THREAD for AMD/Intel

* address feedback
2026-01-02 15:32:30 -06:00
Jeff Bolz 706e3f93a6
vulkan: Implement mmvq for iq1_s/iq1_m (#18450) 2026-01-02 20:19:04 +01:00
Prabod 5755e52d15
model : Maincoder-1B support (#18534)
* Add Maincoder model support

* Removed SPM model vocabulary setting and MOE related GGUF parameters
Removed trailing spaces from maincoder.cpp

* removed set_vocab

* added new line

* Fix formatting

* Add a new line for PEP8
2026-01-02 20:11:59 +01:00
Georgi Gerganov f38de16341
metal : adjust extra size for FA buffer to avoid reallocations (#18545) 2026-01-02 19:02:18 +02:00
Georgi Gerganov af1e8e1a6c
graph : reduce topology branching (#18548) 2026-01-02 19:01:56 +02:00
Georgi Gerganov d84a6a98be
vocab : reduce debug logs about non-EOG control tokens (#18541)
* vocab : reduce debug logs about non-EOG control tokens

* cont : add comment
2026-01-02 16:17:33 +02:00
Chris Rohlf c6f0e832da
rpc : use unordered_map::reserve and emplace (#18513) 2026-01-02 12:09:36 +02:00
MeeMin e86f3c2221
cuda : fix copy of large tensors (ggml_nbytes <= INT_MAX assertion) (#18433)
* ggml-cuda: fixed assertion in ggml_cuda_cpy (#18140)

* ggml-cuda: changes in data types to int64_t

* ggml-cuda: added asserts for CUDA block numbers

* ggml-cuda: changed the condition for y and z dimension
2026-01-02 00:24:20 +01:00
Sigbjørn Skjæret 169ee68ffb
model : remove modern-bert iswa template (#18529)
* remove modern-bert iswa template

* forgotten
2026-01-02 00:06:42 +01:00
tt ced765be44
model: support youtu-vl model (#18479)
* Support Youtu-VL Model

* merge code

* fix bug

* revert qwen2 code & support rsplit in minja.hpp

* update warm info

* fix annotation

* u

* revert minja.hpp

* fix

* Do not write routed_scaling_factor to gguf when routed_scaling_factor is None

* fix expert_weights_scale

* LGTM after whitespace fixes

* fix

* fix

* fix

* layers to layer_index

* enum fix

---------

Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-01-01 19:25:54 +01:00
Piotr Wilkin (ilintar) 3ccccc83f7
Add conversion support for IQuestCoderForCausalLM (#18524) 2026-01-01 18:45:55 +01:00
o7si d0a6a31470
model : add support for JinaBertModel with non-gated ffn (#18475)
* WIP: Initial commit for fixing JinaBert original FF type support

* convert: add jina-v2-de tokenizer variant for German_Semantic_V3

* convert: fix token collision in BERT phantom vocab conversion

* convert: add feed_forward_type metadata

* model: add feed_forward_type metadata for jina-bert-v2

* model: jina-bert-v2 support standard GELU FFN variant

* model: remove ffn_type, detect FFN variant from tensor dimensions

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/models/bert.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/models/bert.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* revert collision fix to be handled in separate PR

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-01-01 18:38:51 +01:00
o7si 2b2afade9f
convert : fix encoding of WPM vocab for BERT models (#18500)
* convert: avoid token collision when stripping ## prefix

* convert: use token types for BERT special tokens check

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-01-01 18:27:07 +01:00
HelloKS f4f5019254
model: add Solar Open model (#18511)
* model: add Solar-Open model

* vocab: add solar-open to end eog blacklist

* model: add proper llm type

* chat: basic template for solar open

* typo: fix comment about vocab

* convert: sugested changes

* convert: suggested changes

* chat: change reasoning end tag for solar-open

* llama-chat: add solar-open template
2026-01-01 18:01:43 +01:00
Anri Lombard d5574c919c
webui: fix code copy stripping XML/HTML tags (#18518)
* webui: fix code copy stripping XML/HTML tags

* webui: update static build
2026-01-01 13:44:11 +01:00
Aman Gupta 26831bded9
ggml-cuda: remove unneccesary prints on ggml_cuda_init (#18502) 2026-01-01 19:18:43 +08:00
Jeff Bolz be47fb9285
vulkan: extend topk_moe to handle sigmoid w/exp_probs_b for nemotron (#18295)
* vulkan: extend topk_moe to handle sigmoid w/exp_probs_b for nemotron

Also handle GGML_OP_SCALE at the end (nemotron, deepseek2).

Fewer pipeline variants and spec constants, just use push constants.

In test_topk_moe, change exp_probs_b to be 1D, matching real networks.

Update test-backend-ops and ggml-backend to allow verifying multiple outputs
in a fusion test (topk_moe has two outputs). Previously only the final node
was verified.

* change test_topk_moe to allow results in arbitrary order

* disable sigmoid fusion for moltenvk
2026-01-01 08:58:27 +01:00
triplenom 9e10bd2eaf
llama: handle short reads in direct I/O path (#18504) 2026-01-01 10:24:43 +08:00
Anri Lombard 4cd162a123
chat: make tool description and parameters optional per OpenAI spec (#18478)
* chat: make tool description and parameters optional per OpenAI spec

Per the OpenAI API specification, both 'description' and 'parameters'
fields in tool function definitions are optional. Previously, the parser
would throw an exception if these fields were missing.

Attempts to fix #17667

* refactor: use value() for cleaner optional field access
2025-12-31 17:21:37 -06:00
Georgi Gerganov 13814eb370 sync : ggml 2025-12-31 18:54:43 +02:00
Georgi Gerganov 54f67b9b66 ggml : bump version to 0.9.5 (ggml/1410) 2025-12-31 18:54:43 +02:00
Anri Lombard 33ded988ba
quantize: prevent input/output file collision (#18451)
Check if input and output files are the same before quantizing to prevent
file corruption when mmap reads from a file being written to.

Fixes #12753
2025-12-31 23:29:03 +08:00
Sigbjørn Skjæret 0db8109849
convert : lint fix (#18507) 2025-12-31 14:28:21 +01:00
Henry147147 9b8329de7a
mtmd : Adding support for Nvidia Music Flamingo Model (#18470)
* Inital commit, debugging q5_k_s quant

* Made hf_to_gguf extend whisper to reduce code duplication

* addressed convert_hf_to_gguf pull request issue

---------

Co-authored-by: Henry D <henrydorsey147@gmail.com>
2025-12-31 12:13:23 +01:00
gatbontonpc 9a6369bb60
metal : add count_equal op (#18314)
* add count equal for metal

* remove trailing whitespace

* updated doc ops table

* changed shmem to i32

* added multi tg and templating

* removed BLAS support from Metal docs

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* add memset to set dst to 0

* metal : cleanup

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-12-31 10:39:48 +02:00
70 changed files with 2840 additions and 802 deletions

View File

@ -1395,6 +1395,14 @@ static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
builder.consume_reasoning_with_xml_tool_calls(form, "<seed:think>", "</seed:think>"); builder.consume_reasoning_with_xml_tool_calls(form, "<seed:think>", "</seed:think>");
} }
static void common_chat_parse_solar_open(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<|think|>", "<|end|><|begin|>assistant<|content|>");
// TODO: Tool calling
builder.add_content(builder.consume_rest());
}
static void common_chat_parse_content_only(common_chat_msg_parser & builder) { static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>"); builder.try_parse_reasoning("<think>", "</think>");
builder.add_content(builder.consume_rest()); builder.add_content(builder.consume_rest());
@ -1479,6 +1487,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_XIAOMI_MIMO: case COMMON_CHAT_FORMAT_XIAOMI_MIMO:
common_chat_parse_xiaomi_mimo(builder); common_chat_parse_xiaomi_mimo(builder);
break; break;
case COMMON_CHAT_FORMAT_SOLAR_OPEN:
common_chat_parse_solar_open(builder);
break;
default: default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
} }

View File

@ -380,8 +380,8 @@ std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & too
const auto & function = tool.at("function"); const auto & function = tool.at("function");
result.push_back({ result.push_back({
/* .name = */ function.at("name"), /* .name = */ function.at("name"),
/* .description = */ function.at("description"), /* .description = */ function.value("description", ""),
/* .parameters = */ function.at("parameters").dump(), /* .parameters = */ function.value("parameters", json::object()).dump(),
}); });
} }
} }
@ -669,6 +669,7 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: return "Qwen3 Coder"; case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: return "Qwen3 Coder";
case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5"; case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5";
case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo"; case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo";
case COMMON_CHAT_FORMAT_SOLAR_OPEN: return "Solar Open";
case COMMON_CHAT_FORMAT_PEG_SIMPLE: return "peg-simple"; case COMMON_CHAT_FORMAT_PEG_SIMPLE: return "peg-simple";
case COMMON_CHAT_FORMAT_PEG_NATIVE: return "peg-native"; case COMMON_CHAT_FORMAT_PEG_NATIVE: return "peg-native";
case COMMON_CHAT_FORMAT_PEG_CONSTRUCTED: return "peg-constructed"; case COMMON_CHAT_FORMAT_PEG_CONSTRUCTED: return "peg-constructed";
@ -2517,6 +2518,27 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp
return data; return data;
} }
static common_chat_params common_chat_params_init_solar_open(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
// TODO: Reasoning effort
json additional_context = {};
data.prompt = apply(tmpl, inputs, std::nullopt, std::nullopt, additional_context);
data.format = COMMON_CHAT_FORMAT_SOLAR_OPEN;
data.preserved_tokens = {
"<|think|>",
"<|content|>",
"<|begin|>",
"<|end|>",
};
// TODO: Tool calling
return data;
}
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data; common_chat_params data;
data.prompt = apply(tmpl, inputs); data.prompt = apply(tmpl, inputs);
@ -2780,6 +2802,13 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_magistral(tmpl, params); return common_chat_params_init_magistral(tmpl, params);
} }
// Solar Open
if (src.find("<|tool_response:begin|>") != std::string::npos &&
src.find("<|tool_response:name|>") != std::string::npos &&
src.find("<|tool_response:result|>") != std::string::npos) {
return common_chat_params_init_solar_open(tmpl, params);
}
// Plain handler (no tools) // Plain handler (no tools)
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
return common_chat_params_init_without_tools(tmpl, params); return common_chat_params_init_without_tools(tmpl, params);

View File

@ -124,6 +124,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_QWEN3_CODER_XML, COMMON_CHAT_FORMAT_QWEN3_CODER_XML,
COMMON_CHAT_FORMAT_APRIEL_1_5, COMMON_CHAT_FORMAT_APRIEL_1_5,
COMMON_CHAT_FORMAT_XIAOMI_MIMO, COMMON_CHAT_FORMAT_XIAOMI_MIMO,
COMMON_CHAT_FORMAT_SOLAR_OPEN,
// These are intended to be parsed by the PEG parser // These are intended to be parsed by the PEG parser
COMMON_CHAT_FORMAT_PEG_SIMPLE, COMMON_CHAT_FORMAT_PEG_SIMPLE,

View File

@ -1062,6 +1062,9 @@ class TextModel(ModelBase):
if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273": if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273":
# ref: https://huggingface.co/alvarobartt/grok-2-tokenizer # ref: https://huggingface.co/alvarobartt/grok-2-tokenizer
res = "grok-2" res = "grok-2"
if chkhsh == "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df":
# ref: https://huggingface.co/aari1995/German_Semantic_V3
res = "jina-v2-de"
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5": if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
res = "llama-bpe" res = "llama-bpe"
@ -1230,6 +1233,12 @@ class TextModel(ModelBase):
if chkhsh == "4a2e2abae11ca2b86d570fc5b44be4d5eb5e72cc8f22dd136a94b37da83ab665": if chkhsh == "4a2e2abae11ca2b86d570fc5b44be4d5eb5e72cc8f22dd136a94b37da83ab665":
# ref: https://huggingface.co/KORMo-Team/KORMo-tokenizer # ref: https://huggingface.co/KORMo-Team/KORMo-tokenizer
res = "kormo" res = "kormo"
if chkhsh == "9d70134b369a70e5735009b6de918f7581b5211f7c074d1f89f753aea8248af1":
# ref: https://huggingface.co/tencent/Youtu-LLM-2B
res = "youtu"
if chkhsh == "16389f0a1f51ee53e562ffd51c371dc508639ab0e4261502071836e50e223e91":
# ref: https://huggingface.co/upstage/Solar-Open-100B
res = "solar-open"
if res is None: if res is None:
logger.warning("\n") logger.warning("\n")
@ -2486,6 +2495,7 @@ class StableLMModel(TextModel):
"VLlama3ForCausalLM", "VLlama3ForCausalLM",
"LlavaForConditionalGeneration", "LlavaForConditionalGeneration",
"VoxtralForConditionalGeneration", "VoxtralForConditionalGeneration",
"IQuestCoderForCausalLM",
"LlamaModel") "LlamaModel")
class LlamaModel(TextModel): class LlamaModel(TextModel):
model_arch = gguf.MODEL_ARCH.LLAMA model_arch = gguf.MODEL_ARCH.LLAMA
@ -3503,7 +3513,7 @@ class QwenModel(TextModel):
self._set_vocab_qwen() self._set_vocab_qwen()
@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration", "KORMoForCausalLM") @ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration", "KORMoForCausalLM", "AudioFlamingo3ForConditionalGeneration")
class Qwen2Model(TextModel): class Qwen2Model(TextModel):
model_arch = gguf.MODEL_ARCH.QWEN2 model_arch = gguf.MODEL_ARCH.QWEN2
@ -5284,13 +5294,14 @@ class BertModel(TextModel):
self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1)) self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1))
# convert to phantom space vocab # convert to phantom space vocab
def phantom(tok): def phantom(tok, toktype):
if tok.startswith("[") and tok.endswith("]"): if toktype == gguf.TokenType.CONTROL:
return tok return tok
if tok.startswith("##"): if tok.startswith("##"):
return tok[2:] return tok[2:]
return "\u2581" + tok return "\u2581" + tok
tokens = list(map(phantom, tokens)) assert len(tokens) == len(toktypes)
tokens = list(map(phantom, tokens, toktypes))
# add vocab to gguf # add vocab to gguf
self.gguf_writer.add_tokenizer_model("bert") self.gguf_writer.add_tokenizer_model("bert")
@ -6404,6 +6415,17 @@ class ARwkv7Model(Rwkv7Model):
self.gguf_writer.add_head_count(0) self.gguf_writer.add_head_count(0)
@ModelBase.register("MaincoderForCausalLM")
class MaincoderModel(TextModel):
model_arch = gguf.MODEL_ARCH.MAINCODER
def set_gguf_parameters(self):
super().set_gguf_parameters()
if (head_dim := self.hparams.get("head_dim")) is not None:
self.gguf_writer.add_rope_dimension_count(head_dim)
@ModelBase.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM") @ModelBase.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
class MambaModel(TextModel): class MambaModel(TextModel):
model_arch = gguf.MODEL_ARCH.MAMBA model_arch = gguf.MODEL_ARCH.MAMBA
@ -7181,6 +7203,7 @@ class DeepseekModel(TextModel):
"DeepseekV2ForCausalLM", "DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM", "DeepseekV3ForCausalLM",
"KimiVLForConditionalGeneration", "KimiVLForConditionalGeneration",
"YoutuForCausalLM",
) )
class DeepseekV2Model(TextModel): class DeepseekV2Model(TextModel):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2 model_arch = gguf.MODEL_ARCH.DEEPSEEK2
@ -7247,7 +7270,15 @@ class DeepseekV2Model(TextModel):
super().set_gguf_parameters() super().set_gguf_parameters()
hparams = self.hparams hparams = self.hparams
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) # first_k_dense_replace: number of leading layers using dense FFN instead of MoE
# For non-MoE models (like Youtu), set to n_layer to use dense FFN for all layers
# For MoE models (like DeepSeek-V2), this is the number of leading non-MoE layers
has_moe = hparams.get("n_routed_experts") is not None
first_k_dense_replace = hparams.get("first_k_dense_replace")
if first_k_dense_replace is None:
# Default: if no MoE, all layers are dense; if MoE, none are dense
first_k_dense_replace = hparams["num_hidden_layers"] if not has_moe else 0
self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace)
self.gguf_writer.add_vocab_size(hparams["vocab_size"]) self.gguf_writer.add_vocab_size(hparams["vocab_size"])
if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None: if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
@ -7259,11 +7290,24 @@ class DeepseekV2Model(TextModel):
self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
self.gguf_writer.add_value_length_mla(hparams["v_head_dim"]) self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) # MoE parameters (required by C++ code for DEEPSEEK2 arch)
self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) # For non-MoE models like Youtu, use intermediate_size as expert_feed_forward_length
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) moe_intermediate_size = self.find_hparam(["moe_intermediate_size", "intermediate_size"], optional=False)
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
if (n_routed_experts := hparams.get("n_routed_experts")) is not None:
self.gguf_writer.add_expert_count(n_routed_experts)
# expert_shared_count is required by C++ code, default to 0 for non-MoE models
n_shared_experts = hparams.get("n_shared_experts", 0)
self.gguf_writer.add_expert_shared_count(n_shared_experts)
# When not set, C++ code will use scale_w = false to skip the no-op scaling
if (routed_scaling_factor := hparams.get("routed_scaling_factor")) is not None:
self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)
if (norm_topk_prob := hparams.get("norm_topk_prob")) is not None and norm_topk_prob:
self.gguf_writer.add_expert_weights_norm(norm_topk_prob)
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
@ -7279,10 +7323,17 @@ class DeepseekV2Model(TextModel):
# skip vision tensors and remove "language_model." for Kimi-VL # skip vision tensors and remove "language_model." for Kimi-VL
if "vision_tower" in name or "multi_modal_projector" in name: if "vision_tower" in name or "multi_modal_projector" in name:
return [] return []
if name.startswith("siglip2.") or name.startswith("merger."):
return []
if name.startswith("language_model."): if name.startswith("language_model."):
name = name.replace("language_model.", "") name = name.replace("language_model.", "")
# skip lm_head.weight if tie_word_embeddings is True
if self.hparams.get("tie_word_embeddings", False):
if name == "lm_head.weight" or name == "model.lm_head.weight":
logger.info("Skipping tied output layer 'lm_head.weight' (will use token_embd.weight)")
return []
# rename e_score_correction_bias tensors # rename e_score_correction_bias tensors
if name.endswith("e_score_correction_bias"): if name.endswith("e_score_correction_bias"):
name = name.replace("e_score_correction_bias", "e_score_correction.bias") name = name.replace("e_score_correction_bias", "e_score_correction.bias")
@ -9292,6 +9343,19 @@ class VoxtralWhisperEncoderModel(WhisperEncoderModel):
self.gguf_writer.add_audio_stack_factor(4) # == intermediate_size // hidden_size self.gguf_writer.add_audio_stack_factor(4) # == intermediate_size // hidden_size
@ModelBase.register("AudioFlamingo3ForConditionalGeneration")
class AudioFlamingo3WhisperEncoderModel(WhisperEncoderModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.MUSIC_FLAMINGO)
def tensor_force_quant(self, name, new_name, bid, n_dims):
if ".conv" in name and ".weight" in name:
# Was trained in BF16, being safe, avoiding quantizing to FP16
return gguf.GGMLQuantizationType.F32
return super().tensor_force_quant(name, new_name, bid, n_dims)
@ModelBase.register("FalconH1ForCausalLM") @ModelBase.register("FalconH1ForCausalLM")
class FalconH1Model(Mamba2Model): class FalconH1Model(Mamba2Model):
model_arch = gguf.MODEL_ARCH.FALCON_H1 model_arch = gguf.MODEL_ARCH.FALCON_H1
@ -10604,6 +10668,79 @@ class JanusProVisionModel(MmprojModel):
return [] return []
@ModelBase.register("YOUTUVLForConditionalGeneration", "YOUTUVLForCausalLM")
class YOUTUVLVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None
self.hparams_vision["image_size"] = self.hparams_vision.get("image_size", 560)
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.YOUTUVL)
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-6))
# Handle activation function
hidden_act = str(self.hparams.get("hidden_act", "gelu_pytorch_tanh")).lower()
if hidden_act in ("gelu", "gelu_pytorch_tanh", "gelu_fast", "gelu_new", "gelu_accurate"):
self.gguf_writer.add_vision_use_gelu(True)
elif hidden_act == "silu":
self.gguf_writer.add_vision_use_silu(True)
else:
raise ValueError(f"Unsupported activation function for YOUTUVL: {hidden_act}")
self.gguf_writer.add_vision_spatial_merge_size(self.hparams.get("spatial_merge_size", 2))
window_size = self.hparams.get("window_size")
if window_size is not None:
self.gguf_writer.add_vision_window_size(window_size)
# fullatt_block_indexes contains explicit layer indices that use full attention
# e.g., [2, 5, 8, 11] means layers 2, 5, 8, 11 use full attention
# All other layers use window attention
fullatt_block_indexes = self.hparams.get("fullatt_block_indexes")
assert fullatt_block_indexes is not None, "fullatt_block_indexes is required for youtuvl"
# Store the explicit layer indices for YoutuVL (irregular pattern approach)
self.gguf_writer.add_vision_wa_layer_indexes(layers=fullatt_block_indexes)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
# Skip language model tensors
skip_prefixes = ('lm_head.', 'model.layers.', 'model.embed_tokens.', 'model.norm.')
if name.startswith(skip_prefixes):
return []
# Try to map the tensor using TensorNameMap (handles vision encoder and projector)
try:
new_name = self.map_tensor_name(name)
return [(new_name, data_torch)]
except ValueError:
# If mapping fails, log warning and skip
logger.warning(f"Cannot map tensor: {name}")
return []
@ModelBase.register("SolarOpenForCausalLM")
class SolarOpenModel(Glm4MoeModel):
model_arch = gguf.MODEL_ARCH.GLM4_MOE
def set_vocab(self):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<unk>"])
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|startoftext|>"])
special_vocab.add_to_gguf(self.gguf_writer)
###### CONVERSION LOGIC ###### ###### CONVERSION LOGIC ######

View File

@ -145,6 +145,8 @@ models = [
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", }, {"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
{"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", }, {"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },
{"name": "kormo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/KORMo-Team/KORMo-tokenizer", }, {"name": "kormo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/KORMo-Team/KORMo-tokenizer", },
{"name": "youtu", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", },
{"name": "solar-open", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", },
] ]
# some models are known to be broken upstream, so we will skip them as exceptions # some models are known to be broken upstream, so we will skip them as exceptions
@ -165,6 +167,8 @@ pre_computed_hashes = [
{"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"}, {"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"},
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"}, {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"},
{"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"}, {"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
# jina-v2-de variants
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"},
] ]

View File

@ -32,7 +32,7 @@ Legend:
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | | COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | | CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |

View File

@ -965,6 +965,7 @@
"Metal","IM2COL","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2560],ne_kernel=[3,3,1,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1","support","1","yes","Metal" "Metal","IM2COL","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2560],ne_kernel=[3,3,1,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1","support","1","yes","Metal"
"Metal","IM2COL","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2560],ne_kernel=[3,3,2,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1","support","1","yes","Metal" "Metal","IM2COL","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2560],ne_kernel=[3,3,2,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1","support","1","yes","Metal"
"Metal","IM2COL","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[5,5,1,32],ne_kernel=[3,4,1,32],s0=1,s1=1,p0=0,p1=0,d0=1,d1=1,is_2D=1","support","1","yes","Metal" "Metal","IM2COL","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[5,5,1,32],ne_kernel=[3,4,1,32],s0=1,s1=1,p0=0,p1=0,d0=1,d1=1,is_2D=1","support","1","yes","Metal"
"Metal","IM2COL","type_input=f32,type_kernel=f32,dst_type=f32,ne_input=[2,2,1536,729],ne_kernel=[2,2,1536,4096],s0=1,s1=1,p0=0,p1=0,d0=1,d1=1,is_2D=1","support","1","yes","Metal"
"Metal","IM2COL_3D","type_input=f32,type_kernel=f32,dst_type=f32,ne_input=[10,10,10,9],ne_kernel=[3,3,3,1],IC=3,s0=1,s1=1,s2=1,p0=1,p1=1,p2=1,d0=1,d1=1,d2=1,v=0","support","0","no","Metal" "Metal","IM2COL_3D","type_input=f32,type_kernel=f32,dst_type=f32,ne_input=[10,10,10,9],ne_kernel=[3,3,3,1],IC=3,s0=1,s1=1,s2=1,p0=1,p1=1,p2=1,d0=1,d1=1,d2=1,v=0","support","0","no","Metal"
"Metal","IM2COL_3D","type_input=f32,type_kernel=f16,dst_type=f32,ne_input=[10,10,10,9],ne_kernel=[3,3,3,1],IC=3,s0=1,s1=1,s2=1,p0=1,p1=1,p2=1,d0=1,d1=1,d2=1,v=0","support","0","no","Metal" "Metal","IM2COL_3D","type_input=f32,type_kernel=f16,dst_type=f32,ne_input=[10,10,10,9],ne_kernel=[3,3,3,1],IC=3,s0=1,s1=1,s2=1,p0=1,p1=1,p2=1,d0=1,d1=1,d2=1,v=0","support","0","no","Metal"
"Metal","IM2COL_3D","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[10,10,10,9],ne_kernel=[3,3,3,1],IC=3,s0=1,s1=1,s2=1,p0=1,p1=1,p2=1,d0=1,d1=1,d2=1,v=0","support","0","no","Metal" "Metal","IM2COL_3D","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[10,10,10,9],ne_kernel=[3,3,3,1],IC=3,s0=1,s1=1,s2=1,p0=1,p1=1,p2=1,d0=1,d1=1,d2=1,v=0","support","0","no","Metal"
@ -4964,8 +4965,9 @@
"Metal","CONV_TRANSPOSE_1D","ne_input=[2,1,1,1],ne_kernel=[3,1,1,1],s0=1,p0=0,d0=1","support","1","yes","Metal" "Metal","CONV_TRANSPOSE_1D","ne_input=[2,1,1,1],ne_kernel=[3,1,1,1],s0=1,p0=0,d0=1","support","1","yes","Metal"
"Metal","CONV_TRANSPOSE_2D","ne_input=[3,2,3,1],ne_kernel=[2,2,1,3],stride=1","support","1","yes","Metal" "Metal","CONV_TRANSPOSE_2D","ne_input=[3,2,3,1],ne_kernel=[2,2,1,3],stride=1","support","1","yes","Metal"
"Metal","CONV_TRANSPOSE_2D","ne_input=[10,10,9,1],ne_kernel=[3,3,1,9],stride=2","support","1","yes","Metal" "Metal","CONV_TRANSPOSE_2D","ne_input=[10,10,9,1],ne_kernel=[3,3,1,9],stride=2","support","1","yes","Metal"
"Metal","COUNT_EQUAL","type=f32,ne=[4,500,1,1]","support","0","no","Metal" "Metal","CONV_TRANSPOSE_2D","ne_input=[129,63,35,1],ne_kernel=[3,3,48,35],stride=1","support","1","yes","Metal"
"Metal","COUNT_EQUAL","type=f32,ne=[4,5000,1,1]","support","0","no","Metal" "Metal","COUNT_EQUAL","type=f32,ne=[4,500,1,1]","support","1","yes","Metal"
"Metal","COUNT_EQUAL","type=f32,ne=[4,5000,1,1]","support","1","yes","Metal"
"Metal","ARGMAX","type=f32,ne=[32,1,1,1]","support","1","yes","Metal" "Metal","ARGMAX","type=f32,ne=[32,1,1,1]","support","1","yes","Metal"
"Metal","ARGMAX","type=f32,ne=[32,513,1,1]","support","1","yes","Metal" "Metal","ARGMAX","type=f32,ne=[32,513,1,1]","support","1","yes","Metal"
"Metal","ARGMAX","type=f32,ne=[100,10,1,1]","support","1","yes","Metal" "Metal","ARGMAX","type=f32,ne=[100,10,1,1]","support","1","yes","Metal"
@ -5715,15 +5717,15 @@
"Metal","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","Metal" "Metal","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","Metal"
"Metal","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=1","support","1","yes","Metal" "Metal","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=1","support","1","yes","Metal"
"Metal","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","Metal" "Metal","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Metal" "Metal","SSM_CONV","type=f32,ne_a=[3,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Metal" "Metal","SSM_CONV","type=f32,ne_a=[6,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]","support","1","yes","Metal" "Metal","SSM_CONV","type=f32,ne_a=[3,1024,4,1],ne_b=[3,1024,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Metal" "Metal","SSM_CONV","type=f32,ne_a=[3,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Metal" "Metal","SSM_CONV","type=f32,ne_a=[6,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]","support","1","yes","Metal" "Metal","SSM_CONV","type=f32,ne_a=[3,1536,4,1],ne_b=[3,1536,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Metal" "Metal","SSM_CONV","type=f32,ne_a=[3,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Metal" "Metal","SSM_CONV","type=f32,ne_a=[6,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]","support","1","yes","Metal" "Metal","SSM_CONV","type=f32,ne_a=[3,2048,4,1],ne_b=[3,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]","support","1","yes","Metal" "Metal","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]","support","1","yes","Metal" "Metal","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]","support","1","yes","Metal" "Metal","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]","support","1","yes","Metal"
@ -5733,6 +5735,15 @@
"Metal","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]","support","1","yes","Metal" "Metal","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]","support","1","yes","Metal" "Metal","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]","support","1","yes","Metal" "Metal","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[9,1024,1,1],ne_b=[9,1024,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[18,1024,1,1],ne_b=[9,1024,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[9,1024,4,1],ne_b=[9,1024,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[9,1536,1,1],ne_b=[9,1536,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[18,1536,1,1],ne_b=[9,1536,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[9,1536,4,1],ne_b=[9,1536,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[9,2048,1,1],ne_b=[9,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[18,2048,1,1],ne_b=[9,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[9,2048,4,1],ne_b=[9,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_SCAN","type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4","support","1","yes","Metal" "Metal","SSM_SCAN","type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4","support","1","yes","Metal"
"Metal","SSM_SCAN","type=f32,d_state=128,head_dim=64,n_head=16,n_group=2,n_seq_tokens=32,n_seqs=4","support","1","yes","Metal" "Metal","SSM_SCAN","type=f32,d_state=128,head_dim=64,n_head=16,n_group=2,n_seq_tokens=32,n_seqs=4","support","1","yes","Metal"
"Metal","SSM_SCAN","type=f32,d_state=256,head_dim=64,n_head=8,n_group=2,n_seq_tokens=32,n_seqs=4","support","1","yes","Metal" "Metal","SSM_SCAN","type=f32,d_state=256,head_dim=64,n_head=8,n_group=2,n_seq_tokens=32,n_seqs=4","support","1","yes","Metal"
@ -8916,6 +8927,8 @@
"Metal","SOFT_MAX","type=f32,ne=[32,2,32,1],mask=1,sinks=0,m_prec=f16,nr23=[1,1],scale=0.100000,max_bias=0.000000,inplace=0","support","1","yes","Metal" "Metal","SOFT_MAX","type=f32,ne=[32,2,32,1],mask=1,sinks=0,m_prec=f16,nr23=[1,1],scale=0.100000,max_bias=0.000000,inplace=0","support","1","yes","Metal"
"Metal","SOFT_MAX","type=f32,ne=[32,2,32,1],mask=1,sinks=1,m_prec=f32,nr23=[1,1],scale=0.100000,max_bias=8.000000,inplace=0","support","1","yes","Metal" "Metal","SOFT_MAX","type=f32,ne=[32,2,32,1],mask=1,sinks=1,m_prec=f32,nr23=[1,1],scale=0.100000,max_bias=8.000000,inplace=0","support","1","yes","Metal"
"Metal","SOFT_MAX","type=f32,ne=[32,2,32,1],mask=1,sinks=1,m_prec=f16,nr23=[1,1],scale=0.100000,max_bias=8.000000,inplace=0","support","1","yes","Metal" "Metal","SOFT_MAX","type=f32,ne=[32,2,32,1],mask=1,sinks=1,m_prec=f16,nr23=[1,1],scale=0.100000,max_bias=8.000000,inplace=0","support","1","yes","Metal"
"Metal","SOFT_MAX","type=f32,ne=[200001,2,3,1],mask=1,sinks=1,m_prec=f32,nr23=[1,1],scale=0.100000,max_bias=8.000000,inplace=0","support","1","yes","Metal"
"Metal","SOFT_MAX","type=f32,ne=[200001,2,3,1],mask=1,sinks=1,m_prec=f16,nr23=[1,1],scale=0.100000,max_bias=8.000000,inplace=0","support","1","yes","Metal"
"Metal","SOFT_MAX_BACK","type=f32,ne=[16,16,1,1],scale=1.000000,max_bias=0.000000","support","0","no","Metal" "Metal","SOFT_MAX_BACK","type=f32,ne=[16,16,1,1],scale=1.000000,max_bias=0.000000","support","0","no","Metal"
"Metal","SOFT_MAX_BACK","type=f32,ne=[15,15,1,1],scale=1.000000,max_bias=0.000000","support","0","no","Metal" "Metal","SOFT_MAX_BACK","type=f32,ne=[15,15,1,1],scale=1.000000,max_bias=0.000000","support","0","no","Metal"
"Metal","SOFT_MAX_BACK","type=f32,ne=[16,16,2,3],scale=1.000000,max_bias=0.000000","support","0","no","Metal" "Metal","SOFT_MAX_BACK","type=f32,ne=[16,16,2,3],scale=1.000000,max_bias=0.000000","support","0","no","Metal"
@ -9542,311 +9555,311 @@
"Metal","ARGSORT","type=f32,ne=[2048,2,1,3],order=1","support","1","yes","Metal" "Metal","ARGSORT","type=f32,ne=[2048,2,1,3],order=1","support","1","yes","Metal"
"Metal","ARGSORT","type=f32,ne=[2049,2,1,3],order=1","support","1","yes","Metal" "Metal","ARGSORT","type=f32,ne=[2049,2,1,3],order=1","support","1","yes","Metal"
"Metal","ARGSORT","type=f32,ne=[2,8,8192,1],order=1","support","1","yes","Metal" "Metal","ARGSORT","type=f32,ne=[2,8,8192,1],order=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1,1,1,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[12,1,2,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[12,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2,1,1,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[13,1,2,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[13,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2,1,1,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[13,1,2,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[13,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4,1,1,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[4,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[15,1,2,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[15,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4,1,1,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[4,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[15,1,2,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[15,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4,1,1,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[4,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[15,1,2,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[15,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8,1,1,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[8,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[19,1,2,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[19,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8,1,1,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[8,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[19,1,2,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[19,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8,1,1,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[8,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[19,1,2,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[19,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8,1,1,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[8,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[19,1,2,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[19,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[27,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[27,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[27,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[27,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16,1,1,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[27,1,2,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[32,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[43,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[32,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[43,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[32,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[43,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[32,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[43,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[32,1,1,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[43,1,2,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[64,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[75,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[64,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[75,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[64,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[75,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[64,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[75,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[64,1,1,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[75,1,2,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[128,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[139,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[128,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[139,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[128,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[139,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[128,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[139,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[128,1,1,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[139,1,2,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[128,1,1,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[139,1,2,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[256,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[267,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[256,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[267,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[256,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[267,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[256,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[267,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[256,1,1,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[267,1,2,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[256,1,1,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[267,1,2,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[512,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[523,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[512,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[523,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[512,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[523,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[512,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[523,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[512,1,1,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[523,1,2,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[512,1,1,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[523,1,2,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=500","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[512,1,1,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=500","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[523,1,2,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=500","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=500","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=1023","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=1023,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=1023","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=1023,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=500","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=500","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=1023","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=1023,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=1023","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=1023,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=500","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=500","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=1023","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=1023,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=1023","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=1023,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=500","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=500","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=1023","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=1023,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=1023","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=1023,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=500","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=500","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=1023","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=1023,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=1023","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=1023,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=9999","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=9999,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=9999","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=9999,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=500","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=500","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=1023","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=1023,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=1023","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=1023,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=9999","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=9999,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=9999","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=9999,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=500","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=500","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=1023","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=1023,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=1023","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=1023,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=9999","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=9999,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=9999","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=9999,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=500","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=500","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=1023","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=1023,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=1023","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=1023,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=9999","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=9999,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=9999","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=9999,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=500","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=500","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=1023","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=1023,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=1023","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=1023,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=9999","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=9999,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=9999","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=9999,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=100","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=500","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=500","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=1023","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=1023,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=1023","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=1023,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=9999","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=9999,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=9999","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=9999,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,10,10,10],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16,10,10,10],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[60,10,10,10],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[60,10,10,10],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=1","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,10,10,10],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16,10,10,10],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[60,10,10,10],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[60,10,10,10],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=2","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,10,10,10],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16,10,10,10],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[60,10,10,10],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[60,10,10,10],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=3","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,10,10,10],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16,10,10,10],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[60,10,10,10],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[60,10,10,10],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=7","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,10,10,10],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16,10,10,10],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[60,10,10,10],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[60,10,10,10],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=15","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=15,ties=0","support","1","yes","Metal"
"Metal","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=nearest,transpose=0","support","1","yes","Metal" "Metal","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=nearest,transpose=0","support","1","yes","Metal"
"Metal","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=nearest,transpose=1","support","1","yes","Metal" "Metal","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=nearest,transpose=1","support","1","yes","Metal"
"Metal","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=nearest,flags=none","support","1","yes","Metal" "Metal","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=nearest,flags=none","support","1","yes","Metal"
@ -9891,8 +9904,9 @@
"Metal","GROUP_NORM","type=f32,ne=[64,64,320,1],num_groups=32,eps=0.000001","support","1","yes","Metal" "Metal","GROUP_NORM","type=f32,ne=[64,64,320,1],num_groups=32,eps=0.000001","support","1","yes","Metal"
"Metal","GROUP_NORM","type=f32,ne=[9,9,1280,1],num_groups=32,eps=0.000001","support","1","yes","Metal" "Metal","GROUP_NORM","type=f32,ne=[9,9,1280,1],num_groups=32,eps=0.000001","support","1","yes","Metal"
"Metal","ACC","type=f32,ne_a=[256,17,1,1],ne_b=[256,16,1,1]","support","1","yes","Metal" "Metal","ACC","type=f32,ne_a=[256,17,1,1],ne_b=[256,16,1,1]","support","1","yes","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,1,1],pad_0=1,pad_1=1","support","1","yes","Metal" "Metal","PAD","type=f32,ne_a=[512,512,1,1],pad_0=1,pad_1=1,circular=0","support","1","yes","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,3,1],lp0=1,rp0=1,lp1=1,rp1=1,lp2=1,rp2=1,lp3=1,rp3=1,v=0","support","0","no","Metal" "Metal","PAD","type=f32,ne_a=[33,17,2,1],pad_0=4,pad_1=3,circular=1","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,3,1],lp0=1,rp0=1,lp1=1,rp1=1,lp2=1,rp2=1,lp3=1,rp3=1,v=0,circular=0","support","0","no","Metal"
"Metal","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","1","yes","Metal" "Metal","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","1","yes","Metal"
"Metal","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","1","yes","Metal" "Metal","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","1","yes","Metal"
"Metal","ROLL","shift0=3,shift1=-2,shift3=1,shift4=-1","support","0","no","Metal" "Metal","ROLL","shift0=3,shift1=-2,shift3=1,shift4=-1","support","0","no","Metal"
@ -9923,17 +9937,41 @@
"Metal","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","1","yes","Metal" "Metal","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","1","yes","Metal"
"Metal","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","1","yes","Metal" "Metal","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","1","yes","Metal"
"Metal","FILL","type=f32,ne=[2048,512,2,2],c=3.500000","support","1","yes","Metal" "Metal","FILL","type=f32,ne=[2048,512,2,2],c=3.500000","support","1","yes","Metal"
"Metal","DIAG","type=f32,ne=[10,1,4,3]","support","0","no","Metal"
"Metal","DIAG","type=f32,ne=[79,1,19,13]","support","0","no","Metal"
"Metal","DIAG","type=f32,ne=[256,1,8,16]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","0","no","Metal" "Metal","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","0","no","Metal" "Metal","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","0","no","Metal" "Metal","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[30,30,7,1],ne_rhs=[8,30,7,1]","support","0","no","Metal" "Metal","SOLVE_TRI","type=f32,ne_lhs=[30,30,7,1],ne_rhs=[8,30,7,1]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[42,42,5,2],ne_rhs=[10,42,5,2]","support","0","no","Metal" "Metal","SOLVE_TRI","type=f32,ne_lhs=[42,42,5,2],ne_rhs=[10,42,5,2]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[64,64,2,2],ne_rhs=[10,64,2,2]","support","0","no","Metal" "Metal","SOLVE_TRI","type=f32,ne_lhs=[64,64,2,2],ne_rhs=[10,64,2,2]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[64,64,2,2],ne_rhs=[64,64,2,2]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[79,79,5,3],ne_rhs=[417,79,5,3]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[128,128,4,2],ne_rhs=[32,128,4,2]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[80,80,2,8],ne_rhs=[80,80,2,8]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[80,80,2,8],ne_rhs=[79,80,2,8]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[80,80,2,8],ne_rhs=[81,80,2,8]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[80,80,8,8],ne_rhs=[80,80,8,8]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[80,80,8,8],ne_rhs=[79,80,8,8]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[80,80,8,8],ne_rhs=[81,80,8,8]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[84,84,4,4],ne_rhs=[32,84,4,4]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[95,95,8,8],ne_rhs=[40,95,8,8]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[100,100,4,4],ne_rhs=[41,100,4,4]","support","0","no","Metal" "Metal","SOLVE_TRI","type=f32,ne_lhs=[100,100,4,4],ne_rhs=[41,100,4,4]","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=0","support","1","yes","Metal" "Metal","SOLVE_TRI","type=f32,ne_lhs=[128,128,4,4],ne_rhs=[31,128,4,4]","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=0","support","0","no","Metal" "Metal","SOLVE_TRI","type=f32,ne_lhs=[128,128,4,4],ne_rhs=[32,128,4,4]","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=1","support","1","yes","Metal" "Metal","SOLVE_TRI","type=f32,ne_lhs=[128,128,3,4],ne_rhs=[32,128,3,4]","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=1","support","0","no","Metal" "Metal","SOLVE_TRI","type=f32,ne_lhs=[128,128,4,1],ne_rhs=[32,128,4,1]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[64,64,4,4],ne_rhs=[200,64,4,4]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[64,64,4,4],ne_rhs=[384,64,4,4]","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=0,circular=0","support","1","yes","Metal"
"Metal","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=0,circular=0","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=0,circular=1","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=0,circular=1","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=1,circular=0","support","1","yes","Metal"
"Metal","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=1,circular=0","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=1,circular=1","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=1,circular=1","support","0","no","Metal"
"Metal","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","1","yes","Metal" "Metal","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","1","yes","Metal"
"Metal","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","1","yes","Metal" "Metal","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","1","yes","Metal"
"Metal","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","1","yes","Metal" "Metal","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","1","yes","Metal"

Can't render this file because it is too large.

View File

@ -4,7 +4,7 @@ project("ggml" C CXX ASM)
### GGML Version ### GGML Version
set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 9) set(GGML_VERSION_MINOR 9)
set(GGML_VERSION_PATCH 4) set(GGML_VERSION_PATCH 5)
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)

View File

@ -358,7 +358,7 @@ extern "C" {
typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data); typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
// Compare the output of two backends // Compare the output of two backends
GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node); GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes);
// Tensor initialization // Tensor initialization
GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr); GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);

View File

@ -2053,7 +2053,7 @@ void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {
ggml_free(copy.ctx_unallocated); ggml_free(copy.ctx_unallocated);
} }
bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node) { bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes) {
struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph); struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
if (copy.buffer == NULL) { if (copy.buffer == NULL) {
return false; return false;
@ -2064,22 +2064,22 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
assert(g1->n_nodes == g2->n_nodes); assert(g1->n_nodes == g2->n_nodes);
if (test_node != nullptr) { if (num_test_nodes != 0) {
// Compute the whole graph and only test the output for a specific tensor GGML_ASSERT(test_nodes);
// Compute the whole graph and only test the output for specific tensors
ggml_backend_graph_compute(backend1, g1); ggml_backend_graph_compute(backend1, g1);
ggml_backend_graph_compute(backend2, g2); ggml_backend_graph_compute(backend2, g2);
int test_node_idx = -1; bool verified = false;
for (int i = 0; i < g1->n_nodes; i++) { for (int i = 0; i < g1->n_nodes; i++) {
struct ggml_tensor * t1 = g1->nodes[i]; for (size_t j = 0; j < num_test_nodes; ++j) {
if (t1 == test_node) { if (g1->nodes[i] == test_nodes[j]) {
test_node_idx = i; callback(i, g1->nodes[i], g2->nodes[i], user_data);
break; verified = true;
}
} }
} }
GGML_ASSERT(test_node_idx != -1); GGML_ASSERT(verified);
callback(test_node_idx, g1->nodes[test_node_idx], g2->nodes[test_node_idx], user_data);
} else { } else {
for (int i = 0; i < g1->n_nodes; i++) { for (int i = 0; i < g1->n_nodes; i++) {
struct ggml_tensor * t1 = g1->nodes[i]; struct ggml_tensor * t1 = g1->nodes[i];

View File

@ -29,8 +29,8 @@ static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
const int nrows, const int nrows,
ggml_sort_order order, ggml_sort_order order,
cudaStream_t stream) { cudaStream_t stream) {
ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows); ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ((size_t) ncols) * nrows);
ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows); ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ((size_t) ncols) * nrows);
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1); ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
int * temp_indices = temp_indices_alloc.get(); int * temp_indices = temp_indices_alloc.get();

View File

@ -12,11 +12,11 @@ const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd dimension if available
const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows
template <cpy_kernel_t cpy_1> template <cpy_kernel_t cpy_1>
static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne, static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
const int nb12, const int nb13) { const int64_t nb12, const int64_t nb13) {
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
if (i >= ne) { if (i >= ne) {
return; return;
@ -40,10 +40,10 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne,
} }
template <typename T> template <typename T>
static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int ne, static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int64_t ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
const int nb12, const int nb13) { const int64_t nb12, const int64_t nb13) {
const T* src = reinterpret_cast<const T*>(cx); const T* src = reinterpret_cast<const T*>(cx);
T* dst = reinterpret_cast<T*>(cdst); T* dst = reinterpret_cast<T*>(cdst);
@ -117,60 +117,60 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
} }
template <cpy_kernel_t cpy_blck, int qk> template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, static __global__ void cpy_f32_q(const char * cx, char * cdst, const int64_t ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
const int nb12, const int nb13) { const int64_t nb12, const int64_t nb13) {
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk;
if (i >= ne) { if (i >= ne) {
return; return;
} }
const int i03 = i/(ne00 * ne01 * ne02); const int64_t i03 = i/(ne00 * ne01 * ne02);
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
const int i13 = i/(ne10 * ne11 * ne12); const int64_t i13 = i/(ne10 * ne11 * ne12);
const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13; const int64_t dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
cpy_blck(cx + x_offset, cdst + dst_offset); cpy_blck(cx + x_offset, cdst + dst_offset);
} }
template <cpy_kernel_t cpy_blck, int qk> template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne, static __global__ void cpy_q_f32(const char * cx, char * cdst, const int64_t ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
const int nb12, const int nb13) { const int64_t nb12, const int64_t nb13) {
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk;
if (i >= ne) { if (i >= ne) {
return; return;
} }
const int i03 = i/(ne00 * ne01 * ne02); const int64_t i03 = i/(ne00 * ne01 * ne02);
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
const int i13 = i/(ne10 * ne11 * ne12); const int64_t i13 = i/(ne10 * ne11 * ne12);
const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13; const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
cpy_blck(cx + x_offset, cdst + dst_offset); cpy_blck(cx + x_offset, cdst + dst_offset);
} }
template<typename src_t, typename dst_t> template<typename src_t, typename dst_t>
static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) { static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) {
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
if (i >= ne) { if (i >= ne) {
return; return;
@ -188,19 +188,20 @@ static void ggml_cpy_scalar_contiguous_cuda(
cudaStream_t stream) { cudaStream_t stream) {
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>> cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne); (cx, cdst, ne);
} }
template<typename src_t, typename dst_t, bool transposed = false> template<typename src_t, typename dst_t, bool transposed = false>
static void ggml_cpy_scalar_cuda( static void ggml_cpy_scalar_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int64_t ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
if (transposed) { if (transposed) {
GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
int ne00n, ne01n, ne02n; int64_t ne00n, ne01n, ne02n;
if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here
ne00n = ne00; ne00n = ne00;
ne01n = ne01; ne01n = ne01;
@ -211,143 +212,159 @@ static void ggml_cpy_scalar_cuda(
ne02n = 1; ne02n = 1;
} }
dim3 dimGrid( (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, int64_t grid_x = (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;
(ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, int64_t grid_y = (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;
(ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM); int64_t grid_z = (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM;
GGML_ASSERT(grid_x < UINT_MAX);
GGML_ASSERT(grid_y < USHRT_MAX);
GGML_ASSERT(grid_z < USHRT_MAX);
dim3 dimGrid(grid_x, grid_y, grid_z);
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1); dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>> cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
(cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} else { } else {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>> cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
} }
static void ggml_cpy_f32_q8_0_cuda( static void ggml_cpy_f32_q8_0_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int64_t ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK8_0 == 0); GGML_ASSERT(ne % QK8_0 == 0);
const int num_blocks = ne / QK8_0; const int64_t num_blocks = ne / QK8_0;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>> cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_q8_0_f32_cuda( static void ggml_cpy_q8_0_f32_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int64_t ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
const int num_blocks = ne; const int64_t num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>> cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_f32_q4_0_cuda( static void ggml_cpy_f32_q4_0_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int64_t ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_0 == 0); GGML_ASSERT(ne % QK4_0 == 0);
const int num_blocks = ne / QK4_0; const int64_t num_blocks = ne / QK4_0;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>> cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_q4_0_f32_cuda( static void ggml_cpy_q4_0_f32_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int64_t ne,
const int ne00, const int ne01, const int ne02, const int64_t ne00, const int64_t ne01, const int64_t ne02,
const int nb00, const int nb01, const int nb02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
const int nb10, const int nb11, const int nb12, const int nb13, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
cudaStream_t stream) { cudaStream_t stream) {
const int num_blocks = ne; const int64_t num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>( cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13); ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_f32_q4_1_cuda( static void ggml_cpy_f32_q4_1_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int64_t ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_1 == 0); GGML_ASSERT(ne % QK4_1 == 0);
const int num_blocks = ne / QK4_1; const int64_t num_blocks = ne / QK4_1;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>> cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_q4_1_f32_cuda( static void ggml_cpy_q4_1_f32_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int64_t ne,
const int ne00, const int ne01, const int ne02, const int64_t ne00, const int64_t ne01, const int64_t ne02,
const int nb00, const int nb01, const int nb02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
const int nb10, const int nb11, const int nb12, const int nb13, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
cudaStream_t stream) { cudaStream_t stream) {
const int num_blocks = ne; const int64_t num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>( cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13); ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_f32_q5_0_cuda( static void ggml_cpy_f32_q5_0_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int64_t ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK5_0 == 0); GGML_ASSERT(ne % QK5_0 == 0);
const int num_blocks = ne / QK5_0; const int64_t num_blocks = ne / QK5_0;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>> cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_q5_0_f32_cuda( static void ggml_cpy_q5_0_f32_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int64_t ne,
const int ne00, const int ne01, const int ne02, const int64_t ne00, const int64_t ne01, const int64_t ne02,
const int nb00, const int nb01, const int nb02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
const int nb10, const int nb11, const int nb12, const int nb13, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
cudaStream_t stream) { cudaStream_t stream) {
const int num_blocks = ne; const int64_t num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>( cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13); ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_f32_q5_1_cuda( static void ggml_cpy_f32_q5_1_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int64_t ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK5_1 == 0); GGML_ASSERT(ne % QK5_1 == 0);
const int num_blocks = ne / QK5_1; const int64_t num_blocks = ne / QK5_1;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>> cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_q5_1_f32_cuda( static void ggml_cpy_q5_1_f32_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int64_t ne,
const int ne00, const int ne01, const int ne02, const int64_t ne00, const int64_t ne01, const int64_t ne02,
const int nb00, const int nb01, const int nb02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
const int nb10, const int nb11, const int nb12, const int nb13, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
cudaStream_t stream) { cudaStream_t stream) {
const int num_blocks = ne; const int64_t num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>( cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13); ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_f32_iq4_nl_cuda( static void ggml_cpy_f32_iq4_nl_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int64_t ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_NL == 0); GGML_ASSERT(ne % QK4_NL == 0);
const int num_blocks = ne / QK4_NL; const int64_t num_blocks = ne / QK4_NL;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>> cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
@ -356,9 +373,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
const int64_t ne = ggml_nelements(src0); const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1)); GGML_ASSERT(ne == ggml_nelements(src1));
GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
const int64_t ne00 = src0->ne[0]; const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1]; const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2]; const int64_t ne02 = src0->ne[2];

View File

@ -918,7 +918,9 @@ void launch_fattn(
blocks_num.y = 1; blocks_num.y = 1;
blocks_num.z = 1; blocks_num.z = 1;
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float)); if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
}
} else { } else {
const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size. const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.

View File

@ -201,16 +201,6 @@ static ggml_cuda_device_info ggml_cuda_init() {
GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES); GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
int64_t total_vram = 0; int64_t total_vram = 0;
#ifdef GGML_CUDA_FORCE_MMQ
GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__);
#else
GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__);
#endif // GGML_CUDA_FORCE_MMQ
#ifdef GGML_CUDA_FORCE_CUBLAS
GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__);
#else
GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
#endif // GGML_CUDA_FORCE_CUBLAS
GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count); GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
std::vector<std::pair<int, std::string>> turing_devices_without_mma; std::vector<std::pair<int, std::string>> turing_devices_without_mma;

View File

@ -85,13 +85,16 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
struct htp_spad * dst_spad, struct htp_spad * dst_spad,
uint32_t nth, uint32_t nth,
uint32_t ith, uint32_t ith,
uint32_t src0_nrows_per_thread) { uint32_t src0_nrows_per_thread,
dma_queue * dma_queue) {
htp_act_preamble3; htp_act_preamble3;
size_t src0_row_size = nb01; size_t src0_row_size = nb01;
size_t src1_row_size = nb11; size_t src1_row_size = nb11;
size_t dst_row_size = nb1; size_t dst_row_size = nb1;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_start_row = src0_nrows_per_thread * ith;
@ -105,10 +108,129 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
uint64_t t1, t2; uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count(); t1 = HAP_perf_get_qtimer_count();
int is_aligned = 1; const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) { const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
is_aligned = 0; uint8_t * restrict data_dst = (uint8_t *) dst->data;
FARF(HIGH, "swiglu-f32: unaligned addresses in elementwise op, possibly slower execution\n");
const bool src1_valid = src1->ne[0];
const int nc = (src1_valid) ? ne00 : ne00 / 2;
if (!src1_valid) {
const int32_t swapped = op_params[1];
data_src1 = data_src0;
src1_row_size = src0_row_size;
const size_t nc_in_bytes = nc * SIZEOF_FP32;
data_src0 += swapped ? nc_in_bytes : 0;
data_src1 += swapped ? 0 : nc_in_bytes;
}
const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN);
const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN);
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
if (BLOCK == 0) {
FARF(ERROR,
"swiglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
src0_spad->size_per_thread, src0_row_size_aligned);
return;
}
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
// Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
dma_queue_push_vtcm_to_ddr(dma_queue,
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
dst_row_size, dst_row_size_aligned, 0);
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
src0_row_size_aligned, src0_row_size, block_size);
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),
src1_row_size_aligned, src1_row_size, block_size);
}
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;
for (uint32_t ib = 0; ib < block_size; ib++) {
const float * src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float));
const float * src1_spad_ptr = src1_spad + ib * (src1_row_size_aligned / sizeof(float));
float * dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
//swiglu(x) = x1 * sigmoid(x0)
hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
hvx_mul_mul_f32_opt((const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr,
(const uint8_t *) src1_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
}
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
dst_row_size_aligned, block_size);
// prefetch N+2 loop iteration if any
const uint32_t pref_block = (ir + BLOCK * 2);
if (pref_block < src0_end_row) {
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
src0_row_size_aligned, src0_row_size, pref_block_size);
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),
src1_row_size_aligned, src1_row_size, pref_block_size);
}
}
dma_queue_flush(dma_queue);
t2 = HAP_perf_get_qtimer_count();
FARF(HIGH, "swiglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
const struct htp_tensor * src1,
struct htp_tensor * dst,
const int32_t * op_params,
struct htp_spad * src0_spad,
struct htp_spad * src1_spad,
struct htp_spad * dst_spad,
uint32_t nth,
uint32_t ith,
uint32_t src0_nrows_per_thread,
dma_queue * dma_queue) {
htp_act_preamble3;
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
size_t src0_row_size = nb01;
size_t src1_row_size = nb11;
size_t dst_row_size = nb1;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
// no work for this thread
if (src0_start_row >= src0_end_row) {
return;
} }
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
@ -127,130 +249,94 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
data_src1 += swapped ? 0 : nc_in_bytes; data_src1 += swapped ? 0 : nc_in_bytes;
} }
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size); const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size); const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN);
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size); const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN);
const bool opt_path = ((1 == is_aligned) && !(nb01 & (VLEN - 1))); uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) { uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size)); uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size));
float * restrict dst = (float *) (data_dst + (ir * dst_row_size));
if (ir + 1 < src0_end_row) { // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size); size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
} size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
if (opt_path) { const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, nc); if (BLOCK == 0) {
hvx_mul_mul_f32_opt((const uint8_t *) src0, (const uint8_t *) src0_spad_data, (const uint8_t *) src1, FARF(ERROR,
(uint8_t *) dst, nc); "swiglu-oai-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least "
} else { "%zu\n",
hvx_exp_f32((const uint8_t *) src0, src0_spad_data, nc, true); src0_spad->size_per_thread, src0_row_size_aligned);
hvx_add_scalar_f32(src0_spad_data, 1.0, src1_spad_data, nc);
hvx_inverse_f32(src1_spad_data, src0_spad_data, nc);
hvx_mul_f32((const uint8_t *) src0, src0_spad_data, dst_spad_data, nc);
hvx_mul_f32(dst_spad_data, (const uint8_t *) src1, (uint8_t *) dst, nc);
}
}
t2 = HAP_perf_get_qtimer_count();
FARF(HIGH, "swiglu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path,
ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
const struct htp_tensor * src1,
struct htp_tensor * dst,
const int32_t * op_params,
struct htp_spad * src0_spad,
struct htp_spad * src1_spad,
struct htp_spad * dst_spad,
uint32_t nth,
uint32_t ith,
uint32_t src0_nrows_per_thread) {
htp_act_preamble3;
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
const size_t src0_row_size = nb01;
const size_t src1_row_size = nb11;
const size_t dst_row_size = nb1;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
// no work for this thread
if (src0_start_row >= src0_end_row) {
return; return;
} }
const float alpha = ((const float *) (op_params))[2];
const float limit = ((const float *) (op_params))[3];
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) { // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
FARF(HIGH, "act-f32: unaligned addresses in activations op, possibly slower execution\n"); for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
// Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
dst_row_size, dst_row_size_aligned, 0);
dma_queue_push_ddr_to_vtcm(
dma_queue,
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
src0_row_size_aligned, src0_row_size, block_size);
dma_queue_push_ddr_to_vtcm(
dma_queue,
dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),
src1_row_size_aligned, src1_row_size, block_size);
} }
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
uint8_t * restrict data_dst = (uint8_t *) dst->data;
bool src1_valid = src1->ne[0]; float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
if (!src1_valid) { float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
data_src1 = data_src0; float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;
}
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size); for (uint32_t ib = 0; ib < block_size; ib++) {
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size); const float * src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float));
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size); const float * src1_spad_ptr = src1_spad + ib * (src1_row_size_aligned / sizeof(float));
float * dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
const int32_t swapped = op_params[1]; // x (src0_spad_data) = std::min(src0_p[k], limit);
const float alpha = ((const float *) (op_params))[2]; hvx_min_scalar_f32((const uint8_t *) src0_spad_ptr, limit, (uint8_t *) src0_spad_ptr, nc);
const float limit = ((const float *) (op_params))[3]; // y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit);
hvx_clamp_scalar_f32((const uint8_t *) src1_spad_ptr, -limit, limit, (uint8_t *) src1_spad_ptr, nc);
const int nc = (src1_valid) ? ne00 : ne00 / 2; // y (src1_spad_data) = y1 + 1.f
hvx_add_scalar_f32((const uint8_t *) src1_spad_ptr, 1.0, (uint8_t *) src1_spad_ptr, nc);
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) { // x1 (dst_spad_data) = alpha * (x)
const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size)); hvx_mul_scalar_f32((const uint8_t *) src0_spad_ptr, alpha, (uint8_t *) dst_spad_ptr, nc);
const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size)); // x2 (dst_spad_data) = sigmoid(x1) = 1/(1+exp(-x1))
float * restrict dst = (float *) (data_dst + (ir * dst_row_size)); hvx_fast_sigmoid_f32((const uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
// out = x * sigmoid(alpha * x) * (y + 1.f)
if (ir + 1 < src0_end_row) { hvx_mul_mul_f32_opt((const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr,
htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size); (const uint8_t *) src1_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
} }
if (!src1) { dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
src0 += swapped ? nc : 0; dst_row_size_aligned, block_size);
src1 += swapped ? 0 : nc;
}
// x (src0_spad_data) = std::min(src0_p[k], limit); // prefetch N+2 loop iteration if any
hvx_min_scalar_f32((const uint8_t *) src0, limit, src0_spad_data, nc); const uint32_t pref_block = (ir + BLOCK * 2);
// y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit); if (pref_block < src0_end_row) {
hvx_clamp_scalar_f32((const uint8_t *) src1, -limit, limit, src1_spad_data, nc); const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
// y (src1_spad_data) = y1 + 1.f dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
hvx_add_scalar_f32(src1_spad_data, 1.0, src1_spad_data, nc); src0_row_size_aligned, src0_row_size, pref_block_size);
// x1 (dst_spad_data) = alpha * (x) dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),
hvx_mul_scalar_f32(src0_spad_data, alpha, dst_spad_data, nc); src1_row_size_aligned, src1_row_size, pref_block_size);
// x2 (dst_spad_data) = expf(-x1) }
hvx_exp_f32(dst_spad_data, dst_spad_data, nc, true);
// x3 (dst_spad_data) = x2 + 1.f
hvx_add_scalar_f32(dst_spad_data, 1.0, dst_spad_data, nc);
// x4 (dst_spad_data) = 1 / x3
hvx_inverse_f32(dst_spad_data, dst_spad_data, nc);
// out_glu(dst_spad_data) = x * x4
hvx_mul_f32(src0_spad_data, dst_spad_data, dst_spad_data, nc);
// out = out_glu * (y + 1.f);
hvx_mul_f32(dst_spad_data, src1_spad_data, (uint8_t *) dst, nc);
} }
dma_queue_flush(dma_queue);
t2 = HAP_perf_get_qtimer_count(); t2 = HAP_perf_get_qtimer_count();
FARF(HIGH, "swiglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, src0->ne[0], FARF(HIGH, "swiglu-oai-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, src0->ne[0],
src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2],
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
} }
@ -371,7 +457,8 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
struct htp_spad * dst_spad, struct htp_spad * dst_spad,
uint32_t nth, uint32_t nth,
uint32_t ith, uint32_t ith,
uint32_t src0_nrows_per_thread) { uint32_t src0_nrows_per_thread,
dma_queue * dma_queue) {
htp_act_preamble2; htp_act_preamble2;
uint64_t t1, t2; uint64_t t1, t2;
@ -379,6 +466,8 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
const size_t src0_row_size = nb01; const size_t src0_row_size = nb01;
const size_t dst_row_size = nb1; const size_t dst_row_size = nb1;
const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN);
const uint32_t src0_nrows = ne01 * ne02 * ne03; const uint32_t src0_nrows = ne01 * ne02 * ne03;
@ -390,64 +479,91 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
return; return;
} }
int is_aligned = 1; const uint8_t * data_src0 = (const uint8_t *) src0->data;
int opt_path = 0; uint8_t * data_dst = (uint8_t *) dst->data;
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
is_aligned = 0; uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
FARF(HIGH, "silu-f32: unaligned addresses in elementwise op, possibly slower execution\n"); uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
}
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
opt_path = 1; size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
if (BLOCK == 0) {
FARF(ERROR, "silu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
src0_spad->size_per_thread, src0_row_size_aligned);
return;
} }
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
uint8_t * restrict data_dst = (uint8_t *) dst->data; for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size); // Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size); dma_queue_push_vtcm_to_ddr(dma_queue,
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
dst_row_size, dst_row_size_aligned, 0);
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) { dma_queue_push_ddr_to_vtcm(dma_queue,
const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size)); dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
float * restrict dst = (float *) (data_dst + (ir * dst_row_size)); src0_row_size_aligned, src0_row_size, block_size);
}
if (ir + 1 < src0_end_row) { for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size); const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
float* dst_spad = (float *) dma_queue_pop(dma_queue).src;
float* src0_spad = (float *) dma_queue_pop(dma_queue).dst;
for (uint32_t ib = 0; ib < block_size; ib++) {
const float* src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float));
float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
// silu = x * sigmoid(x)
hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, ne0);
hvx_mul_f32_opt((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, ne0);
} }
if (1 == opt_path) { dma_queue_push_vtcm_to_ddr(dma_queue,
hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, ne0); dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad),
hvx_mul_f32_opt((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0); dst_row_size, dst_row_size_aligned, block_size);
} else {
hvx_exp_f32((const uint8_t *) src0, src0_spad_data, ne0, true);
hvx_add_scalar_f32(src0_spad_data, 1.0, dst_spad_data, ne0);
hvx_inverse_f32(dst_spad_data, src0_spad_data, ne0);
hvx_mul_f32((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0); // prefetch N+2 loop iteration if any
const uint32_t pref_block = (ir + BLOCK * 2);
if (pref_block < src0_end_row) {
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
src0_row_size_aligned, src0_row_size, pref_block_size);
} }
} }
dma_queue_flush(dma_queue);
t2 = HAP_perf_get_qtimer_count(); t2 = HAP_perf_get_qtimer_count();
FARF(HIGH, "silu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, ne00, ne01, ne02, FARF(HIGH, "silu-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, ne00, ne01, ne02,
ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
} }
static void unary_silu_fp32(unsigned int n, unsigned int i, void * data) { static void unary_silu_fp32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data; struct htp_ops_context * octx = (struct htp_ops_context *) data;
unary_silu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i, unary_silu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
octx->src0_nrows_per_thread); octx->src0_nrows_per_thread, octx->ctx->dma[i]);
} }
static void glu_swiglu_fp32(unsigned int n, unsigned int i, void * data) { static void glu_swiglu_fp32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data; struct htp_ops_context * octx = (struct htp_ops_context *) data;
glu_swiglu_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, glu_swiglu_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread); &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
} }
static void glu_swiglu_oai_fp32(unsigned int n, unsigned int i, void * data) { static void glu_swiglu_oai_fp32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data; struct htp_ops_context * octx = (struct htp_ops_context *) data;
glu_swiglu_oai_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, glu_swiglu_oai_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread); &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
} }
static int execute_op_activations_fp32(struct htp_ops_context * octx) { static int execute_op_activations_fp32(struct htp_ops_context * octx) {

View File

@ -1684,3 +1684,60 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggm
return res; return res;
} }
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->type == GGML_TYPE_I64);
char base[256];
char name[256];
snprintf(base, 256, "kernel_memset_%s", ggml_type_name(op->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_COUNT_EQUAL);
GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne);
GGML_ASSERT(op->src[0]->type == op->src[1]->type);
GGML_ASSERT(op->src[0]->type == GGML_TYPE_I32);
GGML_ASSERT(op->type == GGML_TYPE_I64);
// note: the kernel only supports i32 output due to metal atomic add only supporting atomic_int
GGML_ASSERT(ggml_nelements(op->src[0]) < (1LL << 31));
char base[256];
char name[256];
int nsg = 1;
while (32*nsg < ne00 && nsg < 32) {
nsg *= 2;
}
snprintf(base, 256, "kernel_count_equal_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s_nsg=%d", base, nsg);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, nsg, FC_COUNT_EQUAL + 0);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
}
res.smem = 32 * sizeof(int32_t);
res.nsg = nsg;
return res;
}

View File

@ -147,6 +147,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad( struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib, ggml_metal_library_t lib,

View File

@ -1023,6 +1023,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]); return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_L2_NORM: case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
case GGML_OP_COUNT_EQUAL:
return has_simdgroup_reduction &&
op->src[0]->type == GGML_TYPE_I32 &&
op->src[1]->type == GGML_TYPE_I32 &&
op->type == GGML_TYPE_I64;
case GGML_OP_ARGMAX: case GGML_OP_ARGMAX:
return has_simdgroup_reduction; return has_simdgroup_reduction;
case GGML_OP_NORM: case GGML_OP_NORM:

View File

@ -78,6 +78,7 @@
#define FC_MUL_MM 700 #define FC_MUL_MM 700
#define FC_ROPE 800 #define FC_ROPE 800
#define FC_SSM_CONV 900 #define FC_SSM_CONV 900
#define FC_COUNT_EQUAL 1000
// op-specific constants // op-specific constants
#define OP_FLASH_ATTN_EXT_NQPTG 8 #define OP_FLASH_ATTN_EXT_NQPTG 8
@ -894,6 +895,25 @@ typedef struct {
float step; float step;
} ggml_metal_kargs_arange; } ggml_metal_kargs_arange;
typedef struct {
int64_t val;
} ggml_metal_kargs_memset;
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
} ggml_metal_kargs_count_equal;
typedef struct { typedef struct {
int32_t k0; int32_t k0;
int32_t k1; int32_t k1;

View File

@ -448,7 +448,11 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{ {
n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx); n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
} break; } break;
default: case GGML_OP_COUNT_EQUAL:
{
n_fuse = ggml_metal_op_count_equal(ctx, idx);
} break;
default:
{ {
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op)); GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
@ -2177,7 +2181,11 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
const bool has_mask = op->src[3] != nullptr; const bool has_mask = op->src[3] != nullptr;
if (ggml_metal_op_flash_attn_ext_use_vec(op)) { // note: the non-vec kernel requires more extra memory, so always reserve for it
GGML_ASSERT(OP_FLASH_ATTN_EXT_NCPSG >= OP_FLASH_ATTN_EXT_VEC_NCPSG);
//if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
if (false) {
// note: always reserve the padding space to avoid graph reallocations // note: always reserve the padding space to avoid graph reallocations
//const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0; //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
const bool has_kvpad = true; const bool has_kvpad = true;
@ -4090,3 +4098,64 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
return 1; return 1;
} }
int ggml_metal_op_count_equal(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
{
ggml_metal_kargs_memset args = { /*.val =*/ 0 };
auto pipeline = ggml_metal_library_get_pipeline_memset(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 1);
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
}
ggml_metal_op_concurrency_reset(ctx);
{
ggml_metal_kargs_count_equal args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
};
auto pipeline = ggml_metal_library_get_pipeline_count_equal(lib, op);
const size_t smem = pipeline.smem;
const int nth = 32*pipeline.nsg;
GGML_ASSERT(nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
}
return 1;
}

View File

@ -87,6 +87,7 @@ int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx); int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_count_equal (ggml_metal_op_t ctx, int idx);
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@ -1790,6 +1790,7 @@ kernel void kernel_op_sum_f32(
return; return;
} }
// TODO: become function constant
const uint nsg = (ntg.x + 31) / 32; const uint nsg = (ntg.x + 31) / 32;
float sumf = 0; float sumf = 0;
@ -9914,3 +9915,75 @@ kernel void kernel_opt_step_sgd_f32(
x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid]; x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
} }
template<typename T>
kernel void kernel_memset(
constant ggml_metal_kargs_fill & args,
device T * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = args.val;
}
typedef decltype(kernel_memset<int64_t>) kernel_memset_t;
template [[host_name("kernel_memset_i64")]] kernel kernel_memset_t kernel_memset<int64_t>;
constant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]];
template<typename T>
kernel void kernel_count_equal(
constant ggml_metal_kargs_count_equal & args,
device const char * src0,
device const char * src1,
device atomic_int * dst,
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const short NSG = FC_count_equal_nsg;
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = tgpig.x;
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
int sum = 0;
device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03;
device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13;
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
const T v0 = *(device const T *)(base0 + i0*args.nb00);
const T v1 = *(device const T *)(base1 + i0*args.nb10);
sum += (v0 == v1);
}
sum = simd_sum(sum);
if (tiisg == 0) {
shmem_i32[sgitg] = sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgitg == 0) {
float v = 0.0f;
if (tpitg.x < NSG) {
v = shmem_i32[tpitg.x];
}
float total = simd_sum(v);
if (tpitg.x == 0) {
atomic_fetch_add_explicit(dst, (int32_t) total, memory_order_relaxed);
}
}
}
typedef decltype(kernel_count_equal<int32_t>) kernel_count_equal_t;
template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal<int32_t>;

View File

@ -1517,10 +1517,12 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false); struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
graph->n_nodes = n_nodes; graph->n_nodes = n_nodes;
std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs; std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
tensor_ptrs.reserve(n_tensors);
for (uint32_t i = 0; i < n_tensors; i++) { for (uint32_t i = 0; i < n_tensors; i++) {
tensor_ptrs[tensors[i].id] = &tensors[i]; tensor_ptrs.emplace(tensors[i].id, &tensors[i]);
} }
std::unordered_map<uint64_t, ggml_tensor*> tensor_map; std::unordered_map<uint64_t, ggml_tensor*> tensor_map;
tensor_map.reserve(n_nodes);
for (uint32_t i = 0; i < n_nodes; i++) { for (uint32_t i = 0; i < n_nodes; i++) {
int64_t id; int64_t id;
memcpy(&id, &nodes[i], sizeof(id)); memcpy(&id, &nodes[i], sizeof(id));

View File

@ -434,8 +434,15 @@ static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGM
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
GGML_OP_RESHAPE }; GGML_OP_RESHAPE };
static constexpr std::initializer_list<ggml_op> topk_moe_sigmoid_norm_bias{ GGML_OP_UNARY, GGML_OP_RESHAPE, GGML_OP_ADD,
GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
GGML_OP_DIV, GGML_OP_RESHAPE };
static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS }; GGML_OP_VIEW, GGML_OP_GET_ROWS };
static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW, static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW,
GGML_OP_GET_ROWS, GGML_OP_RESHAPE, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }; GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
@ -464,6 +471,32 @@ static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softma
{ 9, 0, 8 }, // reshape->src[0] == div { 9, 0, 8 }, // reshape->src[0] == div
}; };
//node #436 ( UNARY): ffn_moe_probs-10 ( 256K) [Vulka ] use=2: ffn_moe_logits-10 ( 256K) [Vulka ]
//node #437 ( RESHAPE): ffn_moe_probs-10 (re ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ]
//node #438 ( ADD): ffn_moe_probs_biased ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ] blk.10.exp_probs_b.b ( 0K) [Vulka ]
//node #439 ( ARGSORT): ffn_moe_argsort-10 ( 256K) [Vulka ] use=1: ffn_moe_probs_biased ( 256K) [Vulka ]
//node #440 ( VIEW): ffn_moe_topk-10 ( 255K) [Vulka ] use=3: ffn_moe_argsort-10 ( 256K) [Vulka ]
//node #441 ( GET_ROWS): ffn_moe_weights-10 ( 12K) [Vulka ] use=1: ffn_moe_probs-10 (re ( 256K) [Vulka ] ffn_moe_topk-10 ( 255K) [Vulka ]
//node #442 ( RESHAPE): ffn_moe_weights-10 ( ( 12K) [Vulka ] use=2: ffn_moe_weights-10 ( 12K) [Vulka ]
//node #443 ( SUM_ROWS): ffn_moe_weights_sum- ( 2K) [Vulka ] use=1: ffn_moe_weights-10 ( ( 12K) [Vulka ]
//node #444 ( CLAMP): ffn_moe_weights_sum_ ( 2K) [Vulka ] use=1: ffn_moe_weights_sum- ( 2K) [Vulka ]
//node #445 ( DIV): ffn_moe_weights_norm ( 12K) [Vulka ] use=1: ffn_moe_weights-10 ( ( 12K) [Vulka ] ffn_moe_weights_sum_ ( 2K) [Vulka ]
//node #446 ( RESHAPE): ffn_moe_weights_norm ( 12K) [Vulka ] use=1: ffn_moe_weights_norm ( 12K) [Vulka ]
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_sigmoid_norm_bias_edges {
{ 1, 0, 0 }, // reshape->src[0] == sigmoid
{ 2, 0, 0 }, // add->src[0] == sigmoid
{ 3, 0, 2 }, // argsort->src[0] == add
{ 4, 0, 3 }, // view->src[0] == argsort
{ 5, 0, 1 }, // get_rows->src[0] == reshape
{ 5, 1, 4 }, // get_rows->src[1] == view
{ 6, 0, 5 }, // reshape->src[0] == get_rows
{ 7, 0, 6 }, // sum_rows->src[0] == reshape
{ 8, 0, 7 }, // clamp->src[0] == sum_rows
{ 9, 0, 6 }, // div->src[0] == reshape
{ 9, 1, 8 }, // div->src[1] == clamp
{10, 0, 9 }, // reshape->src[0] == div
};
// same as early_softmax_norm but ending after the get_rows // same as early_softmax_norm but ending after the get_rows
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges { static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
{ 1, 0, 0 }, // reshape->src[0] == softmax { 1, 0, 0 }, // reshape->src[0] == softmax
@ -491,16 +524,10 @@ enum topk_moe_mode {
TOPK_MOE_EARLY_SOFTMAX, TOPK_MOE_EARLY_SOFTMAX,
TOPK_MOE_EARLY_SOFTMAX_NORM, TOPK_MOE_EARLY_SOFTMAX_NORM,
TOPK_MOE_LATE_SOFTMAX, TOPK_MOE_LATE_SOFTMAX,
TOPK_MOE_SIGMOID_NORM_BIAS,
TOPK_MOE_COUNT, TOPK_MOE_COUNT,
}; };
static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM :
num == topk_moe_early_softmax.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX :
TOPK_MOE_LATE_SOFTMAX;
return mode;
}
static constexpr std::initializer_list<std::array<int, 3>> rope_view_set_rows_edges { static constexpr std::initializer_list<std::array<int, 3>> rope_view_set_rows_edges {
{ 1, 0, 0 }, // view->src[0] == rope { 1, 0, 0 }, // view->src[0] == rope
{ 2, 0, 1 }, // set_rows->src[0] == view { 2, 0, 1 }, // set_rows->src[0] == view
@ -738,6 +765,9 @@ struct vk_device_struct {
vk_pipeline pipeline_topk_f32[num_topk_pipelines]; vk_pipeline pipeline_topk_f32[num_topk_pipelines];
vk_pipeline pipeline_sum_rows_f32; vk_pipeline pipeline_sum_rows_f32;
vk_pipeline pipeline_cumsum_f32; vk_pipeline pipeline_cumsum_f32;
vk_pipeline pipeline_cumsum_small_f32;
vk_pipeline pipeline_cumsum_multipass1_f32;
vk_pipeline pipeline_cumsum_multipass2_f32;
vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_argmax_f32;
vk_pipeline pipeline_count_equal_i32; vk_pipeline pipeline_count_equal_i32;
std::map<vk_solve_tri_pipeline_state, vk_pipeline> pipeline_solve_tri_f32; std::map<vk_solve_tri_pipeline_state, vk_pipeline> pipeline_solve_tri_f32;
@ -766,7 +796,7 @@ struct vk_device_struct {
vk_pipeline pipeline_count_experts; vk_pipeline pipeline_count_experts;
// [2] is for whether to take n_experts from spec constant (0) or push constant (1) // [2] is for whether to take n_experts from spec constant (0) or push constant (1)
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT][2]; vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
std::vector<vk_pipeline_ref> all_pipelines; std::vector<vk_pipeline_ref> all_pipelines;
@ -1181,6 +1211,11 @@ struct vk_op_topk_moe_push_constants {
uint32_t n_expert_used; uint32_t n_expert_used;
float clamp_min; float clamp_min;
float clamp_max; float clamp_max;
uint32_t gating_func;
uint32_t has_bias;
uint32_t with_norm;
float output_scale;
float output_bias;
}; };
struct vk_op_add_id_push_constants { struct vk_op_add_id_push_constants {
@ -1771,6 +1806,8 @@ struct ggml_backend_vk_context {
// Bit 'i' means nodes[start_of_fusion + i] writes to memory. // Bit 'i' means nodes[start_of_fusion + i] writes to memory.
// If there's no fusion, bit 0 is still set. // If there's no fusion, bit 0 is still set.
int fused_ops_write_mask {}; int fused_ops_write_mask {};
topk_moe_mode fused_topk_moe_mode {};
bool fused_topk_moe_scale {};
// for GGML_VK_PERF_LOGGER // for GGML_VK_PERF_LOGGER
std::unique_ptr<vk_perf_logger> perf_logger; std::unique_ptr<vk_perf_logger> perf_logger;
@ -2668,7 +2705,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
switch (src0_type) { switch (src0_type) {
case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_M:
lut_size = 2*2048; lut_size = 2*2048 + 4*2048;
break; break;
case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS:
lut_size = 8*256; lut_size = 8*256;
@ -3593,6 +3630,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
uint32_t rm_kq = 2; uint32_t rm_kq = 2;
uint32_t rm_stdq_int = 1; uint32_t rm_stdq_int = 1;
uint32_t rm_kq_int = 1; uint32_t rm_kq_int = 1;
auto const &rm_iq_int = [](uint32_t i) { return i == 0 ? 8u : 4u; };
if (device->vendor_id == VK_VENDOR_ID_AMD) { if (device->vendor_id == VK_VENDOR_ID_AMD) {
if (device->architecture == AMD_GCN) { if (device->architecture == AMD_GCN) {
rm_stdq = 2; rm_stdq = 2;
@ -3696,6 +3734,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_q8_1_f32", arr_dmmv_iq1_s_q8_1_f32_len[reduc], arr_dmmv_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_q8_1_f32", arr_dmmv_iq1_m_q8_1_f32_len[reduc], arr_dmmv_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int);
} }
#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
} }
@ -3742,6 +3784,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_q8_1_f32", arr_dmmv_id_iq1_s_q8_1_f32_len[reduc], arr_dmmv_id_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_q8_1_f32", arr_dmmv_id_iq1_m_q8_1_f32_len[reduc], arr_dmmv_id_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);
} }
#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
} }
@ -3749,6 +3794,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
#if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) #if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
GGML_UNUSED(rm_stdq_int); GGML_UNUSED(rm_stdq_int);
GGML_UNUSED(rm_kq_int); GGML_UNUSED(rm_kq_int);
GGML_UNUSED(rm_iq_int);
#endif #endif
// dequant shaders // dequant shaders
@ -4135,7 +4181,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size); const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4;
ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_cumsum_small_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size, 1 }, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass1_f32, "cumsum_multipass1_f32", cumsum_multipass1_f32_len, cumsum_multipass1_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass2_f32, "cumsum_multipass2_f32", cumsum_multipass2_f32_len, cumsum_multipass2_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
@ -4291,9 +4341,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
for (uint32_t use_push = 0; use_push < 2; ++use_push) { for (uint32_t use_push = 0; use_push < 2; ++use_push) {
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) { for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX][use_push], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0, use_push}, 1, true, true, device->subgroup_size); ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][use_push], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 4, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, use_push}, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM][use_push], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0, use_push}, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX][use_push], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1, use_push}, 1, true, true, device->subgroup_size);
} }
} }
@ -5584,6 +5632,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
case GGML_TYPE_Q4_K: case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K: case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
break; break;
default: default:
return nullptr; return nullptr;
@ -5740,6 +5790,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
case GGML_TYPE_Q4_K: case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K: case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
break; break;
default: default:
return nullptr; return nullptr;
@ -7005,7 +7057,7 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
// Quantization overhead is not worth it for small k // Quantization overhead is not worth it for small k
switch (device->vendor_id) { switch (device->vendor_id) {
case VK_VENDOR_ID_NVIDIA: case VK_VENDOR_ID_NVIDIA:
if (src0_type == GGML_TYPE_Q2_K) { if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) {
return true; return true;
} }
@ -8684,10 +8736,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
if (ctx->num_additional_fused_ops) { if (ctx->num_additional_fused_ops) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
GGML_ASSERT(idx < num_topk_moe_pipelines); GGML_ASSERT(idx < num_topk_moe_pipelines);
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
// use n_experts from push constant if it's not equal to the power of two spec constant // use n_experts from push constant if it's not equal to the power of two spec constant
bool use_push = dst->ne[0] != (1u << idx); bool use_push = dst->ne[0] != (1u << idx);
return ctx->device->pipeline_topk_moe[idx][mode][use_push]; return ctx->device->pipeline_topk_moe[idx][use_push];
} }
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
@ -8760,7 +8811,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr; return nullptr;
case GGML_OP_CUMSUM: case GGML_OP_CUMSUM:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_cumsum_f32; if (src0->ne[0] <= 512) {
return ctx->device->pipeline_cumsum_small_f32;
} else {
return ctx->device->pipeline_cumsum_f32;
}
} }
return nullptr; return nullptr;
case GGML_OP_SOLVE_TRI: case GGML_OP_SOLVE_TRI:
@ -10346,14 +10401,16 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
} }
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) { static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); topk_moe_mode mode = ctx->fused_topk_moe_mode;
ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0]; ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] : ggml_tensor * bias = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 2]->src[1] : logits;
(mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] : ggml_tensor * weights = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
cgraph->nodes[node_idx + 5]; ggml_tensor * ids = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 4] :
ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3]; (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] :
cgraph->nodes[node_idx + 3];
GGML_ASSERT(logits->type == GGML_TYPE_F32); GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(bias->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32); GGML_ASSERT(weights->type == GGML_TYPE_F32);
GGML_ASSERT(ids->type == GGML_TYPE_I32); GGML_ASSERT(ids->type == GGML_TYPE_I32);
@ -10368,6 +10425,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer(ctx, logits); vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer(ctx, logits);
vk_subbuffer bias_buf = ggml_vk_tensor_subbuffer(ctx, bias);
vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer(ctx, weights); vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer(ctx, weights);
vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer(ctx, ids); vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer(ctx, ids);
@ -10375,18 +10433,45 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
pc.n_rows = n_rows; pc.n_rows = n_rows;
pc.n_experts_push = n_experts; pc.n_experts_push = n_experts;
pc.n_expert_used = n_expert_used; pc.n_expert_used = n_expert_used;
pc.clamp_min = -std::numeric_limits<float>::infinity();
pc.clamp_max = std::numeric_limits<float>::infinity();
if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) { if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
ggml_tensor * clamp = cgraph->nodes[node_idx + 7]; ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
pc.clamp_min = ggml_get_op_params_f32(clamp, 0); pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
pc.clamp_max = ggml_get_op_params_f32(clamp, 1); pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
} }
if (mode == TOPK_MOE_SIGMOID_NORM_BIAS) {
ggml_tensor * clamp = cgraph->nodes[node_idx + 8];
GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
}
#define GATING_FUNC_SOFTMAX 0
#define GATING_FUNC_SIGMOID 1
#define GATING_FUNC_SOFTMAX_WEIGHT 2
pc.gating_func = mode == TOPK_MOE_SIGMOID_NORM_BIAS ? GATING_FUNC_SIGMOID :
mode == TOPK_MOE_LATE_SOFTMAX ? GATING_FUNC_SOFTMAX_WEIGHT :
GATING_FUNC_SOFTMAX;
pc.has_bias = mode == TOPK_MOE_SIGMOID_NORM_BIAS;
pc.with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM || mode == TOPK_MOE_SIGMOID_NORM_BIAS;
if (ctx->fused_topk_moe_scale) {
GGML_ASSERT(weights->op == GGML_OP_SCALE);
pc.output_scale = ggml_get_op_params_f32(weights, 0);
pc.output_bias = ggml_get_op_params_f32(weights, 1);
} else {
pc.output_scale = 1.0f;
pc.output_bias = 0.0f;
}
GGML_ASSERT(n_expert_used <= n_experts); GGML_ASSERT(n_expert_used <= n_experts);
const uint32_t rows_per_block = 4; const uint32_t rows_per_block = 4;
std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 }; std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, weights_buf, ids_buf}, pc, elements); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, bias_buf, weights_buf, ids_buf}, pc, elements);
} }
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop) { static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop) {
@ -10634,8 +10719,50 @@ static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, cons
} }
static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]); vk_op_sum_rows_push_constants pc = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, p); // Use the single pass shader when the rows are small or there are enough rows to fill the GPU.
// For fewer, larger rows, use the multipass shader to spread each row across SMs.
if (dst->ne[0] <= 4096 || ggml_nrows(dst) >= ctx->device->shader_core_count) {
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, pc);
return;
}
// First pass computes partial sums within a block, and stores the last partial
// to the temp buffer. Second pass sums the block partials from the temp buffer
// and adds that to the result of the first pass.
vk_pipeline pipeline1 = ctx->device->pipeline_cumsum_multipass1_f32;
vk_pipeline pipeline2 = ctx->device->pipeline_cumsum_multipass2_f32;
GGML_ASSERT(pipeline1 != nullptr && pipeline2 != nullptr);
ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);
std::array<uint32_t, 3> elements;
elements[0] = dst->ne[0];
elements[1] = (uint32_t)ggml_nrows(dst);
elements[2] = 1;
size_t temp_size = sizeof(float) * elements[0] * ggml_nrows(dst);
if (ctx->prealloc_size_split_k < temp_size) {
ctx->prealloc_size_split_k = temp_size;
ggml_vk_preallocate_buffers(ctx, subctx);
}
vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src0);
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
vk_subbuffer temp_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
if (ctx->prealloc_split_k_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, {src_buf, dst_buf, temp_buf}, pc, elements);
ggml_vk_sync_buffers(ctx, subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, {src_buf, dst_buf, temp_buf}, pc, elements);
ctx->prealloc_split_k_need_sync = true;
} }
static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@ -12128,6 +12255,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break; break;
case GGML_OP_UNARY: case GGML_OP_UNARY:
if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
break;
}
switch (ggml_get_unary_op(node)) { switch (ggml_get_unary_op(node)) {
case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_SILU:
@ -12175,7 +12307,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break; break;
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
if (ctx->num_additional_fused_ops) { if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx); ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
} else { } else {
ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node); ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node);
@ -12195,7 +12327,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break; break;
case GGML_OP_ARGSORT: case GGML_OP_ARGSORT:
if (ctx->num_additional_fused_ops) { if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx); ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
} else { } else {
ggml_vk_argsort(ctx, compute_ctx, src0, node); ggml_vk_argsort(ctx, compute_ctx, src0, node);
@ -13048,6 +13180,24 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
get_rows = cgraph->nodes[node_idx + 4]; get_rows = cgraph->nodes[node_idx + 4];
argsort = cgraph->nodes[node_idx + 2]; argsort = cgraph->nodes[node_idx + 2];
break; break;
case TOPK_MOE_SIGMOID_NORM_BIAS:
softmax = cgraph->nodes[node_idx + 0]; // really sigmoid
weights = cgraph->nodes[node_idx + 10];
get_rows = cgraph->nodes[node_idx + 5];
argsort = cgraph->nodes[node_idx + 3];
if (ggml_get_unary_op(softmax) != GGML_UNARY_OP_SIGMOID) {
return false;
}
// bias is expected to be 1D
if (ggml_nrows(cgraph->nodes[node_idx + 2]->src[1]) != 1 ||
!ggml_is_contiguous(cgraph->nodes[node_idx + 2]->src[1])) {
return false;
}
// sigmoid fusion seems to generate infinities on moltenvk
if (ctx->device->driver_id == vk::DriverId::eMoltenvk) {
return false;
}
break;
case TOPK_MOE_EARLY_SOFTMAX: case TOPK_MOE_EARLY_SOFTMAX:
softmax = cgraph->nodes[node_idx + 0]; softmax = cgraph->nodes[node_idx + 0];
weights = cgraph->nodes[node_idx + 4]; weights = cgraph->nodes[node_idx + 4];
@ -13071,26 +13221,28 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
probs = probs->src[0]; probs = probs->src[0];
ggml_tensor * selection_probs = argsort->src[0]; ggml_tensor * selection_probs = argsort->src[0];
if (probs != selection_probs) { if (probs != selection_probs && mode != TOPK_MOE_SIGMOID_NORM_BIAS) {
return false; return false;
} }
const float * op_params = (const float *)softmax->op_params;
float scale = op_params[0];
float max_bias = op_params[1];
if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) { if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
return false; return false;
} }
if (scale != 1.0f || max_bias != 0.0f) { if (softmax->op == GGML_OP_SOFT_MAX) {
return false; const float * op_params = (const float *)softmax->op_params;
}
// don't fuse when masks or sinks are present float scale = op_params[0];
if (softmax->src[1] || softmax->src[2]) { float max_bias = op_params[1];
return false;
if (scale != 1.0f || max_bias != 0.0f) {
return false;
}
// don't fuse when masks or sinks are present
if (softmax->src[1] || softmax->src[2]) {
return false;
}
} }
const int n_expert = softmax->ne[0]; const int n_expert = softmax->ne[0];
@ -13363,6 +13515,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
total_mul_mat_bytes += bytes; total_mul_mat_bytes += bytes;
} }
ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
ctx->fused_topk_moe_scale = false;
const char *fusion_string {}; const char *fusion_string {};
if (!ctx->device->disable_fusion) { if (!ctx->device->disable_fusion) {
uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i); uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
@ -13408,13 +13562,23 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1; ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
// view of argsort writes to memory // view of argsort writes to memory
ctx->fused_ops_write_mask |= 1 << 3; ctx->fused_ops_write_mask |= 1 << 3;
ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM;
fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM"; fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM";
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) &&
ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) {
ctx->num_additional_fused_ops = topk_moe_sigmoid_norm_bias.size() - 1;
// view of argsort writes to memory
ctx->fused_ops_write_mask |= 1 << 4;
ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS;
fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS";
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) && } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) && ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) { ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1; ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
// view of argsort writes to memory // view of argsort writes to memory
ctx->fused_ops_write_mask |= 1 << 3; ctx->fused_ops_write_mask |= 1 << 3;
ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX;
fusion_string = "TOPK_MOE_EARLY_SOFTMAX"; fusion_string = "TOPK_MOE_EARLY_SOFTMAX";
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) && } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) && ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
@ -13422,8 +13586,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1; ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
// view of argsort writes to memory // view of argsort writes to memory
ctx->fused_ops_write_mask |= 1 << 1; ctx->fused_ops_write_mask |= 1 << 1;
ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX;
fusion_string = "TOPK_MOE_LATE_SOFTMAX"; fusion_string = "TOPK_MOE_LATE_SOFTMAX";
} }
if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
// Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano.
if (ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops - 1, { GGML_OP_DIV, GGML_OP_RESHAPE, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 }) ||
ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) {
ctx->fused_topk_moe_scale = true;
ctx->num_additional_fused_ops++;
}
}
} }
ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops; ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;
@ -13602,6 +13775,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
if (keep_pattern(topk_moe_early_softmax_norm)) { if (keep_pattern(topk_moe_early_softmax_norm)) {
continue; continue;
} }
if (keep_pattern(topk_moe_sigmoid_norm_bias)) {
continue;
}
if (keep_pattern(topk_moe_early_softmax)) { if (keep_pattern(topk_moe_early_softmax)) {
continue; continue;
} }
@ -13628,6 +13804,7 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
} }
// Don't pull forward nodes from fusion patterns // Don't pull forward nodes from fusion patterns
if (match_pattern(topk_moe_early_softmax_norm, j) || if (match_pattern(topk_moe_early_softmax_norm, j) ||
match_pattern(topk_moe_sigmoid_norm_bias, j) ||
match_pattern(topk_moe_early_softmax, j) || match_pattern(topk_moe_early_softmax, j) ||
match_pattern(topk_moe_late_softmax, j)) { match_pattern(topk_moe_late_softmax, j)) {
continue; continue;

View File

@ -14,6 +14,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 128; layout (constant_id = 0) const uint BLOCK_SIZE = 128;
layout (constant_id = 1) const uint SUBGROUP_SIZE = 32; layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
layout (constant_id = 2) const uint ELEM_PER_THREAD = 4;
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
@ -38,32 +39,45 @@ void main() {
last_sum = 0; last_sum = 0;
} }
uint col = tid; uint col = tid * ELEM_PER_THREAD;
uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE); uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE * ELEM_PER_THREAD);
for (int i = 0; i < num_iter; ++i) { for (int i = 0; i < num_iter; ++i) {
FLOAT_TYPE v = 0; FLOAT_TYPE v[ELEM_PER_THREAD];
if (col < p.n_cols) { FLOAT_TYPE thread_sum = 0;
v = FLOAT_TYPE(data_a[src_idx + col]); [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
if (col + j < p.n_cols) {
thread_sum += FLOAT_TYPE(data_a[src_idx + col + j]);
}
v[j] = thread_sum;
} }
v = subgroupInclusiveAdd(v);
thread_sum = subgroupExclusiveAdd(thread_sum);
[[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
v[j] += thread_sum;
}
// Store the largest partial sum for each subgroup, then add the partials for all // Store the largest partial sum for each subgroup, then add the partials for all
// lower subgroups and the final partial sum from the previous iteration. // lower subgroups and the final partial sum from the previous iteration.
if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) { if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
partial[subgroup_id] = v; partial[subgroup_id] = v[ELEM_PER_THREAD - 1];
} }
barrier(); barrier();
for (int j = 0; j < subgroup_id; ++j) { for (int s = 0; s < subgroup_id; ++s) {
v += partial[j]; [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
v[j] += partial[s];
}
}
[[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
v[j] += last_sum;
} }
v += last_sum;
barrier(); barrier();
if (tid == BLOCK_SIZE - 1) { if (tid == BLOCK_SIZE - 1) {
last_sum = v; last_sum = v[ELEM_PER_THREAD - 1];
} }
if (col < p.n_cols) { [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
data_d[dst_idx + col] = D_TYPE(v); if (col + j < p.n_cols) {
data_d[dst_idx + col + j] = D_TYPE(v[j]);
}
} }
col += BLOCK_SIZE; col += BLOCK_SIZE * ELEM_PER_THREAD;
} }
} }

View File

@ -0,0 +1,60 @@
#version 450
#include "types.glsl"
#include "sum_rows.glsl"
#extension GL_EXT_control_flow_attributes : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (binding = 2) writeonly buffer T {D_TYPE data_t[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 128;
layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE];
void main() {
const uint row = gl_WorkGroupID.y;
const uint tid = gl_LocalInvocationID.x;
const uint col = gl_GlobalInvocationID.x;
const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
const uint i03_offset = i03 * p.ne01*p.ne02;
const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
const uint i01 = row - i03_offset - i02*p.ne01;
const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
uint subgroup_id = tid / SUBGROUP_SIZE;
FLOAT_TYPE v = 0;
if (col < p.n_cols) {
v = FLOAT_TYPE(data_a[src_idx + col]);
}
v = subgroupInclusiveAdd(v);
// Store the largest partial sum for each subgroup, then add the partials for all
// lower subgroups and the final partial sum from the previous iteration.
if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
partial[subgroup_id] = v;
}
barrier();
for (int j = 0; j < subgroup_id; ++j) {
v += partial[j];
}
barrier();
if (tid == BLOCK_SIZE - 1) {
data_t[gl_WorkGroupID.x + gl_NumWorkGroups.x * row] = v;
}
if (col < p.n_cols) {
data_d[dst_idx + col] = D_TYPE(v);
}
}

View File

@ -0,0 +1,66 @@
#version 450
#include "types.glsl"
#include "sum_rows.glsl"
#extension GL_EXT_control_flow_attributes : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) buffer D {D_TYPE data_d[];};
layout (binding = 2) readonly buffer T {D_TYPE data_t[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 128;
layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
shared FLOAT_TYPE temp[BLOCK_SIZE / SUBGROUP_SIZE];
void main() {
const uint row = gl_WorkGroupID.y;
const uint tid = gl_LocalInvocationID.x;
const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
const uint i03_offset = i03 * p.ne01*p.ne02;
const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
const uint i01 = row - i03_offset - i02*p.ne01;
const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
const uint col = gl_GlobalInvocationID.x;
float v = 0;
// prefetch value we're adding to
if (col < p.n_cols) {
v = data_d[dst_idx + col];
}
// compute the sum of all previous blocks
uint c = tid;
float sum = 0;
while (c < gl_WorkGroupID.x) {
sum += data_t[c + gl_NumWorkGroups.x * row];
c += BLOCK_SIZE;
}
sum = subgroupAdd(sum);
if (gl_SubgroupInvocationID == 0) {
temp[gl_SubgroupID] = sum;
}
barrier();
sum = 0;
[[unroll]] for (uint s = 0; s < BLOCK_SIZE / SUBGROUP_SIZE; ++s) {
sum += temp[s];
}
// Add the sum to what the first pass computed
if (col < p.n_cols) {
data_d[dst_idx + col] = v + sum;
}
}

View File

@ -14,6 +14,8 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
#define K_PER_ITER 8 #define K_PER_ITER 8
#elif defined(DATA_A_QUANT_K) #elif defined(DATA_A_QUANT_K)
#define K_PER_ITER 16 #define K_PER_ITER 16
#elif defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M)
#define K_PER_ITER 32
#else #else
#error unimplemented #error unimplemented
#endif #endif
@ -49,6 +51,15 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1]; cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1];
cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2]; cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2];
cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3]; cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3];
#elif K_PER_ITER == 32
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 ];
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 1];
cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 2];
cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 3];
cache_b_qs[4] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 4];
cache_b_qs[5] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 5];
cache_b_qs[6] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 6];
cache_b_qs[7] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 7];
#else #else
#error unimplemented #error unimplemented
#endif #endif

View File

@ -377,3 +377,118 @@ FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
return FLOAT_TYPE(float(cache_b_ds.x) * float(d_scale) * float(q_sum)); return FLOAT_TYPE(float(cache_b_ds.x) * float(d_scale) * float(q_sum));
} }
#endif #endif
#if defined(DATA_A_IQ1_S)
void repack8(uint ib, uint iqs, out i32vec4 out0, out i32vec4 out1) {
const uint ib32 = iqs / 32;
const uint qh = data_a[ib].qh[ib32];
const uint qs16_0 = data_a_packed16[ib].qs[(4 * ib32 + 0) / 2];
const uint qs16_1 = data_a_packed16[ib].qs[(4 * ib32 + 2) / 2];
const uint qs0 = qs16_0 & 0xFF;
const uint qs1 = qs16_0 >> 8;
const uint qs2 = qs16_1 & 0xFF;
const uint qs3 = qs16_1 >> 8;
const uint hi0 = bitfieldExtract(qh, 3 * int(0), 3);
const uint hi1 = bitfieldExtract(qh, 3 * int(1), 3);
const uint hi2 = bitfieldExtract(qh, 3 * int(2), 3);
const uint hi3 = bitfieldExtract(qh, 3 * int(3), 3);
const int32_t grid0 = int32_t(iq1s_grid_gpu[qs0 | (hi0 << 8)]);
const int32_t grid1 = int32_t(iq1s_grid_gpu[qs1 | (hi1 << 8)]);
const int32_t grid2 = int32_t(iq1s_grid_gpu[qs2 | (hi2 << 8)]);
const int32_t grid3 = int32_t(iq1s_grid_gpu[qs3 | (hi3 << 8)]);
out0 = i32vec4((grid0 >> 0) & 0x0F0F0F0F,
(grid0 >> 4) & 0x0F0F0F0F,
(grid1 >> 0) & 0x0F0F0F0F,
(grid1 >> 4) & 0x0F0F0F0F);
out1 = i32vec4((grid2 >> 0) & 0x0F0F0F0F,
(grid2 >> 4) & 0x0F0F0F0F,
(grid3 >> 0) & 0x0F0F0F0F,
(grid3 >> 4) & 0x0F0F0F0F);
}
vec2 get_dm(uint ib, uint iqs) {
const uint ib32 = iqs / 32;
const uint qh = data_a[ib].qh[ib32];
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
const float d = float(data_a[ib].d);
const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
// the -1 cancels out the bias in iq1s_grid_gpu
return FLOAT_TYPE_VEC2(dl, dl * (delta - 1));
}
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
int32_t q_sum = 0;
const uint ib_k = ib_a / 8;
const uint iqs_k = (ib_a % 8) * 32 + iqs * 32;
i32vec4 qs_a0;
i32vec4 qs_a1;
repack8(ib_k, iqs_k, qs_a0, qs_a1);
const vec2 dm = get_dm(ib_k, iqs_k);
q_sum += dotPacked4x8EXT(qs_a0.x, cache_b_qs[0]);
q_sum += dotPacked4x8EXT(qs_a0.y, cache_b_qs[1]);
q_sum += dotPacked4x8EXT(qs_a0.z, cache_b_qs[2]);
q_sum += dotPacked4x8EXT(qs_a0.w, cache_b_qs[3]);
q_sum += dotPacked4x8EXT(qs_a1.x, cache_b_qs[4]);
q_sum += dotPacked4x8EXT(qs_a1.y, cache_b_qs[5]);
q_sum += dotPacked4x8EXT(qs_a1.z, cache_b_qs[6]);
q_sum += dotPacked4x8EXT(qs_a1.w, cache_b_qs[7]);
return FLOAT_TYPE(float(cache_b_ds.x) * float(dm.x) * float(q_sum) + float(dm.y) * float(cache_b_ds.y));
}
#endif
#if defined(DATA_A_IQ1_M)
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
const uint ib_k = ib_a / 8;
const uint iqs_k = (ib_a % 8) * 32 + iqs * 32;
const uint ib32 = iqs_k / 32;
const uint ib64 = ib32 / 2;
const uint16_t[4] scales = data_a[ib_k].scales;
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
const uint qs32 = data_a_packed32[ib_k].qs[ib32];
const uint qh16 = data_a_packed16[ib_k].qh[ib32];
float sum = 0;
const uint sc = data_a[ib_k].scales[ib64];
[[unroll]] for (int l = 0; l < 4; ++l) {
const uint ib16 = 2 * ib32 + l / 2;
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
const uint qh = qh16 >> (4 * l);
const uint qs = (qs32 >> (8 * l)) & 0xFF;
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
const int32_t grid = int32_t(iq1s_grid_gpu[qs | ((qh & 7) << 8)]);
int32_t q_sum = 0;
q_sum += dotPacked4x8EXT((grid >> 0) & 0x0F0F0F0F, cache_b_qs[2 * l + 0]);
q_sum += dotPacked4x8EXT((grid >> 4) & 0x0F0F0F0F, cache_b_qs[2 * l + 1]);
int32_t y_sum = 0;
y_sum += dotPacked4x8EXT(int(0x01010101), cache_b_qs[2 * l + 0]);
y_sum += dotPacked4x8EXT(int(0x01010101), cache_b_qs[2 * l + 1]);
// the -1 cancels out the bias in iq1s_grid_gpu
sum += dl * (q_sum + y_sum * (delta - 1));
}
sum *= float(cache_b_ds.x);
return sum;
}
#endif

View File

@ -7,6 +7,10 @@
#include "types.glsl" #include "types.glsl"
#define GATING_FUNC_SOFTMAX 0
#define GATING_FUNC_SIGMOID 1
#define GATING_FUNC_SOFTMAX_WEIGHT 2
layout (push_constant) uniform parameter layout (push_constant) uniform parameter
{ {
uint n_rows; uint n_rows;
@ -14,15 +18,18 @@ layout (push_constant) uniform parameter
uint n_expert_used; uint n_expert_used;
float clamp_min; float clamp_min;
float clamp_max; float clamp_max;
uint gating_func;
uint has_bias;
uint with_norm;
float output_scale;
float output_bias;
}; };
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
layout(constant_id = 0) const uint WARP_SIZE = 32; layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts_spec = 512; layout(constant_id = 1) const uint n_experts_spec = 512;
layout(constant_id = 2) const bool with_norm = true; layout(constant_id = 2) const bool nexperts_use_push = false;
layout(constant_id = 3) const bool late_softmax = false;
layout(constant_id = 4) const bool nexperts_use_push = false;
uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec; uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;
@ -31,8 +38,9 @@ uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;
const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE); const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE);
layout (binding = 0, std430) readonly buffer Logits {float logits[];}; layout (binding = 0, std430) readonly buffer Logits {float logits[];};
layout (binding = 1, std430) writeonly buffer Weights {float weights[];}; layout (binding = 1, std430) readonly buffer BiasProbs {float bias[];};
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];}; layout (binding = 2, std430) writeonly buffer Weights {float weights[];};
layout (binding = 3, std430) writeonly buffer Ids {uint ids[];};
const float INFINITY = 1.0 / 0.0; const float INFINITY = 1.0 / 0.0;
@ -87,20 +95,40 @@ void main() {
} }
const uint logits_offset = n_experts * row; const uint logits_offset = n_experts * row;
const uint bias_offset = 0; // 1D
const uint weights_offset = n_expert_used * row; const uint weights_offset = n_expert_used * row;
const uint ids_offset = n_experts * row; const uint ids_offset = n_experts * row;
const uint lane = gl_SubgroupInvocationID; const uint lane = gl_SubgroupInvocationID;
float wt[experts_per_thread]; float probs[experts_per_thread];
[[unroll]] [[unroll]]
for (uint i = 0; i < n_experts; i += WARP_SIZE) { for (uint i = 0; i < n_experts; i += WARP_SIZE) {
const uint expert = i + lane; const uint expert = i + lane;
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY; probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
} }
if (!late_softmax) { if (gating_func == GATING_FUNC_SOFTMAX) {
softmax_warp_inplace(wt, n_experts, lane, nexperts_use_push); softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push);
} else if (gating_func == GATING_FUNC_SIGMOID) {
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
probs[i] = 1.f / (1.f + exp(-probs[i]));
}
}
float selection_probs[experts_per_thread];
if (has_bias != 0) {
[[unroll]]
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
const uint expert = i + lane;
selection_probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? probs[i / WARP_SIZE] + bias[bias_offset + expert] : -INFINITY;
}
} else {
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
selection_probs[i] = probs[i];
}
} }
// at this point, each thread holds a portion of softmax, // at this point, each thread holds a portion of softmax,
@ -117,14 +145,16 @@ void main() {
} }
for (int k = 0; k < n_expert_used; k++) { for (int k = 0; k < n_expert_used; k++) {
float max_val = wt[0]; float max_val = probs[0];
float max_val_s = selection_probs[0];
uint max_expert = lane; uint max_expert = lane;
[[unroll]] [[unroll]]
for (int i = 1; i < experts_per_thread; i++) { for (int i = 1; i < experts_per_thread; i++) {
const uint expert = lane + i * WARP_SIZE; const uint expert = lane + i * WARP_SIZE;
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) { if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i] > max_val_s) {
max_val = wt[i]; max_val = probs[i];
max_val_s = selection_probs[i];
max_expert = expert; max_expert = expert;
} }
} }
@ -132,9 +162,11 @@ void main() {
[[unroll]] [[unroll]]
for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) { for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
const float val = subgroupShuffleXor(max_val, mask); const float val = subgroupShuffleXor(max_val, mask);
const float val_s = subgroupShuffleXor(max_val_s, mask);
const uint expert = subgroupShuffleXor(max_expert, mask); const uint expert = subgroupShuffleXor(max_expert, mask);
if (val > max_val || (val == max_val && expert < max_expert)) { if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
max_val = val; max_val = val;
max_val_s = val_s;
max_expert = expert; max_expert = expert;
} }
} }
@ -144,16 +176,14 @@ void main() {
} }
if ((max_expert & (WARP_SIZE - 1)) == lane) { if ((max_expert & (WARP_SIZE - 1)) == lane) {
wt[max_expert / WARP_SIZE] = -INFINITY; selection_probs[max_expert / WARP_SIZE] = -INFINITY;
ids[ids_offset + k] = max_expert; ids[ids_offset + k] = max_expert;
if (with_norm) { wt_sum += max_val;
wt_sum += max_val;
}
} }
} }
if (with_norm) { if (with_norm != 0) {
wt_sum = subgroupAdd(wt_sum); wt_sum = subgroupAdd(wt_sum);
wt_sum = clamp(wt_sum, clamp_min, clamp_max); wt_sum = clamp(wt_sum, clamp_min, clamp_max);
const float inv_sum = 1.0f / wt_sum; const float inv_sum = 1.0f / wt_sum;
@ -164,7 +194,7 @@ void main() {
} }
} }
if (late_softmax) { if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
softmax_warp_inplace(output_weights, n_expert_used, lane, true); softmax_warp_inplace(output_weights, n_expert_used, lane, true);
} }
@ -172,7 +202,7 @@ void main() {
for (uint i = 0; i < experts_per_thread; ++i) { for (uint i = 0; i < experts_per_thread; ++i) {
uint idx = i * WARP_SIZE + lane; uint idx = i * WARP_SIZE + lane;
if (idx < n_expert_used) { if (idx < n_expert_used) {
weights[weights_offset + idx] = output_weights[i]; weights[weights_offset + idx] = output_scale * output_weights[i] + output_bias;
} }
} }
} }

View File

@ -396,6 +396,12 @@ struct block_iq1_s {
uint16_t qh[QUANT_K_IQ1_S/32]; uint16_t qh[QUANT_K_IQ1_S/32];
}; };
struct block_iq1_s_packed16 {
float16_t d;
uint16_t qs[QUANT_K_IQ1_S/8/2];
uint16_t qh[QUANT_K_IQ1_S/32];
};
#define QUANT_K_IQ1_M 256 #define QUANT_K_IQ1_M 256
#define QUANT_R_IQ1_M 1 #define QUANT_R_IQ1_M 1
@ -405,6 +411,18 @@ struct block_iq1_m {
uint16_t scales[QUANT_K_IQ1_M/64]; uint16_t scales[QUANT_K_IQ1_M/64];
}; };
struct block_iq1_m_packed16 {
uint16_t qs[QUANT_K_IQ1_M/8/2];
uint16_t qh[QUANT_K_IQ1_M/16/2];
uint16_t scales[QUANT_K_IQ1_M/64];
};
struct block_iq1_m_packed32 {
uint32_t qs[QUANT_K_IQ1_M/8/4];
uint32_t qh[QUANT_K_IQ1_M/16/4];
uint32_t scales[QUANT_K_IQ1_M/64/2];
};
struct block_iq1_m_packed64 { struct block_iq1_m_packed64 {
uint64_t qs[QUANT_K_IQ1_M/8/8]; uint64_t qs[QUANT_K_IQ1_M/8/8];
uint64_t qh[QUANT_K_IQ1_M/16/8]; uint64_t qh[QUANT_K_IQ1_M/16/8];
@ -415,12 +433,15 @@ struct block_iq1_m_packed64 {
#define QUANT_K QUANT_K_IQ1_S #define QUANT_K QUANT_K_IQ1_S
#define QUANT_R QUANT_R_IQ1_S #define QUANT_R QUANT_R_IQ1_S
#define A_TYPE block_iq1_s #define A_TYPE block_iq1_s
#define A_TYPE_PACKED16 block_iq1_s_packed16
#endif #endif
#if defined(DATA_A_IQ1_M) #if defined(DATA_A_IQ1_M)
#define QUANT_K QUANT_K_IQ1_M #define QUANT_K QUANT_K_IQ1_M
#define QUANT_R QUANT_R_IQ1_M #define QUANT_R QUANT_R_IQ1_M
#define A_TYPE block_iq1_m #define A_TYPE block_iq1_m
#define A_TYPE_PACKED16 block_iq1_m_packed16
#define A_TYPE_PACKED32 block_iq1_m_packed32
#endif #endif
#if defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M) #if defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M)
@ -559,7 +580,270 @@ const uint[1024] iq1s_grid_const = {
0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557 0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557
}; };
// Same content as iq1s_grid_const except each 2-bit value is expanded to 4-bit
// and has 1 added to it (allows packed values to be extracted with & 0x0F0F0F0F
// and 0xF0F0F0F0).
const uint32_t[2048] iq1s_grid_gpu_const = {
0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000,
0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101,
0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200,
0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212,
0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011,
0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111,
0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220,
0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022,
0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220,
0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101,
0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110,
0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111,
0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010,
0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210,
0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221,
0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021,
0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002,
0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101,
0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101,
0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211,
0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110,
0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022,
0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121,
0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220,
0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001,
0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101,
0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102,
0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012,
0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010,
0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111,
0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122,
0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222,
0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001,
0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102,
0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101,
0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000,
0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101,
0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112,
0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110,
0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211,
0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012,
0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111,
0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120,
0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122,
0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121,
0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221,
0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001,
0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101,
0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101,
0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011,
0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111,
0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011,
0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122,
0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121,
0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222,
0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101,
0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000,
0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200,
0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110,
0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112,
0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222,
0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021,
0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121,
0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201,
0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200,
0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101,
0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011,
0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010,
0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211,
0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121,
0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000,
0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202,
0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202,
0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211,
0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112,
0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020,
0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121,
0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222,
0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102,
0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100,
0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110,
0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011,
0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111,
0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110,
0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121,
0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222,
0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201,
0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102,
0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201,
0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012,
0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010,
0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010,
0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110,
0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011,
0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212,
0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021,
0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021,
0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021,
0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101,
0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101,
0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100,
0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010,
0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111,
0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010,
0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111,
0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120,
0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120,
0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101,
0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001,
0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201,
0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210,
0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211,
0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111,
0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112,
0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211,
0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010,
0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021,
0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122,
0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221,
0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102,
0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100,
0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101,
0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101,
0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101,
0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012,
0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110,
0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112,
0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210,
0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210,
0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210,
0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010,
0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110,
0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122,
0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020,
0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021,
0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022,
0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120,
0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222,
0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221,
0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001,
0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102,
0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201,
0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012,
0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111,
0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012,
0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110,
0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110,
0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121,
0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221,
0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220,
0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222,
0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000,
0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201,
0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012,
0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011,
0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212,
0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221,
0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121,
0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202,
0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202,
0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002,
0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101,
0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210,
0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112,
0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011,
0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011,
0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210,
0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020,
0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220,
0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222,
0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222,
0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001,
0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010,
0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111,
0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010,
0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110,
0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221,
0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122,
0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202,
0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100,
0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101,
0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112,
0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111,
0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211,
0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222,
0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221,
0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022,
0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101,
0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211,
0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111,
0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111,
0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010,
0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121,
0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222,
0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000,
0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202,
0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000,
0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202,
0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110,
0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110,
0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222,
0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120,
0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022,
0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101,
0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202,
0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110,
0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110,
0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111,
0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111,
0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120,
0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121,
0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001,
0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202,
0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001,
0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200,
0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011,
0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212,
0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012,
0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110,
0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012,
0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111,
0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020,
0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121,
0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222,
0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102,
0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102,
0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101,
0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212,
0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210,
0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111,
0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212,
0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221,
0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121,
0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002,
0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000,
0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202,
0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112,
0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111,
0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020,
0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221,
0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022,
0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100,
0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201,
0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112,
0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211,
0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012,
0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121,
0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020,
0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120,
0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200,
0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200,
0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110,
0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011,
0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222,
0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020,
0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222,
};
shared uint16_t iq1s_grid[2048]; shared uint16_t iq1s_grid[2048];
shared uint32_t iq1s_grid_gpu[2048];
#define NEEDS_INIT_IQ_SHMEM #define NEEDS_INIT_IQ_SHMEM
void init_iq_shmem(uvec3 wgsize) void init_iq_shmem(uvec3 wgsize)
@ -573,6 +857,12 @@ void init_iq_shmem(uvec3 wgsize)
iq1s_grid[2*idx+1] = g.y; iq1s_grid[2*idx+1] = g.y;
} }
} }
[[unroll]] for (uint i = 0; i < iq1s_grid_gpu_const.length(); i += wgsize.x) {
uint idx = i + gl_LocalInvocationIndex.x;
if (iq1s_grid_gpu_const.length() % wgsize.x == 0 || idx < iq1s_grid_gpu_const.length()) {
iq1s_grid_gpu[idx] = iq1s_grid_gpu_const[idx];
}
}
barrier(); barrier();
} }
#endif #endif

View File

@ -685,7 +685,7 @@ void process_shaders() {
// mul mat vec with integer dot product // mul mat vec with integer dot product
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname)) { if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname) || tname == "iq1_s" || tname == "iq1_m") {
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
@ -944,6 +944,8 @@ void process_shaders() {
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("cumsum_multipass1_f32", "cumsum_multipass1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("cumsum_multipass2_f32", "cumsum_multipass2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("count_experts", "count_experts.comp", merge_maps(base_dict, {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}})); string_to_spv("count_experts", "count_experts.comp", merge_maps(base_dict, {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}}));
@ -1123,7 +1125,7 @@ void write_output_files() {
for (const std::string& btype : btypes) { for (const std::string& btype : btypes) {
for (const auto& tname : type_names) { for (const auto& tname : type_names) {
if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname)) { if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname) && tname != "iq1_s" && tname != "iq1_m") {
continue; continue;
} }
hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n"; hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n";

View File

@ -294,7 +294,9 @@ class Keys:
USE_GELU = "clip.use_gelu" USE_GELU = "clip.use_gelu"
USE_SILU = "clip.use_silu" USE_SILU = "clip.use_silu"
N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl
WA_LAYER_INDEXES = "clip.vision.wa_layer_indexes" # used by youtuvl
IS_DEEPSTACK_LAYERS = "clip.vision.is_deepstack_layers" IS_DEEPSTACK_LAYERS = "clip.vision.is_deepstack_layers"
WINDOW_SIZE = "clip.vision.window_size"
class Attention: class Attention:
HEAD_COUNT = "clip.vision.attention.head_count" HEAD_COUNT = "clip.vision.attention.head_count"
@ -452,6 +454,7 @@ class MODEL_ARCH(IntEnum):
MISTRAL3 = auto() MISTRAL3 = auto()
MIMO2 = auto() MIMO2 = auto()
LLAMA_EMBED = auto() LLAMA_EMBED = auto()
MAINCODER = auto()
class VISION_PROJECTOR_TYPE(IntEnum): class VISION_PROJECTOR_TYPE(IntEnum):
@ -850,6 +853,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.MISTRAL3: "mistral3", MODEL_ARCH.MISTRAL3: "mistral3",
MODEL_ARCH.MIMO2: "mimo2", MODEL_ARCH.MIMO2: "mimo2",
MODEL_ARCH.LLAMA_EMBED: "llama-embed", MODEL_ARCH.LLAMA_EMBED: "llama-embed",
MODEL_ARCH.MAINCODER: "maincoder",
} }
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@ -3257,6 +3261,22 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP, MODEL_TENSOR.FFN_UP_EXP,
], ],
MODEL_ARCH.MAINCODER: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
# TODO # TODO
} }
@ -3492,7 +3512,9 @@ class VisionProjectorType:
COGVLM = "cogvlm" COGVLM = "cogvlm"
JANUS_PRO = "janus_pro" JANUS_PRO = "janus_pro"
LFM2A = "lfm2a" # audio LFM2A = "lfm2a" # audio
MUSIC_FLAMINGO = "musicflamingo" # audio
GLM4V = "glm4v" GLM4V = "glm4v"
YOUTUVL = "youtuvl"
# Items here are (block size, type size) # Items here are (block size, type size)

View File

@ -1129,11 +1129,40 @@ class GGUFWriter:
self.add_uint32(Keys.ClipVision.Projector.SCALE_FACTOR, value) self.add_uint32(Keys.ClipVision.Projector.SCALE_FACTOR, value)
def add_vision_n_wa_pattern(self, value: int) -> None: def add_vision_n_wa_pattern(self, value: int) -> None:
"""Add window attention pattern interval for vision models.
This defines the pattern interval for window attention vs full attention layers.
For example, if n_wa_pattern=4, then layers 3, 7, 11, ... use full attention,
while other layers use window attention.
Used by models like Qwen2.5-VL where full attention layers follow a regular pattern.
"""
self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value) self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)
def add_vision_wa_layer_indexes(self, layers: Sequence[int]) -> None:
"""Add explicit layer indexes that use full attention in vision models.
This specifies the exact layer indices (0-based) that should use full attention
instead of window attention. All other layers will use window attention.
Args:
layers: List of layer indices that use full attention (e.g., [3, 7, 11, 15])
Used by models like YoutuVL where full attention layers are explicitly specified
rather than following a regular pattern.
Difference from add_vision_n_wa_pattern:
- n_wa_pattern: Defines a regular interval pattern (every Nth layer uses full attention)
- wa_layer_indexes: Explicitly lists which layers use full attention (irregular pattern)
"""
self.add_array(Keys.ClipVision.WA_LAYER_INDEXES, layers)
def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None: def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None:
self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers) self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers)
def add_vision_window_size(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.WINDOW_SIZE, value)
# audio models # audio models
def add_audio_projection_dim(self, value: int) -> None: def add_audio_projection_dim(self, value: int) -> None:

View File

@ -1221,6 +1221,7 @@ class TensorNameMap:
MODEL_TENSOR.V_MMPROJ: ( MODEL_TENSOR.V_MMPROJ: (
"multi_modal_projector.linear_{bid}", "multi_modal_projector.linear_{bid}",
"visual.merger.mlp.{bid}", # qwen2vl "visual.merger.mlp.{bid}", # qwen2vl
"merger.mlp.{bid}",
), ),
MODEL_TENSOR.V_MMPROJ_FC: ( MODEL_TENSOR.V_MMPROJ_FC: (
@ -1258,6 +1259,7 @@ class TensorNameMap:
"visual.patch_embed.proj", # qwen2vl "visual.patch_embed.proj", # qwen2vl
"vision_tower.patch_embed.proj", # kimi-vl "vision_tower.patch_embed.proj", # kimi-vl
"model.vision.patch_embedding.proj", # cogvlm "model.vision.patch_embedding.proj", # cogvlm
"siglip2.vision_model.embeddings.patch_embedding",
), ),
MODEL_TENSOR.V_ENC_EMBD_NORM: ( MODEL_TENSOR.V_ENC_EMBD_NORM: (
@ -1291,6 +1293,7 @@ class TensorNameMap:
"vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral "vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral
"visual.blocks.{bid}.attn.q", # qwen2vl, generated "visual.blocks.{bid}.attn.q", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated "vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated
"siglip2.vision_model.encoder.layers.{bid}.self_attn.q_proj", # youtuvl
), ),
MODEL_TENSOR.V_ENC_ATTN_Q_NORM: ( MODEL_TENSOR.V_ENC_ATTN_Q_NORM: (
@ -1308,6 +1311,7 @@ class TensorNameMap:
"vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral "vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral
"visual.blocks.{bid}.attn.k", # qwen2vl, generated "visual.blocks.{bid}.attn.k", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated "vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated
"siglip2.vision_model.encoder.layers.{bid}.self_attn.k_proj",
), ),
MODEL_TENSOR.V_ENC_ATTN_K_NORM: ( MODEL_TENSOR.V_ENC_ATTN_K_NORM: (
@ -1325,6 +1329,7 @@ class TensorNameMap:
"vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral "vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral
"visual.blocks.{bid}.attn.v", # qwen2vl, generated "visual.blocks.{bid}.attn.v", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated "vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated
"siglip2.vision_model.encoder.layers.{bid}.self_attn.v_proj",
), ),
MODEL_TENSOR.V_ENC_INPUT_NORM: ( MODEL_TENSOR.V_ENC_INPUT_NORM: (
@ -1339,6 +1344,7 @@ class TensorNameMap:
"visual.blocks.{bid}.norm1", # qwen2vl "visual.blocks.{bid}.norm1", # qwen2vl
"vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1) "vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1)
"model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm "model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm
"siglip2.vision_model.encoder.layers.{bid}.layer_norm1",
), ),
MODEL_TENSOR.V_ENC_ATTN_O: ( MODEL_TENSOR.V_ENC_ATTN_O: (
@ -1354,6 +1360,7 @@ class TensorNameMap:
"visual.blocks.{bid}.attn.proj", # qwen2vl "visual.blocks.{bid}.attn.proj", # qwen2vl
"vision_tower.encoder.blocks.{bid}.wo", # kimi-vl "vision_tower.encoder.blocks.{bid}.wo", # kimi-vl
"model.vision.transformer.layers.{bid}.attention.dense", # cogvlm "model.vision.transformer.layers.{bid}.attention.dense", # cogvlm
"siglip2.vision_model.encoder.layers.{bid}.self_attn.out_proj", # youtuvl
), ),
MODEL_TENSOR.V_ENC_POST_ATTN_NORM: ( MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
@ -1368,6 +1375,7 @@ class TensorNameMap:
"visual.blocks.{bid}.norm2", # qwen2vl "visual.blocks.{bid}.norm2", # qwen2vl
"vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1) "vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1)
"model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm "model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm
"siglip2.vision_model.encoder.layers.{bid}.layer_norm2",
), ),
MODEL_TENSOR.V_ENC_FFN_UP: ( MODEL_TENSOR.V_ENC_FFN_UP: (
@ -1383,6 +1391,7 @@ class TensorNameMap:
"visual.blocks.{bid}.mlp.linear_fc1", # qwen3vl "visual.blocks.{bid}.mlp.linear_fc1", # qwen3vl
"vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1) "vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1)
"model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm "model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm
"siglip2.vision_model.encoder.layers.{bid}.mlp.fc1",
), ),
MODEL_TENSOR.V_ENC_FFN_GATE: ( MODEL_TENSOR.V_ENC_FFN_GATE: (
@ -1404,6 +1413,7 @@ class TensorNameMap:
"visual.blocks.{bid}.mlp.linear_fc2", # qwen3vl "visual.blocks.{bid}.mlp.linear_fc2", # qwen3vl
"vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1) "vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1)
"model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm "model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm
"siglip2.vision_model.encoder.layers.{bid}.mlp.fc2",
), ),
MODEL_TENSOR.V_LAYER_SCALE_1: ( MODEL_TENSOR.V_LAYER_SCALE_1: (
@ -1430,6 +1440,7 @@ class TensorNameMap:
"visual.merger.ln_q", # qwen2vl "visual.merger.ln_q", # qwen2vl
"vision_tower.encoder.final_layernorm", # kimi-vl "vision_tower.encoder.final_layernorm", # kimi-vl
"visual.post_layernorm", # glm4v "visual.post_layernorm", # glm4v
"siglip2.vision_model.post_layernorm",
), ),
MODEL_TENSOR.V_MM_POST_NORM: ( MODEL_TENSOR.V_MM_POST_NORM: (
@ -1446,6 +1457,7 @@ class TensorNameMap:
"multi_modal_projector.pre_norm", "multi_modal_projector.pre_norm",
"pre_mm_projector_norm", "pre_mm_projector_norm",
"model.vision.linear_proj.norm1", # cogvlm "model.vision.linear_proj.norm1", # cogvlm
"merger.ln_q",
), ),
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: ( MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (

View File

@ -1 +1 @@
130bc125a88bb57664b88932c48c38a1cb316fac ebc3a0f4a56be1c9424a89fbec09962ac34fde85

View File

@ -87,6 +87,7 @@ add_library(llama
models/llada.cpp models/llada.cpp
models/llama-iswa.cpp models/llama-iswa.cpp
models/llama.cpp models/llama.cpp
models/maincoder.cpp
models/mamba.cpp models/mamba.cpp
models/mimo2-iswa.cpp models/mimo2-iswa.cpp
models/minicpm3.cpp models/minicpm3.cpp

View File

@ -118,6 +118,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MISTRAL3, "mistral3" }, { LLM_ARCH_MISTRAL3, "mistral3" },
{ LLM_ARCH_MIMO2, "mimo2" }, { LLM_ARCH_MIMO2, "mimo2" },
{ LLM_ARCH_LLAMA_EMBED, "llama-embed" }, { LLM_ARCH_LLAMA_EMBED, "llama-embed" },
{ LLM_ARCH_MAINCODER, "maincoder" },
{ LLM_ARCH_UNKNOWN, "(unknown)" }, { LLM_ARCH_UNKNOWN, "(unknown)" },
}; };
@ -2234,6 +2235,23 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
return { return {
LLM_TENSOR_TOKEN_EMBD, LLM_TENSOR_TOKEN_EMBD,
}; };
case LLM_ARCH_MAINCODER:
return {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_Q,
LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K,
LLM_TENSOR_ATTN_K_NORM,
LLM_TENSOR_ATTN_V,
LLM_TENSOR_ATTN_OUT,
LLM_TENSOR_FFN_NORM,
LLM_TENSOR_FFN_GATE,
LLM_TENSOR_FFN_DOWN,
LLM_TENSOR_FFN_UP,
};
default: default:
GGML_ABORT("unknown architecture for tensor mapping"); GGML_ABORT("unknown architecture for tensor mapping");
} }

View File

@ -122,6 +122,7 @@ enum llm_arch {
LLM_ARCH_MISTRAL3, LLM_ARCH_MISTRAL3,
LLM_ARCH_MIMO2, LLM_ARCH_MIMO2,
LLM_ARCH_LLAMA_EMBED, LLM_ARCH_LLAMA_EMBED,
LLM_ARCH_MAINCODER,
LLM_ARCH_UNKNOWN, LLM_ARCH_UNKNOWN,
}; };

View File

@ -74,6 +74,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS }, { "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS },
{ "grok-2", LLM_CHAT_TEMPLATE_GROK_2 }, { "grok-2", LLM_CHAT_TEMPLATE_GROK_2 },
{ "pangu-embedded", LLM_CHAT_TEMPLATE_PANGU_EMBED }, { "pangu-embedded", LLM_CHAT_TEMPLATE_PANGU_EMBED },
{ "solar-open", LLM_CHAT_TEMPLATE_SOLAR_OPEN },
}; };
llm_chat_template llm_chat_template_from_str(const std::string & name) { llm_chat_template llm_chat_template_from_str(const std::string & name) {
@ -216,6 +217,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_GROK_2; return LLM_CHAT_TEMPLATE_GROK_2;
} else if (tmpl_contains(LU8("[unused9]系统:[unused10]"))) { } else if (tmpl_contains(LU8("[unused9]系统:[unused10]"))) {
return LLM_CHAT_TEMPLATE_PANGU_EMBED; return LLM_CHAT_TEMPLATE_PANGU_EMBED;
} else if (tmpl_contains("<|begin|>") && tmpl_contains("<|end|>") && tmpl_contains("<|content|>")) {
return LLM_CHAT_TEMPLATE_SOLAR_OPEN;
} }
return LLM_CHAT_TEMPLATE_UNKNOWN; return LLM_CHAT_TEMPLATE_UNKNOWN;
} }
@ -845,6 +848,14 @@ int32_t llm_chat_apply_template(
if (add_ass) { if (add_ass) {
ss << "[unused9]助手:"; ss << "[unused9]助手:";
} }
} else if (tmpl == LLM_CHAT_TEMPLATE_SOLAR_OPEN) {
for (auto message : chat) {
std::string role(message->role);
ss << "<|begin|>" << role << "<|content|>" << message->content << "<|end|>";
}
if (add_ass) {
ss << "<|begin|>assistant";
}
} else { } else {
// template not supported // template not supported
return -1; return -1;

View File

@ -54,6 +54,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_SEED_OSS, LLM_CHAT_TEMPLATE_SEED_OSS,
LLM_CHAT_TEMPLATE_GROK_2, LLM_CHAT_TEMPLATE_GROK_2,
LLM_CHAT_TEMPLATE_PANGU_EMBED, LLM_CHAT_TEMPLATE_PANGU_EMBED,
LLM_CHAT_TEMPLATE_SOLAR_OPEN,
LLM_CHAT_TEMPLATE_UNKNOWN, LLM_CHAT_TEMPLATE_UNKNOWN,
}; };

View File

@ -1458,7 +1458,7 @@ ggml_cgraph * llama_context::graph_reserve(
if (n_tokens % n_seqs != 0) { if (n_tokens % n_seqs != 0) {
n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
n_outputs = std::min(n_outputs, n_tokens); n_outputs = std::max(n_outputs, n_tokens);
LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs); LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
} }

View File

@ -240,9 +240,10 @@ struct llama_file::impl {
throw std::runtime_error("unexpectedly reached end of file"); throw std::runtime_error("unexpectedly reached end of file");
} }
} else { } else {
bool successful = false; size_t bytes_read = 0;
while (!successful) { while (bytes_read < len) {
off_t ret = read(fd, ptr, len); const size_t to_read = len - bytes_read;
ssize_t ret = ::read(fd, reinterpret_cast<char *>(ptr) + bytes_read, to_read);
if (ret == -1) { if (ret == -1) {
if (errno == EINTR) { if (errno == EINTR) {
@ -251,10 +252,16 @@ struct llama_file::impl {
throw std::runtime_error(format("read error: %s", strerror(errno))); throw std::runtime_error(format("read error: %s", strerror(errno)));
} }
if (ret == 0) { if (ret == 0) {
// EOF: allow if this read was only pulling alignment padding past file end
off_t pos = lseek(fd, 0, SEEK_CUR);
if (pos != -1 && (size_t) pos == size) {
std::memset(reinterpret_cast<char *>(ptr) + bytes_read, 0, len - bytes_read);
return;
}
throw std::runtime_error("unexpectedly reached end of file"); throw std::runtime_error("unexpectedly reached end of file");
} }
successful = true; bytes_read += (size_t) ret;
} }
} }
} }

View File

@ -126,6 +126,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; case LLM_TYPE_31B_A3_5B: return "31B.A3.5B";
case LLM_TYPE_80B_A3B: return "80B.A3B"; case LLM_TYPE_80B_A3B: return "80B.A3B";
case LLM_TYPE_100B_A6B: return "100B.A6B"; case LLM_TYPE_100B_A6B: return "100B.A6B";
case LLM_TYPE_102B_A12B: return "102B.A12B";
case LLM_TYPE_106B_A12B: return "106B.A12B"; case LLM_TYPE_106B_A12B: return "106B.A12B";
case LLM_TYPE_230B_A10B: return "230B.A10B"; case LLM_TYPE_230B_A10B: return "230B.A10B";
case LLM_TYPE_235B_A22B: return "235B.A22B"; case LLM_TYPE_235B_A22B: return "235B.A22B";
@ -1109,6 +1110,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN; default: type = LLM_TYPE_UNKNOWN;
} }
} break; } break;
case LLM_ARCH_MAINCODER:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 32: type = LLM_TYPE_1B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_QWEN3VL: case LLM_ARCH_QWEN3VL:
{ {
ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false); ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false);
@ -1682,7 +1691,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false); ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
@ -1778,6 +1787,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
switch (hparams.n_layer) { switch (hparams.n_layer) {
case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer)
case 48: type = LLM_TYPE_102B_A12B; break; // Solar Open
case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer) case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer)
default: type = LLM_TYPE_UNKNOWN; default: type = LLM_TYPE_UNKNOWN;
} }
@ -3320,7 +3330,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, layer.ffn_gate ? n_ff : n_ff * 2}, 0);
const auto tn_ffn_up_weight = tn(LLM_TENSOR_FFN_UP, "weight", i);
ggml_tensor * t_ffn_up = ml.get_tensor_meta(tn_ffn_up_weight.str().c_str());
const int64_t n_ffn_up = t_ffn_up ? t_ffn_up->ne[1] : n_ff;
GGML_ASSERT(n_ffn_up == n_ff || n_ffn_up == n_ff * 2);
layer.ffn_up = create_tensor(tn_ffn_up_weight, {n_embd, n_ffn_up}, 0);
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ffn_up}, TENSOR_NOT_REQUIRED);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
@ -4776,7 +4793,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
// output // output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); // try to load output.weight, if not found, use token_embd (tied embeddings)
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
if (!output) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}
for (int i = 0; i < n_layer; ++i) { for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i]; auto & layer = layers[i];
@ -4839,7 +4860,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
// output // output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); // try to load output.weight, if not found, use token_embd (tied embeddings)
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
if (!output) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}
for (int i = 0; i < n_layer; ++i) { for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i]; auto & layer = layers[i];
@ -5206,9 +5231,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, flags); layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, flags);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, flags); layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, flags);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, flags); layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, flags);
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, flags); layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, TENSOR_NOT_REQUIRED | flags);
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, flags); layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, TENSOR_NOT_REQUIRED | flags);
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, flags); layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, TENSOR_NOT_REQUIRED | flags);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags);
@ -6761,6 +6786,37 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
} }
} break; } break;
case LLM_ARCH_MAINCODER:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
default: default:
throw std::runtime_error("unknown architecture"); throw std::runtime_error("unknown architecture");
} }
@ -7406,6 +7462,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{ {
llm = std::make_unique<llm_build_llama<true>>(*this, params); llm = std::make_unique<llm_build_llama<true>>(*this, params);
} break; } break;
case LLM_ARCH_MAINCODER:
{
llm = std::make_unique<llm_build_maincoder>(*this, params);
} break;
case LLM_ARCH_DECI: case LLM_ARCH_DECI:
{ {
llm = std::make_unique<llm_build_deci>(*this, params); llm = std::make_unique<llm_build_deci>(*this, params);
@ -7440,7 +7500,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
} break; } break;
case LLM_ARCH_MODERN_BERT: case LLM_ARCH_MODERN_BERT:
{ {
llm = std::make_unique<llm_build_modern_bert<true>>(*this, params); llm = std::make_unique<llm_build_modern_bert>(*this, params);
} break; } break;
case LLM_ARCH_NEO_BERT: case LLM_ARCH_NEO_BERT:
{ {
@ -8014,6 +8074,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_ERNIE4_5_MOE: case LLM_ARCH_ERNIE4_5_MOE:
case LLM_ARCH_MISTRAL3: case LLM_ARCH_MISTRAL3:
case LLM_ARCH_LLAMA_EMBED: case LLM_ARCH_LLAMA_EMBED:
case LLM_ARCH_MAINCODER:
return LLAMA_ROPE_TYPE_NORM; return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2 // the pairs of head values are offset by n_rot/2

View File

@ -119,6 +119,7 @@ enum llm_type {
LLM_TYPE_31B_A3_5B, LLM_TYPE_31B_A3_5B,
LLM_TYPE_80B_A3B, // Qwen3 Next LLM_TYPE_80B_A3B, // Qwen3 Next
LLM_TYPE_100B_A6B, LLM_TYPE_100B_A6B,
LLM_TYPE_102B_A12B, // Solar-Open
LLM_TYPE_106B_A12B, // GLM-4.5-Air LLM_TYPE_106B_A12B, // GLM-4.5-Air
LLM_TYPE_230B_A10B, // Minimax M2 LLM_TYPE_230B_A10B, // Minimax M2
LLM_TYPE_235B_A22B, LLM_TYPE_235B_A22B,

View File

@ -314,6 +314,12 @@ struct llm_tokenizer_bpe : llm_tokenizer {
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+", "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
}; };
break; break;
case LLAMA_VOCAB_PRE_TYPE_YOUTU:
regex_exprs = {
"[가-힣ㄱ-ㆎ]+|[!…“”‘’—:;,、-〿︰-]+|[ㄅ-ㄯ]+|[一-龥぀-ゟ゠-ヿ]+",
"[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
};
break;
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER: case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
regex_exprs = { regex_exprs = {
"[\r\n]", "[\r\n]",
@ -355,6 +361,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
case LLAMA_VOCAB_PRE_TYPE_STABLELM2: case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
case LLAMA_VOCAB_PRE_TYPE_QWEN2: case LLAMA_VOCAB_PRE_TYPE_QWEN2:
case LLAMA_VOCAB_PRE_TYPE_HUNYUAN: case LLAMA_VOCAB_PRE_TYPE_HUNYUAN:
case LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN:
regex_exprs = { regex_exprs = {
// original regex from tokenizer.json // original regex from tokenizer.json
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
@ -1860,6 +1867,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "deepseek-v3") { tokenizer_pre == "deepseek-v3") {
pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM; pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM;
clean_spaces = false; clean_spaces = false;
} else if (
tokenizer_pre == "youtu") {
pre_type = LLAMA_VOCAB_PRE_TYPE_YOUTU;
clean_spaces = false;
ignore_merges = true;
} else if ( } else if (
tokenizer_pre == "falcon") { tokenizer_pre == "falcon") {
pre_type = LLAMA_VOCAB_PRE_TYPE_FALCON; pre_type = LLAMA_VOCAB_PRE_TYPE_FALCON;
@ -2015,6 +2027,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "minimax-m2") { tokenizer_pre == "minimax-m2") {
pre_type = LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2; pre_type = LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2;
clean_spaces = false; clean_spaces = false;
} else if (
tokenizer_pre == "solar-open") {
pre_type = LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN;
clean_spaces = false;
} else { } else {
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
} }
@ -2187,6 +2203,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
// for now, we apply this workaround to find the tokens based on their text // for now, we apply this workaround to find the tokens based on their text
for (const auto & t : token_to_id) { for (const auto & t : token_to_id) {
auto & attr = id_to_token[t.second].attr;
// find EOT token: "<|eot_id|>", "<|im_end|>", "<end_of_turn>", etc. // find EOT token: "<|eot_id|>", "<|im_end|>", "<end_of_turn>", etc.
if (special_eot_id == LLAMA_TOKEN_NULL) { if (special_eot_id == LLAMA_TOKEN_NULL) {
if (false if (false
@ -2202,10 +2220,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|| t.first == "<end_of_utterance>" // smoldocling || t.first == "<end_of_utterance>" // smoldocling
) { ) {
special_eot_id = t.second; special_eot_id = t.second;
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
__func__, t.second, t.first.c_str()); __func__, t.second, t.first.c_str());
id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
} }
} }
} }
@ -2216,10 +2234,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|| t.first == "<|eom_id|>" || t.first == "<|eom_id|>"
) { ) {
special_eom_id = t.second; special_eom_id = t.second;
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
__func__, t.second, t.first.c_str()); __func__, t.second, t.first.c_str());
id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
} }
} }
} }
@ -2236,10 +2254,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|| t.first == "<|code_prefix|>" // GLM-4.5 || t.first == "<|code_prefix|>" // GLM-4.5
) { ) {
special_fim_pre_id = t.second; special_fim_pre_id = t.second;
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
__func__, t.second, t.first.c_str()); __func__, t.second, t.first.c_str());
id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
} }
} }
} }
@ -2256,10 +2274,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|| t.first == "<|code_suffix|>" // GLM-4.5 || t.first == "<|code_suffix|>" // GLM-4.5
) { ) {
special_fim_suf_id = t.second; special_fim_suf_id = t.second;
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
__func__, t.second, t.first.c_str()); __func__, t.second, t.first.c_str());
id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
} }
} }
} }
@ -2276,10 +2294,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|| t.first == "<|code_middle|>" // GLM-4.5 || t.first == "<|code_middle|>" // GLM-4.5
) { ) {
special_fim_mid_id = t.second; special_fim_mid_id = t.second;
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
__func__, t.second, t.first.c_str()); __func__, t.second, t.first.c_str());
id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
} }
} }
} }
@ -2293,10 +2311,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|| t.first == "<PAD>" || t.first == "<PAD>"
) { ) {
special_fim_pad_id = t.second; special_fim_pad_id = t.second;
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
__func__, t.second, t.first.c_str()); __func__, t.second, t.first.c_str());
id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
} }
} }
} }
@ -2311,10 +2329,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|| t.first == "<reponame>" // Granite || t.first == "<reponame>" // Granite
) { ) {
special_fim_rep_id = t.second; special_fim_rep_id = t.second;
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
__func__, t.second, t.first.c_str()); __func__, t.second, t.first.c_str());
id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
} }
} }
} }
@ -2325,15 +2343,41 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|| t.first == "<|file_sep|>" // Qwen || t.first == "<|file_sep|>" // Qwen
) { ) {
special_fim_sep_id = t.second; special_fim_sep_id = t.second;
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
__func__, t.second, t.first.c_str()); __func__, t.second, t.first.c_str());
id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
} }
} }
} }
} }
// auto-detect unused tokens: e.g. control tokens with the word "unused"
// ideally, these tokens should be marked as unused during conversion
{
uint32_t n_unused = 0;
for (const auto & t : token_to_id) {
auto & attr = id_to_token[t.second].attr;
if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
continue;
}
if ((attr & LLAMA_TOKEN_ATTR_UNUSED) == 0) {
if (strstr(t.first.c_str(), "unused") != NULL) {
attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_UNUSED);
}
}
if (attr & LLAMA_TOKEN_ATTR_UNUSED) {
n_unused++;
}
}
LLAMA_LOG_INFO("%s: %u unused tokens\n", __func__, n_unused);
}
// maintain a list of tokens that cause end-of-generation // maintain a list of tokens that cause end-of-generation
// this is currently determined based on the token text, which is obviously not ideal // this is currently determined based on the token text, which is obviously not ideal
// ref: https://github.com/ggerganov/llama.cpp/issues/9606 // ref: https://github.com/ggerganov/llama.cpp/issues/9606
@ -2352,12 +2396,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
} }
for (const auto & t : token_to_id) { for (const auto & t : token_to_id) {
auto & attr = id_to_token[t.second].attr;
if (false if (false
|| t.first == "<|eot_id|>" || t.first == "<|eot_id|>"
|| t.first == "<|im_end|>" || t.first == "<|im_end|>"
|| t.first == "<|end|>" || t.first == "<|end|>"
|| t.first == "<|return|>" // o200k_harmony || t.first == "<|return|>" // o200k_harmony
|| t.first == "<|call|>" // o200k_harmony || t.first == "<|call|>" // o200k_harmony
|| t.first == "<|flush|>" // solar-open
|| t.first == "<|calls|>" // solar-open
|| t.first == "<end_of_turn>" || t.first == "<end_of_turn>"
|| t.first == "<|endoftext|>" || t.first == "<|endoftext|>"
|| t.first == "<|eom_id|>" || t.first == "<|eom_id|>"
@ -2367,24 +2415,28 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|| t.first == "<end_of_utterance>" // smoldocling || t.first == "<end_of_utterance>" // smoldocling
) { ) {
special_eog_ids.insert(t.second); special_eog_ids.insert(t.second);
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
__func__, t.second, t.first.c_str()); __func__, t.second, t.first.c_str());
id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
} }
} else { } else {
// token is control, but not marked as EOG -> print a debug log if (attr & LLAMA_TOKEN_ATTR_CONTROL && !(attr & LLAMA_TOKEN_ATTR_UNUSED)) {
if (id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && special_eog_ids.count(t.second) == 0) { // token is control, but not marked as EOG -> print a debug log
LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n", if (special_eog_ids.count(t.second) == 0) {
__func__, t.second, t.first.c_str()); LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n",
__func__, t.second, t.first.c_str());
}
} }
} }
} }
// @ngxson : quick hack for gpt-oss, always render these tokens // @ngxson : quick hack for gpt-oss, always render these tokens
for (const auto & t : token_to_id) { for (const auto & t : token_to_id) {
auto & attr = id_to_token[t.second].attr;
if (t.first == "<|channel|>" || t.first == "<|message|>" || t.first == "<|start|>" || t.first == "<|constrain|>") { if (t.first == "<|channel|>" || t.first == "<|message|>" || t.first == "<|start|>" || t.first == "<|constrain|>") {
id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_USER_DEFINED; attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_USER_DEFINED);
} }
} }
@ -2404,34 +2456,42 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__); LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
} }
// TODO: workaround for o200k_harmony tokenizer: the "<|end|>" token should not be EOG // TODO: workaround for o200k_harmony and solar-open tokenizer: the "<|end|>" token should not be EOG
// we don't have a good way to detect this, so for now, if we have "<|return|>" and "<|call|>" tokens, // we don't have a good way to detect this, so for now, if we have "<|return|>" and "<|call|>" tokens ("<|calls|>" and "<|flush|>" for solar-open),
// we remove the "<|end|>" token from the EOG list // we remove the "<|end|>" token from the EOG list
{ {
bool has_return = false; bool has_return = false;
bool has_call = false; bool has_call = false;
bool has_end = false; bool has_end = false;
bool has_flush = false;
llama_token end_id = LLAMA_TOKEN_NULL; llama_token end_id = LLAMA_TOKEN_NULL;
LLAMA_LOG_INFO("%s: printing all EOG tokens:\n", __func__); LLAMA_LOG_INFO("%s: printing all EOG tokens:\n", __func__);
for (auto tid : special_eog_ids) { for (auto tid : special_eog_ids) {
LLAMA_LOG_INFO("%s: - %d ('%s')\n", __func__, tid, id_to_token[tid].text.c_str()); auto & text = id_to_token[tid].text;
if (id_to_token[tid].text == "<|return|>") { LLAMA_LOG_INFO("%s: - %d ('%s')\n", __func__, tid, text.c_str());
if (text == "<|return|>") {
has_return = true; has_return = true;
} else if (id_to_token[tid].text == "<|call|>") { } else if (text == "<|call|>" || text == "<|calls|>") {
has_call = true; has_call = true;
} else if (id_to_token[tid].text == "<|end|>") { } else if (text == "<|flush|>") {
has_flush = true;
} else if (text == "<|end|>") {
has_end = true; has_end = true;
end_id = tid; end_id = tid;
} }
} }
if (has_return && has_call && has_end) { if ((has_return && has_call && has_end) || (has_call && has_flush && has_end)) {
special_eog_ids.erase(end_id); special_eog_ids.erase(end_id);
id_to_token[end_id].attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>' tokens, removing '<|end|>' token from EOG list\n", __func__); auto & attr = id_to_token[end_id].attr;
attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_USER_DEFINED);
LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>', or '<|calls|>' and '<|flush|>' tokens, removing '<|end|>' token from EOG list\n", __func__);
} }
} }
} }

View File

@ -51,6 +51,8 @@ enum llama_vocab_pre_type {
LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40, LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40,
LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41, LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41,
LLAMA_VOCAB_PRE_TYPE_AFMOE = 42, LLAMA_VOCAB_PRE_TYPE_AFMOE = 42,
LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN = 43,
LLAMA_VOCAB_PRE_TYPE_YOUTU = 44,
}; };
struct LLM_KV; struct LLM_KV;

View File

@ -142,11 +142,13 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params
LLM_FFN_GELU, LLM_FFN_SEQ, il); LLM_FFN_GELU, LLM_FFN_SEQ, il);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} else if (model.arch == LLM_ARCH_JINA_BERT_V2) { } else if (model.arch == LLM_ARCH_JINA_BERT_V2) {
const bool up_contains_gate = !model.layers[il].ffn_gate && model.layers[il].ffn_up->ne[1] != hparams.n_ff();
auto type_op = up_contains_gate ? LLM_FFN_GEGLU : LLM_FFN_GELU;
cur = build_ffn(cur, cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL,
model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_GEGLU, LLM_FFN_PAR, il); type_op, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} else { } else {
cur = build_ffn(cur, cur = build_ffn(cur,

View File

@ -3,12 +3,14 @@
llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_params & params) : llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) { llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_head = hparams.n_embd_head_v;
float kq_scale = 1.0f / sqrtf(float(n_embd_head)); const float kq_scale = 1.0f / sqrtf(float(n_embd_head));
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot); GGML_ASSERT(n_embd_head == hparams.n_rot);
ggml_tensor *inpL, *cur; ggml_tensor * inpL;
ggml_tensor * cur;
inpL = build_inp_embd(model.tok_embd); inpL = build_inp_embd(model.tok_embd);
ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_pos = build_inp_pos();
@ -44,7 +46,7 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_pa
} }
ggml_tensor * inpSA = inpL; ggml_tensor * inpSA = inpL;
cur = build_norm(inpSA, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cur = build_norm(inpSA, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
// build self attention // build self attention
{ {

View File

@ -215,7 +215,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
model.layers[il].ffn_exp_probs_b, model.layers[il].ffn_exp_probs_b,
n_expert, n_expert_used, n_expert, n_expert_used,
LLM_FFN_SILU, hparams.expert_weights_norm, LLM_FFN_SILU, hparams.expert_weights_norm,
true, hparams.expert_weights_scale, hparams.expert_weights_scale, hparams.expert_weights_scale,
(llama_expert_gating_func_type) hparams.expert_gating_func, (llama_expert_gating_func_type) hparams.expert_gating_func,
il); il);
cb(moe_out, "ffn_moe_out", il); cb(moe_out, "ffn_moe_out", il);

View File

@ -1,7 +1,5 @@
#include "models.h" #include "models.h"
llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) : llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) { llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_k; const int64_t n_embd_head = hparams.n_embd_head_k;
@ -12,10 +10,8 @@ llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model,
inpL = build_inp_embd(model.tok_embd); inpL = build_inp_embd(model.tok_embd);
// important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
if (ubatch.token) { inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f);
inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); cb(inpL, "inp_scaled", -1);
cb(inpL, "inp_scaled", -1);
}
// inp_pos - contains the positions // inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_pos = build_inp_pos();

View File

@ -10,10 +10,9 @@ llm_build_gemma3<iswa>::llm_build_gemma3(const llama_model & model, const llm_gr
inpL = build_inp_embd(model.tok_embd); inpL = build_inp_embd(model.tok_embd);
// important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
if (ubatch.token) { inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f);
inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); cb(inpL, "inp_scaled", -1);
cb(inpL, "inp_scaled", -1);
}
// inp_pos - contains the positions // inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_pos = build_inp_pos();

View File

@ -1,7 +1,5 @@
#include "models.h" #include "models.h"
llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) : llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params), llm_graph_context(params),
model(model), model(model),
@ -15,10 +13,9 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
inpL = build_inp_embd(model.tok_embd); inpL = build_inp_embd(model.tok_embd);
// important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
if (ubatch.token) { inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f);
inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); cb(inpL, "inp_scaled", -1);
cb(inpL, "inp_scaled", -1);
}
// inp_pos - contains the positions // inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_pos = build_inp_pos();
@ -248,7 +245,7 @@ ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) {
// equivalent to get_per_layer_inputs() in python code // equivalent to get_per_layer_inputs() in python code
// output shape: [n_embd_altup, n_layer, n_tokens] // output shape: [n_embd_altup, n_layer, n_tokens]
ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() { ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
auto inp = std::make_unique<llm_graph_input_embd>(); auto inp = std::make_unique<llm_graph_input_embd>();
ggml_tensor * inp_per_layer; ggml_tensor * inp_per_layer;
if (ubatch.token) { if (ubatch.token) {
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);

117
src/models/maincoder.cpp Normal file
View File

@ -0,0 +1,117 @@
#include "models.h"
llm_build_maincoder::llm_build_maincoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
// norm
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
// self-attention
{
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
cb(Kcur, "Kcur_normed", il);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(inp_attn,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
}
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// feed-forward network
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}

View File

@ -312,6 +312,10 @@ struct llm_build_llama_iswa : public llm_graph_context {
llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params); llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params);
}; };
struct llm_build_maincoder : public llm_graph_context {
llm_build_maincoder(const llama_model & model, const llm_graph_params & params);
};
struct llm_build_mamba : public llm_graph_context_mamba { struct llm_build_mamba : public llm_graph_context_mamba {
llm_build_mamba(const llama_model & model, const llm_graph_params & params); llm_build_mamba(const llama_model & model, const llm_graph_params & params);
}; };
@ -332,7 +336,6 @@ struct llm_build_mistral3 : public llm_graph_context {
llm_build_mistral3(const llama_model & model, const llm_graph_params & params); llm_build_mistral3(const llama_model & model, const llm_graph_params & params);
}; };
template <bool iswa>
struct llm_build_modern_bert : public llm_graph_context { struct llm_build_modern_bert : public llm_graph_context {
llm_build_modern_bert(const llama_model & model, const llm_graph_params & params); llm_build_modern_bert(const llama_model & model, const llm_graph_params & params);
}; };

View File

@ -1,7 +1,6 @@
#include "models.h" #include "models.h"
template <bool iswa> llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
llm_build_modern_bert<iswa>::llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
@ -24,13 +23,7 @@ llm_build_modern_bert<iswa>::llm_build_modern_bert(const llama_model & model, co
auto * inp_attn = build_attn_inp_no_cache(); auto * inp_attn = build_attn_inp_no_cache();
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
float freq_base_l = 0.0f; float freq_base_l = model.get_rope_freq_base(cparams, il);
if constexpr (iswa) {
freq_base_l = model.get_rope_freq_base(cparams, il);
} else {
freq_base_l = freq_base;
}
cur = inpL; cur = inpL;
@ -120,7 +113,3 @@ llm_build_modern_bert<iswa>::llm_build_modern_bert(const llama_model & model, co
res->t_embd = cur; res->t_embd = cur;
ggml_build_forward_expand(gf, cur); ggml_build_forward_expand(gf, cur);
} }
// Explicit template instantiations
template struct llm_build_modern_bert<false>;
template struct llm_build_modern_bert<true>;

View File

@ -964,6 +964,11 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
{ "\\p{P}", unicode_cpt_flags::PUNCTUATION }, { "\\p{P}", unicode_cpt_flags::PUNCTUATION },
{ "\\p{M}", unicode_cpt_flags::ACCENT_MARK }, { "\\p{M}", unicode_cpt_flags::ACCENT_MARK },
{ "\\p{S}", unicode_cpt_flags::SYMBOL }, { "\\p{S}", unicode_cpt_flags::SYMBOL },
{ "\\p{Lu}", unicode_cpt_flags::LETTER }, // Uppercase letter
{ "\\p{Ll}", unicode_cpt_flags::LETTER }, // Lowercase letter
{ "\\p{Lt}", unicode_cpt_flags::LETTER }, // Titlecase letter
{ "\\p{Lm}", unicode_cpt_flags::LETTER }, // Modifier letter
{ "\\p{Lo}", unicode_cpt_flags::LETTER }, // Other letter
}; };
static const std::map<int, int> k_ucat_cpt = { static const std::map<int, int> k_ucat_cpt = {
@ -1074,22 +1079,26 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
continue; continue;
} }
if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() && // Match \p{...} Unicode properties of varying lengths
if (regex_expr[i + 0] == '\\' && i + 3 < regex_expr.size() &&
regex_expr[i + 1] == 'p' && regex_expr[i + 1] == 'p' &&
regex_expr[i + 2] == '{' && regex_expr[i + 2] == '{') {
regex_expr[i + 4] == '}') { // Find the closing brace
const std::string pat = regex_expr.substr(i, 5); size_t closing_brace = regex_expr.find('}', i + 3);
if (k_ucat_enum.find(pat) != k_ucat_enum.end()) { if (closing_brace != std::string::npos && closing_brace <= i + 10) { // reasonable limit
if (!inside) { const std::string pat = regex_expr.substr(i, closing_brace - i + 1);
regex_expr_collapsed += '['; if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
if (!inside) {
regex_expr_collapsed += '[';
}
regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
if (!inside) {
regex_expr_collapsed += ']';
}
i = closing_brace;
continue;
} }
regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
if (!inside) {
regex_expr_collapsed += ']';
}
i += 4;
continue;
} }
} }

View File

@ -1158,6 +1158,7 @@ struct test_case {
} }
virtual bool run_whole_graph() { return false; } virtual bool run_whole_graph() { return false; }
virtual std::vector<ggml_tensor *> fusion_test_nodes() { return {}; }
ggml_cgraph * gf = nullptr; ggml_cgraph * gf = nullptr;
ggml_cgraph * gb = nullptr; ggml_cgraph * gb = nullptr;
@ -1391,7 +1392,13 @@ struct test_case {
GGML_UNUSED(index); GGML_UNUSED(index);
}; };
const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud, run_whole_graph() ? out : nullptr); std::vector<ggml_tensor *> fused_nodes_to_verify = fusion_test_nodes();
if (fused_nodes_to_verify.size() == 0 && run_whole_graph()) {
fused_nodes_to_verify.push_back(out);
}
const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud,
run_whole_graph() ? fused_nodes_to_verify.data() : nullptr,
fused_nodes_to_verify.size());
ggml_backend_buffer_free(buf); ggml_backend_buffer_free(buf);
@ -5180,6 +5187,8 @@ struct test_topk_moe : public test_case {
const bool bias_probs; const bool bias_probs;
const MoeGatingFunc gating_func; const MoeGatingFunc gating_func;
const float scale_w; const float scale_w;
ggml_tensor * weights {};
ggml_tensor * selected_experts {};
test_topk_moe(std::array<int64_t, 4> ne = { 10, 5, 1, 1 }, test_topk_moe(std::array<int64_t, 4> ne = { 10, 5, 1, 1 },
int n_expert_used = 1, int n_expert_used = 1,
@ -5217,16 +5226,16 @@ struct test_topk_moe : public test_case {
ggml_tensor * selection_probs = probs; ggml_tensor * selection_probs = probs;
if (bias_probs) { if (bias_probs) {
ggml_tensor * exp_probs_b = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data()); ggml_tensor * exp_probs_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]);
ggml_set_name(exp_probs_b, "exp_probs_b"); ggml_set_name(exp_probs_b, "exp_probs_b");
selection_probs = ggml_add(ctx, probs, exp_probs_b); selection_probs = ggml_add(ctx, probs, exp_probs_b);
ggml_set_name(selection_probs, "selection_probs"); ggml_set_name(selection_probs, "selection_probs");
} }
ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens] selected_experts = ggml_argsort_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
ggml_set_name(selected_experts, "selected_experts"); ggml_set_name(selected_experts, "selected_experts");
ggml_tensor * weights = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] weights = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
ggml_set_name(weights, "weights"); ggml_set_name(weights, "weights");
if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) { if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
@ -5252,6 +5261,21 @@ struct test_topk_moe : public test_case {
ggml_set_name(weights, "weights"); ggml_set_name(weights, "weights");
return weights; return weights;
} }
// Verify two outputs
std::vector<ggml_tensor *> fusion_test_nodes() override { return { selected_experts, weights }; }
// allow output in arbitrary order
double err(const float * a, const float * b, size_t n) override {
std::vector<float> a2(n);
std::vector<float> b2(n);
for (size_t i = 0; i < n; ++i) {
a2[i] = a[i];
b2[i] = b[i];
}
std::sort(a2.begin(), a2.end());
std::sort(b2.begin(), b2.end());
return nmse(a2.data(), b2.data(), n);
}
}; };
struct test_mul_mat_vec_fusion : public test_case { struct test_mul_mat_vec_fusion : public test_case {

View File

@ -724,6 +724,30 @@ static void test_tools_oaicompat_json_conversion() {
"]" "]"
), ),
common_chat_tools_to_json_oaicompat<json>({special_function_tool}).dump(2)); common_chat_tools_to_json_oaicompat<json>({special_function_tool}).dump(2));
{
auto tools_no_params = common_chat_tools_parse_oaicompat(json::parse(
R"([{"type": "function", "function": {"name": "test_func", "description": "A test"}}])"));
assert_equals((size_t) 1, tools_no_params.size());
assert_equals(std::string("test_func"), tools_no_params[0].name);
assert_equals(std::string("A test"), tools_no_params[0].description);
assert_equals(std::string("{}"), tools_no_params[0].parameters);
}
{
auto tools_no_desc = common_chat_tools_parse_oaicompat(json::parse(
R"([{"type": "function", "function": {"name": "test_func", "parameters": {"type": "object"}}}])"));
assert_equals((size_t) 1, tools_no_desc.size());
assert_equals(std::string("test_func"), tools_no_desc[0].name);
assert_equals(std::string(""), tools_no_desc[0].description);
}
{
auto tools_minimal = common_chat_tools_parse_oaicompat(json::parse(
R"([{"type": "function", "function": {"name": "test_func"}}])"));
assert_equals((size_t) 1, tools_minimal.size());
assert_equals(std::string("test_func"), tools_minimal[0].name);
assert_equals(std::string(""), tools_minimal[0].description);
assert_equals(std::string("{}"), tools_minimal[0].parameters);
}
} }
static void test_template_output_parsers() { static void test_template_output_parsers() {

View File

@ -27,6 +27,7 @@ add_library(mtmd
models/qwen3vl.cpp models/qwen3vl.cpp
models/siglip.cpp models/siglip.cpp
models/whisper-enc.cpp models/whisper-enc.cpp
models/youtuvl.cpp
) )
set_target_properties(mtmd PROPERTIES set_target_properties(mtmd PROPERTIES

View File

@ -45,13 +45,14 @@
#define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size" #define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
#define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers" #define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers"
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution" #define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern" #define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern"
#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size" #define KEY_WIN_ATTN_LAYER_INDEXES "clip.vision.wa_layer_indexes"
#define KEY_MINICPMV_VERSION "clip.minicpmv_version" #define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size"
#define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num" #define KEY_MINICPMV_VERSION "clip.minicpmv_version"
#define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num"
// audio-specific // audio-specific
#define KEY_AUDIO_PROJ_TYPE "clip.audio.projector_type" // for models with mixed modalities #define KEY_AUDIO_PROJ_TYPE "clip.audio.projector_type" // for models with mixed modalities
@ -180,6 +181,7 @@ enum projector_type {
PROJECTOR_TYPE_GLMA, PROJECTOR_TYPE_GLMA,
PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
PROJECTOR_TYPE_VOXTRAL, PROJECTOR_TYPE_VOXTRAL,
PROJECTOR_TYPE_MUSIC_FLAMINGO,
PROJECTOR_TYPE_LFM2, PROJECTOR_TYPE_LFM2,
PROJECTOR_TYPE_KIMIVL, PROJECTOR_TYPE_KIMIVL,
PROJECTOR_TYPE_LIGHTONOCR, PROJECTOR_TYPE_LIGHTONOCR,
@ -187,6 +189,7 @@ enum projector_type {
PROJECTOR_TYPE_JANUS_PRO, PROJECTOR_TYPE_JANUS_PRO,
PROJECTOR_TYPE_LFM2A, PROJECTOR_TYPE_LFM2A,
PROJECTOR_TYPE_GLM4V, PROJECTOR_TYPE_GLM4V,
PROJECTOR_TYPE_YOUTUVL,
PROJECTOR_TYPE_UNKNOWN, PROJECTOR_TYPE_UNKNOWN,
}; };
@ -209,6 +212,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_GLMA, "glma"}, { PROJECTOR_TYPE_GLMA, "glma"},
{ PROJECTOR_TYPE_QWEN25O, "qwen2.5o"}, { PROJECTOR_TYPE_QWEN25O, "qwen2.5o"},
{ PROJECTOR_TYPE_VOXTRAL, "voxtral"}, { PROJECTOR_TYPE_VOXTRAL, "voxtral"},
{ PROJECTOR_TYPE_MUSIC_FLAMINGO, "musicflamingo"},
{ PROJECTOR_TYPE_LFM2, "lfm2"}, { PROJECTOR_TYPE_LFM2, "lfm2"},
{ PROJECTOR_TYPE_KIMIVL, "kimivl"}, { PROJECTOR_TYPE_KIMIVL, "kimivl"},
{ PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"}, { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
@ -216,6 +220,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_JANUS_PRO, "janus_pro"}, { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
{ PROJECTOR_TYPE_LFM2A, "lfm2a"}, { PROJECTOR_TYPE_LFM2A, "lfm2a"},
{ PROJECTOR_TYPE_GLM4V, "glm4v"}, { PROJECTOR_TYPE_GLM4V, "glm4v"},
{ PROJECTOR_TYPE_YOUTUVL, "youtuvl"},
}; };
static projector_type clip_projector_type_from_string(const std::string & str) { static projector_type clip_projector_type_from_string(const std::string & str) {

View File

@ -61,6 +61,7 @@ struct clip_hparams {
std::unordered_set<int32_t> vision_feature_layer; std::unordered_set<int32_t> vision_feature_layer;
int32_t attn_window_size = 0; int32_t attn_window_size = 0;
int32_t n_wa_pattern = 0; int32_t n_wa_pattern = 0;
std::unordered_set<int32_t> wa_layer_indexes; // explicit layer indexes that use full attention (for irregular patterns like YoutuVL)
// audio // audio
int32_t n_mel_bins = 0; // whisper preprocessor int32_t n_mel_bins = 0; // whisper preprocessor
@ -319,7 +320,8 @@ struct clip_model {
bool audio_has_avgpool() const { bool audio_has_avgpool() const {
return proj_type == PROJECTOR_TYPE_QWEN2A return proj_type == PROJECTOR_TYPE_QWEN2A
|| proj_type == PROJECTOR_TYPE_VOXTRAL; || proj_type == PROJECTOR_TYPE_VOXTRAL
|| proj_type == PROJECTOR_TYPE_MUSIC_FLAMINGO;
} }
bool audio_has_stack_frames() const { bool audio_has_stack_frames() const {

View File

@ -818,6 +818,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_VOXTRAL:
case PROJECTOR_TYPE_QWEN2A: case PROJECTOR_TYPE_QWEN2A:
case PROJECTOR_TYPE_GLMA: case PROJECTOR_TYPE_GLMA:
case PROJECTOR_TYPE_MUSIC_FLAMINGO:
{ {
builder = std::make_unique<clip_graph_whisper_enc>(ctx, img); builder = std::make_unique<clip_graph_whisper_enc>(ctx, img);
} break; } break;
@ -845,6 +846,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{ {
builder = std::make_unique<clip_graph_glm4v>(ctx, img); builder = std::make_unique<clip_graph_glm4v>(ctx, img);
} break; } break;
case PROJECTOR_TYPE_YOUTUVL:
{
builder = std::make_unique<clip_graph_youtuvl>(ctx, img);
} break;
default: default:
GGML_ABORT("missing cgraph builder"); GGML_ABORT("missing cgraph builder");
} }
@ -1158,6 +1163,20 @@ struct clip_model_loader {
LOG_WRN("%s: more info: https://github.com/ggml-org/llama.cpp/issues/16842\n\n", __func__); LOG_WRN("%s: more info: https://github.com/ggml-org/llama.cpp/issues/16842\n\n", __func__);
} }
} break; } break;
case PROJECTOR_TYPE_YOUTUVL:
{
hparams.n_merge = 2;
get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false);
get_u32(KEY_ATTN_WINDOW_SIZE, hparams.attn_window_size, true);
std::vector<int> wa_layer_indexes_vec;
get_arr_int(KEY_WIN_ATTN_LAYER_INDEXES, wa_layer_indexes_vec, true);
for (auto & layer : wa_layer_indexes_vec) {
hparams.wa_layer_indexes.insert(layer);
}
// support max_height * max_width = 8000 * 8000. 8000/16/2 = 250 image tokens
hparams.set_limit_image_tokens(1, 62500);
hparams.set_warmup_n_tokens(16*16); // avoid OOM on warmup
} break;
case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_GLM4V:
{ {
hparams.rope_theta = 10000.0f; hparams.rope_theta = 10000.0f;
@ -1176,6 +1195,7 @@ struct clip_model_loader {
case PROJECTOR_TYPE_QWEN2A: case PROJECTOR_TYPE_QWEN2A:
case PROJECTOR_TYPE_GLMA: case PROJECTOR_TYPE_GLMA:
case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_VOXTRAL:
case PROJECTOR_TYPE_MUSIC_FLAMINGO:
{ {
bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX || bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX ||
model.proj_type == PROJECTOR_TYPE_VOXTRAL || model.proj_type == PROJECTOR_TYPE_VOXTRAL ||
@ -1225,7 +1245,14 @@ struct clip_model_loader {
LOG_INF("%s: has_llava_proj: %d\n", __func__, hparams.has_llava_projector); LOG_INF("%s: has_llava_proj: %d\n", __func__, hparams.has_llava_projector);
LOG_INF("%s: minicpmv_version: %d\n", __func__, hparams.minicpmv_version); LOG_INF("%s: minicpmv_version: %d\n", __func__, hparams.minicpmv_version);
LOG_INF("%s: n_merge: %d\n", __func__, hparams.n_merge); LOG_INF("%s: n_merge: %d\n", __func__, hparams.n_merge);
LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern); LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
if (!hparams.wa_layer_indexes.empty()) {
LOG_INF("%s: wa_layer_indexes: ", __func__);
for (auto & layer : hparams.wa_layer_indexes) {
LOG_INF("%d ", layer);
}
LOG_INF("\n");
}
if (hparams.image_min_pixels > 0) { if (hparams.image_min_pixels > 0) {
LOG_INF("%s: image_min_pixels: %d%s\n", __func__, hparams.image_min_pixels, hparams.custom_image_min_tokens > 0 ? " (custom value)" : ""); LOG_INF("%s: image_min_pixels: %d%s\n", __func__, hparams.image_min_pixels, hparams.custom_image_min_tokens > 0 ? " (custom value)" : "");
} }
@ -1493,6 +1520,14 @@ struct clip_model_loader {
model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
} break; } break;
case PROJECTOR_TYPE_YOUTUVL:
{
model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM); // merger.ln_q (RMS norm)
model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); // merger.mlp.0
model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); // merger.mlp.2
model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
} break;
case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_GLM4V:
{ {
model.projection = get_tensor(TN_MM_PROJECTOR); model.projection = get_tensor(TN_MM_PROJECTOR);
@ -1576,6 +1611,17 @@ struct clip_model_loader {
model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight")); model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight")); model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
} break; } break;
case PROJECTOR_TYPE_MUSIC_FLAMINGO:
{
model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
model.mm_1_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "bias"));
model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
model.mm_2_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "bias"));
} break;
case PROJECTOR_TYPE_INTERNVL: case PROJECTOR_TYPE_INTERNVL:
{ {
model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight")); model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
@ -2684,6 +2730,57 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
// res_imgs->data[0] = *res; // res_imgs->data[0] = *res;
res_imgs->entries.push_back(std::move(img_f32)); res_imgs->entries.push_back(std::move(img_f32));
} break; } break;
case PROJECTOR_TYPE_YOUTUVL:
{
const int patch_size = params.patch_size; // typically 16
const int merge_size = params.n_merge; // typically 2
const int align_size = patch_size * merge_size; // 32
const int max_num_patches = params.image_max_pixels > 0 ?
params.image_max_pixels / (patch_size * patch_size) : 256;
// Linear search for optimal scale to fit within max_num_patches
float scale = 1.0f;
int target_height = original_size.height;
int target_width = original_size.width;
auto get_scaled_image_size = [align_size](float scale, int size) -> int {
float scaled_size = size * scale;
// Round up to nearest multiple of align_size
int aligned = static_cast<int>(std::ceil(scaled_size / align_size)) * align_size;
// Ensure at least one patch
return std::max(align_size, aligned);
};
// Linear search with 0.02 step size
while (scale > 0.0f) {
target_height = get_scaled_image_size(scale, original_size.height);
target_width = get_scaled_image_size(scale, original_size.width);
int num_patches_h = target_height / patch_size;
int num_patches_w = target_width / patch_size;
int num_patches = num_patches_h * num_patches_w;
if (num_patches > max_num_patches) {
scale -= 0.02f;
} else {
break;
}
}
clip_image_size new_size = {target_width, target_height};
// Resize the image
clip_image_u8 resized;
img_tool::resize(*img, resized, new_size, img_tool::RESIZE_ALGO_BILINEAR, false);
// Normalize to float32
clip_image_f32_ptr img_f32(clip_image_f32_init());
normalize_image_u8_to_f32(resized, *img_f32, params.image_mean, params.image_std);
// Add to results
res_imgs->entries.push_back(std::move(img_f32));
} break;
case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_IDEFICS3:
{ {
@ -2916,6 +3013,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 *
case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN25VL:
case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_QWEN3VL:
case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_GLM4V:
case PROJECTOR_TYPE_YOUTUVL:
return (img->nx / params.patch_size) / 2; return (img->nx / params.patch_size) / 2;
default: default:
break; break;
@ -2931,6 +3029,7 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 *
case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN25VL:
case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_QWEN3VL:
case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_GLM4V:
case PROJECTOR_TYPE_YOUTUVL:
return (img->ny / params.patch_size) / 2; return (img->ny / params.patch_size) / 2;
default: default:
break; break;
@ -2991,6 +3090,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN25VL:
case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_QWEN3VL:
case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_GLM4V:
case PROJECTOR_TYPE_YOUTUVL:
{ {
// dynamic size (2 conv, so double patch size) // dynamic size (2 conv, so double patch size)
int x_patch = img->nx / (params.patch_size * 2); int x_patch = img->nx / (params.patch_size * 2);
@ -3031,6 +3131,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_VOXTRAL:
case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_QWEN2A: case PROJECTOR_TYPE_QWEN2A:
case PROJECTOR_TYPE_MUSIC_FLAMINGO:
{ {
n_patches = img->nx; n_patches = img->nx;
@ -3117,7 +3218,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
const int pos_w = image_size_width / patch_size; const int pos_w = image_size_width / patch_size;
const int pos_h = image_size_height / patch_size; const int pos_h = image_size_height / patch_size;
const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl
auto get_inp_tensor = [&gf](const char * name) { auto get_inp_tensor = [&gf](const char * name) {
ggml_tensor * inp = ggml_graph_get_tensor(gf, name); ggml_tensor * inp = ggml_graph_get_tensor(gf, name);
@ -3266,9 +3366,11 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
set_input_i32("positions", positions); set_input_i32("positions", positions);
} break; } break;
case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN25VL:
case PROJECTOR_TYPE_YOUTUVL:
{ {
// pw * ph = number of tokens output by ViT after apply patch merger // pw * ph = number of tokens output by ViT after apply patch merger
// ipw * ipw = number of vision token been processed inside ViT // ipw * ipw = number of vision token been processed inside ViT
const bool use_window_attn = ctx->model.proj_type == PROJECTOR_TYPE_QWEN25VL ? hparams.n_wa_pattern > 0 : !hparams.wa_layer_indexes.empty();
const int merge_ratio = 2; const int merge_ratio = 2;
const int pw = image_size_width / patch_size / merge_ratio; const int pw = image_size_width / patch_size / merge_ratio;
const int ph = image_size_height / patch_size / merge_ratio; const int ph = image_size_height / patch_size / merge_ratio;
@ -3279,7 +3381,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
std::vector<int> inv_idx(ph * pw); std::vector<int> inv_idx(ph * pw);
if (use_window_attn) { if (use_window_attn) {
const int attn_window_size = 112; const int attn_window_size = hparams.attn_window_size > 0 ? hparams.attn_window_size : 112;
const int grid_window = attn_window_size / patch_size / merge_ratio; const int grid_window = attn_window_size / patch_size / merge_ratio;
int dst = 0; int dst = 0;
// [num_vision_tokens, num_vision_tokens] attention mask tensor // [num_vision_tokens, num_vision_tokens] attention mask tensor
@ -3403,6 +3505,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_VOXTRAL:
case PROJECTOR_TYPE_MUSIC_FLAMINGO:
case PROJECTOR_TYPE_JANUS_PRO: case PROJECTOR_TYPE_JANUS_PRO:
case PROJECTOR_TYPE_COGVLM: case PROJECTOR_TYPE_COGVLM:
{ {
@ -3516,6 +3619,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN2VL:
case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN25VL:
case PROJECTOR_TYPE_JANUS_PRO: case PROJECTOR_TYPE_JANUS_PRO:
case PROJECTOR_TYPE_YOUTUVL:
return ctx->model.mm_1_b->ne[0]; return ctx->model.mm_1_b->ne[0];
case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_QWEN3VL:
// main path + deepstack paths // main path + deepstack paths
@ -3526,6 +3630,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->model.projection->ne[1]; return ctx->model.projection->ne[1];
case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_VOXTRAL:
case PROJECTOR_TYPE_MUSIC_FLAMINGO:
return ctx->model.mm_2_w->ne[1]; return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_INTERNVL: case PROJECTOR_TYPE_INTERNVL:
return ctx->model.mm_3_w->ne[1]; return ctx->model.mm_3_w->ne[1];
@ -3587,7 +3692,8 @@ bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX
|| ctx->proj_type() == PROJECTOR_TYPE_QWEN2A || ctx->proj_type() == PROJECTOR_TYPE_QWEN2A
|| ctx->proj_type() == PROJECTOR_TYPE_GLMA || ctx->proj_type() == PROJECTOR_TYPE_GLMA
|| ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL; || ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL
|| ctx->proj_type() == PROJECTOR_TYPE_MUSIC_FLAMINGO;
} }
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) { bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {

View File

@ -27,6 +27,11 @@ struct clip_graph_qwen3vl : clip_graph {
ggml_cgraph * build() override; ggml_cgraph * build() override;
}; };
struct clip_graph_youtuvl : clip_graph {
clip_graph_youtuvl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override;
};
struct clip_graph_minicpmv : clip_graph { struct clip_graph_minicpmv : clip_graph {
clip_graph_minicpmv(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} clip_graph_minicpmv(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override; ggml_cgraph * build() override;

View File

@ -86,6 +86,15 @@ ggml_cgraph * clip_graph_whisper_enc::build() {
FFN_GELU_ERF, FFN_GELU_ERF,
-1); -1);
} else if (proj_type == PROJECTOR_TYPE_MUSIC_FLAMINGO) {
// projector
cur = build_ffn(cur,
model.mm_1_w, model.mm_1_b,
nullptr, nullptr,
model.mm_2_w, model.mm_2_b,
FFN_GELU_ERF,
-1);
} else if (proj_type == PROJECTOR_TYPE_GLMA) { } else if (proj_type == PROJECTOR_TYPE_GLMA) {
cur = ggml_norm(ctx0, cur, hparams.eps); cur = ggml_norm(ctx0, cur, hparams.eps);
cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w); cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);

View File

@ -0,0 +1,179 @@
#include "models.h"
ggml_cgraph * clip_graph_youtuvl::build() {
GGML_ASSERT(model.class_embedding == nullptr);
const int batch_size = 1;
const bool use_window_attn = !hparams.wa_layer_indexes.empty();
const int n_pos = n_patches;
const int num_position_ids = n_pos * 4;
const int m = 2;
const int Wp = n_patches_x;
const int Hp = n_patches_y;
const int Hm = Hp / m;
const int Wm = Wp / m;
norm_type norm_t = NORM_TYPE_NORMAL;
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
ggml_tensor * inp = build_inp_raw();
// change conv3d to linear
// reshape and permute to get patches, permute from (patch_size, m, Wm, patch_size, m, Hm, C) to (C, patch_size, patch_size, m, m, Wm, Hm)
{
inp = ggml_reshape_4d(
ctx0, inp,
Wm * m * patch_size, m * patch_size, Hm, 3);
inp = ggml_permute(ctx0, inp, 1, 2, 3, 0);
inp = ggml_cont_4d(
ctx0, inp,
m * patch_size * 3, Wm, m * patch_size, Hm);
inp = ggml_permute(ctx0, inp, 0, 2, 1, 3);
inp = ggml_cont_4d(
ctx0, inp,
m * patch_size * 3, patch_size, m, Hm * Wm);
inp = ggml_permute(ctx0, inp, 1, 0, 2, 3);
inp = ggml_cont_4d(
ctx0, inp,
patch_size, 3, patch_size, Hm * Wm * m * m);
inp = ggml_permute(ctx0, inp, 2, 0, 1, 3);
inp = ggml_cont_3d(
ctx0, inp,
3*patch_size* patch_size, Hm * Wm * m * m, 1);
}
inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp);
if (model.patch_bias) {
inp = ggml_add(ctx0, inp, model.patch_bias);
}
inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches);
ggml_tensor * inpL = inp;
ggml_tensor * window_mask = nullptr;
ggml_tensor * window_idx = nullptr;
ggml_tensor * inv_window_idx = nullptr;
ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
ggml_set_name(positions, "positions");
ggml_set_input(positions);
// pre-layernorm
if (model.pre_ln_w) {
inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1);
}
if (use_window_attn) {
inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4);
ggml_set_name(inv_window_idx, "inv_window_idx");
ggml_set_input(inv_window_idx);
// mask for window attention
window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_pos, n_pos);
ggml_set_name(window_mask, "window_mask");
ggml_set_input(window_mask);
// if flash attn is used, we need to pad the mask and cast to f16
if (flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16);
}
// inpL shape: [n_embd, n_patches_x * n_patches_y, batch_size]
GGML_ASSERT(batch_size == 1);
inpL = ggml_reshape_2d(ctx0, inpL, n_embd * 4, n_patches_x * n_patches_y * batch_size / 4);
inpL = ggml_get_rows(ctx0, inpL, inv_window_idx);
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_patches_x * n_patches_y, batch_size);
}
// loop over layers
for (int il = 0; il < n_layer; il++) {
const auto & layer = model.layers[il];
const bool full_attn = use_window_attn ? hparams.wa_layer_indexes.count(il) > 0 : true;
ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states
// layernorm1
cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il);
// self-attention
{
ggml_tensor * Qcur = ggml_add(ctx0,
ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b);
ggml_tensor * Kcur = ggml_add(ctx0,
ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b);
ggml_tensor * Vcur = ggml_add(ctx0,
ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b);
Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches);
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches);
Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_patches);
Qcur = ggml_rope_multi(
ctx0, Qcur, positions, nullptr,
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
Kcur = ggml_rope_multi(
ctx0, Kcur, positions, nullptr,
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
ggml_tensor * attn_mask = full_attn ? nullptr : window_mask;
cur = build_attn(layer.o_w, layer.o_b,
Qcur, Kcur, Vcur, attn_mask, kq_scale, il);
}
// re-add the layer input, e.g., residual
cur = ggml_add(ctx0, cur, inpL);
inpL = cur; // inpL = residual, cur = hidden_states
// layernorm2
cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il);
// ffn
cur = build_ffn(cur,
layer.ff_up_w, layer.ff_up_b,
nullptr, nullptr,
layer.ff_down_w, layer.ff_down_b,
hparams.ffn_op, il);
// residual 2
cur = ggml_add(ctx0, inpL, cur);
inpL = cur;
}
ggml_tensor * embeddings = inpL;
if (use_window_attn) {
const int spatial_merge_unit = 4;
window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / spatial_merge_unit);
ggml_set_name(window_idx, "window_idx");
ggml_set_input(window_idx);
GGML_ASSERT(batch_size == 1);
embeddings = ggml_reshape_2d(ctx0, embeddings, n_embd * spatial_merge_unit, n_patches / spatial_merge_unit);
embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd, n_patches, batch_size);
cb(embeddings, "window_order_restored", -1);
}
// post-layernorm (part of Siglip2VisionTransformer, applied after encoder)
if (model.post_ln_w) {
embeddings = build_norm(embeddings, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer);
}
// Now apply merger (VLPatchMerger):
// 1. Apply RMS norm (ln_q in VLPatchMerger)
embeddings = build_norm(embeddings, model.mm_input_norm_w, nullptr, NORM_TYPE_RMS, 1e-6, -1);
cb(embeddings, "merger_normed", -1);
// 2. First reshape for spatial merge (merge 2x2 patches)
embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size);
cb(embeddings, "merger_reshaped", -1);
embeddings = build_ffn(embeddings,
model.mm_0_w, model.mm_0_b,
nullptr, nullptr,
model.mm_1_w, model.mm_1_b,
FFN_GELU,
-1);
ggml_build_forward_expand(gf, embeddings);
return gf;
}

View File

@ -283,7 +283,7 @@ struct mtmd_context {
// https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
img_end = "[IMG_END]"; img_end = "[IMG_END]";
} else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL || proj == PROJECTOR_TYPE_QWEN3VL) { } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL || proj == PROJECTOR_TYPE_QWEN3VL || proj == PROJECTOR_TYPE_YOUTUVL) {
// <|vision_start|> ... (image embeddings) ... <|vision_end|> // <|vision_start|> ... (image embeddings) ... <|vision_end|>
img_beg = "<|vision_start|>"; img_beg = "<|vision_start|>";
img_end = "<|vision_end|>"; img_end = "<|vision_end|>";
@ -330,6 +330,7 @@ struct mtmd_context {
case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_VOXTRAL:
case PROJECTOR_TYPE_GLMA: case PROJECTOR_TYPE_GLMA:
case PROJECTOR_TYPE_MUSIC_FLAMINGO:
audio_preproc = std::make_unique<mtmd_audio_preprocessor_whisper>(ctx_a); audio_preproc = std::make_unique<mtmd_audio_preprocessor_whisper>(ctx_a);
break; break;
case PROJECTOR_TYPE_LFM2A: case PROJECTOR_TYPE_LFM2A:
@ -352,6 +353,9 @@ struct mtmd_context {
// [BEGIN_AUDIO] ... (embeddings) ... // [BEGIN_AUDIO] ... (embeddings) ...
aud_beg = "[BEGIN_AUDIO]"; aud_beg = "[BEGIN_AUDIO]";
} else if (proj == PROJECTOR_TYPE_MUSIC_FLAMINGO) {
// <sound> ... (embeddings) ...
aud_beg = "<sound>";
} }
} }

View File

@ -12,6 +12,7 @@
#include <cmath> #include <cmath>
#include <cctype> #include <cctype>
#include <algorithm> #include <algorithm>
#include <filesystem>
struct quant_option { struct quant_option {
std::string name; std::string name;
@ -643,6 +644,11 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
if (std::error_code ec; std::filesystem::equivalent(fname_inp, fname_out, ec)) {
fprintf(stderr, "%s: error: input and output files are the same: '%s'\n", __func__, fname_inp.c_str());
return 1;
}
print_build_info(); print_build_info();
fprintf(stderr, "%s: quantizing '%s' to '%s' as %s", __func__, fname_inp.c_str(), fname_out.c_str(), ftype_str.c_str()); fprintf(stderr, "%s: quantizing '%s' to '%s' as %s", __func__, fname_inp.c_str(), fname_out.c_str(), ftype_str.c_str());

Binary file not shown.

View File

@ -65,10 +65,7 @@ export async function copyCodeToClipboard(
successMessage = 'Code copied to clipboard', successMessage = 'Code copied to clipboard',
errorMessage = 'Failed to copy code' errorMessage = 'Failed to copy code'
): Promise<boolean> { ): Promise<boolean> {
const doc = new DOMParser().parseFromString(rawCode, 'text/html'); return copyToClipboard(rawCode, successMessage, errorMessage);
const decodedCode = doc.body.textContent ?? rawCode;
return copyToClipboard(decodedCode, successMessage, errorMessage);
} }
/** /**