diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 70b93e3425..74ba6c2705 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -4,6 +4,7 @@
 // Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
 #include "clip.h"
 #include "clip-impl.h"
+#include "mtmd.h"
 #include "ggml.h"
 #include "ggml-cpp.h"
 #include "ggml-cpu.h"
@@ -427,12 +428,14 @@ struct clip_ctx {
     int max_nodes = 8192;
     ggml_backend_sched_ptr sched;
 
+    llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
+
     // for debugging
     bool debug_graph = false;
     std::vector<ggml_tensor *> debug_print_tensors;
 
     clip_ctx(clip_context_params & ctx_params) {
+        flash_attn_type = ctx_params.flash_attn_type;
         debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr;
         backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
         if (!backend_cpu) {
@@ -2261,16 +2264,36 @@ private:
         ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
         //cb(k, "k", il);
 
-        ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-        //cb(k, "v", il);
+        ggml_tensor * cur;
 
-        k = ggml_cast(ctx0, k, GGML_TYPE_F16);
-        v = ggml_cast(ctx0, v, GGML_TYPE_F16);
+        if (ctx->flash_attn_type == LLAMA_FLASH_ATTN_TYPE_ENABLED) {
+            ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
 
-        ggml_tensor * cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
-        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+            k = ggml_cast(ctx0, k, GGML_TYPE_F16);
+            v = ggml_cast(ctx0, v, GGML_TYPE_F16);
 
-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
+            cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
+            ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+
+            cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
+
+        } else {
+            ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
+            v = ggml_cont(ctx0, v);
+
+            const auto n_tokens = q->ne[1];
+            const auto n_head   = q->ne[2];
+
+            ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+            // F32 may not be needed for vision encoders?
+            // ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+
+            kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f);
+
+            ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+            cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+            cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        }
 
         cb(cur, "kqv_out", il);
 
@@ -3181,7 +3204,30 @@ struct clip_model_loader {
         }
     }
 
-    void alloc_compute_meta(clip_ctx & ctx_clip) {
+    void warmup(clip_ctx & ctx_clip) {
+        if (ctx_clip.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
+            // try to enable flash attention to see if it's supported
+            ctx_clip.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
+            bool supported = alloc_compute_meta(ctx_clip);
+            if (!supported) {
+                LOG_WRN("%s: flash attention not supported, memory usage will increase\n", __func__);
+                // TODO: maybe log more details about why flash attention is not supported
+                ctx_clip.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
+                alloc_compute_meta(ctx_clip);
+            }
+        } else {
+            bool supported = alloc_compute_meta(ctx_clip);
+            if (!supported) {
+                LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__);
+            }
+        }
+
+        LOG_INF("%s: flash attention is %s\n", __func__,
+                (ctx_clip.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
+    }
+
+    // return false if flash attention is not supported
+    bool alloc_compute_meta(clip_ctx & ctx_clip) {
         const auto & hparams = ctx_clip.model.hparams;
         ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
 
@@ -3217,6 +3263,17 @@ struct clip_model_loader {
         const int n_nodes = ggml_graph_n_nodes(gf);
 
         LOG_INF("%s: graph splits = %d, nodes = %d\n", __func__, n_splits, n_nodes);
+
+        // check flash attention support
+        for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+            ggml_tensor * node = ggml_graph_node(gf, i);
+            if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+                if (!ggml_backend_supports_op(ctx_clip.backend, node)) {
+                    return false;
+                }
+            }
+        }
+        return true;
     }
 
     void get_bool(const std::string & key, bool & output, bool required = true) {
@@ -3306,14 +3363,14 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
             ctx_vision = new clip_ctx(ctx_params);
             loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION);
             loader.load_tensors(*ctx_vision);
-            loader.alloc_compute_meta(*ctx_vision);
+            loader.warmup(*ctx_vision);
         }
 
         if (loader.has_audio) {
             ctx_audio = new clip_ctx(ctx_params);
             loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO);
             loader.load_tensors(*ctx_audio);
-            loader.alloc_compute_meta(*ctx_audio);
+            loader.warmup(*ctx_audio);
         }
 
     } catch (const std::exception & e) {
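Note on the two branches above: the flash path hands F16-cast K/V to ggml_flash_attn_ext(), while the fallback spells out softmax(k*q * scale)*v with two ggml_mul_mat calls around ggml_soft_max_ext. The standalone sketch below reproduces that fallback math for a single head with plain loops standing in for the ggml ops; the names, sizes, and explicit max-subtraction are illustrative and not taken from this patch (ggml's softmax performs the equivalent stabilization internally).

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    using mat = std::vector<std::vector<float>>;

    // q, k, v: row-major [n_tokens][d_head]; returns softmax(q * k^T * scale) * v
    static mat naive_attention(const mat & q, const mat & k, const mat & v, float scale) {
        const size_t n = q.size(), d = q[0].size();
        mat out(n, std::vector<float>(d, 0.0f));
        for (size_t i = 0; i < n; i++) {
            // scores for token i (the kq = ggml_mul_mat(k, q) step)
            std::vector<float> s(n, 0.0f);
            for (size_t j = 0; j < n; j++) {
                for (size_t t = 0; t < d; t++) {
                    s[j] += q[i][t] * k[j][t];
                }
                s[j] *= scale;
            }
            // numerically stable softmax (the ggml_soft_max_ext step)
            const float mx = *std::max_element(s.begin(), s.end());
            float sum = 0.0f;
            for (size_t j = 0; j < n; j++) {
                s[j] = std::exp(s[j] - mx);
                sum += s[j];
            }
            // weighted sum over v (the kqv = ggml_mul_mat(v, kq) step)
            for (size_t j = 0; j < n; j++) {
                for (size_t t = 0; t < d; t++) {
                    out[i][t] += (s[j] / sum) * v[j][t];
                }
            }
        }
        return out;
    }

    int main() {
        const mat q = {{1, 0}, {0, 1}};
        const mat k = q;
        const mat v = {{2, 0}, {0, 2}};
        const mat o = naive_attention(q, k, v, 1.0f / std::sqrt(2.0f));
        printf("%.3f %.3f\n", o[0][0], o[0][1]);
        return 0;
    }

Both branches produce the same output activations; the fallback just materializes the n_tokens x n_tokens score matrix that flash attention avoids, which is why the warmup path below warns about increased memory usage when it has to disable flash attention.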
"enabled" : "disabled"); + } + + // return false if flash attention is not supported + bool alloc_compute_meta(clip_ctx & ctx_clip) { const auto & hparams = ctx_clip.model.hparams; ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead()); @@ -3217,6 +3263,17 @@ struct clip_model_loader { const int n_nodes = ggml_graph_n_nodes(gf); LOG_INF("%s: graph splits = %d, nodes = %d\n", __func__, n_splits, n_nodes); + + // check flash attention support + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + ggml_tensor * node = ggml_graph_node(gf, i); + if (node->op == GGML_OP_FLASH_ATTN_EXT) { + if (!ggml_backend_supports_op(ctx_clip.backend, node)) { + return false; + } + } + } + return true; } void get_bool(const std::string & key, bool & output, bool required = true) { @@ -3306,14 +3363,14 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_vision = new clip_ctx(ctx_params); loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION); loader.load_tensors(*ctx_vision); - loader.alloc_compute_meta(*ctx_vision); + loader.warmup(*ctx_vision); } if (loader.has_audio) { ctx_audio = new clip_ctx(ctx_params); loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO); loader.load_tensors(*ctx_audio); - loader.alloc_compute_meta(*ctx_audio); + loader.warmup(*ctx_audio); } } catch (const std::exception & e) { diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h index 3387cdbd36..afef15205e 100644 --- a/tools/mtmd/clip.h +++ b/tools/mtmd/clip.h @@ -1,6 +1,7 @@ #pragma once #include "ggml.h" +#include "mtmd.h" #include #include @@ -25,6 +26,7 @@ enum clip_modality { struct clip_context_params { bool use_gpu; enum ggml_log_level verbosity; + llama_flash_attn_type flash_attn_type; }; struct clip_init_result { diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index fd1fb6581b..17aea1472b 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -136,6 +136,7 @@ struct mtmd_cli_context { mparams.print_timings = true; mparams.n_threads = params.cpuparams.n_threads; mparams.verbosity = params.verbosity > 0 ? 
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 92d30664e4..fa45531eae 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -2456,6 +2456,7 @@ struct server_context {
             mparams.print_timings = false;
             mparams.n_threads = params_base.cpuparams.n_threads;
             mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
+            mparams.flash_attn_type = params_base.flash_attn_type;
            mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
            if (mctx == nullptr) {
                SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
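For reference, the support probe this patch adds to alloc_compute_meta() can be read as a free function: build the graph once, then ask the backend about every GGML_OP_FLASH_ATTN_EXT node. A sketch under that framing (not code from the patch):

    #include "ggml.h"
    #include "ggml-backend.h"

    // Mirrors the check added to alloc_compute_meta(): returns false if the
    // given backend cannot execute some flash-attention node in the built graph.
    static bool graph_supports_flash_attn(ggml_backend_t backend, ggml_cgraph * gf) {
        for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
            ggml_tensor * node = ggml_graph_node(gf, i);
            if (node->op == GGML_OP_FLASH_ATTN_EXT &&
                !ggml_backend_supports_op(backend, node)) {
                return false;
            }
        }
        return true;
    }

In AUTO mode a failed probe flips flash_attn_type to DISABLED and rebuilds the graph without the node; with a forced ENABLED, the node stays and ggml_backend_sched schedules it on the CPU backend, which is what the degraded-performance warning in warmup() refers to.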