Merge pull request #9 from sfallah/sf/deepseek-ocr-attn

using common build_attn in sam

This commit is contained in: commit 6687b4e746
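Summary: the SAM-style window attention used by the DeepSeek-OCR vision encoder previously called ggml_flash_attn_ext directly and applied its output projection by hand. This change routes it through the shared build_attn helper in the clip graph code, so the flash-vs-fallback decision and the o_w/o_b projection live in one place, and the mtmd wrapper now forwards the caller's flash-attention setting to clip instead of hard-coding CLIP_FLASH_ATTN_TYPE_AUTO. The change markers below are reconstructed from the hunk headers.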
@@ -2590,10 +2590,7 @@ private:
        } else {
            ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
            v = ggml_cont(ctx0, v);
 
-           const auto n_tokens = q->ne[1];
-           const auto n_head   = q->ne[2];
-
            ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
            // F32 may not needed for vision encoders?
            // ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
@@ -2602,7 +2599,8 @@ private:
 
            ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
            cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-           cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+           cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur), cur->ne[0] * cur->ne[1], cur->ne[2] * cur->ne[3]);
+
        }
 
        cb(cur, "kqv_out", il);
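The first two hunks touch the generic fallback (non-flash) path inside build_attn itself: the n_tokens/n_head locals captured from q are dropped, and the final flatten reads its sizes off the tensor being flattened. A minimal sketch of the shape bookkeeping, assuming ggml's [ne0, ne1, ne2, ne3] ordering and the usual [d_head, n_tokens, n_head, batch] layout of kqv at this point (the helper name is illustrative, not from the source):

```cpp
#include "ggml.h"

// sketch: flatten the permuted attention output back to 2D.
// `cur` is assumed to be ggml_permute(ctx0, kqv, 0, 2, 1, 3),
// i.e. [d_head, n_head, n_tokens, batch].
static ggml_tensor * flatten_heads(ggml_context * ctx0, ggml_tensor * cur) {
    // old: ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens)
    //      -> [d_head*n_head, n_tokens], which bakes in batch == 1
    // new: derive both sizes from the tensor itself, so a batched graph
    //      flattens to [d_head*n_head, n_tokens*batch] with no extra locals
    return ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur),
            cur->ne[0]*cur->ne[1],   // d_head * n_head
            cur->ne[2]*cur->ne[3]);  // n_tokens * batch
}
```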
@@ -2789,15 +2787,12 @@ private:
 
        Q = ggml_view_3d   (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 0*cur->nb[1]);
        Q = ggml_reshape_4d(ctx0, ggml_cont(ctx0, Q), d_heads, n_heads, W*H, B);
-       Q = ggml_cont      (ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); // [B, n_heads, H*W, d_heads]
 
        K = ggml_view_3d   (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 1*cur->nb[1]);
        K = ggml_reshape_4d(ctx0, ggml_cont(ctx0, K), d_heads, n_heads, W*H, B);
-       K = ggml_cont      (ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); // [B, n_heads, H*W, d_heads]
 
        V = ggml_view_3d   (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 2*cur->nb[1]);
        V = ggml_reshape_4d(ctx0, ggml_cont(ctx0, V), d_heads, n_heads, W*H, B);
-       V = ggml_cont      (ctx0, ggml_permute(ctx0, V, 0, 2, 1, 3)); // [B, n_heads, H*W, d_heads]
 
        ggml_tensor * mask;
        ggml_tensor * rw;
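build_attn permutes its inputs internally, so the window-attention code can hand over Q/K/V in head-major [d_heads, n_heads, W*H, B] form and drop three explicit ggml_cont(ggml_permute(...)) materializations per layer. A sketch of how Q is now prepared (K and V are identical apart from the 1*nb[1] and 2*nb[1] offsets; the wrapper function is illustrative):

```cpp
#include "ggml.h"

// sketch: slice q out of the packed qkv tensor `cur` and split n_embd into heads;
// W*H = tokens per window, B = number of windows
static ggml_tensor * window_q(ggml_context * ctx0, ggml_tensor * cur,
                              int64_t n_embd, int64_t d_heads, int64_t n_heads,
                              int64_t W, int64_t H, int64_t B) {
    ggml_tensor * Q = ggml_view_3d(ctx0, cur, n_embd, W*H, B,
            cur->nb[2], cur->nb[3], 0*cur->nb[1]);  // q slice of packed qkv
    Q = ggml_reshape_4d(ctx0, ggml_cont(ctx0, Q),
            d_heads, n_heads, W*H, B);              // split n_embd into heads
    // the old extra step, now gone:
    //   Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
    // build_attn applies that permute lazily, saving one full copy per tensor
    return Q;
}
```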
@@ -2806,7 +2801,8 @@ private:
 
            rw = get_rel_pos(ctx0, layer.rel_pos_w, W, W); // [W, W, C]
            rh = get_rel_pos(ctx0, layer.rel_pos_h, H, H); // [H, H, C]
-           qr = ggml_reshape_4d(ctx0, Q, d_heads, W, H, B*n_heads);
+           qr = ggml_permute(ctx0, Q, 0, 2, 1, 3);
+           qr = ggml_reshape_4d(ctx0, ggml_cont(ctx0, qr), d_heads, W, H, B * n_heads);
 
            const int WH_pad = GGML_PAD(W*H, GGML_KQ_MASK_PAD) - W*H;
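Because Q now stays in [d_heads, n_heads, W*H, B], the decomposed relative-position branch can no longer reshape it straight to [d_heads, W, H, B*n_heads]; it first has to move the token dimension next to the channel dimension. A sketch of the two-step view with the intermediate shape spelled out (wrapper function illustrative):

```cpp
#include "ggml.h"

// sketch: rel-pos query view, given Q in the new [d_heads, n_heads, W*H, B] layout
static ggml_tensor * relpos_q(ggml_context * ctx0, ggml_tensor * Q,
                              int64_t d_heads, int64_t W, int64_t H,
                              int64_t B, int64_t n_heads) {
    ggml_tensor * qr = ggml_permute(ctx0, Q, 0, 2, 1, 3);   // [d_heads, W*H, n_heads, B]
    return ggml_reshape_4d(ctx0, ggml_cont(ctx0, qr),
            d_heads, W, H, B*n_heads);                      // one (W, H) grid per head
}
```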
@@ -2822,11 +2818,10 @@ private:
            mask = ggml_cast (ctx0, mask, GGML_TYPE_F16);
 
            float scale = 1.0f / sqrtf((float)d_heads);
-           cur = ggml_flash_attn_ext(ctx0, Q, K, V, mask, scale, 0.0f, 0.0f); // [B, H*W, n_heads, d_heads]
 
+           cur = build_attn(layer.o_w, layer.o_b, Q, K, V, mask, scale,
+                   il); // [B, H*W, n_embd]
            cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), n_embd, W, H, B);
-           cur = ggml_mul_mat(ctx0, layer.o_w, cur);
-           cur = ggml_add_inplace(ctx0, cur, layer.o_b);
        }
 
        if (hparams.is_global_attn(il) == false) {
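This hunk is the core of the change: the hard-wired ggml_flash_attn_ext call and the hand-rolled output projection (ggml_mul_mat with layer.o_w, ggml_add_inplace with layer.o_b) collapse into one build_attn call, which applies the projection itself and picks the flash or fallback path from the clip context. Note the comment change: build_attn returns the projected [B, H*W, n_embd] activation rather than the raw per-head attention output, which is why the reshape back to [n_embd, W, H, B] survives unchanged. The helper's shape, paraphrased from this call site (not copied verbatim from clip.cpp):

```cpp
// assumed signature of the shared helper; runs flash or fallback attention,
// then applies the wo/wo_b output projection
ggml_tensor * build_attn(
        ggml_tensor * wo,        // output projection weight (layer.o_w here)
        ggml_tensor * wo_b,      // output projection bias   (layer.o_b here)
        ggml_tensor * q_cur,     // [d_heads, n_heads, n_tokens, batch]
        ggml_tensor * k_cur,     //  "
        ggml_tensor * v_cur,     //  "
        ggml_tensor * kq_mask,   // optional mask, cast to F16 for the flash path
        float         kq_scale,  // here: 1.0f / sqrtf((float)d_heads)
        int           il);       // layer index, used by the cb(..., "kqv_out", il) callback
```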
@@ -175,7 +175,7 @@ struct mtmd_context {
 
        clip_context_params ctx_clip_params {
            /* use_gpu */            ctx_params.use_gpu,
-           /* flash_attn_type */    CLIP_FLASH_ATTN_TYPE_AUTO,
+           /* flash_attn_type */    mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type),
            /* image_min_tokens */   ctx_params.image_min_tokens,
            /* image_max_tokens */   ctx_params.image_max_tokens,
            /* warmup */             ctx_params.warmup,
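On the mtmd side, the clip context previously always got CLIP_FLASH_ATTN_TYPE_AUTO; it now inherits whatever the caller configured for the text model, via mtmd_get_clip_flash_attn_type, so a user's flash-attention choice also reaches the vision encoder. The diff does not show that helper's body; a plausible one-to-one mapping would look like this (hypothetical implementation — only the function name comes from the diff, and the enum member names are assumed):

```cpp
#include "clip.h"
#include "llama.h"

// hypothetical body: map llama.cpp's flash-attention setting onto clip's enum
static clip_flash_attn_type mtmd_get_clip_flash_attn_type(enum llama_flash_attn_type flash_attn_type) {
    switch (flash_attn_type) {
        case LLAMA_FLASH_ATTN_TYPE_AUTO:     return CLIP_FLASH_ATTN_TYPE_AUTO;
        case LLAMA_FLASH_ATTN_TYPE_DISABLED: return CLIP_FLASH_ATTN_TYPE_DISABLED;
        case LLAMA_FLASH_ATTN_TYPE_ENABLED:  return CLIP_FLASH_ATTN_TYPE_ENABLED;
    }
    GGML_ABORT("unknown flash_attn_type");
}
```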