implement "auto" mode for clip flash attn

This commit is contained in:
Xuan Son Nguyen 2025-11-01 23:52:40 +01:00
parent 19116a4b38
commit b4955f0ae6
6 changed files with 74 additions and 10 deletions

View File

@ -4,6 +4,7 @@
// Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
#include "clip.h"
#include "clip-impl.h"
#include "mtmd.h"
#include "ggml.h"
#include "ggml-cpp.h"
#include "ggml-cpu.h"
@ -427,12 +428,14 @@ struct clip_ctx {
int max_nodes = 8192;
ggml_backend_sched_ptr sched;
llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
// for debugging
bool debug_graph = false;
std::vector<ggml_tensor *> debug_print_tensors;
clip_ctx(clip_context_params & ctx_params) {
flash_attn_type = ctx_params.flash_attn_type;
debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr;
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
if (!backend_cpu) {
@ -2261,16 +2264,36 @@ private:
ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
//cb(k, "k", il);
ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
//cb(k, "v", il);
ggml_tensor * cur;
k = ggml_cast(ctx0, k, GGML_TYPE_F16);
v = ggml_cast(ctx0, v, GGML_TYPE_F16);
if (ctx->flash_attn_type == LLAMA_FLASH_ATTN_TYPE_ENABLED) {
ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
ggml_tensor * cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
k = ggml_cast(ctx0, k, GGML_TYPE_F16);
v = ggml_cast(ctx0, v, GGML_TYPE_F16);
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
} else {
ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
v = ggml_cont(ctx0, v);
const auto n_tokens = q->ne[1];
const auto n_head = q->ne[2];
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
// F32 may not needed for vision encoders?
// ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f);
ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
}
cb(cur, "kqv_out", il);
@ -3181,7 +3204,30 @@ struct clip_model_loader {
}
}
void alloc_compute_meta(clip_ctx & ctx_clip) {
void warmup(clip_ctx & ctx_clip) {
if (ctx_clip.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
// try to enable flash attention to see if it's supported
ctx_clip.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
bool supported = alloc_compute_meta(ctx_clip);
if (!supported) {
LOG_WRN("%s: flash attention not supported, memory usage will increase\n", __func__);
// TODO: maybe log more details about why flash attention is not supported
ctx_clip.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
alloc_compute_meta(ctx_clip);
}
} else {
bool supported = alloc_compute_meta(ctx_clip);
if (!supported) {
LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__);
}
}
LOG_INF("%s: flash attention is %s\n", __func__,
(ctx_clip.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
}
// return false if flash attention is not supported
bool alloc_compute_meta(clip_ctx & ctx_clip) {
const auto & hparams = ctx_clip.model.hparams;
ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
@ -3217,6 +3263,17 @@ struct clip_model_loader {
const int n_nodes = ggml_graph_n_nodes(gf);
LOG_INF("%s: graph splits = %d, nodes = %d\n", __func__, n_splits, n_nodes);
// check flash attention support
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
ggml_tensor * node = ggml_graph_node(gf, i);
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
if (!ggml_backend_supports_op(ctx_clip.backend, node)) {
return false;
}
}
}
return true;
}
void get_bool(const std::string & key, bool & output, bool required = true) {
@ -3306,14 +3363,14 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
ctx_vision = new clip_ctx(ctx_params);
loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION);
loader.load_tensors(*ctx_vision);
loader.alloc_compute_meta(*ctx_vision);
loader.warmup(*ctx_vision);
}
if (loader.has_audio) {
ctx_audio = new clip_ctx(ctx_params);
loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO);
loader.load_tensors(*ctx_audio);
loader.alloc_compute_meta(*ctx_audio);
loader.warmup(*ctx_audio);
}
} catch (const std::exception & e) {

View File

@ -1,6 +1,7 @@
#pragma once
#include "ggml.h"
#include "mtmd.h"
#include <stddef.h>
#include <stdint.h>
@ -25,6 +26,7 @@ enum clip_modality {
struct clip_context_params {
bool use_gpu;
enum ggml_log_level verbosity;
llama_flash_attn_type flash_attn_type;
};
struct clip_init_result {

View File

@ -136,6 +136,7 @@ struct mtmd_cli_context {
mparams.print_timings = true;
mparams.n_threads = params.cpuparams.n_threads;
mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
mparams.flash_attn_type = params.flash_attn_type;
ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
if (!ctx_vision.get()) {
LOG_ERR("Failed to load vision model from %s\n", clip_path);

View File

@ -100,6 +100,7 @@ mtmd_context_params mtmd_context_params_default() {
params.verbosity = GGML_LOG_LEVEL_INFO;
params.image_marker = MTMD_DEFAULT_IMAGE_MARKER;
params.media_marker = mtmd_default_marker();
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
return params;
}
@ -164,6 +165,7 @@ struct mtmd_context {
clip_context_params ctx_clip_params;
ctx_clip_params.use_gpu = ctx_params.use_gpu;
ctx_clip_params.verbosity = ctx_params.verbosity;
ctx_clip_params.flash_attn_type = ctx_params.flash_attn_type;
auto res = clip_init(mmproj_fname, ctx_clip_params);
ctx_v = res.ctx_v;
ctx_a = res.ctx_a;

View File

@ -82,6 +82,7 @@ struct mtmd_context_params {
enum ggml_log_level verbosity;
const char * image_marker; // deprecated, use media_marker instead
const char * media_marker;
llama_flash_attn_type flash_attn_type;
};
MTMD_API const char * mtmd_default_marker(void);

View File

@ -2456,6 +2456,7 @@ struct server_context {
mparams.print_timings = false;
mparams.n_threads = params_base.cpuparams.n_threads;
mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
mparams.flash_attn_type = params_base.flash_attn_type;
mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
if (mctx == nullptr) {
SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());