convert : add Llama4ForCausalLM (#16042)

* convert : add Llama4ForCausalLM

* handle swa

* half working version

* fix use_kq_norm

* fix use_kq_norm
This commit is contained in:
Xuan-Son Nguyen 2025-09-18 00:18:21 +07:00 committed by GitHub
parent c959b676be
commit 8f8f2274ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 50 additions and 12 deletions

View File

@ -2393,7 +2393,10 @@ class SmolVLMModel(MmprojModel):
return [] # skip other tensors
@ModelBase.register("Llama4ForConditionalGeneration")
@ModelBase.register(
"Llama4ForConditionalGeneration",
"Llama4ForCausalLM",
)
class Llama4Model(LlamaModel):
model_arch = gguf.MODEL_ARCH.LLAMA4
undo_permute = False
@ -2411,6 +2414,10 @@ class Llama4Model(LlamaModel):
super().set_gguf_parameters()
self.gguf_writer.add_interleave_moe_layer_step(self.hparams["interleave_moe_layer_step"])
self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size_moe"])
if "layer_types" in self.hparams:
if all(lt == "full_attention" for lt in self.hparams["layer_types"]):
# all layers are full attention (for MobileLLM), disable swa
self.gguf_writer.add_sliding_window(0)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
if name.startswith("language_model."):

View File

@ -149,7 +149,7 @@ struct llama_hparams {
bool causal_attn = true;
bool use_alibi = false;
bool attn_soft_cap = false;
bool use_kq_norm = true;
bool use_kq_norm = false;
// for Classifiers
uint32_t n_cls_out = 1;

View File

@ -36,6 +36,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_80M: return "80M";
case LLM_TYPE_109M: return "109M";
case LLM_TYPE_137M: return "137M";
case LLM_TYPE_140M: return "140M";
case LLM_TYPE_160M: return "160M";
case LLM_TYPE_190M: return "190M";
case LLM_TYPE_220M: return "220M";
@ -44,6 +45,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_270M: return "270M";
case LLM_TYPE_335M: return "335M";
case LLM_TYPE_350M: return "350M";
case LLM_TYPE_360M: return "360M";
case LLM_TYPE_410M: return "410M";
case LLM_TYPE_450M: return "450M";
case LLM_TYPE_475M: return "475M";
@ -51,6 +53,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_700M: return "700M";
case LLM_TYPE_770M: return "770M";
case LLM_TYPE_780M: return "780M";
case LLM_TYPE_950M: return "950M";
case LLM_TYPE_0_3B: return "0.3B";
case LLM_TYPE_0_5B: return "0.5B";
case LLM_TYPE_0_6B: return "0.6B";
@ -622,19 +625,32 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);
hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
hparams.n_swa = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
if (found_swa && hparams.n_swa == 0) {
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
hparams.n_no_rope_layer_step = hparams.n_layer; // always use rope
} else {
hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
hparams.n_swa = 8192;
hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
}
switch (hparams.n_expert) {
case 0: {
// MobileLLM (no MoE)
switch (hparams.n_embd) {
case 2048: type = LLM_TYPE_140M; break;
case 4096: type = LLM_TYPE_360M; break;
case 6144: type = LLM_TYPE_950M; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case 16: type = LLM_TYPE_17B_16E; break;
case 128: type = LLM_TYPE_17B_128E; break;
default: type = LLM_TYPE_UNKNOWN;
}
if (type == LLM_TYPE_17B_128E) {
hparams.use_kq_norm = false;
}
hparams.use_kq_norm = type != LLM_TYPE_17B_128E;
} break;
case LLM_ARCH_ARCEE:
{
@ -2454,9 +2470,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}
GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Llama 4 requires n_moe_layer_step > 0");
for (int i = 0; i < n_layer; ++i) {
bool is_moe_layer = (i + 1) % hparams.n_moe_layer_step == 0;
bool is_moe_layer = hparams.n_moe_layer_step > 0 && (i + 1) % hparams.n_moe_layer_step == 0;
auto & layer = layers[i];
@ -6328,6 +6343,14 @@ struct llm_build_llama : public llm_graph_context {
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
if (hparams.use_kq_norm) {
// Llama4TextL2Norm
Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps);
Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps);
cb(Qcur, "Qcur_normed", il);
cb(Kcur, "Kcur_normed", il);
}
cur = build_attn(inp_attn,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
@ -6435,7 +6458,8 @@ struct llm_build_llama_iswa : public llm_graph_context {
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0;
const bool use_rope = hparams.n_no_rope_layer_step > 0 &&
(il + 1) % hparams.n_no_rope_layer_step != 0;
// norm
cur = build_norm(inpL,
@ -18981,7 +19005,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
} break;
case LLM_ARCH_LLAMA4:
{
llm = std::make_unique<llm_build_llama_iswa>(*this, params);
if (hparams.swa_type == LLAMA_SWA_TYPE_NONE) {
llm = std::make_unique<llm_build_llama>(*this, params);
} else {
llm = std::make_unique<llm_build_llama_iswa>(*this, params);
}
} break;
case LLM_ARCH_DECI:
{

View File

@ -28,6 +28,7 @@ enum llm_type {
LLM_TYPE_80M,
LLM_TYPE_109M,
LLM_TYPE_137M,
LLM_TYPE_140M,
LLM_TYPE_160M,
LLM_TYPE_190M,
LLM_TYPE_220M,
@ -36,6 +37,7 @@ enum llm_type {
LLM_TYPE_270M,
LLM_TYPE_335M,
LLM_TYPE_350M,
LLM_TYPE_360M,
LLM_TYPE_410M,
LLM_TYPE_450M,
LLM_TYPE_475M,
@ -43,6 +45,7 @@ enum llm_type {
LLM_TYPE_700M,
LLM_TYPE_770M,
LLM_TYPE_780M,
LLM_TYPE_950M,
LLM_TYPE_0_3B,
LLM_TYPE_0_5B,
LLM_TYPE_0_6B,