implement "auto" mode for clip flash attn
This commit is contained in:
parent 19116a4b38
commit b4955f0ae6
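The commit threads the same tri-state flash-attention flag that llama.cpp already uses for the text model through the mtmd/clip stack. For context, the enum comes from llama.h; a minimal sketch of its shape (the enumerator names appear in the diff below, the exact numeric values are an assumption):

    // from llama.h (values assumed): the tri-state reused by this commit
    enum llama_flash_attn_type {
        LLAMA_FLASH_ATTN_TYPE_AUTO     = -1, // probe support at load time
        LLAMA_FLASH_ATTN_TYPE_DISABLED =  0,
        LLAMA_FLASH_ATTN_TYPE_ENABLED  =  1,
    };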
clip.cpp
@@ -4,6 +4,7 @@
 // Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
 #include "clip.h"
 #include "clip-impl.h"
+#include "mtmd.h"
 #include "ggml.h"
 #include "ggml-cpp.h"
 #include "ggml-cpu.h"
@@ -427,12 +428,14 @@ struct clip_ctx {

     int max_nodes = 8192;
     ggml_backend_sched_ptr sched;
+    llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;

     // for debugging
     bool debug_graph = false;
     std::vector<ggml_tensor *> debug_print_tensors;

     clip_ctx(clip_context_params & ctx_params) {
+        flash_attn_type = ctx_params.flash_attn_type;
         debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr;
         backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
         if (!backend_cpu) {
@@ -2261,17 +2264,37 @@ private:
         ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
         //cb(k, "k", il);

-        ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-        //cb(k, "v", il);
-
-        k = ggml_cast(ctx0, k, GGML_TYPE_F16);
-        v = ggml_cast(ctx0, v, GGML_TYPE_F16);
-
-        ggml_tensor * cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
-        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
-
-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
+        ggml_tensor * cur;
+
+        if (ctx->flash_attn_type == LLAMA_FLASH_ATTN_TYPE_ENABLED) {
+            ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
+
+            k = ggml_cast(ctx0, k, GGML_TYPE_F16);
+            v = ggml_cast(ctx0, v, GGML_TYPE_F16);
+
+            cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
+            ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+
+            cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
+
+        } else {
+            ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
+            v = ggml_cont(ctx0, v);
+
+            const auto n_tokens = q->ne[1];
+            const auto n_head   = q->ne[2];
+
+            ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+            // F32 may not needed for vision encoders?
+            // ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+
+            kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f);
+
+            ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+            cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+            cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        }

         cb(cur, "kqv_out", il);

         if (wo) {
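Both branches above compute the same attention output, softmax(scale * QK^T + mask) * V; the flash-attention path fuses it into a single op (which additionally requires K and V cast to F16), while the fallback spells it out with mul_mat and soft_max. A minimal plain-C++ reference for that semantics, single head and no mask; the function name and layout are illustrative only, not part of the commit:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // out[t] = sum_s softmax_s(scale * dot(q[t], k[s])) * v[s]
    // q, k, v are n x d, row-major; returns n x d
    std::vector<float> naive_attn(const std::vector<float> & q,
                                  const std::vector<float> & k,
                                  const std::vector<float> & v,
                                  int n, int d, float scale) {
        std::vector<float> out(n * d, 0.0f);
        for (int t = 0; t < n; t++) {
            std::vector<float> w(n);
            float wmax = -INFINITY;
            for (int s = 0; s < n; s++) {
                float dot = 0.0f;
                for (int i = 0; i < d; i++) {
                    dot += q[t*d + i] * k[s*d + i];
                }
                w[s] = dot * scale;
                wmax = std::max(wmax, w[s]);
            }
            float wsum = 0.0f;
            for (int s = 0; s < n; s++) {
                w[s] = std::exp(w[s] - wmax); // subtract max for numerical stability
                wsum += w[s];
            }
            for (int s = 0; s < n; s++) {
                for (int i = 0; i < d; i++) {
                    out[t*d + i] += (w[s] / wsum) * v[s*d + i];
                }
            }
        }
        return out;
    }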
@@ -3181,7 +3204,30 @@ struct clip_model_loader {
         }
     }

-    void alloc_compute_meta(clip_ctx & ctx_clip) {
+    void warmup(clip_ctx & ctx_clip) {
+        if (ctx_clip.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
+            // try to enable flash attention to see if it's supported
+            ctx_clip.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
+            bool supported = alloc_compute_meta(ctx_clip);
+            if (!supported) {
+                LOG_WRN("%s: flash attention not supported, memory usage will increase\n", __func__);
+                // TODO: maybe log more details about why flash attention is not supported
+                ctx_clip.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
+                alloc_compute_meta(ctx_clip);
+            }
+        } else {
+            bool supported = alloc_compute_meta(ctx_clip);
+            if (!supported) {
+                LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__);
+            }
+        }
+
+        LOG_INF("%s: flash attention is %s\n", __func__,
+            (ctx_clip.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
+    }
+
+    // return false if flash attention is not supported
+    bool alloc_compute_meta(clip_ctx & ctx_clip) {
         const auto & hparams = ctx_clip.model.hparams;
         ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());

@@ -3217,6 +3263,17 @@ struct clip_model_loader {
         const int n_nodes = ggml_graph_n_nodes(gf);

         LOG_INF("%s: graph splits = %d, nodes = %d\n", __func__, n_splits, n_nodes);
+
+        // check flash attention support
+        for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+            ggml_tensor * node = ggml_graph_node(gf, i);
+            if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+                if (!ggml_backend_supports_op(ctx_clip.backend, node)) {
+                    return false;
+                }
+            }
+        }
+        return true;
     }

     void get_bool(const std::string & key, bool & output, bool required = true) {
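The probe above walks the measured graph and rejects flash attention as soon as one GGML_OP_FLASH_ATTN_EXT node is not natively supported by the backend. The TODO in warmup() (logging why it was rejected) could be served by a helper along these lines; a hypothetical sketch, not part of the commit:

    #include "ggml.h"
    #include "ggml-backend.h"

    // hypothetical: report every op in the measured graph that the backend
    // cannot run natively (the commit itself only checks FLASH_ATTN_EXT)
    static void log_unsupported_ops(ggml_backend_t backend, struct ggml_cgraph * gf) {
        for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
            struct ggml_tensor * node = ggml_graph_node(gf, i);
            if (!ggml_backend_supports_op(backend, node)) {
                LOG_WRN("unsupported op %s on node %s\n", ggml_op_name(node->op), node->name);
            }
        }
    }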
@@ -3306,14 +3363,14 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
         ctx_vision = new clip_ctx(ctx_params);
         loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION);
         loader.load_tensors(*ctx_vision);
-        loader.alloc_compute_meta(*ctx_vision);
+        loader.warmup(*ctx_vision);
     }

     if (loader.has_audio) {
         ctx_audio = new clip_ctx(ctx_params);
         loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO);
         loader.load_tensors(*ctx_audio);
-        loader.alloc_compute_meta(*ctx_audio);
+        loader.warmup(*ctx_audio);
     }

 } catch (const std::exception & e) {
clip.h
@@ -1,6 +1,7 @@
 #pragma once

 #include "ggml.h"
+#include "mtmd.h"
 #include <stddef.h>
 #include <stdint.h>
@@ -25,6 +26,7 @@ enum clip_modality {
 struct clip_context_params {
     bool use_gpu;
     enum ggml_log_level verbosity;
+    llama_flash_attn_type flash_attn_type;
 };

 struct clip_init_result {
mtmd-cli.cpp
@@ -136,6 +136,7 @@ struct mtmd_cli_context {
         mparams.print_timings = true;
         mparams.n_threads = params.cpuparams.n_threads;
         mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
+        mparams.flash_attn_type = params.flash_attn_type;
         ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
         if (!ctx_vision.get()) {
             LOG_ERR("Failed to load vision model from %s\n", clip_path);
mtmd.cpp
@@ -100,6 +100,7 @@ mtmd_context_params mtmd_context_params_default() {
     params.verbosity = GGML_LOG_LEVEL_INFO;
     params.image_marker = MTMD_DEFAULT_IMAGE_MARKER;
     params.media_marker = mtmd_default_marker();
+    params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
     return params;
 }
@@ -164,6 +165,7 @@ struct mtmd_context {
     clip_context_params ctx_clip_params;
     ctx_clip_params.use_gpu = ctx_params.use_gpu;
     ctx_clip_params.verbosity = ctx_params.verbosity;
+    ctx_clip_params.flash_attn_type = ctx_params.flash_attn_type;
     auto res = clip_init(mmproj_fname, ctx_clip_params);
     ctx_v = res.ctx_v;
     ctx_a = res.ctx_a;
mtmd.h
@@ -82,6 +82,7 @@ struct mtmd_context_params {
     enum ggml_log_level verbosity;
     const char * image_marker; // deprecated, use media_marker instead
     const char * media_marker;
+    llama_flash_attn_type flash_attn_type;
 };

 MTMD_API const char * mtmd_default_marker(void);
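With mtmd_context_params_default() now setting AUTO, existing callers get the probe-and-fallback behaviour without code changes. A hedged usage sketch (the model pointer and mmproj file path are placeholders):

    // assumes an already-loaded llama_model * model and an mmproj file on disk
    mtmd_context_params mparams = mtmd_context_params_default();
    // mparams.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO at this point;
    // uncomment to force the old non-FA code path instead:
    // mparams.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
    mtmd_context * mctx = mtmd_init_from_file("mmproj.gguf", model, mparams);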
server.cpp
@@ -2456,6 +2456,7 @@ struct server_context {
         mparams.print_timings = false;
         mparams.n_threads = params_base.cpuparams.n_threads;
         mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
+        mparams.flash_attn_type = params_base.flash_attn_type;
         mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
         if (mctx == nullptr) {
             SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());