mtmd: build_attn modified, flash_attn on/off via ctx_params (#19729)
This commit is contained in:
parent
2bf318fd2f
commit
e6267a9359
|
|
@ -628,9 +628,6 @@ ggml_tensor * clip_graph::build_attn(
|
|||
ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
|
||||
v = ggml_cont(ctx0, v);
|
||||
|
||||
const auto n_tokens = q->ne[1];
|
||||
const auto n_head = q->ne[2];
|
||||
|
||||
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
||||
// F32 may not needed for vision encoders?
|
||||
// ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
|
||||
|
|
@ -639,7 +636,7 @@ ggml_tensor * clip_graph::build_attn(
|
|||
|
||||
ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
|
||||
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
||||
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
|
||||
cur = ggml_cont_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2] * cur->ne[3]);
|
||||
}
|
||||
|
||||
cb(cur, "kqv_out", il);
|
||||
|
|
|
|||
|
|
@ -175,7 +175,7 @@ struct mtmd_context {
|
|||
|
||||
clip_context_params ctx_clip_params {
|
||||
/* use_gpu */ ctx_params.use_gpu,
|
||||
/* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_AUTO,
|
||||
/* flash_attn_type */ mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type),
|
||||
/* image_min_tokens */ ctx_params.image_min_tokens,
|
||||
/* image_max_tokens */ ctx_params.image_max_tokens,
|
||||
/* warmup */ ctx_params.warmup,
|
||||
|
|
|
|||
|
|
@ -28,6 +28,14 @@ if [ "${1:-}" = "huge" ]; then
|
|||
echo "Include BIG and HUGE models..."
|
||||
fi
|
||||
|
||||
# Check if the second argument is "flash", then enable flash attention
|
||||
# This is useful to test if flash attention off works correctly
|
||||
FLASH_ATTN="on"
|
||||
if [ "${2:-}" = "flash_off" ] || [ "${1:-}" = "flash_off" ]; then
|
||||
FLASH_ATTN="off"
|
||||
echo "Flash attention disabled..."
|
||||
fi
|
||||
|
||||
###############
|
||||
|
||||
arr_prefix=()
|
||||
|
|
@ -143,6 +151,7 @@ for i in "${!arr_hf[@]}"; do
|
|||
-hf $(printf %q "$hf") \
|
||||
--image $(printf %q "$SCRIPT_DIR/$inp_file") \
|
||||
--temp 0 -n 128 \
|
||||
--flash-attn $(printf %q "$FLASH_ATTN") \
|
||||
${extra_args}"
|
||||
|
||||
# if extra_args does not contain -p, we add a default prompt
|
||||
|
|
|
|||
Loading…
Reference in New Issue