From f11a0241b91a96bbab993dec7c2f3abf85fd7c9c Mon Sep 17 00:00:00 2001
From: sumitchatterjee13
Date: Mon, 9 Mar 2026 20:48:56 +1100
Subject: [PATCH] model: add sarvam_moe architecture support

---
 convert_hf_to_gguf.py     | 109 ++++++++++++++++++++++++++++++
 gguf-py/gguf/constants.py |  24 +++++++
 src/CMakeLists.txt        |   1 +
 src/llama-arch.cpp        |  24 +++++++
 src/llama-arch.h          |   1 +
 src/llama-model.cpp       |  78 ++++++++++++++++++++++
 src/llama-vocab.cpp       |  10 +++
 src/llama-vocab.h         |   1 +
 src/models/models.h       |   4 ++
 src/models/sarvam-moe.cpp | 135 ++++++++++++++++++++++++++++++++++++++
 10 files changed, 387 insertions(+)
 create mode 100644 src/models/sarvam-moe.cpp

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 083b5bca9e..e180ab80d8 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -1307,6 +1307,9 @@ class TextModel(ModelBase):
         if chkhsh == "e4d54df1ebc1f2b91acd986c5b51aa50837d5faf7c7398e73c1f9e9ee5d19869":
             # ref: https://huggingface.co/kakaocorp/kanana-2-30b-a3b-instruct-2601
             res = "kanana2"
+        if chkhsh == "62f6fb0a6fd5098caeabb19b07a5c1099cafc8b9c40eab6ea89ece4ec02fbc57":
+            # ref: https://huggingface.co/sarvamai/sarvam-30b
+            res = "sarvam-moe"
 
         if res is None:
             logger.warning("\n")
@@ -10034,6 +10037,112 @@ class BailingMoeV2Model(TextModel):
             raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register("SarvamMoEForCausalLM", "modeling_sarvam_moe.SarvamMoEForCausalLM")
+class SarvamMoEModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.SARVAM_MOE
+
+    @staticmethod
+    def _build_gpt2_byte_encoder() -> dict[int, str]:
+        """Build GPT-2 bytes_to_unicode mapping (cached on first call)."""
+        bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+        cs = bs[:]
+        n = 0
+        for b in range(256):
+            if b not in bs:
+                bs.append(b)
+                cs.append(256 + n)
+                n += 1
+        return dict(zip(bs, [chr(c) for c in cs]))
+
+    def _sp_to_gpt2(self, token: str) -> str:
+        """Convert SentencePiece-style token (▁ = space) to GPT-2 byte-level encoding."""
+        if not hasattr(self, '_byte_encoder'):
+            self._byte_encoder = self._build_gpt2_byte_encoder()
+        token = token.replace("\u2581", " ")
+        return "".join(self._byte_encoder[b] for b in token.encode("utf-8"))
+
+    def set_vocab(self):
+        # Sarvam uses SentencePiece-style BPE (▁ as space) but llama.cpp's BPE
+        # expects GPT-2 byte-level encoding. Convert tokens and merges.
+        tokens, toktypes, tokpre = self.get_vocab_base()
+        for i, toktype in enumerate(toktypes):
+            if toktype == gguf.TokenType.NORMAL:
+                tokens[i] = self._sp_to_gpt2(tokens[i])
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+        # Convert merges from SentencePiece to GPT-2 encoding
+        special_vocab.merges = [
+            " ".join(self._sp_to_gpt2(part) for part in merge.split(" "))
+            for merge in special_vocab.merges
+        ]
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        if (rope_dim := hparams.get("head_dim")) is None:
+            rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
+
+        # Sarvam uses full rotary embedding (no partial_rotary_factor)
+        self.gguf_writer.add_rope_dimension_count(rope_dim)
+        self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
+        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+        self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
+        self.gguf_writer.add_expert_shared_feed_forward_length(hparams.get("moe_shared_expert_intermediate_size", hparams["moe_intermediate_size"] * hparams["num_shared_experts"]))
+        self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
+        self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
+        self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
+        self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if "mlp.experts" in name:
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+                    yield from super().modify_tensors(data_torch, merged_name, bid)
+            return
+
+        if name.endswith(".expert_bias"):
+            # Zero-mean normalization for expert bias (Sarvam-specific)
+            data_torch = data_torch - data_torch.mean()
+            name = name.replace(".expert_bias", ".expert_bias.bias")
+
+        yield from super().modify_tensors(data_torch, name, bid)
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @ModelBase.register("GroveMoeForCausalLM", "modeling_grove_moe.GroveMoeForCausalLM")
 class GroveMoeModel(TextModel):
     model_arch = gguf.MODEL_ARCH.GROVEMOE
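
Note: the ▁-to-byte-level conversion above is easy to sanity-check on its own. Below is a minimal standalone sketch of the same mapping (the sample tokens are hypothetical; the helpers mirror _build_gpt2_byte_encoder / _sp_to_gpt2 but are not the converter itself):

    # Standalone sketch of the SentencePiece -> GPT-2 byte-level conversion used in set_vocab above.
    def bytes_to_unicode() -> dict[int, str]:
        # GPT-2's reversible byte -> printable-codepoint table
        bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
        cs = bs[:]
        n = 0
        for b in range(256):
            if b not in bs:
                bs.append(b)
                cs.append(256 + n)
                n += 1
        return dict(zip(bs, map(chr, cs)))

    BYTE_ENCODER = bytes_to_unicode()

    def sp_to_gpt2(token: str) -> str:
        token = token.replace("\u2581", " ")  # ▁ marks a space in SentencePiece-style vocabs
        return "".join(BYTE_ENCODER[b] for b in token.encode("utf-8"))

    print(sp_to_gpt2("\u2581the"))      # 'Ġthe' -- the space byte 0x20 maps to Ġ
    print(sp_to_gpt2("\u2581नमस्ते"))    # multi-byte UTF-8 maps into the printable range llama.cpp's BPE expects
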
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 839c6e787f..a2b50c9539 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -451,6 +451,7 @@ class MODEL_ARCH(IntEnum):
     PLM          = auto()
     BAILINGMOE   = auto()
     BAILINGMOE2  = auto()
+    SARVAM_MOE   = auto()
     DOTS1        = auto()
     ARCEE        = auto()
     AFMOE        = auto()
@@ -894,6 +895,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.PLM:          "plm",
     MODEL_ARCH.BAILINGMOE:   "bailingmoe",
     MODEL_ARCH.BAILINGMOE2:  "bailingmoe2",
+    MODEL_ARCH.SARVAM_MOE:   "sarvam_moe",
     MODEL_ARCH.DOTS1:        "dots1",
     MODEL_ARCH.ARCEE:        "arcee",
     MODEL_ARCH.AFMOE:        "afmoe",
@@ -3128,6 +3130,28 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
         MODEL_TENSOR.LAYER_OUT_NORM,
     ],
+    MODEL_ARCH.SARVAM_MOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_EXP_PROBS_B,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+    ],
     MODEL_ARCH.DOTS1: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 283823fa9c..b50b175126 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -44,6 +44,7 @@ add_library(llama
             models/baichuan.cpp
             models/bailingmoe.cpp
             models/bailingmoe2.cpp
+            models/sarvam-moe.cpp
             models/bert.cpp
             models/bitnet.cpp
             models/bloom.cpp
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 9d8eb88d0b..773e240d58 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -100,6 +100,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_PLM,          "plm"         },
     { LLM_ARCH_BAILINGMOE,   "bailingmoe"  },
     { LLM_ARCH_BAILINGMOE2,  "bailingmoe2" },
+    { LLM_ARCH_SARVAM_MOE,   "sarvam_moe"  },
     { LLM_ARCH_DOTS1,        "dots1"       },
     { LLM_ARCH_ARCEE,        "arcee"       },
     { LLM_ARCH_AFMOE,        "afmoe"       },
@@ -2173,6 +2174,29 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
                LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
                LLM_TENSOR_LAYER_OUT_NORM,
            };
+        case LLM_ARCH_SARVAM_MOE:
+            return {
+                LLM_TENSOR_TOKEN_EMBD,
+                LLM_TENSOR_OUTPUT_NORM,
+                LLM_TENSOR_OUTPUT,
+                LLM_TENSOR_ATTN_NORM,
+                LLM_TENSOR_ATTN_Q_NORM,
+                LLM_TENSOR_ATTN_K_NORM,
+                LLM_TENSOR_ATTN_QKV,
+                LLM_TENSOR_ATTN_OUT,
+                LLM_TENSOR_FFN_GATE_INP,
+                LLM_TENSOR_FFN_EXP_PROBS_B,
+                LLM_TENSOR_FFN_NORM,
+                LLM_TENSOR_FFN_GATE,
+                LLM_TENSOR_FFN_DOWN,
+                LLM_TENSOR_FFN_UP,
+                LLM_TENSOR_FFN_GATE_EXPS,
+                LLM_TENSOR_FFN_DOWN_EXPS,
+                LLM_TENSOR_FFN_UP_EXPS,
+                LLM_TENSOR_FFN_GATE_SHEXP,
+                LLM_TENSOR_FFN_DOWN_SHEXP,
+                LLM_TENSOR_FFN_UP_SHEXP,
+            };
         case LLM_ARCH_DOTS1:
             return {
                 LLM_TENSOR_TOKEN_EMBD,
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 07aac40aa1..a07ad7e015 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -104,6 +104,7 @@ enum llm_arch {
     LLM_ARCH_PLM,
     LLM_ARCH_BAILINGMOE,
     LLM_ARCH_BAILINGMOE2,
+    LLM_ARCH_SARVAM_MOE,
     LLM_ARCH_DOTS1,
     LLM_ARCH_ARCEE,
     LLM_ARCH_AFMOE,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index e18cca0524..839ccfee6b 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -2109,6 +2109,22 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_SARVAM_MOE:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+                ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
+                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
+                ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func);
+
+                switch (hparams.n_layer) {
+                    case 19: type = LLM_TYPE_30B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_DOTS1:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -6247,6 +6263,53 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         }
                     }
                 } break;
+            case LLM_ARCH_SARVAM_MOE:
+                {
+                    const int64_t n_ff_exp = hparams.n_ff_exp;
+                    const int64_t n_expert_shared = hparams.n_expert_shared;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for sarvam_moe");
+                    GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for sarvam_moe");
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        if (static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead) { // MoE layers
+                            const int64_t n_ff_shexp = (hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff_exp) * n_expert_shared;
+
+                            layer.ffn_gate_inp    = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
+
+                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
+                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+
+                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
+                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0);
+                            layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_shexp}, 0);
+                        } else { // Dense layers
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                        }
+                    }
+                } break;
         case LLM_ARCH_DOTS1:
             {
                 const int64_t n_ff_exp = hparams.n_ff_exp;
@@ -7812,6 +7875,16 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers);
     }
 
+    if (arch == LLM_ARCH_SARVAM_MOE) {
+        LLAMA_LOG_INFO("%s: n_layer_dense_lead   = %d\n", __func__, hparams.n_layer_dense_lead);
+        LLAMA_LOG_INFO("%s: n_ff_exp             = %d\n", __func__, hparams.n_ff_exp);
+        LLAMA_LOG_INFO("%s: n_ff_shexp           = %d\n", __func__, hparams.n_ff_shexp);
+        LLAMA_LOG_INFO("%s: n_expert_shared      = %d\n", __func__, hparams.n_expert_shared);
+        LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
+        LLAMA_LOG_INFO("%s: expert_weights_norm  = %d\n", __func__, hparams.expert_weights_norm);
+        LLAMA_LOG_INFO("%s: expert_gating_func   = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
+    }
+
     if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) {
         LLAMA_LOG_INFO("%s: n_ff_exp             = %d\n", __func__, hparams.n_ff_exp);
         LLAMA_LOG_INFO("%s: expert_gating_func   = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
@@ -8455,6 +8528,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique(*this, params);
             } break;
+        case LLM_ARCH_SARVAM_MOE:
+            {
+                llm = std::make_unique<llm_build_sarvam_moe>(*this, params);
+            } break;
         case LLM_ARCH_SEED_OSS:
             {
                 llm = std::make_unique<llm_build_seed_oss>(*this, params);
@@ -8799,6 +8876,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_EXAONE_MOE:
         case LLM_ARCH_MINICPM3:
         case LLM_ARCH_BAILINGMOE2:
+        case LLM_ARCH_SARVAM_MOE:
        case LLM_ARCH_DOTS1:
         case LLM_ARCH_HUNYUAN_MOE:
         case LLM_ARCH_JAIS2:
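
Note: the create_tensor shapes above, including the n_ff_shexp-falls-back-to-n_ff_exp behaviour for the shared expert, can be sanity-checked by recomputing them from the hparams. A minimal sketch follows; the numeric values are placeholders, not Sarvam-30B's real config:

    # Recompute the expected MoE tensor shapes, mirroring the create_tensor calls above.
    n_embd          = 4096   # placeholder values
    n_ff_exp        = 1024
    n_expert        = 64
    n_expert_shared = 2
    n_ff_shexp_kv   = 0      # 0 -> key absent in the GGUF, fall back to n_ff_exp

    # same fallback as: (hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff_exp) * n_expert_shared
    n_ff_shexp = (n_ff_shexp_kv if n_ff_shexp_kv else n_ff_exp) * n_expert_shared

    shapes = {
        "ffn_gate_inp":   (n_embd, n_expert),
        "ffn_gate_exps":  (n_embd, n_ff_exp, n_expert),
        "ffn_down_exps":  (n_ff_exp, n_embd, n_expert),
        "ffn_up_exps":    (n_embd, n_ff_exp, n_expert),
        "ffn_gate_shexp": (n_embd, n_ff_shexp),
        "ffn_down_shexp": (n_ff_shexp, n_embd),
        "ffn_up_shexp":   (n_embd, n_ff_shexp),
    }
    for name, shape in shapes.items():
        print(f"{name}: {shape}")
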
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 68ba292d42..58eaa48d0a 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -289,6 +289,12 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                     "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
                 };
                 break;
+            case LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE:
+                regex_exprs = {
+                    // Split on " " with MergedWithPrevious: spaces attach to preceding token
+                    "[^ ]+ |[^ ]+",
+                };
+                break;
             case LLAMA_VOCAB_PRE_TYPE_JAIS2:
                 regex_exprs = {
                     // original regex from tokenizer.json
@@ -2090,6 +2096,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "solar-open") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN;
                 clean_spaces = false;
+            } else if (
+                tokenizer_pre == "sarvam-moe") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE;
+                clean_spaces = false;
             } else {
                 throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
             }
diff --git a/src/llama-vocab.h b/src/llama-vocab.h
index be5b08012d..b4fbf120c5 100644
--- a/src/llama-vocab.h
+++ b/src/llama-vocab.h
@@ -58,6 +58,7 @@ enum llama_vocab_pre_type {
     LLAMA_VOCAB_PRE_TYPE_TINY_AYA   = 47,
     LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM  = 48,
     LLAMA_VOCAB_PRE_TYPE_JAIS2      = 49,
+    LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE = 50,
 };
 
 struct LLM_KV;
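
Note: the SARVAM_MOE pre-tokenizer regex above keeps each trailing space attached to the chunk before it. A small check with Python's re module, standing in for llama.cpp's regex splitting (the sample sentence is arbitrary):

    import re

    # "[^ ]+ |[^ ]+": a run of non-spaces optionally followed by one space,
    # so the space ends up glued to the preceding chunk.
    pre_tok = re.compile(r"[^ ]+ |[^ ]+")
    print(pre_tok.findall("hello world again"))
    # ['hello ', 'world ', 'again']
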
diff --git a/src/models/models.h b/src/models/models.h
index cf9ba04e7f..1ea4f46014 100644
--- a/src/models/models.h
+++ b/src/models/models.h
@@ -122,6 +122,10 @@ struct llm_build_bailingmoe : public llm_graph_context {
     llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_sarvam_moe : public llm_graph_context {
+    llm_build_sarvam_moe(const llama_model & model, const llm_graph_params & params);
+};
+
 struct llm_build_bert : public llm_graph_context {
     llm_build_bert(const llama_model & model, const llm_graph_params & params);
 };
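
Note: the graph builder in sarvam-moe.cpp below reads Q, K and V as strided views into the fused wqkv output at float offsets 0, n_embd and n_embd + n_embd_gqa. A rough shape sketch of that split (plain torch, hypothetical head counts, not the ggml views themselves):

    import torch

    # Hypothetical GQA geometry, only to illustrate the fused-QKV layout.
    n_embd_head = 128
    n_head      = 16        # query heads
    n_head_kv   = 4         # key/value heads
    n_tokens    = 3

    n_embd     = n_embd_head * n_head     # query slice width
    n_embd_gqa = n_embd_head * n_head_kv  # key slice width == value slice width

    qkv = torch.randn(n_tokens, n_embd + 2 * n_embd_gqa)  # one row per token, as produced by wqkv

    q = qkv[:, :n_embd]                         # offset 0
    k = qkv[:, n_embd:n_embd + n_embd_gqa]      # offset n_embd
    v = qkv[:, n_embd + n_embd_gqa:]            # offset n_embd + n_embd_gqa

    q = q.view(n_tokens, n_head,    n_embd_head)
    k = k.view(n_tokens, n_head_kv, n_embd_head)
    v = v.view(n_tokens, n_head_kv, n_embd_head)
    print(q.shape, k.shape, v.shape)
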
diff --git a/src/models/sarvam-moe.cpp b/src/models/sarvam-moe.cpp
new file mode 100644
index 0000000000..b2be70d919
--- /dev/null
+++ b/src/models/sarvam-moe.cpp
@@ -0,0 +1,135 @@
+#include "models.h"
+
+llm_build_sarvam_moe::llm_build_sarvam_moe(const llama_model & model, const llm_graph_params & params) :
+    llm_graph_context(params) {
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    // inp_pos - contains the positions
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    auto * inp_attn = build_attn_inp_kv();
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        // norm
+        cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        // self_attention
+        {
+            cur = build_lora_mm(model.layers[il].wqkv, cur);
+            cb(cur, "wqkv", il);
+
+            ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float),
+                    cur->nb[1], 0 * sizeof(float) * (n_embd));
+            ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float),
+                    cur->nb[1], 1 * sizeof(float) * (n_embd));
+            ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float),
+                    cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa));
+
+            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+            cb(Qcur, "Qcur_normed", il);
+
+            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow);
+
+            Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+            cb(Kcur, "Kcur_normed", il);
+
+            Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow);
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn,
+                    model.layers[il].wo, model.layers[il].bo,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
+        }
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        ggml_tensor * sa_out = ggml_add(ctx0, cur, inpSA);
+        cb(sa_out, "sa_out", il);
+
+        // FFN block
+        cur = build_norm(sa_out, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        if (static_cast<uint32_t>(il) < hparams.n_layer_dense_lead) {
+            // Dense FFN for early layers
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+        } else {
+            // MoE FFN for remaining layers
+            ggml_tensor * moe_out = build_moe_ffn(cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    model.layers[il].ffn_exp_probs_b,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU, hparams.expert_weights_norm,
+                    hparams.expert_weights_scale,
+                    (llama_expert_gating_func_type) hparams.expert_gating_func,
+                    il);
+            cb(moe_out, "ffn_moe_out", il);
+
+            // Shared expert
+            {
+                ggml_tensor * ffn_shexp =
+                    build_ffn(cur,
+                        model.layers[il].ffn_up_shexp,   NULL, NULL,
+                        model.layers[il].ffn_gate_shexp, NULL, NULL,
+                        model.layers[il].ffn_down_shexp, NULL, NULL,
+                        NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(ffn_shexp, "ffn_shexp", il);
+
+                cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                cb(cur, "ffn_out", il);
+            }
+        }
+
+        cur = ggml_add(ctx0, cur, sa_out);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        // input for next layer
+        inpL = cur;
+    }
+
+    cur = inpL;
+
+    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // lm_head
+    cur = build_lora_mm(model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
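
Note: for reference, the routing that build_moe_ffn is asked to perform with the flags selected above (sigmoid gating, bias-adjusted top-k selection, optional top-k renormalization, routed scaling, and the always-on shared expert added afterwards) can be written out in a few lines. This is a hedged sketch of the intended math with toy sizes, not the ggml implementation; details such as the exact ordering of normalization and scaling are assumptions:

    import torch

    # Toy sizes; the real values come from the GGUF metadata written by the converter.
    n_tokens, n_embd, n_expert, n_expert_used = 4, 8, 16, 2
    routed_scaling_factor = 1.0     # expert_weights_scale
    norm_topk_prob        = True    # expert_weights_norm

    x      = torch.randn(n_tokens, n_embd)
    gate_w = torch.randn(n_expert, n_embd)   # ffn_gate_inp
    e_bias = torch.randn(n_expert)           # ffn_exp_probs_b (zero-meaned by the converter)

    scores = torch.sigmoid(x @ gate_w.T)                          # sigmoid gating (ExpertGatingFuncType.SIGMOID)
    _, idx  = torch.topk(scores + e_bias, n_expert_used, dim=-1)  # bias influences selection only
    weights = torch.gather(scores, -1, idx)                       # mixing weights use the unbiased scores
    if norm_topk_prob:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    weights = weights * routed_scaling_factor

    def expert_ffn(h):
        # stand-in for a routed or shared SwiGLU FFN
        return h

    moe_out = torch.zeros_like(x)
    for t in range(n_tokens):
        for w, e in zip(weights[t], idx[t]):
            moe_out[t] += w * expert_ffn(x[t])   # a real implementation dispatches to expert e's weights

    out = moe_out + expert_ffn(x)                # shared expert is always added on top
    print(out.shape)
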