model : Plamo3 support (#17304)
* plamo3
* fix plamo3
* clean code
* clean up the code
* fix diff
* clean up the code
* clean up the code
* clean up the code
* clean up the code
* clean up the code
* clean up the code
* add chat_template if exist
* clean up the code
* fix cpu-backend
* chore: whitespace trim fix + typo fix
* Fix: address review feedback
* restore `FREQ_BASE_SWA` constant
* Fix: address review feedback2
* Fix: typecheck
* Fix: address review feedback3
* final cleanup

---------

Co-authored-by: mmngays <146910567+mmngays@users.noreply.github.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
parent 07a0c4ba92
commit 9c675c7140
@@ -1696,6 +1696,84 @@ class TextModel(ModelBase):
        if template is not None:
            self.gguf_writer.add_chat_template(template)

    def _set_vocab_plamo(self):
        # PLaMo models use a custom tokenizer with a .jsonl file
        tokenizer_jsonl_path = self.dir_model / "tokenizer.jsonl"
        tokenizer_config_path = self.dir_model / "tokenizer_config.json"

        if not tokenizer_jsonl_path.is_file():
            raise FileNotFoundError(f"PLaMo tokenizer file not found: {tokenizer_jsonl_path}")

        # Load tokenizer config
        with open(tokenizer_config_path, "r", encoding="utf-8") as f:
            tokenizer_config = json.load(f)

        # Load tokens from JSONL file (actually a list format)
        tokens = []
        scores = []
        toktypes = []

        with open(tokenizer_jsonl_path, "r", encoding="utf-8") as f:
            for line_num, line in enumerate(f):
                if line.strip():
                    token_data = json.loads(line)
                    # Format: [token, score, type, ?, ?, ?, ?]
                    token = token_data[0].encode("utf-8")
                    score = float(token_data[1])
                    token_type_str = token_data[2] if len(token_data) > 2 else "NORMAL"

                    tokens.append(token)
                    scores.append(score)

                    if token_type_str == "UNKNOWN":
                        toktypes.append(gguf.TokenType.UNKNOWN)
                    elif token_type_str == "CONTROL":
                        toktypes.append(gguf.TokenType.CONTROL)
                    elif token_type_str == "BYTE":
                        toktypes.append(gguf.TokenType.BYTE)
                    else:
                        token_str = token_data[0]
                        if token_str.startswith("<|plamo:") and token_str.endswith("|>"):
                            toktypes.append(gguf.TokenType.CONTROL)
                        else:
                            toktypes.append(gguf.TokenType.NORMAL)

        vocab_size = self.hparams["vocab_size"]
        if vocab_size > len(tokens):
            pad_count = vocab_size - len(tokens)
            logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
            for i in range(1, pad_count + 1):
                tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
                scores.append(-1000.0)
                toktypes.append(gguf.TokenType.UNUSED)

        self.gguf_writer.add_tokenizer_model("plamo2")
        self.gguf_writer.add_tokenizer_pre("default")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_scores(scores)
        self.gguf_writer.add_token_types(toktypes)

        if "bos_token" in tokenizer_config and tokenizer_config["bos_token"] is not None:
            token_id = tokens.index(tokenizer_config["bos_token"].encode("utf-8"))
            self.gguf_writer.add_bos_token_id(token_id)
        if "eos_token" in tokenizer_config and tokenizer_config["eos_token"] is not None:
            token_id = tokens.index(tokenizer_config["eos_token"].encode("utf-8"))
            self.gguf_writer.add_eos_token_id(token_id)
        if "pad_token" in tokenizer_config and tokenizer_config["pad_token"] is not None:
            token_id = tokens.index(tokenizer_config["pad_token"].encode("utf-8"))
            self.gguf_writer.add_pad_token_id(token_id)
        if "sep_token" in tokenizer_config and tokenizer_config["sep_token"] is not None:
            token_id = tokens.index(tokenizer_config["sep_token"].encode("utf-8"))
            self.gguf_writer.add_sep_token_id(token_id)
        if "unk_token" in tokenizer_config and tokenizer_config["unk_token"] is not None:
            token_id = tokens.index(tokenizer_config["unk_token"].encode("utf-8"))
            self.gguf_writer.add_unk_token_id(token_id)

        # Add <|plamo:op|> as EOT to ensure appropriate end of generation
        self.gguf_writer.add_eot_token_id(4)

        self.gguf_writer.add_add_space_prefix(False)


class MmprojModel(ModelBase):
    model_type = ModelType.MMPROJ

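For context on the "# Format: [token, score, type, ?, ?, ?, ?]" comment above: each line of tokenizer.jsonl is a bare JSON array, and only the first three fields are consumed by the converter. A minimal, self-contained sketch of that parsing step; the sample entries are illustrative, not taken from an actual PLaMo tokenizer file:

import json

# Hypothetical tokenizer.jsonl lines; only [token, score, type] matter to the converter above.
sample_lines = [
    '["<|plamo:op|>", 0.0, "CONTROL"]',
    '["hello", -8.25, "NORMAL"]',
]

for line in sample_lines:
    token_data = json.loads(line)
    token = token_data[0].encode("utf-8")
    score = float(token_data[1])
    token_type = token_data[2] if len(token_data) > 2 else "NORMAL"
    print(token, score, token_type)
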
@@ -4798,87 +4876,7 @@ class Plamo2Model(TextModel):
    model_arch = gguf.MODEL_ARCH.PLAMO2

    def set_vocab(self):
        # PLaMo 2 uses a custom tokenizer with a .jsonl file
        # We need to handle this specially
        tokenizer_jsonl_path = self.dir_model / "tokenizer.jsonl"
        tokenizer_config_path = self.dir_model / "tokenizer_config.json"

        if not tokenizer_jsonl_path.is_file():
            raise FileNotFoundError(f"PLaMo 2 tokenizer file not found: {tokenizer_jsonl_path}")

        # Load tokenizer config
        with open(tokenizer_config_path, 'r', encoding='utf-8') as f:
            tokenizer_config = json.load(f)

        # Load tokens from JSONL file (actually a list format)
        tokens = []
        scores = []
        toktypes = []

        with open(tokenizer_jsonl_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f):
                if line.strip():
                    token_data = json.loads(line)
                    # Format: [token, score, type, ?, ?, ?, ?]
                    token = token_data[0].encode("utf-8")
                    score = float(token_data[1])
                    token_type_str = token_data[2] if len(token_data) > 2 else "NORMAL"

                    tokens.append(token)
                    scores.append(score)

                    # Map token type strings to GGUF token types
                    if token_type_str == "UNKNOWN":
                        toktypes.append(gguf.TokenType.UNKNOWN)
                    elif token_type_str == "CONTROL":
                        toktypes.append(gguf.TokenType.CONTROL)
                    elif token_type_str == "BYTE":
                        toktypes.append(gguf.TokenType.BYTE)
                    else:
                        # Check for PLaMo-2 special tokens
                        token_str = token_data[0]
                        if token_str.startswith("<|plamo:") and token_str.endswith("|>"):
                            toktypes.append(gguf.TokenType.CONTROL)
                        else:
                            toktypes.append(gguf.TokenType.NORMAL)

        vocab_size = self.hparams["vocab_size"]
        if vocab_size > len(tokens):
            pad_count = vocab_size - len(tokens)
            logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
            for i in range(1, pad_count + 1):
                tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
                scores.append(-1000.0)
                toktypes.append(gguf.TokenType.UNUSED)

        # Use "plamo2" tokenizer type for PLaMo-2's custom Aho-Corasick tokenizer
        self.gguf_writer.add_tokenizer_model("plamo2")
        self.gguf_writer.add_tokenizer_pre("default")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_scores(scores)
        self.gguf_writer.add_token_types(toktypes)

        # Add special tokens from config
        if "bos_token" in tokenizer_config and tokenizer_config["bos_token"] is not None:
            token_id = tokens.index(tokenizer_config["bos_token"].encode("utf-8"))
            self.gguf_writer.add_bos_token_id(token_id)
        if "eos_token" in tokenizer_config and tokenizer_config["eos_token"] is not None:
            token_id = tokens.index(tokenizer_config["eos_token"].encode("utf-8"))
            self.gguf_writer.add_eos_token_id(token_id)
        if "pad_token" in tokenizer_config and tokenizer_config["pad_token"] is not None:
            token_id = tokens.index(tokenizer_config["pad_token"].encode("utf-8"))
            self.gguf_writer.add_pad_token_id(token_id)
        if "sep_token" in tokenizer_config and tokenizer_config["sep_token"] is not None:
            token_id = tokens.index(tokenizer_config["sep_token"].encode("utf-8"))
            self.gguf_writer.add_sep_token_id(token_id)
        if "unk_token" in tokenizer_config and tokenizer_config["unk_token"] is not None:
            token_id = tokens.index(tokenizer_config["unk_token"].encode("utf-8"))
            self.gguf_writer.add_unk_token_id(token_id)

        # Add <|plamo:op|> as EOT to ensure appropriate end of generation
        self.gguf_writer.add_eot_token_id(4)

        self.gguf_writer.add_add_space_prefix(False)
        self._set_vocab_plamo()

    def set_gguf_parameters(self):
        hparams = self.hparams

@@ -4966,6 +4964,56 @@ class Plamo2Model(TextModel):
        return [(new_name, data_torch)]


@ModelBase.register("Plamo3ForCausalLM", "PLaMo3ForCausalLM")
class Plamo3Model(TextModel):
    model_arch = gguf.MODEL_ARCH.PLAMO3

    def set_vocab(self):
        self._set_vocab_plamo()

        tokenizer_config_path = self.dir_model / "tokenizer_config.json"
        tokenizer_config = {}

        if tokenizer_config_path.is_file():
            with open(tokenizer_config_path, encoding="utf-8") as f:
                tokenizer_config = json.load(f)

        chat_template = tokenizer_config.get("chat_template")
        chat_template_jinja = self.dir_model / "chat_template.jinja"

        if chat_template_jinja.is_file():
            with open(chat_template_jinja, encoding="utf-8") as f:
                chat_template = f.read()

        if chat_template:
            self.gguf_writer.add_chat_template(chat_template)

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
        if (sliding_window := self.find_hparam(["window_size", "sliding_window"], optional=True)) is not None:
            self.gguf_writer.add_sliding_window(sliding_window)
            self.gguf_writer.add_sliding_window_pattern(self.hparams["sliding_window_pattern"])
            self.gguf_writer.add_rope_freq_base_swa(self.rope_parameters.get("sliding_attention", {"rope_theta": self.hparams.get("rope_local_theta")})["rope_theta"])

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:

        if name.endswith(".pre_mixer_norm.weight"):
            data_torch = data_torch + 1.0
        elif name.endswith(".post_mixer_norm.weight"):
            data_torch = data_torch + 1.0 / 5
        elif name.endswith(".pre_mlp_norm.weight"):
            data_torch = data_torch + 1.0
        elif name.endswith(".post_mlp_norm.weight"):
            data_torch = data_torch + 1.0 / (5**1.5)
        elif name.endswith((".mixer.q_norm.weight", ".mixer.k_norm.weight")):
            data_torch = data_torch + 1.0
        elif name.endswith(".norm.weight"):
            data_torch = data_torch + 1.0

        return [(self.map_tensor_name(name), data_torch)]


@ModelBase.register("CodeShellForCausalLM")
class CodeShellModel(TextModel):
    model_arch = gguf.MODEL_ARCH.CODESHELL

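The constant offsets in modify_tensors above suggest the HF checkpoint stores its RMSNorm weights relative to a baseline (a Gemma-style "1 + weight" parameterization, with smaller baselines on the post-norms), so the converter folds the offset into the stored weight and the runtime can apply a plain RMSNorm scale. A rough numpy sketch of the equivalence this relies on; the original-model formulation here is an assumption, not taken from the PLaMo reference code:

import numpy as np

def rms_norm(x, eps=1e-6):
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8)).astype(np.float32)
w = (rng.standard_normal(8) * 0.02).astype(np.float32)  # checkpoint weight, centered near 0
offset = 1.0                                             # 1.0/5 or 1.0/(5**1.5) for the post-norms above

# Assumed original formulation: the model scales the normalized activations by (offset + weight).
reference = rms_norm(x) * (offset + w)

# What the conversion does: store (weight + offset) in the GGUF ...
w_gguf = w + offset
# ... so that a standard RMSNorm-with-weight at runtime reproduces the same output.
assert np.allclose(rms_norm(x) * w_gguf, reference)
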
@@ -377,6 +377,7 @@ class MODEL_ARCH(IntEnum):
    PHIMOE = auto()
    PLAMO = auto()
    PLAMO2 = auto()
    PLAMO3 = auto()
    CODESHELL = auto()
    ORION = auto()
    INTERNLM2 = auto()

@@ -773,6 +774,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
    MODEL_ARCH.PHIMOE: "phimoe",
    MODEL_ARCH.PLAMO: "plamo",
    MODEL_ARCH.PLAMO2: "plamo2",
    MODEL_ARCH.PLAMO3: "plamo3",
    MODEL_ARCH.CODESHELL: "codeshell",
    MODEL_ARCH.ORION: "orion",
    MODEL_ARCH.INTERNLM2: "internlm2",

@@ -1763,6 +1765,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.SSM_B_NORM,
        MODEL_TENSOR.SSM_C_NORM,
    ],
    MODEL_ARCH.PLAMO3: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.ATTN_QKV,
        MODEL_TENSOR.ATTN_Q_NORM,
        MODEL_TENSOR.ATTN_K_NORM,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.ATTN_POST_NORM,
        MODEL_TENSOR.FFN_NORM,
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
        MODEL_TENSOR.FFN_POST_NORM,
    ],
    MODEL_ARCH.GPT2: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.POS_EMBD,

@@ -595,6 +595,7 @@ class TensorNameMap:
            "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
            "transformer.layers.{bid}.attn.q_norm", # openelm
            "model.layers.layers.{bid}.mixer.q", # plamo2
            "model.layers.layers.{bid}.mixer.q_norm", # plamo3
            "layers.{bid}.self_attn.q_norm", # qwen3-embedding
            "model.layers.{bid}.attention.query_layernorm", # apertus
        ),

@@ -610,6 +611,7 @@ class TensorNameMap:
            "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
            "transformer.layers.{bid}.attn.k_norm", # openelm
            "model.layers.layers.{bid}.mixer.k", # plamo2
            "model.layers.layers.{bid}.mixer.k_norm", # plamo3
            "layers.{bid}.self_attn.k_norm", # qwen3-embedding
            "model.layers.{bid}.attention.key_layernorm", # apertus
        ),

@@ -107,6 +107,7 @@ add_library(llama
            models/phi3.cpp
            models/plamo.cpp
            models/plamo2.cpp
            models/plamo3.cpp
            models/plm.cpp
            models/qwen.cpp
            models/qwen2.cpp

@@ -42,6 +42,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_PHIMOE,    "phimoe"    },
    { LLM_ARCH_PLAMO,     "plamo"     },
    { LLM_ARCH_PLAMO2,    "plamo2"    },
    { LLM_ARCH_PLAMO3,    "plamo3"    },
    { LLM_ARCH_CODESHELL, "codeshell" },
    { LLM_ARCH_ORION,     "orion"     },
    { LLM_ARCH_INTERNLM2, "internlm2" },

@@ -1077,6 +1078,22 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
                LLM_TENSOR_ATTN_POST_NORM,
                LLM_TENSOR_FFN_POST_NORM,
            };
        case LLM_ARCH_PLAMO3:
            return {
                LLM_TENSOR_TOKEN_EMBD,
                LLM_TENSOR_OUTPUT_NORM,
                LLM_TENSOR_OUTPUT,
                LLM_TENSOR_ATTN_NORM,
                LLM_TENSOR_ATTN_QKV,
                LLM_TENSOR_ATTN_Q_NORM,
                LLM_TENSOR_ATTN_K_NORM,
                LLM_TENSOR_ATTN_OUT,
                LLM_TENSOR_ATTN_POST_NORM,
                LLM_TENSOR_FFN_NORM,
                LLM_TENSOR_FFN_POST_NORM,
                LLM_TENSOR_FFN_DOWN,
                LLM_TENSOR_FFN_UP,
            };
        case LLM_ARCH_CODESHELL:
            return {
                LLM_TENSOR_TOKEN_EMBD,

@@ -46,6 +46,7 @@ enum llm_arch {
    LLM_ARCH_PHIMOE,
    LLM_ARCH_PLAMO,
    LLM_ARCH_PLAMO2,
    LLM_ARCH_PLAMO3,
    LLM_ARCH_CODESHELL,
    LLM_ARCH_ORION,
    LLM_ARCH_INTERNLM2,

@@ -1227,6 +1227,26 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH,   hparams.n_embd_head_k, false);
                ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
            } break;
        case LLM_ARCH_PLAMO3:
            {
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
                if (found_swa && hparams.n_swa > 0) {
                    uint32_t swa_period = 8;
                    hparams.swa_type                  = LLAMA_SWA_TYPE_STANDARD;
                    hparams.rope_freq_scale_train_swa = 1.0f;
                    ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa);
                    ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false);
                    hparams.set_swa_pattern(swa_period);
                } else {
                    hparams.swa_type = LLAMA_SWA_TYPE_NONE;
                }

                switch (hparams.n_layer) {
                    case 24: type = LLM_TYPE_2B; break;
                    default: type = LLM_TYPE_UNKNOWN;
                }
            } break;
        case LLM_ARCH_GPT2:
            {
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);

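For intuition about set_swa_pattern(swa_period) above: with the default swa_period of 8, most layers use sliding-window attention and every 8th layer falls back to full attention. The exact convention (which layer within each group of 8 is the full-attention one) is an assumption in the sketch below, based on how the pattern helper is commonly implemented in llama.cpp:

# Hypothetical 24-layer configuration (matching the LLM_TYPE_2B case above) with swa_period = 8.
n_layer = 24
swa_period = 8

# Assumption: layer il uses sliding-window attention unless il % swa_period == swa_period - 1.
swa_layers = [il % swa_period < swa_period - 1 for il in range(n_layer)]
full_attention_layers = [il for il, is_swa in enumerate(swa_layers) if not is_swa]

print(full_attention_layers)  # -> [7, 15, 23] under this assumption
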
@@ -3828,6 +3848,44 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0);
                    }
                } break;
            case LLM_ARCH_PLAMO3:
                {
                    const int64_t head_dim_q = hparams.n_embd_head_k;
                    const int64_t head_dim_v = hparams.n_embd_head_v;

                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
                    if (output == NULL) {
                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                    }

                    for (int i = 0; i < n_layer; ++i) {
                        auto & layer = layers[i];

                        const int64_t num_attention_heads = hparams.n_head(i);
                        const int64_t num_key_value_heads = hparams.n_head_kv(i);
                        const int64_t q_proj_dim = num_attention_heads * head_dim_q;
                        const int64_t k_proj_dim = num_key_value_heads * head_dim_q;
                        const int64_t v_proj_dim = num_key_value_heads * head_dim_v;
                        const int64_t n_ff_cur   = hparams.n_ff(i);

                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i),
                                {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0);
                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim_q}, 0);
                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim_q}, 0);
                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {num_attention_heads * head_dim_v, n_embd}, 0);
                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0);

                        layer.ffn_norm      = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0);

                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff_cur * 2}, 0);
                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0);
                    }
                } break;
            case LLM_ARCH_GPT2:
                {
                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

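Note that ffn_up above is created with n_ff_cur * 2 output columns while ffn_down only takes n_ff_cur inputs: the gate and up projections are fused into one matrix, and the graph builder (the SWIGLU FFN in plamo3.cpp below) splits the projection in half before gating. A small numpy sketch of that fused-SwiGLU convention; the split order (gate half first) is an assumption here:

import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
n_embd, n_ff, n_tokens = 16, 32, 4

w_up_fused = rng.standard_normal((n_embd, 2 * n_ff)).astype(np.float32)  # gate and up fused
w_down = rng.standard_normal((n_ff, n_embd)).astype(np.float32)
x = rng.standard_normal((n_tokens, n_embd)).astype(np.float32)

h = x @ w_up_fused
gate, up = h[:, :n_ff], h[:, n_ff:]   # assumed split order
y = (silu(gate) * up) @ w_down        # SwiGLU gating, then down projection
print(y.shape)  # (4, 16)
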
@@ -7473,6 +7531,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
            {
                llm = std::make_unique<llm_build_plamo2>(*this, params);
            } break;
        case LLM_ARCH_PLAMO3:
            {
                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
                    llm = std::make_unique<llm_build_plamo3<true>> (*this, params);
                } else {
                    llm = std::make_unique<llm_build_plamo3<false>>(*this, params);
                }
            } break;
        case LLM_ARCH_GPT2:
            {
                llm = std::make_unique<llm_build_gpt2>(*this, params);

@@ -7977,6 +8043,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
        case LLM_ARCH_PHIMOE:
        case LLM_ARCH_PLAMO:
        case LLM_ARCH_PLAMO2:
        case LLM_ARCH_PLAMO3:
        case LLM_ARCH_GEMMA:
        case LLM_ARCH_GEMMA2:
        case LLM_ARCH_GEMMA3:

@@ -406,6 +406,11 @@ struct llm_build_plamo : public llm_graph_context {
    llm_build_plamo(const llama_model & model, const llm_graph_params & params);
};

template <bool iswa>
struct llm_build_plamo3 : public llm_graph_context {
    llm_build_plamo3(const llama_model & model, const llm_graph_params & params);
};

struct llm_build_plm : public llm_graph_context {
    llm_build_plm(const llama_model & model, const llm_graph_params & params);
};

@@ -0,0 +1,128 @@
#include "models.h"

template <bool iswa>
llm_build_plamo3<iswa>::llm_build_plamo3(const llama_model & model, const llm_graph_params & params) :
        llm_graph_context(params) {
    const int64_t head_dim_q = hparams.n_embd_head_k;
    const int64_t head_dim_v = hparams.n_embd_head_v;

    ggml_tensor * cur;
    ggml_tensor * inpL = build_inp_embd(model.tok_embd);
    ggml_tensor * inp_pos = build_inp_pos();

    using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
    inp_attn_type * inp_attn = nullptr;

    if constexpr (iswa) {
        inp_attn = build_attn_inp_kv_iswa();
    } else {
        inp_attn = build_attn_inp_kv();
    }

    ggml_tensor * inp_out_ids = build_inp_out_ids();

    for (int il = 0; il < n_layer; ++il) {
        ggml_tensor * residual = inpL;

        float freq_base_l  = 0.0f;
        float freq_scale_l = 0.0f;
        if constexpr (iswa) {
            freq_base_l  = model.get_rope_freq_base (cparams, il);
            freq_scale_l = model.get_rope_freq_scale(cparams, il);
        } else {
            freq_base_l  = freq_base;
            freq_scale_l = freq_scale;
        }

        cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "attn_norm", il);

        ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur);
        cb(cur, "wqkv", il);

        const int32_t n_head    = hparams.n_head(il);
        const int32_t n_head_kv = hparams.n_head_kv(il);

        const int64_t q_offset = 0;
        const int64_t k_offset = head_dim_q * n_head;
        const int64_t v_offset = k_offset + head_dim_q * n_head_kv;

        ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, head_dim_q, n_head, n_tokens,
                head_dim_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv));
        ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, head_dim_q, n_head_kv, n_tokens,
                head_dim_q * sizeof(float), qkv->nb[1], k_offset * ggml_element_size(qkv));
        ggml_tensor * Vcur = ggml_view_3d(ctx0, qkv, head_dim_v, n_head_kv, n_tokens,
                head_dim_v * sizeof(float), qkv->nb[1], v_offset * ggml_element_size(qkv));

        cb(Qcur, "Qcur", il);
        cb(Kcur, "Kcur", il);
        cb(Vcur, "Vcur", il);

        Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
        cb(Qcur, "attn_q_norm", il);
        Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
        cb(Kcur, "attn_k_norm", il);

        Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
                n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                ext_factor, attn_factor, beta_fast, beta_slow);
        Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
                n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                ext_factor, attn_factor, beta_fast, beta_slow);

        const float attn_scale = 1.0f / sqrtf(float(head_dim_q));

        cur = build_attn(inp_attn,
                model.layers[il].wo, NULL,
                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, attn_scale, il);
        cb(cur, "attn_out", il);

        if (il == n_layer - 1 && inp_out_ids) {
            cur      = ggml_get_rows(ctx0, cur,      inp_out_ids);
            residual = ggml_get_rows(ctx0, residual, inp_out_ids);
        }

        cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "attn_post_norm", il);

        cur = ggml_add(ctx0, cur, residual);
        cb(cur, "attn_residual", il);

        residual = cur;

        cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "ffn_norm", il);

        cur = build_ffn(cur,
                model.layers[il].ffn_up,   NULL, NULL,
                NULL,                      NULL, NULL,
                model.layers[il].ffn_down, NULL, NULL,
                NULL,
                LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
        cb(cur, "ffn_out", il);

        cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "ffn_post_norm", il);

        cur = ggml_add(ctx0, cur, residual);
        cb(cur, "ffn_residual", il);

        cur = build_cvec(cur, il);
        cb(cur, "l_out", il);
        inpL = cur;
    }

    cur = inpL;

    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
    res->t_embd = cur;

    cur = build_lora_mm(model.output, cur);
    res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);
}

// Explicit template instantiations
template struct llm_build_plamo3<false>;
template struct llm_build_plamo3<true>;
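The three ggml_view_3d calls in the loop above slice Q, K and V out of the fused QKV projection by column offset instead of copying; q_offset, k_offset and v_offset count elements along the projection's output dimension. A numpy analogue of the same offset arithmetic, with illustrative sizes rather than PLaMo 3's real hyperparameters:

import numpy as np

# Illustrative sizes, not PLaMo 3's real hyperparameters.
head_dim_q, head_dim_v = 8, 8
n_head, n_head_kv, n_tokens = 4, 2, 3

q_dim = head_dim_q * n_head
k_dim = head_dim_q * n_head_kv
v_dim = head_dim_v * n_head_kv

# Fused projection output: one row per token, Q columns first, then K, then V.
qkv = np.arange(n_tokens * (q_dim + k_dim + v_dim), dtype=np.float32).reshape(n_tokens, -1)

# Same offsets as q_offset / k_offset / v_offset in the builder above.
q = qkv[:, :q_dim].reshape(n_tokens, n_head, head_dim_q)
k = qkv[:, q_dim:q_dim + k_dim].reshape(n_tokens, n_head_kv, head_dim_q)
v = qkv[:, q_dim + k_dim:].reshape(n_tokens, n_head_kv, head_dim_v)
print(q.shape, k.shape, v.shape)  # (3, 4, 8) (3, 2, 8) (3, 2, 8)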