diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 69abb7367d..f893b24c75 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1696,6 +1696,84 @@ class TextModel(ModelBase): if template is not None: self.gguf_writer.add_chat_template(template) + def _set_vocab_plamo(self): + # PLaMo models use a custom tokenizer with a .jsonl file + tokenizer_jsonl_path = self.dir_model / "tokenizer.jsonl" + tokenizer_config_path = self.dir_model / "tokenizer_config.json" + + if not tokenizer_jsonl_path.is_file(): + raise FileNotFoundError(f"PLaMo tokenizer file not found: {tokenizer_jsonl_path}") + + # Load tokenizer config + with open(tokenizer_config_path, "r", encoding="utf-8") as f: + tokenizer_config = json.load(f) + + # Load tokens from JSONL file (actually a list format) + tokens = [] + scores = [] + toktypes = [] + + with open(tokenizer_jsonl_path, "r", encoding="utf-8") as f: + for line_num, line in enumerate(f): + if line.strip(): + token_data = json.loads(line) + # Format: [token, score, type, ?, ?, ?, ?] + token = token_data[0].encode("utf-8") + score = float(token_data[1]) + token_type_str = token_data[2] if len(token_data) > 2 else "NORMAL" + + tokens.append(token) + scores.append(score) + + if token_type_str == "UNKNOWN": + toktypes.append(gguf.TokenType.UNKNOWN) + elif token_type_str == "CONTROL": + toktypes.append(gguf.TokenType.CONTROL) + elif token_type_str == "BYTE": + toktypes.append(gguf.TokenType.BYTE) + else: + token_str = token_data[0] + if token_str.startswith("<|plamo:") and token_str.endswith("|>"): + toktypes.append(gguf.TokenType.CONTROL) + else: + toktypes.append(gguf.TokenType.NORMAL) + + vocab_size = self.hparams["vocab_size"] + if vocab_size > len(tokens): + pad_count = vocab_size - len(tokens) + logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]") + for i in range(1, pad_count + 1): + tokens.append(bytes(f"[PAD{i}]", encoding="utf-8")) + scores.append(-1000.0) + toktypes.append(gguf.TokenType.UNUSED) + + self.gguf_writer.add_tokenizer_model("plamo2") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + if "bos_token" in tokenizer_config and tokenizer_config["bos_token"] is not None: + token_id = tokens.index(tokenizer_config["bos_token"].encode("utf-8")) + self.gguf_writer.add_bos_token_id(token_id) + if "eos_token" in tokenizer_config and tokenizer_config["eos_token"] is not None: + token_id = tokens.index(tokenizer_config["eos_token"].encode("utf-8")) + self.gguf_writer.add_eos_token_id(token_id) + if "pad_token" in tokenizer_config and tokenizer_config["pad_token"] is not None: + token_id = tokens.index(tokenizer_config["pad_token"].encode("utf-8")) + self.gguf_writer.add_pad_token_id(token_id) + if "sep_token" in tokenizer_config and tokenizer_config["sep_token"] is not None: + token_id = tokens.index(tokenizer_config["sep_token"].encode("utf-8")) + self.gguf_writer.add_sep_token_id(token_id) + if "unk_token" in tokenizer_config and tokenizer_config["unk_token"] is not None: + token_id = tokens.index(tokenizer_config["unk_token"].encode("utf-8")) + self.gguf_writer.add_unk_token_id(token_id) + + # Add <|plamo:op|> as EOT to ensure appropriate end of generation + self.gguf_writer.add_eot_token_id(4) + + self.gguf_writer.add_add_space_prefix(False) + class MmprojModel(ModelBase): model_type = ModelType.MMPROJ @@ -4798,87 +4876,7 @@ class Plamo2Model(TextModel): model_arch = gguf.MODEL_ARCH.PLAMO2 def set_vocab(self): - # PLaMo 2 uses a custom tokenizer with a .jsonl file - # We need to handle this specially - tokenizer_jsonl_path = self.dir_model / "tokenizer.jsonl" - tokenizer_config_path = self.dir_model / "tokenizer_config.json" - - if not tokenizer_jsonl_path.is_file(): - raise FileNotFoundError(f"PLaMo 2 tokenizer file not found: {tokenizer_jsonl_path}") - - # Load tokenizer config - with open(tokenizer_config_path, 'r', encoding='utf-8') as f: - tokenizer_config = json.load(f) - - # Load tokens from JSONL file (actually a list format) - tokens = [] - scores = [] - toktypes = [] - - with open(tokenizer_jsonl_path, 'r', encoding='utf-8') as f: - for line_num, line in enumerate(f): - if line.strip(): - token_data = json.loads(line) - # Format: [token, score, type, ?, ?, ?, ?] - token = token_data[0].encode("utf-8") - score = float(token_data[1]) - token_type_str = token_data[2] if len(token_data) > 2 else "NORMAL" - - tokens.append(token) - scores.append(score) - - # Map token type strings to GGUF token types - if token_type_str == "UNKNOWN": - toktypes.append(gguf.TokenType.UNKNOWN) - elif token_type_str == "CONTROL": - toktypes.append(gguf.TokenType.CONTROL) - elif token_type_str == "BYTE": - toktypes.append(gguf.TokenType.BYTE) - else: - # Check for PLaMo-2 special tokens - token_str = token_data[0] - if token_str.startswith("<|plamo:") and token_str.endswith("|>"): - toktypes.append(gguf.TokenType.CONTROL) - else: - toktypes.append(gguf.TokenType.NORMAL) - - vocab_size = self.hparams["vocab_size"] - if vocab_size > len(tokens): - pad_count = vocab_size - len(tokens) - logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]") - for i in range(1, pad_count + 1): - tokens.append(bytes(f"[PAD{i}]", encoding="utf-8")) - scores.append(-1000.0) - toktypes.append(gguf.TokenType.UNUSED) - - # Use "plamo2" tokenizer type for PLaMo-2's custom Aho-Corasick tokenizer - self.gguf_writer.add_tokenizer_model("plamo2") - self.gguf_writer.add_tokenizer_pre("default") - self.gguf_writer.add_token_list(tokens) - self.gguf_writer.add_token_scores(scores) - self.gguf_writer.add_token_types(toktypes) - - # Add special tokens from config - if "bos_token" in tokenizer_config and tokenizer_config["bos_token"] is not None: - token_id = tokens.index(tokenizer_config["bos_token"].encode("utf-8")) - self.gguf_writer.add_bos_token_id(token_id) - if "eos_token" in tokenizer_config and tokenizer_config["eos_token"] is not None: - token_id = tokens.index(tokenizer_config["eos_token"].encode("utf-8")) - self.gguf_writer.add_eos_token_id(token_id) - if "pad_token" in tokenizer_config and tokenizer_config["pad_token"] is not None: - token_id = tokens.index(tokenizer_config["pad_token"].encode("utf-8")) - self.gguf_writer.add_pad_token_id(token_id) - if "sep_token" in tokenizer_config and tokenizer_config["sep_token"] is not None: - token_id = tokens.index(tokenizer_config["sep_token"].encode("utf-8")) - self.gguf_writer.add_sep_token_id(token_id) - if "unk_token" in tokenizer_config and tokenizer_config["unk_token"] is not None: - token_id = tokens.index(tokenizer_config["unk_token"].encode("utf-8")) - self.gguf_writer.add_unk_token_id(token_id) - - # Add <|plamo:op|> as EOT to ensure appropriate end of generation - self.gguf_writer.add_eot_token_id(4) - - self.gguf_writer.add_add_space_prefix(False) + self._set_vocab_plamo() def set_gguf_parameters(self): hparams = self.hparams @@ -4966,6 +4964,56 @@ class Plamo2Model(TextModel): return [(new_name, data_torch)] +@ModelBase.register("Plamo3ForCausalLM", "PLaMo3ForCausalLM") +class Plamo3Model(TextModel): + model_arch = gguf.MODEL_ARCH.PLAMO3 + + def set_vocab(self): + self._set_vocab_plamo() + + tokenizer_config_path = self.dir_model / "tokenizer_config.json" + tokenizer_config = {} + + if tokenizer_config_path.is_file(): + with open(tokenizer_config_path, encoding="utf-8") as f: + tokenizer_config = json.load(f) + + chat_template = tokenizer_config.get("chat_template") + chat_template_jinja = self.dir_model / "chat_template.jinja" + + if chat_template_jinja.is_file(): + with open(chat_template_jinja, encoding="utf-8") as f: + chat_template = f.read() + + if chat_template: + self.gguf_writer.add_chat_template(chat_template) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + if (sliding_window := self.find_hparam(["window_size", "sliding_window"], optional=True)) is not None: + self.gguf_writer.add_sliding_window(sliding_window) + self.gguf_writer.add_sliding_window_pattern(self.hparams["sliding_window_pattern"]) + self.gguf_writer.add_rope_freq_base_swa(self.rope_parameters.get("sliding_attention", {"rope_theta": self.hparams.get("rope_local_theta")})["rope_theta"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + + if name.endswith(".pre_mixer_norm.weight"): + data_torch = data_torch + 1.0 + elif name.endswith(".post_mixer_norm.weight"): + data_torch = data_torch + 1.0 / 5 + elif name.endswith(".pre_mlp_norm.weight"): + data_torch = data_torch + 1.0 + elif name.endswith(".post_mlp_norm.weight"): + data_torch = data_torch + 1.0 / (5**1.5) + elif name.endswith((".mixer.q_norm.weight", ".mixer.k_norm.weight")): + data_torch = data_torch + 1.0 + elif name.endswith(".norm.weight"): + data_torch = data_torch + 1.0 + + return [(self.map_tensor_name(name), data_torch)] + + @ModelBase.register("CodeShellForCausalLM") class CodeShellModel(TextModel): model_arch = gguf.MODEL_ARCH.CODESHELL diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 27578daaf9..616b8add36 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -377,6 +377,7 @@ class MODEL_ARCH(IntEnum): PHIMOE = auto() PLAMO = auto() PLAMO2 = auto() + PLAMO3 = auto() CODESHELL = auto() ORION = auto() INTERNLM2 = auto() @@ -773,6 +774,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.PHIMOE: "phimoe", MODEL_ARCH.PLAMO: "plamo", MODEL_ARCH.PLAMO2: "plamo2", + MODEL_ARCH.PLAMO3: "plamo3", MODEL_ARCH.CODESHELL: "codeshell", MODEL_ARCH.ORION: "orion", MODEL_ARCH.INTERNLM2: "internlm2", @@ -1763,6 +1765,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.SSM_B_NORM, MODEL_TENSOR.SSM_C_NORM, ], + MODEL_ARCH.PLAMO3: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_POST_NORM, + ], MODEL_ARCH.GPT2: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.POS_EMBD, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 1690d991f2..115df6c7c3 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -595,6 +595,7 @@ class TensorNameMap: "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2 "transformer.layers.{bid}.attn.q_norm", # openelm "model.layers.layers.{bid}.mixer.q", # plamo2 + "model.layers.layers.{bid}.mixer.q_norm", # plamo3 "layers.{bid}.self_attn.q_norm", # qwen3-embedding "model.layers.{bid}.attention.query_layernorm", # apertus ), @@ -610,6 +611,7 @@ class TensorNameMap: "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2 "transformer.layers.{bid}.attn.k_norm", # openelm "model.layers.layers.{bid}.mixer.k", # plamo2 + "model.layers.layers.{bid}.mixer.k_norm", # plamo3 "layers.{bid}.self_attn.k_norm", # qwen3-embedding "model.layers.{bid}.attention.key_layernorm", # apertus ), diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 1e155534bd..762ea65c71 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -107,6 +107,7 @@ add_library(llama models/phi3.cpp models/plamo.cpp models/plamo2.cpp + models/plamo3.cpp models/plm.cpp models/qwen.cpp models/qwen2.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 75013d8d33..94a6807eac 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -42,6 +42,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_PHIMOE, "phimoe" }, { LLM_ARCH_PLAMO, "plamo" }, { LLM_ARCH_PLAMO2, "plamo2" }, + { LLM_ARCH_PLAMO3, "plamo3" }, { LLM_ARCH_CODESHELL, "codeshell" }, { LLM_ARCH_ORION, "orion" }, { LLM_ARCH_INTERNLM2, "internlm2" }, @@ -1077,6 +1078,22 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_ATTN_POST_NORM, LLM_TENSOR_FFN_POST_NORM, }; + case LLM_ARCH_PLAMO3: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_POST_NORM, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; case LLM_ARCH_CODESHELL: return { LLM_TENSOR_TOKEN_EMBD, diff --git a/src/llama-arch.h b/src/llama-arch.h index 27bdedc83c..714ead4025 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -46,6 +46,7 @@ enum llm_arch { LLM_ARCH_PHIMOE, LLM_ARCH_PLAMO, LLM_ARCH_PLAMO2, + LLM_ARCH_PLAMO3, LLM_ARCH_CODESHELL, LLM_ARCH_ORION, LLM_ARCH_INTERNLM2, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 1d6134ec05..5e664c8c57 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1227,6 +1227,26 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false); ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); } break; + case LLM_ARCH_PLAMO3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa > 0) { + uint32_t swa_period = 8; + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.rope_freq_scale_train_swa = 1.0f; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_2B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_GPT2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -3828,6 +3848,44 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); } } break; + case LLM_ARCH_PLAMO3: + { + const int64_t head_dim_q = hparams.n_embd_head_k; + const int64_t head_dim_v = hparams.n_embd_head_v; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + const int64_t num_attention_heads = hparams.n_head(i); + const int64_t num_key_value_heads = hparams.n_head_kv(i); + const int64_t q_proj_dim = num_attention_heads * head_dim_q; + const int64_t k_proj_dim = num_key_value_heads * head_dim_q; + const int64_t v_proj_dim = num_key_value_heads * head_dim_v; + const int64_t n_ff_cur = hparams.n_ff(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), + {n_embd,q_proj_dim + k_proj_dim + v_proj_dim}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim_q}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim_q}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {num_attention_heads * head_dim_v, n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_cur * 2}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0); + } + } break; case LLM_ARCH_GPT2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -7473,6 +7531,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_PLAMO3: + { + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + llm = std::make_unique> (*this, params); + } else { + llm = std::make_unique>(*this, params); + } + } break; case LLM_ARCH_GPT2: { llm = std::make_unique(*this, params); @@ -7977,6 +8043,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_PHIMOE: case LLM_ARCH_PLAMO: case LLM_ARCH_PLAMO2: + case LLM_ARCH_PLAMO3: case LLM_ARCH_GEMMA: case LLM_ARCH_GEMMA2: case LLM_ARCH_GEMMA3: diff --git a/src/models/models.h b/src/models/models.h index dd0e286eda..e2cd4e484f 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -406,6 +406,11 @@ struct llm_build_plamo : public llm_graph_context { llm_build_plamo(const llama_model & model, const llm_graph_params & params); }; +template +struct llm_build_plamo3 : public llm_graph_context { + llm_build_plamo3(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_plm : public llm_graph_context { llm_build_plm(const llama_model & model, const llm_graph_params & params); }; diff --git a/src/models/plamo3.cpp b/src/models/plamo3.cpp new file mode 100644 index 0000000000..55c8064679 --- /dev/null +++ b/src/models/plamo3.cpp @@ -0,0 +1,128 @@ +#include "models.h" + +template +llm_build_plamo3::llm_build_plamo3(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + const int64_t head_dim_q = hparams.n_embd_head_k; + const int64_t head_dim_v = hparams.n_embd_head_v; + + ggml_tensor * cur; + ggml_tensor * inpL = build_inp_embd(model.tok_embd); + ggml_tensor * inp_pos = build_inp_pos(); + + using inp_attn_type = std::conditional_t; + inp_attn_type * inp_attn = nullptr; + + if constexpr (iswa) { + inp_attn = build_attn_inp_kv_iswa(); + } else { + inp_attn = build_attn_inp_kv(); + } + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * residual = inpL; + + float freq_base_l = 0.0f; + float freq_scale_l = 0.0f; + if constexpr (iswa) { + freq_base_l = model.get_rope_freq_base (cparams, il); + freq_scale_l = model.get_rope_freq_scale(cparams, il); + } else { + freq_base_l = freq_base; + freq_scale_l = freq_scale; + } + + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + const int32_t n_head = hparams.n_head(il); + const int32_t n_head_kv = hparams.n_head_kv(il); + + const int64_t q_offset = 0; + const int64_t k_offset = head_dim_q * n_head; + const int64_t v_offset = k_offset + head_dim_q * n_head_kv; + + ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, head_dim_q, n_head, n_tokens, + head_dim_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, head_dim_q, n_head_kv, n_tokens, + head_dim_q * sizeof(float), qkv->nb[1], k_offset * ggml_element_size(qkv)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, qkv, head_dim_v, n_head_kv, n_tokens, + head_dim_v * sizeof(float), qkv->nb[1], v_offset * ggml_element_size(qkv)); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "attn_q_norm", il); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "attn_k_norm", il); + + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + + const float attn_scale = 1.0f / sqrtf(float(head_dim_q)); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, attn_scale, il); + cb(cur, "attn_out", il); + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + residual = ggml_get_rows(ctx0, residual, inp_out_ids); + } + + cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + cur = ggml_add(ctx0, cur, residual); + cb(cur, "attn_residual", il); + + residual = cur; + + cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_post_norm", il); + + cur = ggml_add(ctx0, cur, residual); + cb(cur, "ffn_residual", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +// Explicit template instantiations +template struct llm_build_plamo3; +template struct llm_build_plamo3;