diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index d1bed23d03..b9bcfafa1c 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -2789,15 +2789,12 @@ private: Q = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 0*cur->nb[1]); Q = ggml_reshape_4d(ctx0, ggml_cont(ctx0, Q), d_heads, n_heads, W*H, B); - Q = ggml_cont (ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); // [B, n_heads, H*W, d_heads] K = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 1*cur->nb[1]); K = ggml_reshape_4d(ctx0, ggml_cont(ctx0, K), d_heads, n_heads, W*H, B); - K = ggml_cont (ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); // [B, n_heads, H*W, d_heads] V = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 2*cur->nb[1]); V = ggml_reshape_4d(ctx0, ggml_cont(ctx0, V), d_heads, n_heads, W*H, B); - V = ggml_cont (ctx0, ggml_permute(ctx0, V, 0, 2, 1, 3)); // [B, n_heads, H*W, d_heads] ggml_tensor * mask; ggml_tensor * rw; @@ -2806,7 +2803,7 @@ private: rw = get_rel_pos(ctx0, layer.rel_pos_w, W, W); // [W, W, C] rh = get_rel_pos(ctx0, layer.rel_pos_h, H, H); // [H, H, C] - qr = ggml_reshape_4d(ctx0, Q, d_heads, W, H, B*n_heads); + qr = ggml_reshape_4d(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)), d_heads, W, H, B*n_heads); const int WH_pad = GGML_PAD(W*H, GGML_KQ_MASK_PAD) - W*H; @@ -2822,11 +2819,18 @@ private: mask = ggml_cast (ctx0, mask, GGML_TYPE_F16); float scale = 1.0f / sqrtf((float)d_heads); - cur = ggml_flash_attn_ext(ctx0, Q, K, V, mask, scale, 0.0f, 0.0f); // [B, H*W, n_heads, d_heads] + cur = build_attn( + layer.o_w, + layer.o_b, + Q, + K, + V, + mask, + scale, + il + ); // [B, H*W, n_embd] cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), n_embd, W, H, B); - cur = ggml_mul_mat(ctx0, layer.o_w, cur); - cur = ggml_add_inplace(ctx0, cur, layer.o_b); } if (hparams.is_global_attn(il) == false) {