diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 1fec31b832..c928bc39ce 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -20,7 +20,7 @@ If AI is used to generate any portion of the code, contributors must adhere to t
1. Explicitly disclose the manner in which AI was employed.
2. Perform a comprehensive manual review prior to submitting the pull request.
3. Be prepared to explain every line of code they submitted when asked about it by a maintainer.
-4. Using AI to respond to human reviewers is strictly prohibited.
+4. Using AI to write pull request descriptions or to respond to human reviewers is strictly prohibited.
For more info, please refer to the [AGENTS.md](AGENTS.md) file.
diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp
index 23e23ca8c7..2f073512e0 100644
--- a/common/chat-parser.cpp
+++ b/common/chat-parser.cpp
@@ -1403,6 +1403,118 @@ static void common_chat_parse_solar_open(common_chat_msg_parser & builder) {
builder.add_content(builder.consume_rest());
}
+// Parses the non-reasoning part of an EXAONE MoE response.
+// Every block introduced by a match of `tool_call_open` is read as a JSON object and registered
+// on the builder as a tool call; whatever is left over is added as plain content.
+// Throws common_chat_msg_partial_exception when a tool-call block is not terminated or when the
+// builder rejects the tool call.
+// NOTE(review): several literals and template arguments in this function look like they lost
+// angle-bracket text during extraction (the regex below, the empty literal checked after the JSON
+// body, the bare `.get()` calls) - confirm against the real source file before relying on them.
+static void common_chat_parse_exaone_moe_content(common_chat_msg_parser & builder) {
+ // Accepted JSON payload shapes inside a tool-call block:
+ // 1) { "name": "...", "arguments": {...} }
+ // 2) { "id": "...", "type": "function", "function": { "name": "...", "arguments": {...} } }
+ static const common_regex tool_call_open(R"(]*>)");
+
+ // Tool-call parsing disabled: the whole remaining input is plain content.
+ if (!builder.syntax().parse_tool_calls) {
+ LOG_DBG("%s: not parse_tool_calls\n", __func__);
+ builder.add_content(builder.consume_rest());
+ return;
+ }
+
+ LOG_DBG("%s: parse_tool_calls\n", __func__);
+
+ // Process every tool-call block; the search is asked to add the text preceding each match to
+ // the content (add_prelude_to_content = true).
+ while (auto first = builder.try_find_regex(tool_call_open, std::string::npos, /* add_prelude_to_content= */ true)) {
+ builder.move_to(first->groups[0].end);
+ builder.consume_spaces();
+
+ // Tolerate an optional Markdown code fence around the JSON payload.
+ builder.try_consume_literal("```json");
+ builder.try_consume_literal("```");
+ builder.consume_spaces();
+
+ // Consume JSON object
+ auto data = builder.consume_json();
+
+ builder.consume_spaces();
+ builder.try_consume_literal("```");
+ builder.consume_spaces();
+
+ // The block must be closed by the expected terminator, otherwise it is reported as partial.
+ if (!builder.try_consume_literal("")) {
+ throw common_chat_msg_partial_exception("incomplete tool call");
+ }
+ builder.consume_spaces();
+
+ // Extract name and arguments
+ std::string name;
+ std::string id;
+ nlohmann::ordered_json arguments;
+
+ // Fills name/arguments (and id, when it is a string) from `obj`.
+ // Returns false, leaving the outputs untouched, unless both "name" and "arguments" exist.
+ const auto extract_args = [&](const nlohmann::ordered_json & obj) -> bool {
+ if (!obj.contains("name") || !obj.contains("arguments")) {
+ return false;
+ }
+ name = obj.at("name").get();
+ arguments = obj.at("arguments");
+ if (obj.contains("id") && obj.at("id").is_string()) {
+ id = obj.at("id").get();
+ }
+ return true;
+ };
+
+ // Shape 1 is tried first; on failure fall back to shape 2, where name/arguments live in the
+ // nested "function" object and the id may sit at the top level.
+ if (!extract_args(data.json)) {
+ if (data.json.contains("function") && data.json.at("function").is_object()) {
+ auto fn = data.json.at("function");
+ extract_args(fn);
+ if (id.empty() && data.json.contains("id") && data.json.at("id").is_string()) {
+ id = data.json.at("id").get();
+ }
+ }
+ }
+
+ // If name is empty, treat the JSON object as content
+ if (name.empty()) {
+ LOG_DBG("%s: tool call missing name, treating as content\n", __func__);
+ builder.add_content(data.json.dump());
+ continue;
+ }
+
+ // Arguments are handed to the builder in serialized form.
+ std::string args_str = arguments.dump();
+ if (!builder.add_tool_call(name, id, args_str)) {
+ throw common_chat_msg_partial_exception("incomplete tool call");
+ }
+ }
+
+ // Anything after the last tool-call block is regular content.
+ builder.add_content(builder.consume_rest());
+}
+
+// Entry point for parsing an EXAONE MoE response: separates an optional reasoning section from
+// the rest, then hands the rest to common_chat_parse_exaone_moe_content().
+// NOTE(review): the tag literals passed to try_find_literal()/try_parse_reasoning() below are
+// empty in this view, most likely stripped during extraction - confirm against the real source.
+static void common_chat_parse_exaone_moe(common_chat_msg_parser & builder) {
+ LOG_DBG("%s: parsing exaone_moe\n", __func__);
+ // EXAONE MoE outputs reasoning content between an opening and a closing tag, followed by regular content
+ // First try to parse using the standard reasoning parsing method
+ LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str());
+
+ // Look ahead for the end-of-reasoning literal, then rewind so the probe consumes nothing.
+ auto start_pos = builder.pos();
+ auto found_end_think = builder.try_find_literal("");
+ builder.move_to(start_pos);
+
+ if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) {
+ // Complete (non-partial) output without an end-of-reasoning marker: parse all of it as content.
+ LOG_DBG("%s: no end_think, not partial, adding content\n", __func__);
+ common_chat_parse_exaone_moe_content(builder);
+ } else if (builder.try_parse_reasoning("", "")) {
+ // If reasoning was parsed successfully, the remaining content is regular content
+ LOG_DBG("%s: parsed reasoning, adding content\n", __func__);
+ common_chat_parse_exaone_moe_content(builder);
+ } else {
+ // Reasoning extraction is switched off entirely: everything is content.
+ if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) {
+ LOG_DBG("%s: reasoning_format none, adding content\n", __func__);
+ common_chat_parse_exaone_moe_content(builder);
+ return;
+ }
+ // If no reasoning tags found, check if we should treat everything as reasoning
+ if (builder.syntax().thinking_forced_open) {
+ // If thinking is forced open but no tags found, treat everything as reasoning
+ LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__);
+ builder.add_reasoning_content(builder.consume_rest());
+ } else {
+ LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__);
+ common_chat_parse_exaone_moe_content(builder);
+ }
+ }
+}
+
static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("", "");
builder.add_content(builder.consume_rest());
@@ -1490,6 +1602,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_SOLAR_OPEN:
common_chat_parse_solar_open(builder);
break;
+ case COMMON_CHAT_FORMAT_EXAONE_MOE:
+ common_chat_parse_exaone_moe(builder);
+ break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
diff --git a/common/chat.cpp b/common/chat.cpp
index 22e527bab8..d531388bcb 100644
--- a/common/chat.cpp
+++ b/common/chat.cpp
@@ -670,6 +670,7 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5";
case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo";
case COMMON_CHAT_FORMAT_SOLAR_OPEN: return "Solar Open";
+ case COMMON_CHAT_FORMAT_EXAONE_MOE: return "EXAONE MoE";
case COMMON_CHAT_FORMAT_PEG_SIMPLE: return "peg-simple";
case COMMON_CHAT_FORMAT_PEG_NATIVE: return "peg-native";
case COMMON_CHAT_FORMAT_PEG_CONSTRUCTED: return "peg-constructed";
@@ -2539,6 +2540,65 @@ static common_chat_params common_chat_params_init_solar_open(const common_chat_t
return data;
}
+// Builds the chat parameters (prompt, format, optional tool-call grammar) for the EXAONE MoE template.
+// Returns `data` with format COMMON_CHAT_FORMAT_EXAONE_MOE; grammar fields are only filled in when
+// `inputs.tools` is a non-empty array.
+// NOTE(review): many string literals here (the prompt suffix check, grammar fragments, trigger
+// pattern, preserved tokens) and the element type of `tool_rules` appear to have lost
+// angle-bracket text during extraction - confirm against the real source before relying on them.
+static common_chat_params common_chat_params_init_exaone_moe(const common_chat_template & tmpl, const struct templates_params & inputs) {
+ common_chat_params data;
+
+ data.prompt = apply(tmpl, inputs);
+ data.format = COMMON_CHAT_FORMAT_EXAONE_MOE;
+ // When the rendered prompt ends with the checked suffix: with thinking disabled an extra
+ // literal is appended to the prompt, otherwise the reasoning section is flagged as already open.
+ if (string_ends_with(data.prompt, "\n")) {
+ if (!inputs.enable_thinking) {
+ data.prompt += "\n\n";
+ } else {
+ data.thinking_forced_open = true;
+ }
+ }
+
+ if (inputs.tools.is_array() && !inputs.tools.empty()) {
+ // The grammar is lazy unless a tool call is required or a JSON schema was supplied.
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null();
+ data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+ // One "<name>-call" rule per declared function.
+ std::vector tool_rules;
+ foreach_function(inputs.tools, [&](const json & tool) {
+ const auto & function = tool.at("function");
+ std::string name = function.at("name");
+ auto parameters = function.at("parameters");
+ builder.resolve_refs(parameters);
+ // Expect: {"name": "", "arguments": {...}}
+ tool_rules.push_back(builder.add_rule(
+ name + "-call",
+ "\"\" space " +
+ builder.add_schema(name + "-obj", json{
+ {"type", "object"},
+ {"properties", {
+ {"name", json{{"const", name}}},
+ {"arguments", parameters},
+ }},
+ {"required", json::array({"name", "arguments"})},
+ }) +
+ " space \"\" space"));
+ });
+
+ // Root: an optional prefix when thinking is forced open, followed by one tool call
+ // (or one-or-more when parallel tool calls are allowed).
+ auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | "));
+ builder.add_rule("root",
+ std::string(data.thinking_forced_open ? "( \"\" space )? " : "") +
+ (inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call));
+
+ // Full-pattern trigger that activates the grammar.
+ data.grammar_triggers.push_back({
+ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
+ std::string(data.thinking_forced_open ? "[\\s\\S]*?(\\s*)?" : "") +
+ "()[\\s\\S]*"
+ });
+ data.preserved_tokens = {
+ "",
+ "",
+ "",
+ "",
+ };
+ });
+ }
+
+ return data;
+}
+
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.prompt = apply(tmpl, inputs);
@@ -2709,6 +2769,13 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_xiaomi_mimo(tmpl, params);
}
+ // EXAONE MoE format detection
+ if (src.find("") != std::string::npos &&
+ src.find("") != std::string::npos &&
+ src.find("<|tool_declare|>") != std::string::npos) {
+ return common_chat_params_init_exaone_moe(tmpl, params);
+ }
+
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
if (src.find("") != std::string::npos && params.json_schema.is_null()) {
return common_chat_params_init_hermes_2_pro(tmpl, params);
diff --git a/common/chat.h b/common/chat.h
index 8bd4a325ff..454085e90e 100644
--- a/common/chat.h
+++ b/common/chat.h
@@ -125,6 +125,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_APRIEL_1_5,
COMMON_CHAT_FORMAT_XIAOMI_MIMO,
COMMON_CHAT_FORMAT_SOLAR_OPEN,
+ COMMON_CHAT_FORMAT_EXAONE_MOE,
// These are intended to be parsed by the PEG parser
COMMON_CHAT_FORMAT_PEG_SIMPLE,
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 5735595e32..fba1334527 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -1252,6 +1252,9 @@ class TextModel(ModelBase):
if chkhsh == "16389f0a1f51ee53e562ffd51c371dc508639ab0e4261502071836e50e223e91":
# ref: https://huggingface.co/upstage/Solar-Open-100B
res = "solar-open"
+ if chkhsh == "6c81ce329e0802883b22eabab0d3fa48357337ef1ecb45443828bf1f6254833f":
+ # ref: https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B
+ res = "exaone-moe"
if res is None:
logger.warning("\n")
@@ -8802,6 +8805,106 @@ class Exaone4Model(TextModel):
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))
+@ModelBase.register("ExaoneMoEForCausalLM")
+class ExaoneMoEModel(Exaone4Model):
+    """Conversion for ExaoneMoEForCausalLM checkpoints (EXAONE MoE, incl. NextN/MTP layers)."""
+    model_arch = gguf.MODEL_ARCH.EXAONE_MOE
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # NextN/MTP layers are numbered after the regular decoder layers, so they are
+        # included in the block count used to build the tensor name map.
+        # NOTE(review): the default here is 0 while set_gguf_parameters() defaults the same
+        # key to 1; when the HF config lacks `num_nextn_predict_layers` the MTP tensors
+        # handled in modify_tensors() have no block to map to - confirm this is intended.
+        self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0)
+        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_expert_count(self.hparams["num_experts"])
+        moe_intermediate_size = self.hparams["moe_intermediate_size"]
+        num_shared_experts = self.hparams["num_shared_experts"]
+        self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+        self.gguf_writer.add_expert_shared_count(num_shared_experts)
+        self.gguf_writer.add_expert_shared_feed_forward_length(moe_intermediate_size * num_shared_experts)
+        self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"])
+        self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"])
+        n_dense_layer = self.hparams.get("first_k_dense_replace", self.hparams.get("first_last_k_dense_replace", 0))
+        self.gguf_writer.add_leading_dense_block_count(n_dense_layer)
+        # The number of NextN/MTP layers defaults to 1 for K-EXAONE so that the MTP weights
+        # can be converted to GGUF for speculative decoding: the HF config of K-EXAONE does
+        # not provide `num_nextn_predict_layers` yet.
+        # To be revisited once the HF config is updated.
+        self.gguf_writer.add_nextn_predict_layers(self.hparams.get("num_nextn_predict_layers", 1))
+
+        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+
+    # Per-block buffer of expert tensors waiting to be merged (created lazily).
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        """Rename MTP tensors, merge per-expert weights and map names to GGUF tensor names.
+
+        Returns the (gguf_name, tensor) pairs to write for this input tensor; the list is
+        empty while the expert tensors of a block are still being collected.
+        """
+        if name.startswith("mtp."):
+            if name.find("layers.") != -1:
+                # `mtp.layers.0.[module_name]` format
+                name = name.replace(f"mtp.layers.{bid}", f"model.layers.{bid + self.hparams['num_hidden_layers']}")
+            else:
+                # mtp fc/norm weights
+                remapper = {
+                    "mtp.fc": "model.layers.{bid}.eh_proj",
+                    "mtp.pre_fc_norm_embedding": "model.layers.{bid}.enorm",
+                    "mtp.pre_fc_norm_hidden": "model.layers.{bid}.hnorm",
+                    "mtp.norm": "model.layers.{bid}.shared_head.norm",
+                }
+                _n = Path(name)
+                name_template = remapper[_n.stem] + _n.suffix
+
+                # set shared weights for all NextN/MTP layers
+                # The template is formatted afresh for every layer: formatting it in place
+                # would consume the `{bid}` placeholder on the first iteration and give every
+                # later layer the first layer's name.
+                return [
+                    (self.map_tensor_name(name_template.format(bid=mtp_bid)), data_torch)
+                    for mtp_bid in range(self.hparams['num_hidden_layers'], self.block_count)
+                ]
+
+        if name.endswith("e_score_correction_bias"):
+            name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+
+        if name.find("mlp.experts") != -1:
+            n_experts = self.hparams["num_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            # Merge only once all three projections of every expert in this block are present.
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            # Anything still buffered means a block never received all of its expert tensors.
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
+
+
@ModelBase.register("GraniteForCausalLM")
class GraniteModel(LlamaModel):
"""Conversion for IBM's GraniteForCausalLM"""
diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py
index 74c67e6a9c..aa9843ea17 100755
--- a/convert_hf_to_gguf_update.py
+++ b/convert_hf_to_gguf_update.py
@@ -147,6 +147,7 @@ models = [
{"name": "kormo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/KORMo-Team/KORMo-tokenizer", },
{"name": "youtu", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", },
{"name": "solar-open", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", },
+ {"name": "exaone-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B", },
]
# some models are known to be broken upstream, so we will skip them as exceptions
diff --git a/examples/model-conversion/scripts/causal/modelcard.template b/examples/model-conversion/scripts/causal/modelcard.template
index cfa8e6b433..a045950324 100644
--- a/examples/model-conversion/scripts/causal/modelcard.template
+++ b/examples/model-conversion/scripts/causal/modelcard.template
@@ -7,7 +7,7 @@ base_model:
Recommended way to run this model:
```sh
-llama-server -hf {namespace}/{model_name}-GGUF -c 0
+llama-server -hf {namespace}/{model_name}-GGUF
```
Then, access http://localhost:8080
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 9516d8ec8f..90794ff264 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -262,6 +262,10 @@ static const char * cu_get_error_str(CUresult err) {
#define FLASH_ATTN_AVAILABLE
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
+#if defined(TURING_MMA_AVAILABLE)
+#define LDMATRIX_TRANS_AVAILABLE
+#endif // defined(TURING_MMA_AVAILABLE)
+
static bool fp16_available(const int cc) {
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL ||
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index 3144678728..6b55f784f3 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -914,7 +914,7 @@ void launch_fattn(
const int nblocks_stream_k = max_blocks;
- const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;
+ const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75;
blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
blocks_num.y = 1;
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index 856291dc3c..e53bbc0502 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -98,6 +98,19 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
}
+// Kernel configuration lookup for the RDNA path, keyed by (DKQ, DV, ncols).
+// The combinations listed below get RDNA-specific entries; everything else falls through to the
+// Ampere configuration.
+// NOTE(review): GGML_CUDA_FATTN_MMA_CONFIG_CASE is defined outside this view; it presumably
+// returns from the function when its first three arguments match (DKQ, DV, ncols) - confirm.
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_rdna(const int DKQ, const int DV, const int ncols) {
+ // DKQ = DV = 256
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
+
+ // DKQ = 576, DV = 512
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
+
+ // TODO tune specifically for RDNA
+ return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
+}
+
static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
if (ampere_mma_available(cc)) {
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
@@ -105,6 +118,9 @@ static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, c
if (turing_mma_available(cc)) {
return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
}
+ if (amd_wmma_available(cc)) {
+ return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
+ }
GGML_ASSERT(volta_mma_available(cc));
return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
}
@@ -116,6 +132,8 @@ static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(cons
return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
#elif defined(VOLTA_MMA_AVAILABLE)
return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
+#elif defined(AMD_WMMA_AVAILABLE)
+ return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
#else
GGML_UNUSED_VARS(DKQ, DV, ncols);
return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
@@ -186,6 +204,23 @@ static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ,
return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg;
}
+// Compile-time number of KQ columns handled per thread: 1 when AMD_WMMA_AVAILABLE is defined,
+// 2 otherwise.
+static constexpr __device__ int get_cols_per_thread() {
+#if defined(AMD_WMMA_AVAILABLE)
+ return 1; // RDNA has a single column.
+#else
+ return 2; // This is specifically KQ columns, Volta only has a single VKQ column.
+#endif // defined(AMD_WMMA_AVAILABLE)
+}
+
+// Host-side number of columns per warp for compute capability `cc`:
+// 16 when Turing MMA or AMD WMMA is available, 32 otherwise (Volta).
+static __host__ int get_cols_per_warp(const int cc) {
+    const bool has_16_col_tiles = turing_mma_available(cc) || amd_wmma_available(cc);
+    return has_16_col_tiles ? 16 : 32;
+}
+
// ------------------------------------------------------------------------------------------------------------------
static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) {
@@ -393,10 +428,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const int jt,
const int kb0,
const int k_VKQ_sup) {
-#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
constexpr int ncols = ncols1 * ncols2;
constexpr int cols_per_warp = T_B_KQ::I;
- constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
+ constexpr int cols_per_thread = get_cols_per_thread();
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
@@ -413,6 +448,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const int k_VKQ_0 = kb0 * nbatch_fa;
#if defined(TURING_MMA_AVAILABLE)
T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];
+#elif defined(AMD_WMMA_AVAILABLE)
+ T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
#else // Volta
T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
#endif // defined(TURING_MMA_AVAILABLE)
@@ -461,8 +498,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
if constexpr (cols_per_warp == 8) {
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
} else {
- // Wide version of KQ_C is column-major => swap A and B.
+ // Wide version of KQ_C is column-major
+#if defined(AMD_WMMA_AVAILABLE)
+ // RDNA matrix C is column-major.
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
+#else
+ // swap A and B for CUDA.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A);
+#endif // defined(AMD_WMMA_AVAILABLE)
}
}
}
@@ -479,8 +522,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
T_A_KQ K_A;
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
- // Wide version of KQ_C is column-major => swap A and B.
+ // Wide version of KQ_C is column-major
+#if defined(AMD_WMMA_AVAILABLE)
+ // RDNA matrix C is column-major.
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+#else
+ // swap A and B for CUDA.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
+#endif // defined(AMD_WMMA_AVAILABLE)
}
}
}
@@ -532,7 +581,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
#pragma unroll
for (int l = 0; l < T_C_KQ::ne; ++l) {
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
- KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
+#if defined(AMD_WMMA_AVAILABLE)
+ constexpr int KQ_idx = 0;
+#else
+ // Turing + Volta:
+ const int KQ_idx = l % 2;
+#endif // defined(AMD_WMMA_AVAILABLE)
+ KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
}
}
}
@@ -552,8 +607,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
#pragma unroll
for (int l = 0; l < T_C_KQ::ne; ++l) {
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
- KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[l % 2]);
- KQ_rowsum_add[l % 2] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
+#if defined(AMD_WMMA_AVAILABLE)
+ constexpr int KQ_idx = 0;
+#else
+ // Turing + Volta:
+ const int KQ_idx = l % 2;
+#endif // defined(AMD_WMMA_AVAILABLE)
+ KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]);
+ KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
} else {
KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f;
}
@@ -584,8 +645,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
#pragma unroll
for (int l = 0; l < T_C_KQ::ne; ++l) {
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
+#if defined(AMD_WMMA_AVAILABLE)
+ constexpr int KQ_idx = 0;
+#else
// Turing + Volta:
- KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
+ const int KQ_idx = (l/2) % 2;
+#endif // defined(AMD_WMMA_AVAILABLE)
+ KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
}
}
}
@@ -596,7 +662,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
// Values per KQ column are spread across 4 threads:
constexpr int offset_first = 2;
constexpr int offset_last = 1;
-#else
+#elif defined(AMD_WMMA_AVAILABLE)
+ // Values per KQ column are spread across 2 threads:
+ constexpr int offset_first = 16;
+ constexpr int offset_last = 16;
+#else // Volta
// Values per KQ column are spread across 2 threads:
constexpr int offset_first = 2;
constexpr int offset_last = 2;
@@ -612,10 +682,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
#pragma unroll
for (int l = 0; l < T_C_KQ::ne; ++l) {
- // Turing + Volta:
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
- KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[(l/2) % 2]);
- KQ_rowsum_add[(l/2) % 2] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
+#if defined(AMD_WMMA_AVAILABLE)
+ constexpr int KQ_idx = 0;
+#else
+ // Turing + Volta:
+ const int KQ_idx = (l/2) % 2;
+#endif // defined(AMD_WMMA_AVAILABLE)
+ KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]);
+ KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
} else {
KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f;
}
@@ -639,7 +714,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
#if defined(TURING_MMA_AVAILABLE)
if constexpr (cols_per_warp == 8) {
- const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
#pragma unroll
for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
#pragma unroll
@@ -660,6 +735,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
}
}
+#elif defined(AMD_WMMA_AVAILABLE)
+ const half2 KQ_max_scale_h2 = make_half2(
+ KQ_max_scale[0], KQ_max_scale[0]);
+#pragma unroll
+ for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
+#pragma unroll
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
+ VKQ_C[i].x[l] *= KQ_max_scale_h2;
+ }
+ }
#else // Volta
const half2 KQ_max_scale_h2 = make_half2(
KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]);
@@ -707,6 +792,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
// Therefore, iterate over V in reverse and re-use the data if possible.
static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
+#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
+ T_A_VKQ A_identity;
+ make_identity_mat(A_identity);
+#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
// Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
#pragma unroll
@@ -727,7 +816,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
-#if defined(TURING_MMA_AVAILABLE)
+#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
#pragma unroll
for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
@@ -737,12 +826,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J;
T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
+#if defined(LDMATRIX_TRANS_AVAILABLE)
load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+#else
+ // TODO: Try to transpose tile_V when loading gmem to smem.
+ // Use mma to transpose T_A_VKQ for RDNA.
+ T_A_VKQ A_trans;
+ load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+ mma(A, A_trans, A_identity);
+#endif // defined(LDMATRIX_TRANS_AVAILABLE)
if constexpr (T_B_KQ::I == 8) {
mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
} else {
- // Wide version of VKQ_C is column-major => swap A and B.
+ // Wide version of VKQ_C is column-major.
+#if defined(AMD_WMMA_AVAILABLE)
+ // RDNA matrix C is column-major.
+ mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
+#else
+ // swap A and B for CUDA.
mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A);
+#endif // defined(AMD_WMMA_AVAILABLE)
}
}
}
@@ -761,7 +864,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A);
}
}
-#endif // defined(TURING_MMA_AVAILABLE)
+#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
if constexpr (nstages <= 1) {
__syncthreads(); // Only needed if tile_K == tile_V.
@@ -774,7 +877,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
tile_Q, tile_K, tile_V, tile_mask,
Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
NO_DEVICE_CODE;
-#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
}
#if defined(TURING_MMA_AVAILABLE)
@@ -794,6 +897,15 @@ template<> struct mma_tile_sizes<8> {
using T_B_VKQ = tile< 8, 8, half2>; // column-major
using T_C_VKQ = tile<16, 4, half2>; // row-major
};
+#elif defined(AMD_WMMA_AVAILABLE)
+// Tile types used by the flash-attention kernel when AMD_WMMA_AVAILABLE is defined.
+// NOTE(review): the template parameter list on the next line and the tile<> layout arguments
+// may have been truncated during extraction - confirm against the real source.
+template struct mma_tile_sizes {
+ using T_A_KQ = tile<16, 8, half2>; // row-major
+ using T_B_KQ = tile<16, 8, half2>; // column-major
+ using T_C_KQ = tile<16, 16, float>; // column-major
+ using T_A_VKQ = tile<16, 8, half2>; // row-major
+ using T_B_VKQ = tile<16, 8, half2>; // column-major
+ using T_C_VKQ = tile<16, 8, half2>; // column-major
+};
#else // Volta
template struct mma_tile_sizes {
using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
@@ -828,7 +940,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int jt,
const int kb0_start,
const int kb0_stop) {
-#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
constexpr int ncols = ncols1 * ncols2;
@@ -840,7 +952,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
using T_C_VKQ = typename mma_tile_sizes::T_C_VKQ;
constexpr int cols_per_warp = T_B_KQ::I;
- constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
+ constexpr int cols_per_thread = get_cols_per_thread();
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
@@ -871,6 +983,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];
#if defined(TURING_MMA_AVAILABLE)
T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];
+#elif defined(AMD_WMMA_AVAILABLE)
+ T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
#else // Volta
T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
#endif // defined(TURING_MMA_AVAILABLE)
@@ -1010,6 +1124,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
// The partial sums are spread across 8/4 threads.
constexpr int offset_first = cols_per_warp == 8 ? 16 : 2;
constexpr int offset_last = cols_per_warp == 8 ? 4 : 1;
+#elif defined(AMD_WMMA_AVAILABLE)
+ // The partial sums are spread across 2 threads.
+ constexpr int offset_first = 16;
+ constexpr int offset_last = 16;
#else // Volta
// The partial sums are spread across 2 threads.
constexpr int offset_first = 2;
@@ -1047,7 +1165,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
#if defined(TURING_MMA_AVAILABLE)
if constexpr (cols_per_warp == 8) {
- const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
#pragma unroll
for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
#pragma unroll
@@ -1068,6 +1186,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
}
}
}
+#elif defined(AMD_WMMA_AVAILABLE)
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);
+#pragma unroll
+ for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
+#pragma unroll
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
+ VKQ_C[i].x[l] *= KQ_max_scale_h2;
+ }
+ }
#else // Volta
const int col = (threadIdx.x / 2) % 2;
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
@@ -1119,6 +1246,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4);
const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]);
const bool thread_should_write = threadIdx.x % 4 < cols_per_thread;
+#elif defined(AMD_WMMA_AVAILABLE)
+ const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0);
+ const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]);
+ const bool thread_should_write = threadIdx.x / 16 < cols_per_thread;
#else // Volta
const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2);
const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]);
@@ -1319,7 +1450,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
jt, kb0_start, kb0_stop);
NO_DEVICE_CODE;
-#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
}
template
@@ -1346,7 +1477,7 @@ static __global__ void flash_attn_ext_f16(
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
+#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
@@ -1360,6 +1491,13 @@ static __global__ void flash_attn_ext_f16(
}
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+#if defined(AMD_WMMA_AVAILABLE)
+ if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || DKQ > 128 || ncols2 == 1) {
+ NO_DEVICE_CODE;
+ return;
+ }
+#endif // defined(AMD_WMMA_AVAILABLE)
+
static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
constexpr int ncols = ncols1 * ncols2;
@@ -1473,7 +1611,7 @@ static __global__ void flash_attn_ext_f16(
ne31, ne32, ne33,
nb31, nb32, nb33);
NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
+#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
}
template
@@ -1492,7 +1630,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
const bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols, cc);
const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc);
- const int cols_per_warp = std::min(ncols, turing_mma_available(cc) ? 16 : 32);
+ const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc));
const int nwarps = nthreads / WARP_SIZE;
constexpr bool mla = DKQ == 576;
@@ -1512,29 +1650,34 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+#if defined(GGML_USE_HIP)
+ using fattn_kernel_ptr_t = const void*;
+#else
+ using fattn_kernel_ptr_t = fattn_kernel_t;
+#endif // defined(GGML_USE_HIP)
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16;
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
if (!shared_memory_limit_raised[id]) {
- CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
+ CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
shared_memory_limit_raised[id] = true;
}
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_MUSA)
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16;
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
if (!shared_memory_limit_raised[id]) {
- CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
+ CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
shared_memory_limit_raised[id] = true;
}
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_MUSA)
}
launch_fattn
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 0155406665..598cda7daa 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -18,12 +18,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
}
}
- if (turing_mma_available(cc) && Q->ne[1] <= 16/ncols2) {
+ if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst);
return;
}
- if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) {
+ if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst);
return;
}
@@ -230,7 +230,18 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
// The effective batch size for the kernel can be increased by gqa_ratio.
// The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded,
- const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+ bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+ for (const ggml_tensor * t : {Q, K, V, mask}) {
+ if (t == nullptr) {
+ continue;
+ }
+ for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
+ if (t->nb[i] % 16 != 0) {
+ gqa_opt_applies = false;
+ break;
+ }
+ }
+ }
const int cc = ggml_cuda_info().devices[device].cc;
@@ -337,6 +348,31 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
return BEST_FATTN_KERNEL_WMMA_F16;
}
+    // RDNA4 WMMA path: only taken when the GQA optimization applies and the head size is
+    // at most 128 (excluding 40 and 72).
+    if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72) {
+        // gqa_opt_applies is always true inside this branch, so the vector kernel is only
+        // chosen for a quantized K or V with at most 2 Q columns.
+        if (can_use_vector_kernel && (ggml_is_quantized(K->type) || ggml_is_quantized(V->type)) && Q->ne[1] <= 2) {
+            return BEST_FATTN_KERNEL_VEC;
+        }
+        // Q->ne[0] <= 128 on this path, so the larger limit used for Q->ne[0] == 576 never applies.
+        constexpr int ncols2_max = 8;
+        // Largest power of 2 that divides gqa_ratio, capped at ncols2_max.
+        int gqa_ratio_eff = 1;
+        while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
+            gqa_ratio_eff *= 2;
+        }
+        if (Q->ne[1] * gqa_ratio_eff <= 8) {
+            return BEST_FATTN_KERNEL_TILE; // AMD WMMA is only faster if the full tile width of 16 can be utilized.
+        }
+        return BEST_FATTN_KERNEL_MMA_F16;
+    }
+
// If there are no tensor cores available, use the generic tile kernel:
if (can_use_vector_kernel) {
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index f021de1d74..c3ee2ea066 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3737,6 +3737,7 @@ static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) {
return cuda_ctx->cuda_graph->is_enabled();
#else
+ GGML_UNUSED(cuda_ctx);
return false;
#endif // USE_CUDA_GRAPH
}
diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh
index df9eed7117..42085d1002 100644
--- a/ggml/src/ggml-cuda/mma.cuh
+++ b/ggml/src/ggml-cuda/mma.cuh
@@ -206,10 +206,16 @@ namespace ggml_cuda_mma {
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 16) {
- // matrix C
#if defined(RDNA3)
- return 2 * l + (threadIdx.x / 16);
+ if constexpr (std::is_same_v || std::is_same_v) {
+ // matrix C
+ return 2 * l + (threadIdx.x / 16);
+ } else {
+ // matrix A&B
+ return l;
+ }
#else
+ // matrix C is the transposed matrix A&B on RDNA4
return ne * (threadIdx.x / 16) + l;
#endif // defined(RDNA3)
} else if constexpr (I == 16 && J == 8) {
@@ -621,6 +627,21 @@ namespace ggml_cuda_mma {
return ret;
}
+#elif defined(AMD_WMMA_AVAILABLE)
+ template // NOTE(review): the template parameter list (and the tile<...> arguments below) appear to be missing in this copy of the patch - restore before applying
+ static __device__ __forceinline__ tile get_half2(const tile & tile_float) { // AMD WMMA variant: packs the elements of a float tile pairwise into a half2 tile
+ tile ret;
+#pragma unroll
+ for (int l0 = 0; l0 < tile_float.ne; l0 += 2) { // walk the per-thread elements two at a time
+ ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]); // elements l0 and l0+1 become the two halves of output element l0/2
+ }
+ return ret;
+ }
+
+    // AMD WMMA stub: emits NO_DEVICE_CODE and returns a value-initialized tile.
+    // NOTE(review): presumably only present so that callers shared with other architectures compile - confirm.
+    static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
+        GGML_UNUSED_VARS(t); // the input is intentionally ignored, same convention as the other NO_DEVICE_CODE stubs
+        NO_DEVICE_CODE;
+        return tile<8, 8, half2>{};
+    }
#else // Volta
template
static __device__ __forceinline__ tile get_half2(const tile & tile_float) {
@@ -639,6 +660,19 @@ namespace ggml_cuda_mma {
}
#endif // defined(TURING_MMA_AVAILABLE)
+ static __device__ __forceinline__ void make_identity_mat(tile<16, 8, half2> & t) { // Writes a single 1.0f/0.0f entry into this thread's part of t; only implemented for RDNA4.
+#if defined(RDNA4)
+ const int row = t.get_i(0); // presumably the tile row that this thread's element 0 maps to - confirm against tile::get_i
+ const int left_right = t.get_j(0) / 4; // get_j(0) in units of 4
+ const int up_down = row / 8; // which group of 8 rows the row falls into
+ const int idx = row % 8; // position of the row inside its group of 8
+ reinterpret_cast(t.x)[idx] = left_right == up_down ? 1.0f : 0.0f; // 1 only where both halves agree. NOTE(review): the remaining entries of t are not written here - presumably the caller zero-initializes t, confirm. The cast target type appears to be missing in this copy of the patch.
+#else
+ GGML_UNUSED_VARS(t);
+ NO_DEVICE_CODE; // no implementation for other architectures
+#endif // defined(RDNA4)
+ }
+
template
static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) {
#if defined(AMD_MFMA_AVAILABLE)
@@ -878,6 +912,17 @@ namespace ggml_cuda_mma {
: "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#elif defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
+ using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
+ halfx8_t& acc_frag = reinterpret_cast(D.x[0]);
+ const halfx8_t& a_frag = reinterpret_cast(A.x[0]);
+ const halfx8_t& b_frag = reinterpret_cast(B.x[0]);
+ acc_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
+#else
+ GGML_UNUSED_VARS(D, A, B);
+ NO_DEVICE_CODE;
+#endif // defined(RDNA4)
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index 016b04e5a0..5cc1b54319 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -138,6 +138,8 @@
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
+#define cudaFuncSetAttribute hipFuncSetAttribute
+#define cudaFuncAttributeMaxDynamicSharedMemorySize hipFuncAttributeMaxDynamicSharedMemorySize
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 0c0ee268ac..fd910cf996 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -426,6 +426,7 @@ class MODEL_ARCH(IntEnum):
NEMOTRON_H_MOE = auto()
EXAONE = auto()
EXAONE4 = auto()
+ EXAONE_MOE = auto()
GRANITE = auto()
GRANITE_MOE = auto()
GRANITE_HYBRID = auto()
@@ -846,6 +847,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.NEMOTRON_H_MOE: "nemotron_h_moe",
MODEL_ARCH.EXAONE: "exaone",
MODEL_ARCH.EXAONE4: "exaone4",
+ MODEL_ARCH.EXAONE_MOE: "exaone-moe",
MODEL_ARCH.GRANITE: "granite",
MODEL_ARCH.GRANITE_MOE: "granitemoe",
MODEL_ARCH.GRANITE_HYBRID: "granitehybrid",
@@ -2758,6 +2760,38 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_POST_NORM,
],
+ MODEL_ARCH.EXAONE_MOE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_GATE_SHEXP,
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
+ MODEL_TENSOR.FFN_UP_SHEXP,
+ MODEL_TENSOR.FFN_EXP_PROBS_B,
+ # NextN/MTP tensors - preserved but unused
+ MODEL_TENSOR.NEXTN_EH_PROJ,
+ MODEL_TENSOR.NEXTN_EMBED_TOKENS,
+ MODEL_TENSOR.NEXTN_ENORM,
+ MODEL_TENSOR.NEXTN_HNORM,
+ MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
+ MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
+ ],
MODEL_ARCH.GRANITE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 0fd51c3e87..bdd0af6800 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -436,7 +436,8 @@ class TensorNameMap:
"model.layers.{bid}.mlp.expert_bias", # afmoe
"model.layers.{bid}.feed_forward.expert_bias", # lfm2moe
"model.layers.{bid}.block_sparse_moe.e_score_correction", # minimax-m2
- "backbone.layers.{bid}.mixer.gate.e_score_correction" # nemotron-h-moe
+ "backbone.layers.{bid}.mixer.gate.e_score_correction", # nemotron-h-moe
+ "model.layers.{bid}.mlp.e_score_correction", # exaone-moe
),
# Feed-forward up
@@ -1797,7 +1798,7 @@ class TensorNameMap:
"model.embed_audio.soft_embedding_norm", # gemma3n
),
- # NextN/MTP tensors for GLM4_MOE
+ # NextN/MTP tensors
MODEL_TENSOR.NEXTN_EH_PROJ: (
"model.layers.{bid}.eh_proj",
),
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 0104512661..87b322e4b6 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -63,6 +63,7 @@ add_library(llama
models/paddleocr.cpp
models/exaone.cpp
models/exaone4.cpp
+ models/exaone-moe.cpp
models/falcon-h1.cpp
models/falcon.cpp
models/gemma-embedding.cpp
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 437a978bbb..b016788ee9 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -81,6 +81,7 @@ static const std::map LLM_ARCH_NAMES = {
{ LLM_ARCH_NEMOTRON_H_MOE, "nemotron_h_moe" },
{ LLM_ARCH_EXAONE, "exaone" },
{ LLM_ARCH_EXAONE4, "exaone4" },
+ { LLM_ARCH_EXAONE_MOE, "exaone-moe" },
{ LLM_ARCH_RWKV6, "rwkv6" },
{ LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
{ LLM_ARCH_RWKV7, "rwkv7" },
@@ -1730,6 +1731,38 @@ static std::set llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_FFN_UP,
LLM_TENSOR_FFN_POST_NORM,
};
+ case LLM_ARCH_EXAONE_MOE:
+ return {
+ LLM_TENSOR_TOKEN_EMBD,
+ LLM_TENSOR_OUTPUT_NORM,
+ LLM_TENSOR_OUTPUT,
+ LLM_TENSOR_ROPE_FREQS,
+ LLM_TENSOR_ATTN_NORM,
+ LLM_TENSOR_ATTN_Q,
+ LLM_TENSOR_ATTN_Q_NORM,
+ LLM_TENSOR_ATTN_K,
+ LLM_TENSOR_ATTN_K_NORM,
+ LLM_TENSOR_ATTN_V,
+ LLM_TENSOR_ATTN_OUT,
+ LLM_TENSOR_FFN_NORM,
+ LLM_TENSOR_FFN_GATE,
+ LLM_TENSOR_FFN_DOWN,
+ LLM_TENSOR_FFN_UP,
+ LLM_TENSOR_FFN_GATE_INP,
+ LLM_TENSOR_FFN_GATE_EXPS,
+ LLM_TENSOR_FFN_DOWN_EXPS,
+ LLM_TENSOR_FFN_UP_EXPS,
+ LLM_TENSOR_FFN_GATE_SHEXP,
+ LLM_TENSOR_FFN_UP_SHEXP,
+ LLM_TENSOR_FFN_DOWN_SHEXP,
+ LLM_TENSOR_FFN_EXP_PROBS_B,
+ LLM_TENSOR_NEXTN_EH_PROJ,
+ LLM_TENSOR_NEXTN_EMBED_TOKENS,
+ LLM_TENSOR_NEXTN_ENORM,
+ LLM_TENSOR_NEXTN_HNORM,
+ LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
+ LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
+ };
case LLM_ARCH_RWKV6:
return {
LLM_TENSOR_TOKEN_EMBD,
diff --git a/src/llama-arch.h b/src/llama-arch.h
index c306ddb08a..658785929d 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -85,6 +85,7 @@ enum llm_arch {
LLM_ARCH_NEMOTRON_H_MOE,
LLM_ARCH_EXAONE,
LLM_ARCH_EXAONE4,
+ LLM_ARCH_EXAONE_MOE,
LLM_ARCH_RWKV6,
LLM_ARCH_RWKV6QWEN2,
LLM_ARCH_RWKV7,
diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp
index b54ebbd155..3c7e0afdae 100644
--- a/src/llama-chat.cpp
+++ b/src/llama-chat.cpp
@@ -57,6 +57,7 @@ static const std::map LLM_CHAT_TEMPLATES = {
{ "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
{ "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
{ "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 },
+ { "exaone-moe", LLM_CHAT_TEMPLATE_EXAONE_MOE },
{ "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD },
{ "granite", LLM_CHAT_TEMPLATE_GRANITE },
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
@@ -137,6 +138,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
} else if (tmpl_contains("[gMASK]")) {
return LLM_CHAT_TEMPLATE_CHATGLM_4;
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
+ if (tmpl_contains("<|tool_declare|>")) {
+ return LLM_CHAT_TEMPLATE_EXAONE_MOE;
+ }
return tmpl_contains("") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE;
} else if (tmpl_contains("<|{{ item['role'] }}|>") && tmpl_contains("<|begin_of_image|>")) {
return LLM_CHAT_TEMPLATE_GLMEDGE;
@@ -576,6 +580,22 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "[|assistant|]";
}
+    } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_MOE) {
+        // EXAONE MoE: each supported turn is rendered as its role tag, the trimmed content and "<|endofturn|>\n".
+        // Messages with any other role are skipped.
+        struct role_tag {
+            const char * role;
+            const char * open;
+        };
+        static constexpr role_tag role_tags[] = {
+            { "system",    "<|system|>\n"    },
+            { "user",      "<|user|>\n"      },
+            { "assistant", "<|assistant|>\n" },
+            { "tool",      "<|tool|>\n"      },
+        };
+        for (const auto * turn : chat) {
+            const std::string role(turn->role);
+            for (const auto & tag : role_tags) {
+                if (role == tag.role) {
+                    ss << tag.open << trim(turn->content) << "<|endofturn|>\n";
+                    break;
+                }
+            }
+        }
+        if (add_ass) {
+            // open the assistant turn for generation
+            ss << "<|assistant|>\n";
+        }
} else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
// this template requires the model to have "\n\n" as EOT token
for (size_t i = 0; i < chat.size(); i++) {
diff --git a/src/llama-chat.h b/src/llama-chat.h
index e1f795249c..9ed1db128e 100644
--- a/src/llama-chat.h
+++ b/src/llama-chat.h
@@ -36,6 +36,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_MINICPM,
LLM_CHAT_TEMPLATE_EXAONE_3,
LLM_CHAT_TEMPLATE_EXAONE_4,
+ LLM_CHAT_TEMPLATE_EXAONE_MOE,
LLM_CHAT_TEMPLATE_RWKV_WORLD,
LLM_CHAT_TEMPLATE_GRANITE,
LLM_CHAT_TEMPLATE_GIGACHAT,
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 374ff1ebf3..944c7e53bd 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -96,11 +96,9 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
int32_t * data = (int32_t *) pos_bucket->data;
- for (int h = 0; h < 1; ++h) {
- for (int j = 0; j < n_tokens; ++j) {
- for (int i = 0; i < n_tokens; ++i) {
- data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
- }
+ for (int j = 0; j < n_tokens; ++j) {
+ for (int i = 0; i < n_tokens; ++i) {
+ data[j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
}
}
}
@@ -323,34 +321,32 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
const int64_t n_tokens = ubatch->n_tokens;
const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
- for (int h = 0; h < 1; ++h) {
- for (int i1 = 0; i1 < n_tokens; ++i1) {
- const llama_seq_id s1 = ubatch->seq_id[i1][0];
- const llama_pos p1 = ubatch->pos[i1];
+ for (int i1 = 0; i1 < n_tokens; ++i1) {
+ const llama_seq_id s1 = ubatch->seq_id[i1][0];
+ const llama_pos p1 = ubatch->pos[i1];
- const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;
+ const uint64_t idst = i1*n_kv;
- for (int i0 = 0; i0 < n_tokens; ++i0) {
- const llama_seq_id s0 = ubatch->seq_id[i0][0];
- const llama_pos p0 = ubatch->pos[i0];
+ for (int i0 = 0; i0 < n_tokens; ++i0) {
+ const llama_seq_id s0 = ubatch->seq_id[i0][0];
+ const llama_pos p0 = ubatch->pos[i0];
- // mask different sequences
- if (s0 != s1) {
- continue;
- }
-
- // mask future tokens
- if (cparams.causal_attn && p0 > p1) {
- continue;
- }
-
- // apply SWA if any
- if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
- continue;
- }
-
- data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
+ // mask different sequences
+ if (s0 != s1) {
+ continue;
}
+
+ // mask future tokens
+ if (cparams.causal_attn && p0 > p1) {
+ continue;
+ }
+
+ // apply SWA if any
+ if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
+ continue;
+ }
+
+ data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
}
}
};
@@ -454,27 +450,19 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
float * data = (float *) cross_kq_mask->data;
- for (int h = 0; h < 1; ++h) {
- for (int i = 0; i < n_tokens; ++i) {
- for (int j = 0; j < n_enc; ++j) {
- float f = -INFINITY;
+ for (int i = 0; i < n_tokens; ++i) {
+ for (int j = 0; j < n_enc; ++j) {
+ float f = -INFINITY;
- for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
- const llama_seq_id seq_id = ubatch->seq_id[i][s];
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
- if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
- f = 0.0f;
- }
+ if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
+ f = 0.0f;
}
-
- data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
}
- }
- for (int i = n_tokens; i < n_tokens; ++i) {
- for (int j = 0; j < n_enc; ++j) {
- data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
- }
+ data[i*n_enc + j] = f;
}
}
}
diff --git a/src/llama-mmap.cpp b/src/llama-mmap.cpp
index 2da857b3aa..0c43495b11 100644
--- a/src/llama-mmap.cpp
+++ b/src/llama-mmap.cpp
@@ -244,11 +244,14 @@ struct llama_file::impl {
}
errno = 0;
if (fd == -1) {
- std::size_t ret = std::fread(ptr, len, 1, fp);
+ const size_t curr_off = tell();
+ const size_t to_read = std::min(len, size - curr_off);
+
+ std::size_t ret = std::fread(ptr, to_read, 1, fp);
if (ferror(fp)) {
throw std::runtime_error(format("read error: %s", strerror(errno)));
}
- if (ret != 1) {
+ if (to_read > 0 && ret != 1) {
throw std::runtime_error("unexpectedly reached end of file");
}
} else {
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index a51f693150..e06b1a8e26 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1933,6 +1933,38 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
+ case LLM_ARCH_EXAONE_MOE: // EXAONE MoE hyperparameters: SWA defaults are set first, then values are read from the model metadata
+ {
+ hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+ hparams.n_swa = 128; // default window size. NOTE(review): the sliding-window key below is read with `true` (presumably "required"), which would make this default unreachable - confirm intent
+ hparams.set_swa_pattern(4);
+ hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; // SWA RoPE parameters start out as copies of the regular ones
+ hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
+
+ ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); // trailing `false`: presumably marks the key as optional - confirm against llama_model_loader::get_key
+ ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, true);
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert);
+ ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used);
+ ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false);
+ ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+ ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
+ ml.get_key(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups, false);
+ ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used, false);
+ ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
+ ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
+ ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
+ ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
+
+ ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
+
+ switch (hparams.n_layer) { // the model size label is derived from the layer count alone
+ case 32: type = LLM_TYPE_30B_A3B; break;
+ case 48: // falls through: 48 and 49 layers share one label. NOTE(review): 49 is presumably 48 plus one NextN layer - confirm
+ case 49: type = LLM_TYPE_235B_A22B; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ }
+ } break;
case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2:
{
@@ -5520,6 +5552,84 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
+            case LLM_ARCH_EXAONE_MOE:
+                {
+                    // EXAONE MoE: leading dense FFN layers, then MoE layers (routed + shared experts),
+                    // optionally followed by NextN/MTP layers whose tensors are all marked TENSOR_SKIP.
+                    const int64_t n_ff_exp      = hparams.n_ff_exp;
+                    const int64_t n_expert      = hparams.n_expert;
+                    const int64_t n_expert_used = hparams.n_expert_used;
+                    const int64_t n_ff_shexp    = hparams.n_ff_shexp;
+                    const int64_t head_dim      = hparams.n_embd_head_k;
+                    const int64_t n_qo_dim      = n_head * head_dim;
+                    const int64_t n_kv_dim      = n_head_kv * head_dim;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    // the output projection must be optional, otherwise the NULL fallback below can never run
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                    // if output is NULL, reuse the token embedding as the output projection
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        // the last nextn_predict_layers layers are the NextN/MTP layers
+                        const bool is_nextn_layer = hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers;
+
+                        int flags = 0;
+                        if (is_nextn_layer) {
+                            // skip all tensors in the NextN layers
+                            flags |= TENSOR_SKIP;
+                        }
+
+                        auto & layer = layers[i];
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_qo_dim}, flags);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_kv_dim}, flags);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_kv_dim}, flags);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, flags);
+
+                        layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0) | flags);
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd}, flags);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags);
+
+                        // dense layers for first n_layer_dense_lead layers or nextn_predict_layers layers at the end
+                        if (i < (int) hparams.n_layer_dense_lead || is_nextn_layer) {
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, flags);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, flags);
+                        } else {
+                            // validate the expert counts before any tensor is shaped with them
+                            if (n_expert == 0) {
+                                throw std::runtime_error("n_expert must be > 0");
+                            }
+                            if (n_expert_used == 0) {
+                                throw std::runtime_error("n_expert_used must be > 0");
+                            }
+
+                            // router and its optional bias
+                            layer.ffn_gate_inp    = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,    "weight", i), {n_embd, n_expert}, flags);
+                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias",   i), {n_expert}, TENSOR_NOT_REQUIRED | flags);
+
+                            // routed experts
+                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, flags);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags);
+                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff_exp, n_expert}, flags);
+
+                            // shared expert
+                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags);
+                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags);
+                            layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_shexp}, flags);
+                        }
+
+                        // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers
+                        if (is_nextn_layer) {
+                            layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, flags);
+                            layer.nextn.enorm   = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM,   "weight", i), {n_embd}, flags);
+                            layer.nextn.hnorm   = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM,   "weight", i), {n_embd}, flags);
+
+                            layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), {n_embd}, flags | TENSOR_NOT_REQUIRED);
+                            layer.nextn.embed_tokens     = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS,     "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED);
+                            layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED);
+                        }
+                    }
+                } break;
case LLM_ARCH_RWKV6:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -7106,59 +7216,59 @@ void llama_model::print_info() const {
};
// hparams
- LLAMA_LOG_INFO("%s: arch = %s\n", __func__, arch_name().c_str());
- LLAMA_LOG_INFO("%s: vocab_only = %d\n", __func__, hparams.vocab_only);
- LLAMA_LOG_INFO("%s: no_alloc = %d\n", __func__, hparams.no_alloc);
+ LLAMA_LOG_INFO("%s: arch = %s\n", __func__, arch_name().c_str());
+ LLAMA_LOG_INFO("%s: vocab_only = %d\n", __func__, hparams.vocab_only);
+ LLAMA_LOG_INFO("%s: no_alloc = %d\n", __func__, hparams.no_alloc);
if (!hparams.vocab_only) {
- LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train);
- LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd);
- LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp());
- LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer);
- LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str());
- LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
- LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
- LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
- LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any());
- LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
- LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
- LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str());
- LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str());
- LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str());
- LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps);
- LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps);
- LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
- LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias);
- LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale);
- LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale);
- LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str());
- LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
- LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
- LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups);
- LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used);
- LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn);
- LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type);
- LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type);
- LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str());
- LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
- LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
+ LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train);
+ LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd);
+ LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp());
+ LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer);
+ LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str());
+ LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
+ LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
+ LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
+ LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any());
+ LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
+ LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
+ LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str());
+ LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str());
+ LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str());
+ LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps);
+ LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps);
+ LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
+ LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias);
+ LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale);
+ LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale);
+ LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str());
+ LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
+ LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
+ LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups);
+ LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used);
+ LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn);
+ LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type);
+ LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type);
+ LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str());
+ LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
+ LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
- LLAMA_LOG_INFO("%s: freq_base_swa = %.1f\n", __func__, hparams.rope_freq_base_train_swa);
- LLAMA_LOG_INFO("%s: freq_scale_swa = %g\n", __func__, hparams.rope_freq_scale_train_swa);
+ LLAMA_LOG_INFO("%s: freq_base_swa = %.1f\n", __func__, hparams.rope_freq_base_train_swa);
+ LLAMA_LOG_INFO("%s: freq_scale_swa = %g\n", __func__, hparams.rope_freq_scale_train_swa);
}
- LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
- LLAMA_LOG_INFO("%s: rope_yarn_log_mul= %.4f\n", __func__, hparams.rope_yarn_log_mul);
- LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
+ LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
+ LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul);
+ LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
// MRoPE (Multi-axis Rotary Position Embedding) sections
if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) {
- LLAMA_LOG_INFO("%s: mrope sections = [%d, %d, %d, %d]\n", __func__, s[0], s[1], s[2], s[3]);
+ LLAMA_LOG_INFO("%s: mrope sections = [%d, %d, %d, %d]\n", __func__, s[0], s[1], s[2], s[3]);
}
if (!classifier_labels.empty()) {
- LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out);
+ LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out);
size_t i = 0;
for (auto label : classifier_labels) {
- LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str());
+ LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str());
}
}
}
@@ -7172,55 +7282,55 @@ void llama_model::print_info() const {
arch == LLM_ARCH_QWEN3NEXT ||
arch == LLM_ARCH_NEMOTRON_H ||
arch == LLM_ARCH_NEMOTRON_H_MOE) {
- LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
- LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
- LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
- LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
- LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group);
- LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
+ LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
+ LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
+ LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
+ LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
+ LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group);
+ LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
}
- LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str());
+ LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str());
if (pimpl->n_elements >= 1e12) {
- LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12);
+ LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12);
} else if (pimpl->n_elements >= 1e9) {
- LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9);
+ LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9);
} else if (pimpl->n_elements >= 1e6) {
- LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6);
+ LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6);
} else {
- LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3);
+ LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3);
}
// general kv
- LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str());
+ LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str());
if (arch == LLM_ARCH_DEEPSEEK) {
- LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
- LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
- LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
- LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
+ LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
+ LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
+ LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
+ LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
}
if (arch == LLM_ARCH_DEEPSEEK2) {
- LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
- LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q);
- LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv);
- LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla);
- LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla);
- LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
- LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
- LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
- LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
- LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
+ LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
+ LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q);
+ LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv);
+ LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla);
+ LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla);
+ LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
+ LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
+ LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
+ LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
+ LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
}
if (arch == LLM_ARCH_QWEN2MOE) {
- LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
- LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
+ LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
+ LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
}
if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) {
- LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
+ LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
}
if (arch == LLM_ARCH_MINICPM ||
@@ -7228,41 +7338,41 @@ void llama_model::print_info() const {
arch == LLM_ARCH_GRANITE_MOE ||
arch == LLM_ARCH_GRANITE_HYBRID ||
arch == LLM_ARCH_NEMOTRON_H_MOE) {
- LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
- LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale);
- LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
- LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
+ LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
+ LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale);
+ LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
+ LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
}
if (arch == LLM_ARCH_BAILINGMOE) {
- LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
- LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
- LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
- LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
- LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
+ LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
+ LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
+ LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
+ LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
+ LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
}
if (arch == LLM_ARCH_BAILINGMOE2) {
- LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
- LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
- LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
- LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
- LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
- LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
- LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
- LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers);
+ LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
+ LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
+ LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
+ LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
+ LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
+ LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
+ LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
+ LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers);
}
if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) {
- LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
- LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
+ LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
+ LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
}
if (arch == LLM_ARCH_GROVEMOE) {
- LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
- LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp);
- LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts);
- LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale);
+ LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
+ LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp);
+ LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts);
+ LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale);
}
vocab.print_info();
@@ -7816,6 +7926,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
llm = std::make_unique>(*this, params);
}
} break;
+ case LLM_ARCH_EXAONE_MOE:
+ {
+ llm = std::make_unique(*this, params);
+ } break;
case LLM_ARCH_RWKV6:
{
llm = std::make_unique(*this, params);
@@ -8180,6 +8294,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_NEMOTRON:
case LLM_ARCH_EXAONE:
case LLM_ARCH_EXAONE4:
+ case LLM_ARCH_EXAONE_MOE:
case LLM_ARCH_MINICPM3:
case LLM_ARCH_BAILINGMOE2:
case LLM_ARCH_DOTS1:
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 1a2dff8058..6063cc5290 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -461,6 +461,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\\r\\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
};
break;
+ case LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE:
+ regex_exprs = {
+ // original regex from tokenizer.json
+ // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+"
+ "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
+ };
+ break;
default:
// default regex for BPE tokenization pre-processing
regex_exprs = {
@@ -1965,6 +1972,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
} else if (
tokenizer_pre == "exaone4") {
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
+ } else if (
+ tokenizer_pre == "exaone-moe") {
+ pre_type = LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE;
} else if (
tokenizer_pre == "chameleon") {
pre_type = LLAMA_VOCAB_PRE_TYPE_CHAMELEON;
@@ -2437,7 +2447,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
auto & attr = id_to_token[t.second].attr;
if (t.first == "<|channel|>" || t.first == "<|message|>" || t.first == "<|start|>" || t.first == "<|constrain|>") {
- attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_USER_DEFINED);
+ LLAMA_LOG_WARN("%s: setting token '%s' (%d) attribute to USER_DEFINED (%u), old attributes: %u\n",
+ __func__, t.first.c_str(), t.second, LLAMA_TOKEN_ATTR_USER_DEFINED, attr);
+
+ attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
}
}
@@ -2490,7 +2503,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
special_eog_ids.erase(end_id);
auto & attr = id_to_token[end_id].attr;
- attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_USER_DEFINED);
+ attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>', or '<|calls|>' and '<|flush|>' tokens, removing '<|end|>' token from EOG list\n", __func__);
}
@@ -3290,34 +3303,34 @@ int32_t llama_vocab::impl::detokenize(
}
void llama_vocab::impl::print_info() const {
- LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, type_name().c_str());
- LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, vocab.n_tokens());
- LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (uint32_t) bpe_ranks.size());
+ LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, type_name().c_str());
+ LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, vocab.n_tokens());
+ LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (uint32_t) bpe_ranks.size());
// special tokens
- if (special_bos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, special_bos_id, id_to_token.at(special_bos_id).text.c_str() ); }
- if (special_eos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, special_eos_id, id_to_token.at(special_eos_id).text.c_str() ); }
- if (special_eot_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, special_eot_id, id_to_token.at(special_eot_id).text.c_str() ); }
- if (special_eom_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOM token = %d '%s'\n", __func__, special_eom_id, id_to_token.at(special_eom_id).text.c_str() ); }
- if (special_unk_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, special_unk_id, id_to_token.at(special_unk_id).text.c_str() ); }
- if (special_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, special_sep_id, id_to_token.at(special_sep_id).text.c_str() ); }
- if (special_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, special_pad_id, id_to_token.at(special_pad_id).text.c_str() ); }
- if (special_mask_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, special_mask_id, id_to_token.at(special_mask_id).text.c_str() ); }
+ if (special_bos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, special_bos_id, id_to_token.at(special_bos_id).text.c_str() ); }
+ if (special_eos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, special_eos_id, id_to_token.at(special_eos_id).text.c_str() ); }
+ if (special_eot_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, special_eot_id, id_to_token.at(special_eot_id).text.c_str() ); }
+ if (special_eom_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOM token = %d '%s'\n", __func__, special_eom_id, id_to_token.at(special_eom_id).text.c_str() ); }
+ if (special_unk_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, special_unk_id, id_to_token.at(special_unk_id).text.c_str() ); }
+ if (special_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, special_sep_id, id_to_token.at(special_sep_id).text.c_str() ); }
+ if (special_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, special_pad_id, id_to_token.at(special_pad_id).text.c_str() ); }
+ if (special_mask_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, special_mask_id, id_to_token.at(special_mask_id).text.c_str() ); }
- if (linefeed_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, linefeed_id, id_to_token.at(linefeed_id).text.c_str() ); }
+ if (linefeed_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, linefeed_id, id_to_token.at(linefeed_id).text.c_str() ); }
- if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token = %d '%s'\n", __func__, special_fim_pre_id, id_to_token.at(special_fim_pre_id).text.c_str() ); }
- if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token = %d '%s'\n", __func__, special_fim_suf_id, id_to_token.at(special_fim_suf_id).text.c_str() ); }
- if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token = %d '%s'\n", __func__, special_fim_mid_id, id_to_token.at(special_fim_mid_id).text.c_str() ); }
- if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token = %d '%s'\n", __func__, special_fim_pad_id, id_to_token.at(special_fim_pad_id).text.c_str() ); }
- if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token = %d '%s'\n", __func__, special_fim_rep_id, id_to_token.at(special_fim_rep_id).text.c_str() ); }
- if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token = %d '%s'\n", __func__, special_fim_sep_id, id_to_token.at(special_fim_sep_id).text.c_str() ); }
+ if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token = %d '%s'\n", __func__, special_fim_pre_id, id_to_token.at(special_fim_pre_id).text.c_str() ); }
+ if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token = %d '%s'\n", __func__, special_fim_suf_id, id_to_token.at(special_fim_suf_id).text.c_str() ); }
+ if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token = %d '%s'\n", __func__, special_fim_mid_id, id_to_token.at(special_fim_mid_id).text.c_str() ); }
+ if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token = %d '%s'\n", __func__, special_fim_pad_id, id_to_token.at(special_fim_pad_id).text.c_str() ); }
+ if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token = %d '%s'\n", __func__, special_fim_rep_id, id_to_token.at(special_fim_rep_id).text.c_str() ); }
+ if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token = %d '%s'\n", __func__, special_fim_sep_id, id_to_token.at(special_fim_sep_id).text.c_str() ); }
for (const auto & id : special_eog_ids) {
- LLAMA_LOG_INFO( "%s: EOG token = %d '%s'\n", __func__, id, id_to_token.at(id).text.c_str() );
+ LLAMA_LOG_INFO( "%s: EOG token = %d '%s'\n", __func__, id, id_to_token.at(id).text.c_str() );
}
- LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len);
+ LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len);
}
llama_vocab::llama_vocab() : pimpl(new impl(*this)) {
diff --git a/src/llama-vocab.h b/src/llama-vocab.h
index 2b240a5491..28c3a82b91 100644
--- a/src/llama-vocab.h
+++ b/src/llama-vocab.h
@@ -53,6 +53,7 @@ enum llama_vocab_pre_type {
LLAMA_VOCAB_PRE_TYPE_AFMOE = 42,
LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN = 43,
LLAMA_VOCAB_PRE_TYPE_YOUTU = 44,
+ LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE = 45,
};
struct LLM_KV;
diff --git a/src/models/exaone-moe.cpp b/src/models/exaone-moe.cpp
new file mode 100644
index 0000000000..bef5b2ad35
--- /dev/null
+++ b/src/models/exaone-moe.cpp
@@ -0,0 +1,146 @@
+#include "models.h"
+
+// Graph builder for EXAONE-MoE: embeds the inputs, runs (n_layer - nextn_predict_layers) attention + FFN/MoE blocks, then the final RMS norm and the lm_head projection.
+llm_build_exaone_moe::llm_build_exaone_moe(const llama_model & model, const llm_graph_params & params) :
+ llm_graph_context(params) {
+ const int64_t n_embd_head = hparams.n_embd_head_k; // head size; asserted below to match n_embd_head_v and n_rot
+
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_v);
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+ ggml_tensor * cur;
+ ggml_tensor * inpL;
+
+ inpL = build_inp_embd(model.tok_embd);
+
+ // inp_pos - contains the positions
+ ggml_tensor * inp_pos = build_inp_pos();
+
+ auto * inp_attn_iswa = build_attn_inp_kv_iswa(); // passed to every build_attn() call in the layer loop
+
+ ggml_tensor * inp_out_ids = build_inp_out_ids(); // checked for null before use on the last layer
+
+ const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; // the trailing nextn_predict_layers layers are not added to this graph
+ for (int il = 0; il < n_transformer_layers; ++il) {
+ ggml_tensor * inpSA = inpL; // kept for the residual add after attention
+
+ // use RoPE for SWA layers only: layers where is_swa(il) is false skip the rotation below
+ const bool is_local_layer = hparams.is_swa(il);
+
+ // norm (RMS, applied before attention)
+ cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+ cb(cur, "attn_norm", il);
+
+ // self-attention
+ {
+ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // only consumed by the ggml_rope_ext() calls below
+
+ // compute Q and K and RoPE them
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); // split into heads: Q uses n_head, K/V use n_head_kv
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); // Q/K RMS norm, applied after the head reshape and before RoPE
+ Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+ cb(Qcur, "Qcur_normed", il);
+ cb(Kcur, "Kcur_normed", il);
+
+ if (is_local_layer) { // V is never rotated; Q and K only on SWA layers
+ Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base,
+ freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+
+ Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base,
+ freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+ }
+ cb(Qcur, "Qcur", il); // same callback names as above, now reported for the final Q/K/V fed to attention
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
+
+ cur = build_attn(inp_attn_iswa,
+ model.layers[il].wo, NULL,
+ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); // scale = 1/sqrt(head size)
+ cb(cur, "attn_out", il);
+ }
+ if (il == n_transformer_layers - 1 && inp_out_ids) { // last built layer: keep only the rows listed in inp_out_ids
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); // the residual input must be filtered the same way as cur
+ }
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); // first residual connection (attention output + block input)
+ cb(ffn_inp, "ffn_inp", il);
+
+ // norm (RMS, applied before the feed-forward network)
+ cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+ cb(cur, "ffn_norm", il);
+
+ // feed-forward network: the branch is chosen per layer by whether ffn_gate_inp is set
+ if (model.layers[il].ffn_gate_inp == nullptr) {
+ // dense branch
+ cur = build_ffn(cur,
+ model.layers[il].ffn_up, NULL, NULL,
+ model.layers[il].ffn_gate, NULL, NULL,
+ model.layers[il].ffn_down, NULL, NULL, NULL,
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
+ cb(cur, "ffn_out", il);
+ } else {
+ // MoE branch
+ ggml_tensor * moe_out = build_moe_ffn(cur,
+ model.layers[il].ffn_gate_inp,
+ model.layers[il].ffn_up_exps,
+ model.layers[il].ffn_gate_exps,
+ model.layers[il].ffn_down_exps,
+ model.layers[il].ffn_exp_probs_b,
+ n_expert, n_expert_used,
+ LLM_FFN_SILU, hparams.expert_weights_norm,
+ true, hparams.expert_weights_scale, // NOTE(review): 'true' presumably enables scaling by expert_weights_scale - confirm against build_moe_ffn
+ (llama_expert_gating_func_type) hparams.expert_gating_func,
+ il);
+ cb(moe_out, "ffn_moe_out", il);
+
+ // FFN shared expert: fed the same normed input (cur) as the MoE call above, then summed with its output
+ {
+ ggml_tensor * ffn_shexp =
+ build_ffn(cur,
+ model.layers[il].ffn_up_shexp, NULL, NULL,
+ model.layers[il].ffn_gate_shexp, NULL, NULL,
+ model.layers[il].ffn_down_shexp, NULL, NULL,
+ NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+ cb(ffn_shexp, "ffn_shexp", il);
+
+ cur = ggml_add(ctx0, moe_out, ffn_shexp);
+ cb(cur, "ffn_out", il);
+ }
+ }
+
+ cur = ggml_add(ctx0, cur, ffn_inp); // second residual connection (FFN output + FFN input)
+
+ cur = build_cvec(cur, il); // NOTE(review): presumably applies the per-layer control vector - confirm against build_cvec
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+ cur = inpL;
+
+ // final norm
+ cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+ cb(cur, "result_norm", -1);
+ res->t_embd = cur; // the normed hidden state is stored as the embedding result
+
+ // lm_head
+ cur = build_lora_mm(model.output, cur);
+
+ cb(cur, "result_output", -1);
+ res->t_logits = cur;
+
+ ggml_build_forward_expand(gf, cur); // add everything reachable from the logits tensor to the graph gf
+}
diff --git a/src/models/gemma3n-iswa.cpp b/src/models/gemma3n-iswa.cpp
index 93defbeef9..51acab1490 100644
--- a/src/models/gemma3n-iswa.cpp
+++ b/src/models/gemma3n-iswa.cpp
@@ -258,12 +258,12 @@ ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
res->add_input(std::move(inp));
} else {
// Vision embedding path: use padding token (ID=0) embedding
+ // TODO: verify if this is the correct behavior in transformers implementation
const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_altup * n_layer
- // Extract and dequantize padding token embedding (column 0)
- ggml_tensor * padding_q = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0);
- ggml_tensor * padding_f32 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, embd_size);
- inp_per_layer = ggml_cpy(ctx0, padding_q, padding_f32);
+ // Extract and dequantize padding token embedding (row 0)
+ ggml_tensor * padding = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0);
+ inp_per_layer = ggml_cast(ctx0, padding, GGML_TYPE_F32);
// Reshape to [n_embd_altup, n_layer, 1]
inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, 1);
diff --git a/src/models/models.h b/src/models/models.h
index 2cb6942415..eedbee2c46 100644
--- a/src/models/models.h
+++ b/src/models/models.h
@@ -171,6 +171,10 @@ struct llm_build_exaone : public llm_graph_context {
llm_build_exaone(const llama_model & model, const llm_graph_params & params);
};
+struct llm_build_exaone_moe : public llm_graph_context {
+ llm_build_exaone_moe(const llama_model & model, const llm_graph_params & params);
+};
+
struct llm_build_falcon : public llm_graph_context {
llm_build_falcon(const llama_model & model, const llm_graph_params & params);
};
diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp
index a98ede0a57..aed97e77e5 100644
--- a/tools/llama-bench/llama-bench.cpp
+++ b/tools/llama-bench/llama-bench.cpp
@@ -334,6 +334,7 @@ struct cmd_params {
std::vector> tensor_split;
std::vector> tensor_buft_overrides;
std::vector use_mmap;
+ std::vector use_direct_io;
std::vector embeddings;
std::vector no_op_offload;
std::vector no_host;
@@ -372,6 +373,7 @@ static const cmd_params cmd_params_defaults = {
/* tensor_split */ { std::vector(llama_max_devices(), 0.0f) },
/* tensor_buft_overrides*/ { std::vector{ { nullptr, nullptr } } },
/* use_mmap */ { true },
+ /* use_direct_io */ { true },
/* embeddings */ { false },
/* no_op_offload */ { false },
/* no_host */ { false },
@@ -449,6 +451,8 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -dev, --device (default: auto)\n");
printf(" -mmp, --mmap <0|1> (default: %s)\n",
join(cmd_params_defaults.use_mmap, ",").c_str());
+ printf(" -dio, --direct-io <0|1> (default: %s)\n",
+ join(cmd_params_defaults.use_direct_io, ",").c_str());
printf(" -embd, --embeddings <0|1> (default: %s)\n",
join(cmd_params_defaults.embeddings, ",").c_str());
printf(" -ts, --tensor-split (default: 0)\n");
@@ -772,6 +776,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
}
auto p = string_split(argv[i], split_delim);
params.use_mmap.insert(params.use_mmap.end(), p.begin(), p.end());
+ } else if (arg == "-dio" || arg == "--direct-io") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ auto p = string_split(argv[i], split_delim);
+ params.use_direct_io.insert(params.use_direct_io.end(), p.begin(), p.end());
} else if (arg == "-embd" || arg == "--embeddings") {
if (++i >= argc) {
invalid_param = true;
@@ -1008,6 +1019,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
if (params.use_mmap.empty()) {
params.use_mmap = cmd_params_defaults.use_mmap;
}
+ if (params.use_direct_io.empty()) {
+ params.use_direct_io = cmd_params_defaults.use_direct_io;
+ }
if (params.embeddings.empty()) {
params.embeddings = cmd_params_defaults.embeddings;
}
@@ -1056,6 +1070,7 @@ struct cmd_params_instance {
std::vector tensor_split;
std::vector tensor_buft_overrides;
bool use_mmap;
+ bool use_direct_io;
bool embeddings;
bool no_op_offload;
bool no_host;
@@ -1067,11 +1082,12 @@ struct cmd_params_instance {
if (!devices.empty()) {
mparams.devices = const_cast(devices.data());
}
- mparams.split_mode = split_mode;
- mparams.main_gpu = main_gpu;
- mparams.tensor_split = tensor_split.data();
- mparams.use_mmap = use_mmap;
- mparams.no_host = no_host;
+ mparams.split_mode = split_mode;
+ mparams.main_gpu = main_gpu;
+ mparams.tensor_split = tensor_split.data();
+ mparams.use_mmap = use_mmap;
+ mparams.use_direct_io = use_direct_io;
+ mparams.no_host = no_host;
if (n_cpu_moe <= 0) {
if (tensor_buft_overrides.empty()) {
@@ -1115,7 +1131,8 @@ struct cmd_params_instance {
bool equal_mparams(const cmd_params_instance & other) const {
return model == other.model && n_gpu_layers == other.n_gpu_layers && n_cpu_moe == other.n_cpu_moe &&
split_mode == other.split_mode &&
- main_gpu == other.main_gpu && use_mmap == other.use_mmap && tensor_split == other.tensor_split &&
+ main_gpu == other.main_gpu && tensor_split == other.tensor_split &&
+ use_mmap == other.use_mmap && use_direct_io == other.use_direct_io &&
devices == other.devices &&
no_host == other.no_host &&
vec_tensor_buft_override_equal(tensor_buft_overrides, other.tensor_buft_overrides);
@@ -1153,6 +1170,7 @@ static std::vector get_cmd_params_instances(const cmd_param
for (const auto & ts : params.tensor_split)
for (const auto & ot : params.tensor_buft_overrides)
for (const auto & mmp : params.use_mmap)
+ for (const auto & dio : params.use_direct_io)
for (const auto & noh : params.no_host)
for (const auto & embd : params.embeddings)
for (const auto & nopo : params.no_op_offload)
@@ -1194,6 +1212,7 @@ static std::vector get_cmd_params_instances(const cmd_param
/* .tensor_split = */ ts,
/* .tensor_buft_overrides = */ ot,
/* .use_mmap = */ mmp,
+ /* .use_direct_io= */ dio,
/* .embeddings = */ embd,
/* .no_op_offload= */ nopo,
/* .no_host = */ noh,
@@ -1228,6 +1247,7 @@ static std::vector get_cmd_params_instances(const cmd_param
/* .tensor_split = */ ts,
/* .tensor_buft_overrides = */ ot,
/* .use_mmap = */ mmp,
+ /* .use_direct_io= */ dio,
/* .embeddings = */ embd,
/* .no_op_offload= */ nopo,
/* .no_host = */ noh,
@@ -1262,6 +1282,7 @@ static std::vector get_cmd_params_instances(const cmd_param
/* .tensor_split = */ ts,
/* .tensor_buft_overrides = */ ot,
/* .use_mmap = */ mmp,
+ /* .use_direct_io= */ dio,
/* .embeddings = */ embd,
/* .no_op_offload= */ nopo,
/* .no_host = */ noh,
@@ -1301,6 +1322,7 @@ struct test {
std::vector tensor_split;
std::vector tensor_buft_overrides;
bool use_mmap;
+ bool use_direct_io;
bool embeddings;
bool no_op_offload;
bool no_host;
@@ -1338,6 +1360,7 @@ struct test {
tensor_split = inst.tensor_split;
tensor_buft_overrides = inst.tensor_buft_overrides;
use_mmap = inst.use_mmap;
+ use_direct_io = inst.use_direct_io;
embeddings = inst.embeddings;
no_op_offload = inst.no_op_offload;
no_host = inst.no_host;
@@ -1397,9 +1420,9 @@ struct test {
"n_ubatch", "n_threads", "cpu_mask", "cpu_strict", "poll",
"type_k", "type_v", "n_gpu_layers", "n_cpu_moe", "split_mode",
"main_gpu", "no_kv_offload", "flash_attn", "devices", "tensor_split",
- "tensor_buft_overrides", "use_mmap", "embeddings", "no_op_offload",
- "no_host", "n_prompt", "n_gen", "n_depth", "test_time",
- "avg_ns", "stddev_ns", "avg_ts", "stddev_ts"
+ "tensor_buft_overrides", "use_mmap", "use_direct_io", "embeddings",
+ "no_op_offload", "no_host", "n_prompt", "n_gen", "n_depth",
+ "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts"
};
return fields;
}
@@ -1414,7 +1437,7 @@ struct test {
return INT;
}
if (field == "f16_kv" || field == "no_kv_offload" || field == "cpu_strict" || field == "flash_attn" ||
- field == "use_mmap" || field == "embeddings" || field == "no_host") {
+ field == "use_mmap" || field == "use_direct_io" || field == "embeddings" || field == "no_host") {
return BOOL;
}
if (field == "avg_ts" || field == "stddev_ts") {
@@ -1487,6 +1510,7 @@ struct test {
tensor_split_str,
tensor_buft_overrides_str,
std::to_string(use_mmap),
+ std::to_string(use_direct_io),
std::to_string(embeddings),
std::to_string(no_op_offload),
std::to_string(no_host),
@@ -1672,6 +1696,9 @@ struct markdown_printer : public printer {
if (field == "use_mmap") {
return 4;
}
+ if (field == "use_direct_io") {
+ return 3;
+ }
if (field == "test") {
return 15;
}
@@ -1709,6 +1736,9 @@ struct markdown_printer : public printer {
if (field == "use_mmap") {
return "mmap";
}
+ if (field == "use_direct_io") {
+ return "dio";
+ }
if (field == "embeddings") {
return "embd";
}
@@ -1793,6 +1823,9 @@ struct markdown_printer : public printer {
if (params.use_mmap.size() > 1 || params.use_mmap != cmd_params_defaults.use_mmap) {
fields.emplace_back("use_mmap");
}
+ if (params.use_direct_io.size() > 1 || params.use_direct_io != cmd_params_defaults.use_direct_io) {
+ fields.emplace_back("use_direct_io");
+ }
if (params.embeddings.size() > 1 || params.embeddings != cmd_params_defaults.embeddings) {
fields.emplace_back("embeddings");
}
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index b163078a6c..930b2061f1 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -3854,19 +3854,6 @@ bool clip_is_glm(const struct clip_ctx * ctx) {
return ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE;
}
-bool clip_is_mrope(const struct clip_ctx * ctx) {
- switch (ctx->proj_type()) {
- case PROJECTOR_TYPE_QWEN2VL:
- case PROJECTOR_TYPE_QWEN25VL:
- case PROJECTOR_TYPE_QWEN3VL:
- case PROJECTOR_TYPE_GLM4V:
- case PROJECTOR_TYPE_PADDLEOCR:
- return true;
- default:
- return false;
- }
-}
-
bool clip_is_llava(const struct clip_ctx * ctx) {
return ctx->model.hparams.has_llava_projector;
}
diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h
index 79df0136ba..27ee020182 100644
--- a/tools/mtmd/clip.h
+++ b/tools/mtmd/clip.h
@@ -104,7 +104,6 @@ bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct
int clip_is_minicpmv(const struct clip_ctx * ctx);
bool clip_is_glm(const struct clip_ctx * ctx);
-bool clip_is_mrope(const struct clip_ctx * ctx);
bool clip_is_llava(const struct clip_ctx * ctx);
// note for contributor: this clip_is_(model) pattern is deprecated
// do NOT add new functions like this
diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index 8fb6fc49c7..6610fc5df7 100644
--- a/tools/mtmd/mtmd.cpp
+++ b/tools/mtmd/mtmd.cpp
@@ -146,8 +146,6 @@ struct mtmd_context {
bool tok_row_end_trail = false;
bool ov_img_first = false;
- bool use_mrope = false; // for Qwen2VL, we need to use M-RoPE
-
// string template for slice image delimiters with row/col (idefics3)
std::string sli_img_start_tmpl;
@@ -217,7 +215,6 @@ struct mtmd_context {
void init_vision() {
GGML_ASSERT(ctx_v != nullptr);
- use_mrope = clip_is_mrope(ctx_v);
projector_type proj = clip_get_projector_type(ctx_v);
int minicpmv_version = clip_is_minicpmv(ctx_v);
@@ -631,7 +628,7 @@ struct mtmd_tokenizer {
}
mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
- if (ctx->use_mrope) {
+ if (mtmd_decode_use_mrope(ctx)) {
// for Qwen2VL, we need this information for M-RoPE decoding positions
image_tokens->nx = clip_n_output_tokens_x(ctx->ctx_v, batch_f32.entries[0].get());
image_tokens->ny = clip_n_output_tokens_y(ctx->ctx_v, batch_f32.entries[0].get());
@@ -867,10 +864,7 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
switch (ctx->proj_type_v()) {
- case PROJECTOR_TYPE_QWEN2VL:
- case PROJECTOR_TYPE_QWEN25VL:
- case PROJECTOR_TYPE_QWEN3VL:
- case PROJECTOR_TYPE_YOUTUVL:
+ case PROJECTOR_TYPE_GEMMA3:
return true;
default:
return false;
@@ -878,7 +872,15 @@ bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
}
bool mtmd_decode_use_mrope(mtmd_context * ctx) {
- return ctx->use_mrope;
+ switch (ctx->proj_type_v()) {
+ case PROJECTOR_TYPE_QWEN2VL:
+ case PROJECTOR_TYPE_QWEN25VL:
+ case PROJECTOR_TYPE_QWEN3VL:
+ case PROJECTOR_TYPE_GLM4V:
+ return true;
+ default:
+ return false;
+ }
}
bool mtmd_support_vision(mtmd_context * ctx) {