using common build_attn in sam
This commit is contained in:
parent
4d7d9945f8
commit
5381b9cf63
|
|
@ -2789,15 +2789,12 @@ private:
|
|||
|
||||
Q = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 0*cur->nb[1]);
|
||||
Q = ggml_reshape_4d(ctx0, ggml_cont(ctx0, Q), d_heads, n_heads, W*H, B);
|
||||
Q = ggml_cont (ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); // [B, n_heads, H*W, d_heads]
|
||||
|
||||
K = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 1*cur->nb[1]);
|
||||
K = ggml_reshape_4d(ctx0, ggml_cont(ctx0, K), d_heads, n_heads, W*H, B);
|
||||
K = ggml_cont (ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); // [B, n_heads, H*W, d_heads]
|
||||
|
||||
V = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 2*cur->nb[1]);
|
||||
V = ggml_reshape_4d(ctx0, ggml_cont(ctx0, V), d_heads, n_heads, W*H, B);
|
||||
V = ggml_cont (ctx0, ggml_permute(ctx0, V, 0, 2, 1, 3)); // [B, n_heads, H*W, d_heads]
|
||||
|
||||
ggml_tensor * mask;
|
||||
ggml_tensor * rw;
|
||||
|
|
@ -2806,7 +2803,7 @@ private:
|
|||
|
||||
rw = get_rel_pos(ctx0, layer.rel_pos_w, W, W); // [W, W, C]
|
||||
rh = get_rel_pos(ctx0, layer.rel_pos_h, H, H); // [H, H, C]
|
||||
qr = ggml_reshape_4d(ctx0, Q, d_heads, W, H, B*n_heads);
|
||||
qr = ggml_reshape_4d(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)), d_heads, W, H, B*n_heads);
|
||||
|
||||
const int WH_pad = GGML_PAD(W*H, GGML_KQ_MASK_PAD) - W*H;
|
||||
|
||||
|
|
@ -2822,11 +2819,18 @@ private:
|
|||
mask = ggml_cast (ctx0, mask, GGML_TYPE_F16);
|
||||
|
||||
float scale = 1.0f / sqrtf((float)d_heads);
|
||||
cur = ggml_flash_attn_ext(ctx0, Q, K, V, mask, scale, 0.0f, 0.0f); // [B, H*W, n_heads, d_heads]
|
||||
|
||||
cur = build_attn(
|
||||
layer.o_w,
|
||||
layer.o_b,
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
mask,
|
||||
scale,
|
||||
il
|
||||
); // [B, H*W, n_embd]
|
||||
cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), n_embd, W, H, B);
|
||||
cur = ggml_mul_mat(ctx0, layer.o_w, cur);
|
||||
cur = ggml_add_inplace(ctx0, cur, layer.o_b);
|
||||
}
|
||||
|
||||
if (hparams.is_global_attn(il) == false) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue