llama : support T5 models with unequal number of encoder-decoder layers (#15909)

* Extend the support of T5 models with different encoder-decoder layers

Signed-off-by: Jie Fu <jiefu@tencent.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update gguf-py/gguf/constants.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update gguf-py/gguf/gguf_writer.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-arch.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-arch.h

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-hparams.h

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Rename n_dec_layer --> dec_n_layer

Signed-off-by: Jie Fu <jiefu@tencent.com>

* Adapt to cases when dec_n_layer > n_layer

Signed-off-by: Jie Fu <jiefu@tencent.com>

---------

Signed-off-by: Jie Fu <jiefu@tencent.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
Jie Fu (傅杰) 2025-09-11 02:51:51 +08:00 committed by GitHub
parent 6ab397e12b
commit 4f658855fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 31 additions and 4 deletions

View File

@ -6701,6 +6701,8 @@ class T5Model(TextModel):
self.gguf_writer.add_embedding_length(self.hparams["d_model"]) self.gguf_writer.add_embedding_length(self.hparams["d_model"])
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"]) self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
self.gguf_writer.add_block_count(self.hparams["num_layers"]) self.gguf_writer.add_block_count(self.hparams["num_layers"])
if (dec_n_layer := self.hparams.get("num_decoder_layers")) is not None:
self.gguf_writer.add_decoder_block_count(dec_n_layer)
self.gguf_writer.add_head_count(self.hparams["num_heads"]) self.gguf_writer.add_head_count(self.hparams["num_heads"])
self.gguf_writer.add_key_length(self.hparams["d_kv"]) self.gguf_writer.add_key_length(self.hparams["d_kv"])
self.gguf_writer.add_value_length(self.hparams["d_kv"]) self.gguf_writer.add_value_length(self.hparams["d_kv"])

View File

@ -109,6 +109,7 @@ class Keys:
POOLING_TYPE = "{arch}.pooling_type" POOLING_TYPE = "{arch}.pooling_type"
LOGIT_SCALE = "{arch}.logit_scale" LOGIT_SCALE = "{arch}.logit_scale"
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
DECODER_BLOCK_COUNT = "{arch}.decoder_block_count"
ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping" ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping"
FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping" FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"
SWIN_NORM = "{arch}.swin_norm" SWIN_NORM = "{arch}.swin_norm"

View File

@ -676,6 +676,9 @@ class GGUFWriter:
def add_decoder_start_token_id(self, id: int) -> None: def add_decoder_start_token_id(self, id: int) -> None:
self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id) self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id)
def add_decoder_block_count(self, value: int) -> None:
self.add_uint32(Keys.LLM.DECODER_BLOCK_COUNT.format(arch=self.arch), value)
def add_embedding_length_per_layer_input(self, value: int) -> None: def add_embedding_length_per_layer_input(self, value: int) -> None:
self.add_uint32(Keys.LLM.EMBD_LENGTH_PER_LAYER_INP.format(arch=self.arch), value) self.add_uint32(Keys.LLM.EMBD_LENGTH_PER_LAYER_INP.format(arch=self.arch), value)

View File

@ -137,6 +137,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" }, { LLM_KV_POOLING_TYPE, "%s.pooling_type" },
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
{ LLM_KV_DECODER_BLOCK_COUNT, "%s.decoder_block_count" },
{ LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" }, { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" },
{ LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" }, { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },
{ LLM_KV_SWIN_NORM, "%s.swin_norm" }, { LLM_KV_SWIN_NORM, "%s.swin_norm" },

View File

@ -141,6 +141,7 @@ enum llm_kv {
LLM_KV_POOLING_TYPE, LLM_KV_POOLING_TYPE,
LLM_KV_LOGIT_SCALE, LLM_KV_LOGIT_SCALE,
LLM_KV_DECODER_START_TOKEN_ID, LLM_KV_DECODER_START_TOKEN_ID,
LLM_KV_DECODER_BLOCK_COUNT,
LLM_KV_ATTN_LOGIT_SOFTCAPPING, LLM_KV_ATTN_LOGIT_SOFTCAPPING,
LLM_KV_FINAL_LOGIT_SOFTCAPPING, LLM_KV_FINAL_LOGIT_SOFTCAPPING,
LLM_KV_SWIN_NORM, LLM_KV_SWIN_NORM,

View File

@ -159,6 +159,7 @@ struct llama_hparams {
// needed by encoder-decoder models (e.g. T5, FLAN-T5) // needed by encoder-decoder models (e.g. T5, FLAN-T5)
// ref: https://github.com/ggerganov/llama.cpp/pull/8141 // ref: https://github.com/ggerganov/llama.cpp/pull/8141
llama_token dec_start_token_id = LLAMA_TOKEN_NULL; llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
uint32_t dec_n_layer = 0;
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;

View File

@ -1542,6 +1542,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.dec_start_token_id = dec_start_token_id; hparams.dec_start_token_id = dec_start_token_id;
} }
hparams.dec_n_layer = hparams.n_layer;
ml.get_key(LLM_KV_DECODER_BLOCK_COUNT, hparams.dec_n_layer, false);
switch (hparams.n_layer) { switch (hparams.n_layer) {
case 6: type = LLM_TYPE_60M; break; // t5-small case 6: type = LLM_TYPE_60M; break; // t5-small
case 8: type = LLM_TYPE_80M; break; // flan-t5-small case 8: type = LLM_TYPE_80M; break; // flan-t5-small
@ -4414,6 +4417,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
} }
// n_layer: number of encoder_layers
// dec_n_layer: number of decoder_layers
const int dec_n_layer = hparams.dec_n_layer;
if (dec_n_layer > n_layer) {
layers.resize(dec_n_layer);
}
// load encoder layers
for (int i = 0; i < n_layer; ++i) { for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i]; auto & layer = layers[i];
@ -4429,6 +4440,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
// load decoder layers
for (int i = 0; i < dec_n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_norm = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED);
@ -13509,7 +13525,9 @@ struct llm_build_t5_dec : public llm_graph_context {
ggml_tensor * inp_out_ids = build_inp_out_ids(); ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) { const int64_t dec_n_layer = hparams.dec_n_layer;
for (int il = 0; il < dec_n_layer; ++il) {
ggml_tensor * inpSA = inpL; ggml_tensor * inpSA = inpL;
// norm // norm
@ -13600,7 +13618,7 @@ struct llm_build_t5_dec : public llm_graph_context {
//cb(cur, "kqv_out", il); //cb(cur, "kqv_out", il);
} }
if (il == n_layer - 1 && inp_out_ids) { if (il == dec_n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids); cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids); inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
} }
@ -13621,8 +13639,8 @@ struct llm_build_t5_dec : public llm_graph_context {
model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL, model.layers[il].ffn_down, NULL, NULL,
NULL, NULL,
model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_RELU,
model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ, model.layers[il].ffn_gate ? LLM_FFN_PAR : LLM_FFN_SEQ,
il); il);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} }