From ccbc84a5374bab7a01f68b129411772ddd8e7c79 Mon Sep 17 00:00:00 2001 From: Tarek Dakhran Date: Tue, 6 Jan 2026 21:00:29 +0100 Subject: [PATCH] mtmd: mtmd_audio_streaming_istft (#18645) Change is decoupled from https://github.com/ggml-org/llama.cpp/pull/18641. [LFM2.5-Audio-1.5B](https://huggingface.co/LiquidAI/LFM2.5-Audio-1.5B) needs streaming istft for generating output audio. * add streaming ISTFT class (`mtmd_audio_streaming_istft`) with overlap-add for audio reconstruction * replace global audio cache with per-instance cache, the model requires two independent caches, for preprocessing (audio input) and for istft (audio output). * unified templated FFT/IFFT implementation supporting both forward and inverse transforms --- tools/mtmd/mtmd-audio.cpp | 570 ++++++++++++++++++++++++-------------- tools/mtmd/mtmd-audio.h | 73 +++++ 2 files changed, 428 insertions(+), 215 deletions(-) diff --git a/tools/mtmd/mtmd-audio.cpp b/tools/mtmd/mtmd-audio.cpp index e99101184b..e8eef035ff 100644 --- a/tools/mtmd/mtmd-audio.cpp +++ b/tools/mtmd/mtmd-audio.cpp @@ -9,207 +9,250 @@ #include #include -// most of the code here is copied from whisper.cpp +// some of the code here is copied from whisper.cpp constexpr bool DEBUG = false; -struct mtmd_audio_mel_filters { - int32_t n_mel; - int32_t n_fft; +void mtmd_audio_cache::fill_sin_cos_table(int n) { + sin_vals.resize(n); + cos_vals.resize(n); + for (int i = 0; i < n; i++) { + double theta = (2 * M_PI * i) / n; + sin_vals[i] = sinf(theta); + cos_vals[i] = cosf(theta); + } +} - std::vector data; -}; +void mtmd_audio_cache::fill_hann_window(int length, bool periodic) { + hann_window.resize(length); + int offset = -1; + if (periodic) { + offset = 0; + } + for (int i = 0; i < length; i++) { + hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); + } +} -// note: this global cache is shared among all preprocessors -// if we want to use multiple preprocessors at the same time, -// we will need to enclose it in the preprocessor class in the future -static struct mtmd_audio_global_cache { - // precomputed sin/cos table for FFT - std::vector sin_vals; - std::vector cos_vals; - - // hann window - std::vector hann_window; - - // mel filter bank - mtmd_audio_mel_filters filters; - - void fill_sin_cos_table(int n) { - sin_vals.resize(n); - cos_vals.resize(n); - for (int i = 0; i < n; i++) { - double theta = (2 * M_PI * i) / n; - sin_vals[i] = sinf(theta); - cos_vals[i] = cosf(theta); - } +void mtmd_audio_cache::fill_mel_filterbank_matrix(int n_mel, + int n_fft, + int sample_rate, + float fmin, + float fmax, + bool slaney_area_norm, + float scale) { + GGML_ASSERT(n_mel > 0 && n_fft > 1); + if (fmax <= 0.0f) { + fmax = 0.5f * sample_rate; } - void fill_hann_window(int length, bool periodic) { - hann_window.resize(length); - int offset = -1; - if (periodic) { - offset = 0; - } - for (int i = 0; i < length; i++) { - hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); - } + // Slaney scale (matches librosa default) + const double min_log_hz = 1000.0; + const double lin_slope = 3 / 200.; + const double min_log_mel = min_log_hz * lin_slope; + const double log_step = log(6.4) / 27.0; + auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double { + return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step; + }; + auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double { + return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step); + }; + + // infer N_fft from n_fft_bins + const double bin_hz_step = double(sample_rate) / double(n_fft); + + // mel grid: n_mel + 2 edges + const double m_lo = hz_to_mel(fmin); + const double m_hi = hz_to_mel(fmax); + std::vector mel_pts(n_mel + 2); + for (int i = 0; i < n_mel + 2; ++i) { + mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1)); } - // Build mel filterbank matrix [n_mel × n_fft_bins] at runtime. - // n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257. - void fill_mel_filterbank_matrix( - int n_mel, - int n_fft, - int sample_rate, // e.g. 16000 - float fmin = 0.0f, // e.g. 0.0 - float fmax = -1.0f, // e.g. sr/2; pass -1 for auto - bool slaney_area_norm = true, - float scale = 1.0f // optional extra scaling; use 1.0f/1000.0f to mimic your code - ) { - GGML_ASSERT(n_mel > 0 && n_fft > 1); - if (fmax <= 0.0f) { - fmax = 0.5f * sample_rate; - } + // convert to Hz + std::vector hz_pts(n_mel + 2); + for (int i = 0; i < n_mel + 2; ++i) { + hz_pts[i] = mel_to_hz(mel_pts[i]); + } - // Slaney scale (matches librosa default) - const double min_log_hz = 1000.0; - const double lin_slope = 3 / 200.; - const double min_log_mel = min_log_hz * lin_slope; - const double log_step = log(6.4) / 27.0; - auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double { - return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step; - }; - auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double { - return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step); - }; + const int n_fft_bins = n_fft / 2 + 1; - // infer N_fft from n_fft_bins - const double bin_hz_step = double(sample_rate) / double(n_fft); + // filterbank + std::vector out(n_mel * n_fft_bins, 0); + for (int m = 0; m < n_mel; ++m) { + const double f_left = hz_pts[m]; + const double f_center = hz_pts[m + 1]; + const double f_right = hz_pts[m + 2]; - // mel grid: n_mel + 2 edges - const double m_lo = hz_to_mel(fmin); - const double m_hi = hz_to_mel(fmax); - std::vector mel_pts(n_mel + 2); - for (int i = 0; i < n_mel + 2; ++i) { - mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1)); - } + const double denom_l = std::max(1e-30, f_center - f_left); + const double denom_r = std::max(1e-30, f_right - f_center); + const double enorm = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0; - // convert to Hz - std::vector hz_pts(n_mel + 2); - for (int i = 0; i < n_mel + 2; ++i) { - hz_pts[i] = mel_to_hz(mel_pts[i]); - } - - const int n_fft_bins = n_fft / 2 + 1; - - // filterbank - std::vector out(n_mel * n_fft_bins, 0); - for (int m = 0; m < n_mel; ++m) { - const double f_left = hz_pts[m]; - const double f_center = hz_pts[m + 1]; - const double f_right = hz_pts[m + 2]; - - const double denom_l = std::max(1e-30, f_center - f_left); - const double denom_r = std::max(1e-30, f_right - f_center); - const double enorm = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0; - - for (int k = 0; k < n_fft_bins; ++k) { - const double f = k * bin_hz_step; - double w = 0.0; - if (f >= f_left && f <= f_center) { - w = (f - f_left) / denom_l; - } else if (f > f_center && f <= f_right) { - w = (f_right - f) / denom_r; - } - out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale); + for (int k = 0; k < n_fft_bins; ++k) { + const double f = k * bin_hz_step; + double w = 0.0; + if (f >= f_left && f <= f_center) { + w = (f - f_left) / denom_l; + } else if (f > f_center && f <= f_right) { + w = (f_right - f) / denom_r; } + out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale); } + } - filters.n_mel = n_mel; - filters.n_fft = n_fft; - filters.data = std::move(out); + filters.n_mel = n_mel; + filters.n_fft = n_fft; + filters.data = std::move(out); - if (DEBUG) { // debug - for (size_t i = 0; i < filters.data.size(); ++i) { - if (filters.data[i] != 0.0f) { - printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f); - } + if (DEBUG) { // debug + for (size_t i = 0; i < filters.data.size(); ++i) { + if (filters.data[i] != 0.0f) { + printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f); } } } -} g_cache; +} -// naive Discrete Fourier Transform -// input is real-valued -// output is complex-valued -static void dft(const float * in, int N, float * out) { - const int n_sin_cos_vals = g_cache.sin_vals.size(); - const int sin_cos_step = n_sin_cos_vals / N; +// Unified DFT implementation for both forward and inverse transforms +// Template parameters: +// Inverse: false = DFT with exp(-2πi·k·n/N), no scaling +// true = IDFT with exp(+2πi·k·n/N), scales by 1/N +// RealInput: true = input is real-valued (stride 1), avoids imaginary computations +// false = input is complex-valued (interleaved real/imag, stride 2) +template +static void dft_impl(const mtmd_audio_cache & cache, const float * in, int N, float * out) { + const int n_sin_cos_vals = cache.sin_vals.size(); + const int sin_cos_step = n_sin_cos_vals / N; + + constexpr float sign = Inverse ? 1.0f : -1.0f; + const float scale = Inverse ? (1.0f / N) : 1.0f; for (int k = 0; k < N; k++) { float re = 0; float im = 0; for (int n = 0; n < N; n++) { - int idx = (k * n * sin_cos_step) % (n_sin_cos_vals); // t = 2*M_PI*k*n/N - re += in[n] * g_cache.cos_vals[idx]; // cos(t) - im -= in[n] * g_cache.sin_vals[idx]; // sin(t) + int idx = (k * n * sin_cos_step) % n_sin_cos_vals; + float cos_val = cache.cos_vals[idx]; + float sin_val = cache.sin_vals[idx]; + + if constexpr (RealInput) { + // Real input: in_im = 0, simplifies to: + // re += in_re * cos_val + // im += sign * in_re * sin_val + float in_re = in[n]; + re += in_re * cos_val; + im += sign * in_re * sin_val; + } else { + float in_re = in[n * 2 + 0]; + float in_im = in[n * 2 + 1]; + // (a + bi) * (cos + sign*i*sin) = (a*cos - sign*b*sin) + (sign*a*sin + b*cos)i + re += in_re * cos_val - sign * in_im * sin_val; + im += sign * in_re * sin_val + in_im * cos_val; + } } - out[k*2 + 0] = re; - out[k*2 + 1] = im; + out[k * 2 + 0] = re * scale; + out[k * 2 + 1] = im * scale; } } -// Cooley-Tukey FFT -// poor man's implementation - use something better -// input is real-valued -// output is complex-valued -static void fft(float * in, int N, float * out) { - const int n_sin_cos_vals = g_cache.sin_vals.size(); +// Cooley-Tukey FFT/IFFT unified implementation +// Template parameters: +// Inverse: false = FFT with exp(-2πi·k/N), no scaling +// true = IFFT with exp(+2πi·k/N), scales by 0.5 at each level +// RealInput: true = input is real-valued (stride 1) +// false = input is complex-valued (interleaved real/imag, stride 2) +template +static void fft_impl(const mtmd_audio_cache & cache, float * in, int N, float * out) { + const int n_sin_cos_vals = cache.sin_vals.size(); + if (N == 1) { out[0] = in[0]; - out[1] = 0; + if constexpr (RealInput) { + out[1] = 0.0f; + } else { + out[1] = in[1]; + } return; } const int half_N = N / 2; - if (N - half_N*2 == 1) { - dft(in, N, out); + if (N - half_N * 2 == 1) { + // Odd N: fall back to DFT + dft_impl(cache, in, N, out); return; } - float* even = in + N; - for (int i = 0; i < half_N; ++i) { - even[i]= in[2*i]; - } - float* even_fft = out + 2 * N; - fft(even, half_N, even_fft); + // Split into even and odd + if constexpr (RealInput) { + // Real input: stride is 1, copy only real values + float * even = in + N; + for (int i = 0; i < half_N; ++i) { + even[i] = in[2 * i]; + } + float * even_fft = out + 2 * N; + fft_impl(cache, even, half_N, even_fft); - float* odd = even; - for (int i = 0; i < half_N; ++i) { - odd[i] = in[2*i + 1]; + float * odd = even; + for (int i = 0; i < half_N; ++i) { + odd[i] = in[2 * i + 1]; + } + float * odd_fft = even_fft + N; + fft_impl(cache, odd, half_N, odd_fft); + } else { + // Complex input: stride is 2, copy complex pairs + float * even = in + N * 2; + for (int i = 0; i < half_N; ++i) { + even[i * 2 + 0] = in[2 * i * 2 + 0]; + even[i * 2 + 1] = in[2 * i * 2 + 1]; + } + float * even_fft = out + 2 * N; + fft_impl(cache, even, half_N, even_fft); + + float * odd = even; + for (int i = 0; i < half_N; ++i) { + odd[i * 2 + 0] = in[(2 * i + 1) * 2 + 0]; + odd[i * 2 + 1] = in[(2 * i + 1) * 2 + 1]; + } + float * odd_fft = even_fft + N; + fft_impl(cache, odd, half_N, odd_fft); } - float* odd_fft = even_fft + N; - fft(odd, half_N, odd_fft); + + float * even_fft = out + 2 * N; + float * odd_fft = even_fft + N; const int sin_cos_step = n_sin_cos_vals / N; + + constexpr float sign = Inverse ? 1.0f : -1.0f; + constexpr float scale = Inverse ? 0.5f : 1.0f; + for (int k = 0; k < half_N; k++) { - int idx = k * sin_cos_step; // t = 2*M_PI*k/N - float re = g_cache.cos_vals[idx]; // cos(t) - float im = -g_cache.sin_vals[idx]; // sin(t) + int idx = k * sin_cos_step; // t = 2*M_PI*k/N + float re = cache.cos_vals[idx]; + float im = sign * cache.sin_vals[idx]; - float re_odd = odd_fft[2*k + 0]; - float im_odd = odd_fft[2*k + 1]; + float re_odd = odd_fft[2 * k + 0]; + float im_odd = odd_fft[2 * k + 1]; - out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd; - out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd; + out[2 * k + 0] = scale * (even_fft[2 * k + 0] + re * re_odd - im * im_odd); + out[2 * k + 1] = scale * (even_fft[2 * k + 1] + re * im_odd + im * re_odd); - out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd; - out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd; + out[2 * (k + half_N) + 0] = scale * (even_fft[2 * k + 0] - re * re_odd + im * im_odd); + out[2 * (k + half_N) + 1] = scale * (even_fft[2 * k + 1] - re * im_odd - im * re_odd); } } +// Forward FFT for real input (used by mel spectrogram) +static void fft(const mtmd_audio_cache & cache, float * in, int N, float * out) { + fft_impl(cache, in, N, out); +} + +// Inverse FFT for complex input +static void ifft(const mtmd_audio_cache & cache, float * in, int N, float * out) { + fft_impl(cache, in, N, out); +} + struct filter_params { int32_t n_mel; int32_t n_fft_bins; @@ -222,20 +265,27 @@ struct filter_params { bool norm_per_feature = false; }; -static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector & samples, - int n_samples, int frame_size, int frame_step, int n_threads, - const filter_params & params, mtmd_audio_mel & out) { +static void log_mel_spectrogram_worker_thread(int ith, + const float * hann, + const std::vector & samples, + int n_samples, + int frame_size, + int frame_step, + int n_threads, + const filter_params & params, + const mtmd_audio_cache & cache, + mtmd_audio_mel & out) { std::vector fft_in(frame_size * 2, 0.0); std::vector fft_out(frame_size * 2 * 2 * 2); int n_fft_bins = params.n_fft_bins; int i = ith; - const auto & filters = g_cache.filters; + const auto & filters = cache.filters; // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist GGML_ASSERT(n_fft_bins == 1 + (frame_size / 2)); - GGML_ASSERT(g_cache.sin_vals.size() == g_cache.cos_vals.size()); + GGML_ASSERT(cache.sin_vals.size() == cache.cos_vals.size()); // calculate FFT only when fft_in are not all zero for (; i < std::min(n_samples / frame_step + 1, out.n_len); i += n_threads) { const int offset = i * frame_step; @@ -251,7 +301,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const } // FFT - fft(fft_in.data(), frame_size, fft_out.data()); + fft(cache, fft_in.data(), frame_size, fft_out.data()); // Calculate modulus^2 of complex numbers // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting. @@ -298,6 +348,7 @@ static bool log_mel_spectrogram( const int n_samples_in, const int n_threads, const filter_params & params, + const mtmd_audio_cache & cache, mtmd_audio_mel & out) { //const int64_t t_start_us = ggml_time_us(); @@ -305,9 +356,9 @@ static bool log_mel_spectrogram( int n_samples = n_samples_in; // Hann window - const float * hann = g_cache.hann_window.data(); - const int frame_size = (params.n_fft_bins - 1) * 2; - const int frame_step = params.hop_length; + const float * hann = cache.hann_window.data(); + const int frame_size = (params.n_fft_bins - 1) * 2; + const int frame_step = params.hop_length; // Padding std::vector samples_padded; @@ -335,9 +386,9 @@ static bool log_mel_spectrogram( // preemphasis if (params.preemph) { - const int pad_amount = frame_size / 2; + const int pad_amount = frame_size / 2; const float preemph = 0.97f; - float prev = samples_padded[pad_amount]; + float prev = samples_padded[pad_amount]; for (int i = pad_amount + 1; i + pad_amount < n_samples; ++i) { float cur = samples_padded[i]; samples_padded[i] = cur - preemph * prev; @@ -372,14 +423,14 @@ static bool log_mel_spectrogram( { std::vector workers(n_threads - 1); for (int iw = 0; iw < n_threads - 1; ++iw) { - workers[iw] = std::thread( - log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded), - n_samples, frame_size, frame_step, n_threads, - std::cref(params), std::ref(out)); + workers[iw] = + std::thread(log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded), n_samples, + frame_size, frame_step, n_threads, std::cref(params), std::cref(cache), std::ref(out)); } // main thread - log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params, out); + log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params, + cache, out); for (int iw = 0; iw < n_threads - 1; ++iw) { workers[iw].join(); } @@ -404,7 +455,7 @@ static bool log_mel_spectrogram( for (int j = 0; j < effective_n_len; ++j) { auto &value = out.data[i * out.n_len + j]; - value = (value - mean) / mstd; + value = (value - mean) / mstd; } // pad the rest with zeros @@ -450,18 +501,14 @@ static bool log_mel_spectrogram( // void mtmd_audio_preprocessor_whisper::initialize() { - g_cache.fill_sin_cos_table(hparams.audio_n_fft); - g_cache.fill_hann_window(hparams.audio_window_len, true); - g_cache.fill_mel_filterbank_matrix( - hparams.n_mel_bins, - hparams.audio_n_fft, - hparams.audio_sample_rate); + cache.fill_sin_cos_table(hparams.audio_n_fft); + cache.fill_hann_window(hparams.audio_window_len, true); + cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate); } -bool mtmd_audio_preprocessor_whisper::preprocess( - const float * samples, - size_t n_samples, - std::vector & output) { +bool mtmd_audio_preprocessor_whisper::preprocess(const float * samples, + size_t n_samples, + std::vector & output) { if (n_samples == 0) { // empty audio return false; @@ -471,7 +518,7 @@ bool mtmd_audio_preprocessor_whisper::preprocess( // if input is too short, pad with zeros // this is to avoid potential issues with stage1/2 padding in log_mel_spectrogram // TODO: maybe handle this better - size_t min_samples = (size_t)hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin + size_t min_samples = (size_t) hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin if (n_samples < min_samples) { smpl.resize(min_samples, 0.0f); std::memcpy(smpl.data(), samples, n_samples * sizeof(float)); @@ -486,22 +533,19 @@ bool mtmd_audio_preprocessor_whisper::preprocess( params.hop_length = hparams.audio_hop_len; params.sample_rate = hparams.audio_sample_rate; params.center_padding = false; - params.preemph = 0.0f; // disabled + params.preemph = 0.0f; // disabled params.use_natural_log = false; params.norm_per_feature = false; - // make sure the global cache is initialized - GGML_ASSERT(!g_cache.sin_vals.empty()); - GGML_ASSERT(!g_cache.cos_vals.empty()); - GGML_ASSERT(!g_cache.filters.data.empty()); + // make sure the cache is initialized + GGML_ASSERT(!cache.sin_vals.empty()); + GGML_ASSERT(!cache.cos_vals.empty()); + GGML_ASSERT(!cache.filters.data.empty()); mtmd_audio_mel out_full; - bool ok = log_mel_spectrogram( - samples, - n_samples, - 4, // n_threads - params, - out_full); + bool ok = log_mel_spectrogram(samples, n_samples, + 4, // n_threads + params, cache, out_full); if (!ok) { return false; } @@ -512,21 +556,21 @@ bool mtmd_audio_preprocessor_whisper::preprocess( printf("output: n_mel = %d, n_len = %d\n", out_full.n_mel, out_full.n_len); } const size_t frames_per_chunk = 3000; - GGML_ASSERT((size_t)out_full.n_len > frames_per_chunk); - for (size_t off = 0; off < (size_t)out_full.n_len; off += frames_per_chunk) { - int n_len = std::min(frames_per_chunk, (size_t)out_full.n_len - off); - if ((size_t)n_len < frames_per_chunk) { - break; // last uncomplete chunk will always be a padded chunk, safe to ignore + GGML_ASSERT((size_t) out_full.n_len > frames_per_chunk); + for (size_t off = 0; off < (size_t) out_full.n_len; off += frames_per_chunk) { + int n_len = std::min(frames_per_chunk, (size_t) out_full.n_len - off); + if ((size_t) n_len < frames_per_chunk) { + break; // last uncomplete chunk will always be a padded chunk, safe to ignore } mtmd_audio_mel out_chunk; out_chunk.n_len = n_len; out_chunk.n_mel = out_full.n_mel; - out_chunk.n_len_org = out_full.n_mel; // unused + out_chunk.n_len_org = out_full.n_mel; // unused out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len); for (int i = 0; i < out_full.n_mel; i++) { - auto src = out_full.data.begin() + i*out_full.n_len + off; + auto src = out_full.data.begin() + i * out_full.n_len + off; out_chunk.data.insert(out_chunk.data.end(), src, src + frames_per_chunk); } @@ -541,18 +585,14 @@ bool mtmd_audio_preprocessor_whisper::preprocess( // void mtmd_audio_preprocessor_conformer::initialize() { - g_cache.fill_sin_cos_table(hparams.audio_n_fft); - g_cache.fill_hann_window(hparams.audio_window_len, true); - g_cache.fill_mel_filterbank_matrix( - hparams.n_mel_bins, - hparams.audio_n_fft, - hparams.audio_sample_rate); + cache.fill_sin_cos_table(hparams.audio_n_fft); + cache.fill_hann_window(hparams.audio_window_len, true); + cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate); } -bool mtmd_audio_preprocessor_conformer::preprocess( - const float * samples, - size_t n_samples, - std::vector & output) { +bool mtmd_audio_preprocessor_conformer::preprocess(const float * samples, + size_t n_samples, + std::vector & output) { // empty audio if (n_samples == 0) { return false; @@ -569,18 +609,15 @@ bool mtmd_audio_preprocessor_conformer::preprocess( params.use_natural_log = true; params.norm_per_feature = true; - // make sure the global cache is initialized - GGML_ASSERT(!g_cache.sin_vals.empty()); - GGML_ASSERT(!g_cache.cos_vals.empty()); - GGML_ASSERT(!g_cache.filters.data.empty()); + // make sure the cache is initialized + GGML_ASSERT(!cache.sin_vals.empty()); + GGML_ASSERT(!cache.cos_vals.empty()); + GGML_ASSERT(!cache.filters.data.empty()); mtmd_audio_mel out_full; - bool ok = log_mel_spectrogram( - samples, - n_samples, - 4, // n_threads - params, - out_full); + bool ok = log_mel_spectrogram(samples, n_samples, + 4, // n_threads + params, cache, out_full); if (!ok) { return false; } @@ -588,3 +625,106 @@ bool mtmd_audio_preprocessor_conformer::preprocess( output.push_back(std::move(out_full)); return true; } + +// +// mtmd_audio_streaming_istft implementation +// + +mtmd_audio_streaming_istft::mtmd_audio_streaming_istft(int n_fft, int hop_length) : + n_fft(n_fft), + hop_length(hop_length), + n_fft_bins(n_fft / 2 + 1), + overlap_buffer(n_fft, 0.0f), + window_sum_buffer(n_fft, 0.0f), + padding_to_remove((n_fft - hop_length) / 2), + ifft_in(n_fft * 2 * 4, 0.0f), // extra space for recursive IFFT + ifft_out(n_fft * 2 * 4, 0.0f) { + cache.fill_sin_cos_table(n_fft); + cache.fill_hann_window(n_fft, true); +} + +void mtmd_audio_streaming_istft::reset() { + std::fill(overlap_buffer.begin(), overlap_buffer.end(), 0.0f); + std::fill(window_sum_buffer.begin(), window_sum_buffer.end(), 0.0f); + padding_to_remove = (n_fft - hop_length) / 2; +} + +std::vector mtmd_audio_streaming_istft::process_frame(const float * frame_spectrum) { + std::vector output(hop_length); + + // copy frequencies + for (int j = 0; j < n_fft_bins; j++) { + ifft_in[j * 2 + 0] = frame_spectrum[j * 2 + 0]; + ifft_in[j * 2 + 1] = frame_spectrum[j * 2 + 1]; + } + + // mirror negative frequencies + for (int j = 1; j < n_fft_bins - 1; j++) { + int mirror_idx = n_fft - j; + ifft_in[mirror_idx * 2 + 0] = ifft_in[j * 2 + 0]; + ifft_in[mirror_idx * 2 + 1] = -ifft_in[j * 2 + 1]; // conjugate + } + + ifft(cache, ifft_in.data(), n_fft, ifft_out.data()); + + // update window sum and overlap buffer + for (int j = 0; j < n_fft; j++) { + window_sum_buffer[j] += cache.hann_window[j] * cache.hann_window[j]; + overlap_buffer[j] += ifft_out[j * 2] * cache.hann_window[j]; + } + + // extract hop_length samples with normalization + for (int i = 0; i < hop_length; i++) { + if (window_sum_buffer[i] > 1e-8f) { + output[i] = overlap_buffer[i] / window_sum_buffer[i]; + } else { + output[i] = overlap_buffer[i]; + } + } + + // shift buffers left by hop_length + std::copy(overlap_buffer.begin() + hop_length, overlap_buffer.end(), overlap_buffer.begin()); + std::fill(overlap_buffer.end() - hop_length, overlap_buffer.end(), 0.0f); + + std::copy(window_sum_buffer.begin() + hop_length, window_sum_buffer.end(), window_sum_buffer.begin()); + std::fill(window_sum_buffer.end() - hop_length, window_sum_buffer.end(), 0.0f); + + // Remove padding if needed + int to_remove = std::min(padding_to_remove, (int) output.size()); + padding_to_remove -= to_remove; + output.erase(output.begin(), output.begin() + to_remove); + + return output; +} + +std::vector mtmd_audio_streaming_istft::flush() { + std::vector output; + + // Extract remaining samples from overlap buffer + // Continue until we've extracted all meaningful samples + int remaining = n_fft - hop_length; + while (remaining > 0) { + int chunk_size = std::min(remaining, hop_length); + + for (int i = 0; i < chunk_size; i++) { + float sample; + if (window_sum_buffer[i] > 1e-8f) { + sample = overlap_buffer[i] / window_sum_buffer[i]; + } else { + sample = overlap_buffer[i]; + } + output.push_back(sample); + } + + // Shift buffers + std::copy(overlap_buffer.begin() + chunk_size, overlap_buffer.end(), overlap_buffer.begin()); + std::fill(overlap_buffer.end() - chunk_size, overlap_buffer.end(), 0.0f); + + std::copy(window_sum_buffer.begin() + chunk_size, window_sum_buffer.end(), window_sum_buffer.begin()); + std::fill(window_sum_buffer.end() - chunk_size, window_sum_buffer.end(), 0.0f); + + remaining -= chunk_size; + } + + return output; +} diff --git a/tools/mtmd/mtmd-audio.h b/tools/mtmd/mtmd-audio.h index d484c9d030..016c7392e4 100644 --- a/tools/mtmd/mtmd-audio.h +++ b/tools/mtmd/mtmd-audio.h @@ -17,6 +17,38 @@ struct mtmd_audio_mel { std::vector data; }; +struct mtmd_audio_mel_filters { + int32_t n_mel; + int32_t n_fft; + + std::vector data; +}; + +// cache for audio processing, each processor instance owns its own cache +struct mtmd_audio_cache { + std::vector sin_vals; + std::vector cos_vals; + + std::vector hann_window; + + mtmd_audio_mel_filters filters; + + void fill_sin_cos_table(int n); + + void fill_hann_window(int length, bool periodic); + + // Build mel filterbank matrix [n_mel × n_fft_bins] at runtime. + // n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257. + void fill_mel_filterbank_matrix(int n_mel, + int n_fft, + int sample_rate, // e.g. 16000 + float fmin = 0.0f, // e.g. 0.0 + float fmax = -1.0f, // e.g. sr/2; pass -1 for auto + bool slaney_area_norm = true, + float scale = 1.0f // optional extra scaling + ); +}; + struct mtmd_audio_preprocessor { const clip_hparams & hparams; @@ -31,10 +63,51 @@ struct mtmd_audio_preprocessor_whisper : mtmd_audio_preprocessor { mtmd_audio_preprocessor_whisper(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {} void initialize() override; bool preprocess(const float * samples, size_t n_samples, std::vector & output) override; + + private: + mtmd_audio_cache cache; }; struct mtmd_audio_preprocessor_conformer : mtmd_audio_preprocessor { mtmd_audio_preprocessor_conformer(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {} void initialize() override; bool preprocess(const float * samples, size_t n_samples, std::vector & output) override; + + private: + mtmd_audio_cache cache; +}; + +// +// streaming ISTFT - converts spectrogram frames back to audio one frame at a time +// +struct mtmd_audio_streaming_istft { + mtmd_audio_streaming_istft(int n_fft, int hop_length); + + // reset streaming state + void reset(); + + // process a single STFT frame (streaming) + // frame_spectrum: [n_fft_bins x 2] interleaved real/imag + // returns: up to hop_length samples + std::vector process_frame(const float * frame_spectrum); + + // flush remaining samples at end of stream + std::vector flush(); + + private: + int n_fft; + int hop_length; + int n_fft_bins; + + // Own cache for output processing + mtmd_audio_cache cache; + + // Streaming state + std::vector overlap_buffer; + std::vector window_sum_buffer; + int padding_to_remove; + + // Working buffers for IFFT + std::vector ifft_in; + std::vector ifft_out; };