corrected code-branch when flash-attn disabled
enabling usage of --flash-attn option
This commit is contained in:
parent
5381b9cf63
commit
076138a428
|
|
@ -2591,9 +2591,6 @@ private:
|
|||
ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
|
||||
v = ggml_cont(ctx0, v);
|
||||
|
||||
const auto n_tokens = q->ne[1];
|
||||
const auto n_head = q->ne[2];
|
||||
|
||||
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
||||
// F32 may not needed for vision encoders?
|
||||
// ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
|
||||
|
|
@ -2601,8 +2598,9 @@ private:
|
|||
kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f);
|
||||
|
||||
ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
|
||||
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
||||
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
|
||||
cur = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3));
|
||||
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
|
||||
|
||||
}
|
||||
|
||||
cb(cur, "kqv_out", il);
|
||||
|
|
|
|||
|
|
@ -175,7 +175,7 @@ struct mtmd_context {
|
|||
|
||||
clip_context_params ctx_clip_params {
|
||||
/* use_gpu */ ctx_params.use_gpu,
|
||||
/* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_AUTO,
|
||||
/* flash_attn_type */ mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type),
|
||||
/* image_min_tokens */ ctx_params.image_min_tokens,
|
||||
/* image_max_tokens */ ctx_params.image_max_tokens,
|
||||
/* warmup */ ctx_params.warmup,
|
||||
|
|
|
|||
Loading…
Reference in New Issue