implement "auto" mode for clip flash attn
commit b4955f0ae6 (parent 19116a4b38)
@@ -4,6 +4,7 @@
 // Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
 #include "clip.h"
 #include "clip-impl.h"
+#include "mtmd.h"
 #include "ggml.h"
 #include "ggml-cpp.h"
 #include "ggml-cpu.h"
@@ -427,12 +428,14 @@ struct clip_ctx {
     int max_nodes = 8192;
     ggml_backend_sched_ptr sched;
+    llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;

     // for debugging
     bool debug_graph = false;
     std::vector<ggml_tensor *> debug_print_tensors;

     clip_ctx(clip_context_params & ctx_params) {
+        flash_attn_type = ctx_params.flash_attn_type;
         debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr;
         backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
         if (!backend_cpu) {
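The member added above reuses llama.cpp's tri-state flash-attention flag rather than a plain bool. For reference, the type comes from llama.h and is declared roughly as follows (values copied from upstream at the time of this commit; verify against your tree):

enum llama_flash_attn_type {
    LLAMA_FLASH_ATTN_TYPE_AUTO     = -1,
    LLAMA_FLASH_ATTN_TYPE_DISABLED = 0,
    LLAMA_FLASH_ATTN_TYPE_ENABLED  = 1,
};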
@@ -2261,17 +2264,37 @@ private:
     ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
     //cb(k, "k", il);

+    ggml_tensor * cur;
+
+    if (ctx->flash_attn_type == LLAMA_FLASH_ATTN_TYPE_ENABLED) {
         ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
         //cb(v, "v", il);

         k = ggml_cast(ctx0, k, GGML_TYPE_F16);
         v = ggml_cast(ctx0, v, GGML_TYPE_F16);

-    ggml_tensor * cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
+        cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
         ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);

         cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);

+    } else {
+        ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
+        v = ggml_cont(ctx0, v);
+
+        const auto n_tokens = q->ne[1];
+        const auto n_head   = q->ne[2];
+
+        ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+        // F32 may not be needed for vision encoders?
+        // ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+
+        kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f);
+
+        ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+        cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+    }

     cb(cur, "kqv_out", il);

     if (wo) {
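Both branches above compute the same scaled-dot-product attention; the first fuses it into a single GGML_OP_FLASH_ATTN_EXT node, the second spells it out as mul_mat → soft_max → mul_mat. As a minimal sketch of the math (single head, no mask, plain floats; illustrative only, not part of the commit):

#include <algorithm>
#include <cmath>
#include <vector>

// out[i] = sum_j softmax_j(q_i . k_j * kq_scale) * v_j, row-major [n_tokens x d_head]
std::vector<float> naive_attention(const std::vector<float> & q,
                                   const std::vector<float> & k,
                                   const std::vector<float> & v,
                                   int n_tokens, int d_head, float kq_scale) {
    std::vector<float> out(n_tokens * d_head, 0.0f);
    std::vector<float> scores(n_tokens);
    for (int i = 0; i < n_tokens; i++) {
        float max_s = -INFINITY;
        for (int j = 0; j < n_tokens; j++) {
            float s = 0.0f;
            for (int d = 0; d < d_head; d++) {
                s += q[i*d_head + d] * k[j*d_head + d];
            }
            scores[j] = s * kq_scale;
            max_s = std::max(max_s, scores[j]);
        }
        float sum = 0.0f;
        for (int j = 0; j < n_tokens; j++) {
            scores[j] = std::exp(scores[j] - max_s); // subtract max for numerical stability
            sum += scores[j];
        }
        for (int j = 0; j < n_tokens; j++) {
            for (int d = 0; d < d_head; d++) {
                out[i*d_head + d] += (scores[j] / sum) * v[j*d_head + d];
            }
        }
    }
    return out;
}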
@@ -3181,7 +3204,30 @@ struct clip_model_loader {
         }
     }

-    void alloc_compute_meta(clip_ctx & ctx_clip) {
+    void warmup(clip_ctx & ctx_clip) {
+        if (ctx_clip.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
+            // try to enable flash attention to see if it's supported
+            ctx_clip.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
+            bool supported = alloc_compute_meta(ctx_clip);
+            if (!supported) {
+                LOG_WRN("%s: flash attention not supported, memory usage will increase\n", __func__);
+                // TODO: maybe log more details about why flash attention is not supported
+                ctx_clip.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
+                alloc_compute_meta(ctx_clip);
+            }
+        } else {
+            bool supported = alloc_compute_meta(ctx_clip);
+            if (!supported) {
+                LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__);
+            }
+        }
+
+        LOG_INF("%s: flash attention is %s\n", __func__,
+                (ctx_clip.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
+    }
+
+    // return false if flash attention is not supported
+    bool alloc_compute_meta(clip_ctx & ctx_clip) {
         const auto & hparams = ctx_clip.model.hparams;
         ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
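Condensed, the control flow of warmup() is: in AUTO mode, build the graph once with flash attention enabled and keep it if every FLASH_ATTN_EXT node is supported, otherwise rebuild without it; in explicit mode, honor the request and only warn. A sketch of that decision, assuming llama.h for the enum and with probe() standing in for alloc_compute_meta():

#include "llama.h" // for llama_flash_attn_type

// Illustrative sketch only; probe(type) = set ctx_clip.flash_attn_type to `type`,
// then run alloc_compute_meta(), which returns false iff flash attention is unsupported.
llama_flash_attn_type resolve_flash_attn(llama_flash_attn_type requested,
                                         bool (*probe)(llama_flash_attn_type)) {
    if (requested != LLAMA_FLASH_ATTN_TYPE_AUTO) {
        probe(requested);                           // warm up; a failed probe only warns
        return requested;
    }
    if (probe(LLAMA_FLASH_ATTN_TYPE_ENABLED)) {     // try flash attention first
        return LLAMA_FLASH_ATTN_TYPE_ENABLED;
    }
    probe(LLAMA_FLASH_ATTN_TYPE_DISABLED);          // re-allocate with the fallback graph
    return LLAMA_FLASH_ATTN_TYPE_DISABLED;          // works everywhere, uses more memory
}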
@@ -3217,6 +3263,17 @@ struct clip_model_loader {
         const int n_nodes = ggml_graph_n_nodes(gf);

         LOG_INF("%s: graph splits = %d, nodes = %d\n", __func__, n_splits, n_nodes);
+
+        // check flash attention support
+        for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+            ggml_tensor * node = ggml_graph_node(gf, i);
+            if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+                if (!ggml_backend_supports_op(ctx_clip.backend, node)) {
+                    return false;
+                }
+            }
+        }
+        return true;
     }

     void get_bool(const std::string & key, bool & output, bool required = true) {
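The support check above generalizes: once a graph is built, any op can be probed against the backend before committing to it. A standalone sketch of the same idea (hypothetical helper name; assumes a valid backend and graph):

#include "ggml.h"
#include "ggml-backend.h"

// Returns false if the backend cannot execute some node of the given op type.
static bool backend_supports_all(ggml_backend_t backend, ggml_cgraph * gf, enum ggml_op op) {
    for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
        ggml_tensor * node = ggml_graph_node(gf, i);
        if (node->op == op && !ggml_backend_supports_op(backend, node)) {
            return false;
        }
    }
    return true;
}

// e.g. backend_supports_all(ctx_clip.backend, gf, GGML_OP_FLASH_ATTN_EXT), as in the loader above.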
@@ -3306,14 +3363,14 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
         ctx_vision = new clip_ctx(ctx_params);
         loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION);
         loader.load_tensors(*ctx_vision);
-        loader.alloc_compute_meta(*ctx_vision);
+        loader.warmup(*ctx_vision);
     }

     if (loader.has_audio) {
         ctx_audio = new clip_ctx(ctx_params);
         loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO);
         loader.load_tensors(*ctx_audio);
-        loader.alloc_compute_meta(*ctx_audio);
+        loader.warmup(*ctx_audio);
     }

 } catch (const std::exception & e) {
@@ -1,6 +1,7 @@
 #pragma once

 #include "ggml.h"
+#include "mtmd.h"
 #include <stddef.h>
 #include <stdint.h>
@@ -25,6 +26,7 @@ enum clip_modality {
 struct clip_context_params {
     bool use_gpu;
     enum ggml_log_level verbosity;
+    llama_flash_attn_type flash_attn_type;
 };

 struct clip_init_result {
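With the new field, a caller can leave the decision to the loader. A hypothetical caller-side sketch (the file name is a placeholder; clip_init and the params struct are as in this diff):

clip_context_params params;
params.use_gpu         = true;
params.verbosity       = GGML_LOG_LEVEL_INFO;
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; // let warmup() probe and decide
clip_init_result res = clip_init("mmproj-model.gguf", params); // placeholder path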
@@ -136,6 +136,7 @@ struct mtmd_cli_context {
     mparams.print_timings = true;
     mparams.n_threads = params.cpuparams.n_threads;
     mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
+    mparams.flash_attn_type = params.flash_attn_type;
     ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
     if (!ctx_vision.get()) {
         LOG_ERR("Failed to load vision model from %s\n", clip_path);
@@ -100,6 +100,7 @@ mtmd_context_params mtmd_context_params_default() {
     params.verbosity = GGML_LOG_LEVEL_INFO;
     params.image_marker = MTMD_DEFAULT_IMAGE_MARKER;
     params.media_marker = mtmd_default_marker();
+    params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
     return params;
 }
@@ -164,6 +165,7 @@ struct mtmd_context {
     clip_context_params ctx_clip_params;
     ctx_clip_params.use_gpu = ctx_params.use_gpu;
     ctx_clip_params.verbosity = ctx_params.verbosity;
+    ctx_clip_params.flash_attn_type = ctx_params.flash_attn_type;
     auto res = clip_init(mmproj_fname, ctx_clip_params);
     ctx_v = res.ctx_v;
     ctx_a = res.ctx_a;
@@ -82,6 +82,7 @@ struct mtmd_context_params {
     enum ggml_log_level verbosity;
     const char * image_marker; // deprecated, use media_marker instead
     const char * media_marker;
+    llama_flash_attn_type flash_attn_type;
 };

 MTMD_API const char * mtmd_default_marker(void);
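Downstream users get the same control through the mtmd API, e.g. to force-disable flash attention (sketch only; the model pointer and path are placeholders, the calls appear elsewhere in this diff):

mtmd_context_params mparams = mtmd_context_params_default();
mparams.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; // or _ENABLED / _AUTO (the default)
mtmd_context * mctx = mtmd_init_from_file("mmproj-model.gguf", model, mparams);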
@@ -2456,6 +2456,7 @@ struct server_context {
     mparams.print_timings = false;
     mparams.n_threads = params_base.cpuparams.n_threads;
     mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
+    mparams.flash_attn_type = params_base.flash_attn_type;
     mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
     if (mctx == nullptr) {
         SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());