mtmd: fix --no-warmup (#17695)

Xuan-Son Nguyen, 2025-12-02 22:48:08 +01:00 (committed by GitHub)
parent 4eba8d9451
commit a96283adc4
2 changed files with 31 additions and 20 deletions

@@ -1226,7 +1226,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
             params.warmup = false;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY}));
     add_opt(common_arg(
         {"--spm-infill"},
         string_format(
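
With LLAMA_EXAMPLE_MTMD added to the list above, the multimodal tools now recognize --no-warmup. For illustration only (the file names below are placeholders, not from this commit), a run that skips the load-time warmup could look like:

    llama-mtmd-cli -m model.gguf --mmproj mmproj.gguf --image input.jpg -p "describe this image" --no-warmup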

@@ -428,6 +428,7 @@ struct clip_ctx {
     int max_nodes = 8192;
     ggml_backend_sched_ptr sched;
     clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO;
+    bool is_allocated = false;
 
     // for debugging
     bool debug_graph = false;
@@ -3305,12 +3306,30 @@
     };
 
     static void warmup(clip_ctx & ctx_clip) {
+        // create a fake batch
+        const auto & hparams = ctx_clip.model.hparams;
+        clip_image_f32_batch batch;
+        clip_image_f32_ptr img(clip_image_f32_init());
+        if (ctx_clip.model.modality == CLIP_MODALITY_VISION) {
+            img->nx = hparams.warmup_image_size;
+            img->ny = hparams.warmup_image_size;
+            LOG_INF("%s: warmup with image size = %d x %d\n", __func__, img->nx, img->ny);
+        } else {
+            img->nx = hparams.warmup_audio_size;
+            img->ny = hparams.n_mel_bins;
+            LOG_INF("%s: warmup with audio size = %d\n", __func__, img->nx);
+        }
+        batch.entries.push_back(std::move(img));
+        warmup(ctx_clip, batch);
+    }
+
+    static void warmup(clip_ctx & ctx_clip, const clip_image_f32_batch & batch) {
         support_info_graph info;
 
         if (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_AUTO) {
             // try to enable flash attention to see if it's supported
             ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_ENABLED;
-            info = alloc_compute_meta(ctx_clip);
+            info = alloc_compute_meta(ctx_clip, batch);
             if (!info.fattn && info.fattn_op) {
                 auto op = info.fattn_op;
                 LOG_WRN("%s: *****************************************************************\n", __func__);
@@ -3329,15 +3348,17 @@
                 LOG_WRN("%s: please report this on github as an issue\n", __func__);
                 LOG_WRN("%s: *****************************************************************\n", __func__);
                 ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_DISABLED;
-                alloc_compute_meta(ctx_clip);
+                alloc_compute_meta(ctx_clip, batch);
             }
         } else {
-            info = alloc_compute_meta(ctx_clip);
+            info = alloc_compute_meta(ctx_clip, batch);
             if (!info.fattn && ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
                 LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__);
             }
         }
 
+        ctx_clip.is_allocated = true; // mark buffers as allocated
+
         LOG_INF("%s: flash attention is %s\n", __func__,
             (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
 
@@ -3369,24 +3390,9 @@
         }
     }
 
-    static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip) {
-        const auto & hparams = ctx_clip.model.hparams;
+    static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip, const clip_image_f32_batch & batch) {
         ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
 
-        // create a fake batch
-        clip_image_f32_batch batch;
-        clip_image_f32_ptr img(clip_image_f32_init());
-        if (ctx_clip.model.modality == CLIP_MODALITY_VISION) {
-            img->nx = hparams.warmup_image_size;
-            img->ny = hparams.warmup_image_size;
-            LOG_INF("%s: warmup with image size = %d x %d\n", __func__, img->nx, img->ny);
-        } else {
-            img->nx = hparams.warmup_audio_size;
-            img->ny = hparams.n_mel_bins;
-            LOG_INF("%s: warmup with audio size = %d\n", __func__, img->nx);
-        }
-        batch.entries.push_back(std::move(img));
-
         ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch);
         ggml_backend_sched_reserve(ctx_clip.sched.get(), gf);
 
@@ -4630,6 +4636,11 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         return false; // only support batch size of 1
     }
 
+    // if buffers are not allocated, we need to do a warmup run to allocate them
+    if (!ctx->is_allocated) {
+        clip_model_loader::warmup(*ctx, *imgs_c_ptr);
+    }
+
     // build the inference graph
     ctx->debug_print_tensors.clear();
     ggml_backend_sched_reset(ctx->sched.get());
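
Taken together, the fix makes buffer allocation lazy: with --no-warmup, nothing is reserved at load time, and the first call to clip_image_batch_encode() runs the batch-aware warmup before building the real graph. A minimal, self-contained sketch of that pattern (EncoderCtx and Batch are hypothetical stand-ins for clip_ctx and clip_image_f32_batch; this is not the actual clip.cpp code):

    #include <cstdio>

    // Hypothetical stand-in for clip_image_f32_batch.
    struct Batch {
        int nx;
        int ny;
    };

    // Hypothetical stand-in for clip_ctx plus the warmup/encode logic around it.
    struct EncoderCtx {
        bool is_allocated = false; // mirrors clip_ctx::is_allocated

        // Stand-in for clip_model_loader::warmup(ctx, batch) ->
        // alloc_compute_meta(ctx, batch): reserve compute buffers sized
        // for the batch that is about to be encoded.
        void warmup(const Batch & batch) {
            std::printf("reserving compute buffers for %d x %d\n", batch.nx, batch.ny);
            is_allocated = true; // mark buffers as allocated
        }

        // Stand-in for clip_image_batch_encode(): if the load-time warmup
        // was skipped (--no-warmup), allocate lazily on the first encode.
        bool encode(const Batch & batch) {
            if (!is_allocated) {
                warmup(batch);
            }
            std::printf("encoding %d x %d\n", batch.nx, batch.ny);
            return true;
        }
    };

    int main() {
        EncoderCtx ctx;           // loaded with --no-warmup: no buffers reserved yet
        ctx.encode({336, 336});   // first encode triggers the warmup path
        ctx.encode({336, 336});   // later encodes reuse the reserved buffers
        return 0;
    }

Subsequent encodes skip the reservation because the flag stays set, which is exactly the role of the is_allocated member added to clip_ctx above.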