mtmd: support combined QKV projection in buid_vit

This commit is contained in:
bluebread 2025-12-04 17:57:43 +00:00
parent 2dd9924076
commit fc3f625fef
1 changed files with 36 additions and 11 deletions

View File

@ -2152,19 +2152,44 @@ private:
// self-attention
{
ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
if (layer.q_b) {
Qcur = ggml_add(ctx0, Qcur, layer.q_b);
}
ggml_tensor * Qcur;
ggml_tensor * Kcur;
ggml_tensor * Vcur;
if (layer.qkv_w) {
ggml_tensor * QKV;
ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
if (layer.k_b) {
Kcur = ggml_add(ctx0, Kcur, layer.k_b);
}
QKV = ggml_mul_mat(ctx0, layer.qkv_w, cur);
if (layer.qkv_b) {
QKV = ggml_add(ctx0, QKV, layer.qkv_b);
}
QKV = ggml_reshape_4d(ctx0, QKV, cur->ne[0], 3, cur->ne[1]*cur->ne[2], cur->ne[3]);
ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
if (layer.v_b) {
Vcur = ggml_add(ctx0, Vcur, layer.v_b);
const int ne0 = QKV->ne[0];
const int ne2 = QKV->ne[2];
const int ne3 = QKV->ne[3];
const int nb1 = QKV->nb[1];
const int nb2 = QKV->nb[2];
const int nb3 = QKV->nb[3];
Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, QKV, ne0, ne2, ne3, nb2, nb3, 0*nb1));
Kcur = ggml_cont(ctx0, ggml_view_3d(ctx0, QKV, ne0, ne2, ne3, nb2, nb3, 1*nb1));
Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, QKV, ne0, ne2, ne3, nb2, nb3, 2*nb1));
} else {
Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
if (layer.q_b) {
Qcur = ggml_add(ctx0, Qcur, layer.q_b);
}
Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
if (layer.k_b) {
Kcur = ggml_add(ctx0, Kcur, layer.k_b);
}
Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
if (layer.v_b) {
Vcur = ggml_add(ctx0, Vcur, layer.v_b);
}
}
if (layer.q_norm) {