mtmd: support combined QKV projection in buid_vit
This commit is contained in:
parent
2dd9924076
commit
fc3f625fef
|
|
@ -2152,19 +2152,44 @@ private:
|
||||||
|
|
||||||
// self-attention
|
// self-attention
|
||||||
{
|
{
|
||||||
ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
|
ggml_tensor * Qcur;
|
||||||
if (layer.q_b) {
|
ggml_tensor * Kcur;
|
||||||
Qcur = ggml_add(ctx0, Qcur, layer.q_b);
|
ggml_tensor * Vcur;
|
||||||
}
|
|
||||||
|
if (layer.qkv_w) {
|
||||||
|
ggml_tensor * QKV;
|
||||||
|
|
||||||
ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
|
QKV = ggml_mul_mat(ctx0, layer.qkv_w, cur);
|
||||||
if (layer.k_b) {
|
if (layer.qkv_b) {
|
||||||
Kcur = ggml_add(ctx0, Kcur, layer.k_b);
|
QKV = ggml_add(ctx0, QKV, layer.qkv_b);
|
||||||
}
|
}
|
||||||
|
QKV = ggml_reshape_4d(ctx0, QKV, cur->ne[0], 3, cur->ne[1]*cur->ne[2], cur->ne[3]);
|
||||||
|
|
||||||
ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
|
const int ne0 = QKV->ne[0];
|
||||||
if (layer.v_b) {
|
const int ne2 = QKV->ne[2];
|
||||||
Vcur = ggml_add(ctx0, Vcur, layer.v_b);
|
const int ne3 = QKV->ne[3];
|
||||||
|
const int nb1 = QKV->nb[1];
|
||||||
|
const int nb2 = QKV->nb[2];
|
||||||
|
const int nb3 = QKV->nb[3];
|
||||||
|
|
||||||
|
Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, QKV, ne0, ne2, ne3, nb2, nb3, 0*nb1));
|
||||||
|
Kcur = ggml_cont(ctx0, ggml_view_3d(ctx0, QKV, ne0, ne2, ne3, nb2, nb3, 1*nb1));
|
||||||
|
Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, QKV, ne0, ne2, ne3, nb2, nb3, 2*nb1));
|
||||||
|
} else {
|
||||||
|
Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
|
||||||
|
if (layer.q_b) {
|
||||||
|
Qcur = ggml_add(ctx0, Qcur, layer.q_b);
|
||||||
|
}
|
||||||
|
|
||||||
|
Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
|
||||||
|
if (layer.k_b) {
|
||||||
|
Kcur = ggml_add(ctx0, Kcur, layer.k_b);
|
||||||
|
}
|
||||||
|
|
||||||
|
Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
|
||||||
|
if (layer.v_b) {
|
||||||
|
Vcur = ggml_add(ctx0, Vcur, layer.v_b);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (layer.q_norm) {
|
if (layer.q_norm) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue