model : support Rnj-1 (#17811)

* add support for rnj1

* refactor gemma3 to support rnj-1

* address review comments
philip-essential authored on 2025-12-08 19:49:03 -08:00, committed by GitHub
parent c8554b66e0, commit 1d2a1ab73d
5 changed files with 76 additions and 24 deletions

convert_hf_to_gguf.py

@@ -5825,9 +5825,11 @@ class Gemma3Model(TextModel):
     norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value
 
     def set_vocab(self):
-        self._set_vocab_sentencepiece()
-
-        self.gguf_writer.add_add_space_prefix(False)
+        if (self.dir_model / "tokenizer.model").is_file():
+            self._set_vocab_sentencepiece()
+            self.gguf_writer.add_add_space_prefix(False)
+        else:
+            self._set_vocab_gpt2()
 
     def set_gguf_parameters(self):
         hparams = self.hparams
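The set_vocab change above is what lets the converter handle Rnj-1 checkpoints that apparently ship no SentencePiece tokenizer.model: when the file is missing, the GPT-2/BPE vocab path is used instead. A minimal sketch of that detection, with pick_vocab_loader as a hypothetical helper (not part of the converter):

    from pathlib import Path

    def pick_vocab_loader(dir_model: Path) -> str:
        # Gemma 3 checkpoints carry a SentencePiece "tokenizer.model";
        # checkpoints without it fall back to the GPT-2/BPE loader.
        if (dir_model / "tokenizer.model").is_file():
            return "sentencepiece"
        return "gpt2"

    print(pick_vocab_loader(Path(".")))  # "gpt2" unless a tokenizer.model sits in the cwd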
@@ -5845,13 +5847,24 @@ class Gemma3Model(TextModel):
         self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers
         # attn_logit_softcapping is removed in Gemma3
         assert hparams.get("attn_logit_softcapping") is None
-        self.gguf_writer.add_sliding_window(hparams["sliding_window"])
+        if (final_logit_softcap := hparams.get("final_logit_softcapping")):
+            self.gguf_writer.add_final_logit_softcapping(final_logit_softcap)
+        if hparams.get("sliding_window_pattern") != 1:
+            self.gguf_writer.add_sliding_window(hparams["sliding_window"])
         self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
         if hparams.get("rope_scaling") is not None:
-            assert hparams["rope_scaling"]["rope_type"] == "linear"
-            # important: this rope_scaling is only applied for global layers, and not used by 1B model
-            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
-            self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
+            rope_scaling = hparams["rope_scaling"]
+            if rope_scaling["rope_type"] == "linear":
+                # important: this rope_scaling is only applied for global layers, and not used by 1B model
+                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+                self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            elif rope_scaling["rope_type"] == "yarn":
+                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+                self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+                self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+                self.gguf_writer.add_rope_scaling_yarn_ext_factor(rope_scaling["extrapolation_factor"])
+                self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_scaling["beta_fast"])
+                self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_scaling["beta_slow"])
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
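The rope_scaling branch above now handles both the existing linear case and a YARN-style config. A rough sketch of the mapping it performs, using a made-up rope_scaling dict and a stand-in writer whose method names simply mirror the gguf.GGUFWriter calls in the diff:

    # Stand-in writer that records calls instead of writing GGUF metadata.
    class FakeWriter:
        def __getattr__(self, name):
            return lambda *args: print(f"{name}{args}")

    writer = FakeWriter()

    # Hypothetical rope_scaling block from a config.json; the values are made up.
    rope_scaling = {
        "rope_type": "yarn",
        "factor": 8.0,
        "original_max_position_embeddings": 8192,
        "extrapolation_factor": 1.0,
        "beta_fast": 32.0,
        "beta_slow": 1.0,
    }

    if rope_scaling["rope_type"] == "linear":
        writer.add_rope_scaling_type("linear")
        writer.add_rope_scaling_factor(rope_scaling["factor"])
    elif rope_scaling["rope_type"] == "yarn":
        writer.add_rope_scaling_type("yarn")
        writer.add_rope_scaling_factor(rope_scaling["factor"])
        writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
        writer.add_rope_scaling_yarn_ext_factor(rope_scaling["extrapolation_factor"])
        writer.add_rope_scaling_yarn_beta_fast(rope_scaling["beta_fast"])
        writer.add_rope_scaling_yarn_beta_slow(rope_scaling["beta_slow"])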
@@ -5865,8 +5878,10 @@ class Gemma3Model(TextModel):
 
         # remove OOV (out-of-vocabulary) rows in token_embd
        if "embed_tokens.weight" in name:
-            vocab = self._create_vocab_sentencepiece()
-            tokens = vocab[0]
+            if (self.dir_model / "tokenizer.model").is_file():
+                tokens = self._create_vocab_sentencepiece()[0]
+            else:
+                tokens = self.get_vocab_base()[0]
             data_torch = data_torch[:len(tokens)]
 
         # ref code in Gemma3RMSNorm
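modify_tensors trims the embedding table to the real vocabulary, and the only change here is that the token list may now come from either vocab builder. The trimming itself is a plain slice; a toy example with made-up shapes:

    import torch

    data_torch = torch.randn(1000, 64)            # hypothetical token_embd weight with padded rows
    tokens = [f"tok_{i}" for i in range(960)]     # stand-in for the converted vocabulary

    # keep one embedding row per real token, dropping the OOV/padding rows at the end
    data_torch = data_torch[:len(tokens)]
    print(data_torch.shape)                       # torch.Size([960, 64])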

src/CMakeLists.txt

@@ -67,7 +67,7 @@ add_library(llama
             models/gemma-embedding.cpp
             models/gemma.cpp
             models/gemma2-iswa.cpp
-            models/gemma3-iswa.cpp
+            models/gemma3.cpp
             models/gemma3n-iswa.cpp
             models/glm4-moe.cpp
             models/glm4.cpp

src/llama-model.cpp

@@ -1264,18 +1264,25 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_GEMMA3:
             {
-                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
-                hparams.set_swa_pattern(6);
+                const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
+                if (found_swa && hparams.n_swa > 0) {
+                    hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+                    hparams.set_swa_pattern(6);
 
-                hparams.rope_freq_base_train_swa  = 10000.0f;
-                hparams.rope_freq_scale_train_swa = 1.0f;
+                    hparams.rope_freq_base_train_swa  = 10000.0f;
+                    hparams.rope_freq_scale_train_swa = 1.0f;
+                } else {
+                    hparams.swa_type = LLAMA_SWA_TYPE_NONE;
+                }
 
-                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
+                hparams.f_final_logit_softcapping = 0.0f;
+                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
                 switch (hparams.n_layer) {
                     case 18: type = LLM_TYPE_270M; break;
                     case 26: type = LLM_TYPE_1B; break;
+                    case 32: type = LLM_TYPE_8B; break; // Rnj-1
                     case 34: type = LLM_TYPE_4B; break;
                     case 48: type = LLM_TYPE_12B; break;
                     case 62: type = LLM_TYPE_27B; break;
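With this hunk, sliding-window attention becomes optional for LLM_ARCH_GEMMA3: a GGUF that carries a positive sliding-window value keeps the usual interleaved SWA setup, while one that omits it (or stores 0), as an Rnj-1 conversion would after the converter change above, runs full attention on every layer. The same decision in a Python sketch, with a stand-in metadata dict in place of the arch-prefixed GGUF key:

    def pick_swa(meta: dict) -> str:
        n_swa = meta.get("attention.sliding_window", 0)   # stand-in for the GGUF key
        if n_swa > 0:
            # interleaved pattern: every 6th layer stays global, the rest use the window
            return f"standard SWA, window={n_swa}, pattern=6"
        return "no SWA (full attention on every layer)"

    print(pick_swa({"attention.sliding_window": 1024}))   # classic Gemma 3 conversion
    print(pick_swa({}))                                   # Rnj-1 style: key absent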
@@ -7304,7 +7311,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             } break;
         case LLM_ARCH_GEMMA3:
             {
-                llm = std::make_unique<llm_build_gemma3_iswa>(*this, params);
+                if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) {
+                    llm = std::make_unique<llm_build_gemma3<true>>(*this, params);
+                } else {
+                    llm = std::make_unique<llm_build_gemma3<false>>(*this, params);
+                }
             } break;
         case LLM_ARCH_GEMMA3N:
             {

src/models/gemma3-iswa.cpp → src/models/gemma3.cpp

@@ -1,6 +1,7 @@
 #include "models.h"
 
-llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+template <bool iswa>
+llm_build_gemma3<iswa>::llm_build_gemma3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     const int64_t n_embd_head = hparams.n_embd_head_k;
 
     ggml_tensor * cur;
@@ -17,13 +18,28 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     ggml_tensor * inp_pos = build_inp_pos();
 
     // TODO: is causal == true correct? might need some changes
-    auto * inp_attn = build_attn_inp_kv_iswa();
+    using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
+    inp_attn_type * inp_attn = nullptr;
+
+    if constexpr (iswa) {
+        inp_attn = build_attn_inp_kv_iswa();
+    } else {
+        inp_attn = build_attn_inp_kv();
+    }
 
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
     for (int il = 0; il < n_layer; ++il) {
-        const float freq_base_l  = model.get_rope_freq_base (cparams, il);
-        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+        float freq_base_l  = 0.0f;
+        float freq_scale_l = 0.0f;
+
+        if constexpr (iswa) {
+            freq_base_l  = model.get_rope_freq_base (cparams, il);
+            freq_scale_l = model.get_rope_freq_scale(cparams, il);
+        } else {
+            freq_base_l  = freq_base;
+            freq_scale_l = freq_scale;
+        }
 
         // norm
         cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
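In the iswa instantiation each layer still asks the model for its RoPE parameters, so local layers get the SWA base frequency and global layers the regular one; the non-iswa instantiation simply reuses the context-wide freq_base/freq_scale everywhere. A toy sketch of that per-layer choice, assuming the usual Gemma 3 layout (every sixth layer global, 10 000 local base, 1 000 000 global base):

    def rope_freq_base_for_layer(il: int, iswa: bool) -> float:
        if not iswa:
            return 1_000_000.0                 # one base for every layer
        is_global = (il + 1) % 6 == 0          # assumed 5-local / 1-global pattern
        return 1_000_000.0 if is_global else 10_000.0

    print([rope_freq_base_for_layer(il, iswa=True)  for il in range(6)])
    print([rope_freq_base_for_layer(il, iswa=False) for il in range(6)])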
@@ -102,7 +118,7 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
 
         cur = build_norm(cur,
                 model.layers[il].ffn_post_norm, NULL,
                 LLM_NORM_RMS, -1);
-        cb(cur, "ffn_post_norm", -1);
+        cb(cur, "ffn_post_norm", il);
 
         cur = ggml_add(ctx0, cur, sa_out);
@@ -124,8 +140,17 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     // lm_head
     cur = build_lora_mm(model.output, cur);
 
+    if (hparams.f_final_logit_softcapping) {
+        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+        cur = ggml_tanh(ctx0, cur);
+        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+    }
+
     cb(cur, "result_output", -1);
     res->t_logits = cur;
 
     ggml_build_forward_expand(gf, cur);
 }
+
+template struct llm_build_gemma3<false>;
+template struct llm_build_gemma3<true>;
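The new block before result_output applies final logit softcapping, which squashes the logits into (-cap, cap) as cap * tanh(x / cap); when f_final_logit_softcapping is 0 the block is skipped and the graph is unchanged. The same transform in plain Python, with a hypothetical cap of 30:

    import math

    def softcap(x: float, cap: float) -> float:
        # matches the ggml_scale -> ggml_tanh -> ggml_scale chain above
        return cap * math.tanh(x / cap)

    cap = 30.0  # hypothetical final_logit_softcapping value
    for x in (1.0, 10.0, 100.0):
        print(x, "->", round(softcap(x, cap), 3))
    # small logits pass through almost unchanged; large ones saturate near the cap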

src/models/models.h

@@ -179,8 +179,9 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
     llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params);
 };
 
-struct llm_build_gemma3_iswa : public llm_graph_context {
-    llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params);
+template <bool iswa>
+struct llm_build_gemma3 : public llm_graph_context {
+    llm_build_gemma3(const llama_model & model, const llm_graph_params & params);
 };
 
 struct llm_build_gemma3n_iswa : public llm_graph_context {