llama : fix shapes for bert/mpt q/k norm (#16409)

This commit is contained in:
Sigbjørn Skjæret 2025-10-03 14:40:25 +02:00 committed by GitHub
parent 638d330246
commit 946f71ed9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 7 additions and 0 deletions

View File

@ -7843,6 +7843,8 @@ struct llm_build_bert : public llm_graph_context {
}
if (model.layers[il].attn_q_norm) {
Qcur = ggml_reshape_2d(ctx0, Qcur, n_embd_head*n_head, n_tokens);
Qcur = build_norm(Qcur,
model.layers[il].attn_q_norm,
model.layers[il].attn_q_norm_b,
@ -7852,6 +7854,8 @@ struct llm_build_bert : public llm_graph_context {
}
if (model.layers[il].attn_k_norm) {
Kcur = ggml_reshape_2d(ctx0, Kcur, n_embd_head*n_head_kv, n_tokens);
Kcur = build_norm(Kcur,
model.layers[il].attn_k_norm,
model.layers[il].attn_k_norm_b,
@ -8234,6 +8238,9 @@ struct llm_build_mpt : public llm_graph_context {
// Q/K Layernorm
if (model.layers[il].attn_q_norm) {
Qcur = ggml_reshape_2d(ctx0, Qcur, n_embd_head*n_head, n_tokens);
Kcur = ggml_reshape_2d(ctx0, Kcur, n_embd_head*n_head_kv, n_tokens);
Qcur = build_norm(Qcur,
model.layers[il].attn_q_norm,
model.layers[il].attn_q_norm_b,