model : support Rnj-1 (#17811)
* add support for rnj1
* refactor gemma3 to support rnj-1
* address review comments
parent c8554b66e0
commit 1d2a1ab73d
```diff
@@ -5825,9 +5825,11 @@ class Gemma3Model(TextModel):
         norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value

     def set_vocab(self):
-        self._set_vocab_sentencepiece()
-
-        self.gguf_writer.add_add_space_prefix(False)
+        if (self.dir_model / "tokenizer.model").is_file():
+            self._set_vocab_sentencepiece()
+            self.gguf_writer.add_add_space_prefix(False)
+        else:
+            self._set_vocab_gpt2()

     def set_gguf_parameters(self):
         hparams = self.hparams
```
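The vocab change is a tokenizer fallback: classic Gemma3 checkpoints ship a SentencePiece `tokenizer.model`, while the new path covers checkpoints that only provide a GPT-2/BPE-style tokenizer (presumably the Rnj-1 case). A minimal Python sketch of the same decision, assuming only that the presence of `tokenizer.model` is the distinguishing signal, as in the diff above:

```python
from pathlib import Path

# Sketch only: mirrors the fallback logic in set_vocab() above.
def pick_vocab_loader(dir_model: Path) -> str:
    if (dir_model / "tokenizer.model").is_file():
        # SentencePiece vocab; the converter also disables the added space prefix
        return "sentencepiece"
    # no tokenizer.model -> fall back to the GPT-2/BPE vocab path
    return "gpt2"
```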
```diff
@@ -5845,13 +5847,24 @@ class Gemma3Model(TextModel):
         self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers
         # attn_logit_softcapping is removed in Gemma3
         assert hparams.get("attn_logit_softcapping") is None
-        self.gguf_writer.add_sliding_window(hparams["sliding_window"])
+        if (final_logit_softcap := hparams.get("final_logit_softcapping")):
+            self.gguf_writer.add_final_logit_softcapping(final_logit_softcap)
+        if hparams.get("sliding_window_pattern") != 1:
+            self.gguf_writer.add_sliding_window(hparams["sliding_window"])
         self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
         if hparams.get("rope_scaling") is not None:
-            assert hparams["rope_scaling"]["rope_type"] == "linear"
-            # important: this rope_scaling is only applied for global layers, and not used by 1B model
-            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
-            self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
+            rope_scaling = hparams["rope_scaling"]
+            if rope_scaling["rope_type"] == "linear":
+                # important: this rope_scaling is only applied for global layers, and not used by 1B model
+                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+                self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            elif rope_scaling["rope_type"] == "yarn":
+                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+                self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+                self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+                self.gguf_writer.add_rope_scaling_yarn_ext_factor(rope_scaling["extrapolation_factor"])
+                self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_scaling["beta_fast"])
+                self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_scaling["beta_slow"])

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid # unused
```
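For reference, these are the two `rope_scaling` shapes the converter now accepts from the HF `config.json`. The field names come from the code above; the numeric values are purely illustrative:

```python
# Illustrative values only; field names are the ones read by set_gguf_parameters() above.
rope_scaling_linear = {
    "rope_type": "linear",
    "factor": 8.0,
}

rope_scaling_yarn = {
    "rope_type": "yarn",
    "factor": 4.0,
    "original_max_position_embeddings": 8192,
    "extrapolation_factor": 1.0,
    "beta_fast": 32.0,
    "beta_slow": 1.0,
}
```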
```diff
@@ -5865,8 +5878,10 @@ class Gemma3Model(TextModel):

         # remove OOV (out-of-vocabulary) rows in token_embd
         if "embed_tokens.weight" in name:
-            vocab = self._create_vocab_sentencepiece()
-            tokens = vocab[0]
+            if (self.dir_model / "tokenizer.model").is_file():
+                tokens = self._create_vocab_sentencepiece()[0]
+            else:
+                tokens = self.get_vocab_base()[0]
             data_torch = data_torch[:len(tokens)]

         # ref code in Gemma3RMSNorm
```
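The OOV trimming itself is unchanged in spirit: the token embedding matrix may contain more rows than the tokenizer actually defines, so the converter keeps only the first `len(tokens)` rows; the new branch just picks the matching vocab loader. A small self-contained sketch of the slicing (the shapes are made up):

```python
import torch

tokens = ["<pad>", "<bos>", "<eos>", "hello"]  # stand-in vocab of 4 tokens
embed_tokens = torch.randn(8, 16)              # 8 rows padded out, only 4 are real
embed_tokens = embed_tokens[:len(tokens)]      # same slicing as modify_tensors() above
assert embed_tokens.shape == (4, 16)
```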
```diff
@@ -67,7 +67,7 @@ add_library(llama
             models/gemma-embedding.cpp
             models/gemma.cpp
             models/gemma2-iswa.cpp
-            models/gemma3-iswa.cpp
+            models/gemma3.cpp
             models/gemma3n-iswa.cpp
             models/glm4-moe.cpp
             models/glm4.cpp
```
```diff
@@ -1264,18 +1264,25 @@ void llama_model::load_hparams(llama_model_loader & ml) {
            } break;
        case LLM_ARCH_GEMMA3:
            {
-                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
-                hparams.set_swa_pattern(6);
+                const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
+                if (found_swa && hparams.n_swa > 0) {
+                    hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+                    hparams.set_swa_pattern(6);

-                hparams.rope_freq_base_train_swa  = 10000.0f;
-                hparams.rope_freq_scale_train_swa = 1.0f;
+                    hparams.rope_freq_base_train_swa  = 10000.0f;
+                    hparams.rope_freq_scale_train_swa = 1.0f;
+                } else {
+                    hparams.swa_type = LLAMA_SWA_TYPE_NONE;
+                }

-                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
+                hparams.f_final_logit_softcapping = 0.0f;
+                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

                switch (hparams.n_layer) {
                    case 18: type = LLM_TYPE_270M; break;
                    case 26: type = LLM_TYPE_1B; break;
+                    case 32: type = LLM_TYPE_8B; break; // Rnj-1
                    case 34: type = LLM_TYPE_4B; break;
                    case 48: type = LLM_TYPE_12B; break;
                    case 62: type = LLM_TYPE_27B; break;
```
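In words: the loader now treats the sliding-window metadata as optional. If the GGUF carries a positive sliding-window size, the original Gemma3 iSWA setup is kept (every-6th-layer pattern, RoPE base 10000 on the local layers); otherwise the model runs with full attention on every layer, which is what the Rnj-1 conversion appears to produce given the converter change above. A Python paraphrase of that branch (the metadata key name and return shape here are hypothetical):

```python
# Python paraphrase of the C++ branch above; the metadata key name is hypothetical.
def pick_attention_config(metadata: dict) -> dict:
    n_swa = metadata.get("attention.sliding_window", 0)
    if n_swa > 0:
        return {
            "swa_type": "standard",
            "swa_pattern": 6,          # every 6th layer uses full (global) attention
            "n_swa": n_swa,
            "rope_base_swa": 10000.0,  # local (sliding-window) layers
            "rope_scale_swa": 1.0,
        }
    return {"swa_type": "none"}        # no sliding window -> full attention everywhere
```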
```diff
@@ -7304,7 +7311,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
            } break;
        case LLM_ARCH_GEMMA3:
            {
-                llm = std::make_unique<llm_build_gemma3_iswa>(*this, params);
+                if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) {
+                    llm = std::make_unique<llm_build_gemma3<true>>(*this, params);
+                } else {
+                    llm = std::make_unique<llm_build_gemma3<false>>(*this, params);
+                }
            } break;
        case LLM_ARCH_GEMMA3N:
            {
```
```diff
@@ -1,6 +1,7 @@
 #include "models.h"

-llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+template <bool iswa>
+llm_build_gemma3<iswa>::llm_build_gemma3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     const int64_t n_embd_head = hparams.n_embd_head_k;

     ggml_tensor * cur;
```
```diff
@@ -17,13 +18,28 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const ll
     ggml_tensor * inp_pos = build_inp_pos();

     // TODO: is causal == true correct? might need some changes
-    auto * inp_attn = build_attn_inp_kv_iswa();
+    using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
+    inp_attn_type * inp_attn = nullptr;
+
+    if constexpr (iswa) {
+        inp_attn = build_attn_inp_kv_iswa();
+    } else {
+        inp_attn = build_attn_inp_kv();
+    }

     ggml_tensor * inp_out_ids = build_inp_out_ids();

     for (int il = 0; il < n_layer; ++il) {
-        const float freq_base_l  = model.get_rope_freq_base (cparams, il);
-        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+        float freq_base_l  = 0.0f;
+        float freq_scale_l = 0.0f;
+
+        if constexpr (iswa) {
+            freq_base_l  = model.get_rope_freq_base (cparams, il);
+            freq_scale_l = model.get_rope_freq_scale(cparams, il);
+        } else {
+            freq_base_l  = freq_base;
+            freq_scale_l = freq_scale;
+        }

         // norm
         cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
```
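The per-layer RoPE frequency lookup only matters when layers differ (local vs. global); in the non-iSWA instantiation every layer uses the graph-wide `freq_base`/`freq_scale`. A toy illustration of the intent, using the training values that appear elsewhere in this diff (1e6 for global layers, 1e4 for sliding-window layers); the real lookup inside `get_rope_freq_base` is not reproduced here:

```python
# Toy illustration; the real per-layer lookup lives in model.get_rope_freq_base().
def rope_freq_base_for_layer(iswa: bool, layer_is_swa: bool,
                             freq_base: float = 1_000_000.0,
                             freq_base_swa: float = 10_000.0) -> float:
    if iswa and layer_is_swa:
        return freq_base_swa   # local (sliding-window) layers
    return freq_base           # global layers, or the full-attention build

print(rope_freq_base_for_layer(iswa=True,  layer_is_swa=True))   # 10000.0
print(rope_freq_base_for_layer(iswa=False, layer_is_swa=False))  # 1000000.0
```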
```diff
@@ -102,7 +118,7 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const ll
         cur = build_norm(cur,
                 model.layers[il].ffn_post_norm, NULL,
                 LLM_NORM_RMS, -1);
-        cb(cur, "ffn_post_norm", -1);
+        cb(cur, "ffn_post_norm", il);

         cur = ggml_add(ctx0, cur, sa_out);
```
```diff
@@ -124,8 +140,17 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const ll
     // lm_head
     cur = build_lora_mm(model.output, cur);

+    if (hparams.f_final_logit_softcapping) {
+        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+        cur = ggml_tanh(ctx0, cur);
+        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+    }
+
     cb(cur, "result_output", -1);
     res->t_logits = cur;

     ggml_build_forward_expand(gf, cur);
 }
+
+template struct llm_build_gemma3<false>;
+template struct llm_build_gemma3<true>;
```
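Final logit soft-capping is the same trick Gemma2 used: squash the logits through tanh so they stay within (-cap, +cap), i.e. `logits = cap * tanh(logits / cap)`. A tiny numeric check of the scale/tanh/scale sequence above (the cap value 30.0 is illustrative):

```python
import math

# logits -> cap * tanh(logits / cap), matching the scale/tanh/scale ops above
def softcap(logit: float, cap: float = 30.0) -> float:
    return cap * math.tanh(logit / cap)

print(round(softcap(5.0), 2))    # 4.95  -> small logits are nearly unchanged
print(round(softcap(200.0), 2))  # 30.0  -> large logits saturate at the cap
```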
```diff
@@ -179,8 +179,9 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
     llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params);
 };

-struct llm_build_gemma3_iswa : public llm_graph_context {
-    llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params);
+template <bool iswa>
+struct llm_build_gemma3 : public llm_graph_context {
+    llm_build_gemma3(const llama_model & model, const llm_graph_params & params);
 };

 struct llm_build_gemma3n_iswa : public llm_graph_context {
```