mtmd: refactor audio preprocessing (#17978)

* mtmd: refactor audio preprocessing

* refactor

Co-authored-by: Tarek <tdakhran@users.noreply.github.com>

* wip

* wip (2)

* improve constructor

* fix use_natural_log

* fix padding for short input

* clean up

* remove need_chunking

---------

Co-authored-by: Tarek <tdakhran@users.noreply.github.com>
This commit is contained in:
Xuan-Son Nguyen 2025-12-15 14:16:52 +01:00 committed by GitHub
parent 4a4f7e6550
commit 96a181a933
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 379 additions and 596 deletions

View File

@ -65,6 +65,13 @@ struct clip_hparams {
int32_t n_mel_bins = 0; // whisper preprocessor
int32_t proj_stack_factor = 0; // ultravox
// audio-to-mel preprocessor params
int32_t audio_chunk_len = -1; // in seconds
int32_t audio_sample_rate = -1;
int32_t audio_n_fft = -1;
int32_t audio_window_len = -1;
int32_t audio_hop_len = -1;
// legacy
bool has_llava_projector = false;
int minicpmv_version = 0;
@ -278,3 +285,5 @@ struct clip_model {
|| proj_type == PROJECTOR_TYPE_VOXTRAL;
}
};
const clip_hparams * clip_get_hparams(const struct clip_ctx * ctx);

View File

@ -1170,11 +1170,15 @@ struct clip_model_loader {
model.proj_type == PROJECTOR_TYPE_VOXTRAL ||
model.proj_type == PROJECTOR_TYPE_GLMA;
get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack);
if (hparams.n_mel_bins != 128) {
throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__));
}
hparams.ffn_op = FFN_GELU_ERF;
log_ffn_op = "gelu_erf"; // temporary solution for logging
// audio preprocessing params
hparams.audio_chunk_len = 30; // in seconds
hparams.audio_sample_rate = 16000;
hparams.audio_n_fft = 400;
hparams.audio_window_len = 400;
hparams.audio_hop_len = 160;
} break;
default:
break;
@ -1212,6 +1216,11 @@ struct clip_model_loader {
LOG_INF("\n--- audio hparams ---\n");
LOG_INF("%s: n_mel_bins: %d\n", __func__, hparams.n_mel_bins);
LOG_INF("%s: proj_stack_factor: %d\n", __func__, hparams.proj_stack_factor);
LOG_INF("%s: audio_chunk_len: %d\n", __func__, hparams.audio_chunk_len);
LOG_INF("%s: audio_sample_rate: %d\n", __func__, hparams.audio_sample_rate);
LOG_INF("%s: audio_n_fft: %d\n", __func__, hparams.audio_n_fft);
LOG_INF("%s: audio_window_len: %d\n", __func__, hparams.audio_window_len);
LOG_INF("%s: audio_hop_len: %d\n", __func__, hparams.audio_hop_len);
}
LOG_INF("\n");
LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
@ -3478,3 +3487,7 @@ void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel
batch->entries.push_back(clip_image_f32_ptr(audio));
batch->is_audio = true;
}
const clip_hparams * clip_get_hparams(const struct clip_ctx * ctx) {
return &ctx->model.hparams;
}

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,7 @@
#pragma once
#include "ggml.h"
#include "clip-model.h"
#include <cstdint>
#include <vector>
@ -8,18 +9,7 @@
#define MTMD_INTERNAL_HEADER
#define WHISPER_ASSERT GGML_ASSERT
#define WHISPER_SAMPLE_RATE 16000
#define WHISPER_N_FFT 400
#define WHISPER_HOP_LENGTH 160
#define WHISPER_CHUNK_SIZE 30
#define COMMON_SAMPLE_RATE 16000
namespace whisper_preprocessor {
struct whisper_mel {
struct mtmd_audio_mel {
int n_len;
int n_len_org;
int n_mel;
@ -27,23 +17,18 @@ struct whisper_mel {
std::vector<float> data;
};
struct whisper_filters {
int32_t n_mel;
int32_t n_fft;
struct mtmd_audio_preprocessor {
const clip_hparams & hparams;
std::vector<float> data;
mtmd_audio_preprocessor(const clip_ctx * ctx): hparams(*clip_get_hparams(ctx)) {}
virtual ~mtmd_audio_preprocessor() = default;
virtual void initialize() = 0; // NOT thread-safe
virtual bool preprocess(const float * samples, size_t n_samples, std::vector<mtmd_audio_mel> & output) = 0;
};
bool preprocess_audio(
const float * samples,
size_t n_samples,
const whisper_filters & filters,
std::vector<whisper_mel> & output);
} // namespace whisper_preprocessor
namespace whisper_precalc_filters {
whisper_preprocessor::whisper_filters get_128_bins();
} // namespace whisper_precalc_filters
struct mtmd_audio_preprocessor_whisper : mtmd_audio_preprocessor {
mtmd_audio_preprocessor_whisper(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {}
void initialize() override;
bool preprocess(const float * samples, size_t n_samples, std::vector<mtmd_audio_mel> & output) override;
};

View File

@ -151,8 +151,7 @@ struct mtmd_context {
// string template for slice image delimiters with row/col (idefics3)
std::string sli_img_start_tmpl;
// for whisper, we pre-calculate the mel filter bank
whisper_preprocessor::whisper_filters w_filters;
std::unique_ptr<mtmd_audio_preprocessor> audio_preproc;
// TODO @ngxson : add timings
@ -317,14 +316,25 @@ struct mtmd_context {
GGML_ASSERT(ctx_a != nullptr);
projector_type proj = clip_get_projector_type(ctx_a);
if (clip_has_whisper_encoder(ctx_a)) {
// TODO @ngxson : check if model n_mel is 128 or 80
w_filters = whisper_precalc_filters::get_128_bins();
}
LOG_WRN("%s: audio input is in experimental stage and may have reduced quality:\n"
" https://github.com/ggml-org/llama.cpp/discussions/13759\n", __func__);
// set preprocessor
switch (proj) {
case PROJECTOR_TYPE_QWEN2A:
case PROJECTOR_TYPE_QWEN25O:
case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_VOXTRAL:
audio_preproc = std::make_unique<mtmd_audio_preprocessor_whisper>(ctx_a);
break;
default:
GGML_ABORT("unsupported audio projector type");
}
// initialize audio preprocessor
audio_preproc->initialize();
// set special tokens
if (proj == PROJECTOR_TYPE_QWEN2A) {
// <|audio_bos|> ... (embeddings) ... <|audio_eos|>
aud_beg = "<|audio_bos|>";
@ -653,11 +663,10 @@ struct mtmd_tokenizer {
}
// preprocess audio
GGML_ASSERT(ctx->w_filters.n_mel); // make sure we have filter preloaded
std::vector<whisper_preprocessor::whisper_mel> mel_spec_chunks;
std::vector<mtmd_audio_mel> mel_spec_chunks;
const float * samples = (const float *)bitmap->data.data();
size_t n_samples = bitmap->data.size() / sizeof(float);
bool ok = whisper_preprocessor::preprocess_audio(samples, n_samples, ctx->w_filters, mel_spec_chunks);
bool ok = ctx->audio_preproc->preprocess(samples, n_samples, mel_spec_chunks);
if (!ok) {
LOG_ERR("Unable to preprocess audio\n");
return 2;
@ -863,8 +872,7 @@ int mtmd_get_audio_bitrate(mtmd_context * ctx) {
if (!ctx->ctx_a) {
return -1;
}
// for now, we assume that all audio models have the same bitrate
return 16000; // 16kHz
return clip_get_hparams(ctx->ctx_a)->audio_sample_rate;
}
//