mtmd : add MERaLiON-2 multimodal audio support (#21756)
* mtmd : add MERaLiON-2 multimodal audio support Adds support for A*STAR's MERaLiON-2 audio-language model (3B and 10B) to the multimodal framework. Architecture: - Whisper large-v2 encoder for audio feature extraction - Gated MLP adaptor: ln_speech -> frame stack (x15) -> Linear+SiLU -> GLU -> out_proj - Gemma2 3B / 27B decoder The mmproj GGUF is generated via convert_hf_to_gguf.py --mmproj on the full MERaLiON-2 model directory (architecture: MERaLiON2ForConditionalGeneration). The decoder is converted separately as a standard Gemma2 model after stripping the text_decoder. weight prefix. New projector type: PROJECTOR_TYPE_MERALION Supports tasks: speech transcription (EN/ZH/MS/TA), translation, spoken QA. Model: https://huggingface.co/MERaLiON/MERaLiON-2-3B https://huggingface.co/MERaLiON/MERaLiON-2-10B * simplify comments in meralion adaptor * meralion: use format_tensor_name, ascii arrows in comments
This commit is contained in:
parent
af1127d3c4
commit
073bb2c20b
|
|
@ -11279,6 +11279,48 @@ class UltravoxWhisperEncoderModel(WhisperEncoderModel):
|
|||
self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
|
||||
|
||||
|
||||
@ModelBase.register("MERaLiON2ForConditionalGeneration")
|
||||
class MERaLiONWhisperEncoderModel(WhisperEncoderModel):
|
||||
has_vision_encoder = False
|
||||
has_audio_encoder = True
|
||||
|
||||
def get_audio_config(self) -> dict[str, Any] | None:
|
||||
return self.global_config.get("speech_config")
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.MERALION)
|
||||
self.gguf_writer.add_audio_stack_factor(self.global_config.get("speech_mlp_scale_factor", 15))
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if name.startswith("text_decoder."):
|
||||
return
|
||||
|
||||
if name.startswith("speech_encoder."):
|
||||
name = name.replace("speech_encoder.", "audio_tower.")
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
return
|
||||
|
||||
suffix = "." + name.rsplit(".", 1)[-1]
|
||||
|
||||
if name.startswith("ln_speech."):
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.A_MM_NORM_PRE, suffix=suffix), data_torch)
|
||||
return
|
||||
|
||||
if name.startswith("speech_audio_adapter."):
|
||||
if ".mlp_adapter.0." in name:
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.A_MMPROJ, 0, suffix=suffix), data_torch)
|
||||
elif ".gate_proj." in name:
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.A_MMPROJ, 1, suffix=suffix), data_torch)
|
||||
elif ".pool_proj." in name:
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.A_MMPROJ, 2, suffix=suffix), data_torch)
|
||||
elif ".out_proj." in name:
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.A_MMPROJ, 3, suffix=suffix), data_torch)
|
||||
return
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("VoxtralForConditionalGeneration")
|
||||
class VoxtralWhisperEncoderModel(WhisperEncoderModel):
|
||||
has_vision_encoder = False # no vision encoder
|
||||
|
|
|
|||
|
|
@ -4115,6 +4115,7 @@ class VisionProjectorType:
|
|||
GLMA = "glma" # audio
|
||||
QWEN25O = "qwen2.5o" # omni
|
||||
VOXTRAL = "voxtral"
|
||||
MERALION = "meralion" # audio: Whisper + gated MLP adaptor
|
||||
LFM2 = "lfm2"
|
||||
KIMIVL = "kimivl"
|
||||
PADDLEOCR = "paddleocr"
|
||||
|
|
|
|||
|
|
@ -2041,7 +2041,7 @@ class TensorNameMap:
|
|||
# this prefix is added in the conversion code in modify_tensors()
|
||||
|
||||
MODEL_TENSOR.A_MMPROJ: (
|
||||
"audio.multi_modal_projector.linear_{bid}", # ultravox
|
||||
"audio.multi_modal_projector.linear_{bid}", # ultravox, meralion
|
||||
"audio_adapter.model.{bid}" # lfm2
|
||||
),
|
||||
|
||||
|
|
|
|||
|
|
@ -259,6 +259,7 @@ enum projector_type {
|
|||
PROJECTOR_TYPE_GLMA,
|
||||
PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
|
||||
PROJECTOR_TYPE_VOXTRAL,
|
||||
PROJECTOR_TYPE_MERALION,
|
||||
PROJECTOR_TYPE_MUSIC_FLAMINGO,
|
||||
PROJECTOR_TYPE_LFM2,
|
||||
PROJECTOR_TYPE_KIMIVL,
|
||||
|
|
@ -302,6 +303,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
|||
{ PROJECTOR_TYPE_GLMA, "glma"},
|
||||
{ PROJECTOR_TYPE_QWEN25O, "qwen2.5o"},
|
||||
{ PROJECTOR_TYPE_VOXTRAL, "voxtral"},
|
||||
{ PROJECTOR_TYPE_MERALION, "meralion"},
|
||||
{ PROJECTOR_TYPE_MUSIC_FLAMINGO, "musicflamingo"},
|
||||
{ PROJECTOR_TYPE_LFM2, "lfm2"},
|
||||
{ PROJECTOR_TYPE_KIMIVL, "kimivl"},
|
||||
|
|
|
|||
|
|
@ -467,7 +467,8 @@ struct clip_model {
|
|||
|
||||
bool audio_has_stack_frames() const {
|
||||
return proj_type == PROJECTOR_TYPE_ULTRAVOX
|
||||
|| proj_type == PROJECTOR_TYPE_VOXTRAL;
|
||||
|| proj_type == PROJECTOR_TYPE_VOXTRAL
|
||||
|| proj_type == PROJECTOR_TYPE_MERALION;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -890,6 +890,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
case PROJECTOR_TYPE_MERALION:
|
||||
case PROJECTOR_TYPE_MUSIC_FLAMINGO:
|
||||
{
|
||||
builder = std::make_unique<clip_graph_whisper_enc>(ctx, img);
|
||||
|
|
@ -1399,10 +1400,12 @@ struct clip_model_loader {
|
|||
case PROJECTOR_TYPE_QWEN2A:
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
case PROJECTOR_TYPE_MERALION:
|
||||
case PROJECTOR_TYPE_MUSIC_FLAMINGO:
|
||||
{
|
||||
bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX ||
|
||||
model.proj_type == PROJECTOR_TYPE_VOXTRAL ||
|
||||
model.proj_type == PROJECTOR_TYPE_MERALION ||
|
||||
model.proj_type == PROJECTOR_TYPE_GLMA;
|
||||
get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack);
|
||||
hparams.ffn_op = FFN_GELU_ERF;
|
||||
|
|
@ -2017,6 +2020,30 @@ struct clip_model_loader {
|
|||
model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"));
|
||||
model.mm_norm_mid_w = get_tensor(string_format(TN_MM_NORM_MID, "weight"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_MERALION:
|
||||
{
|
||||
// Whisper encoder conv layers
|
||||
model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
|
||||
model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
|
||||
model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
|
||||
model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
|
||||
// MERaLiON adaptor: 4 linear layers + ln_pre
|
||||
// linear_0 = frame compression (19200->6400) + SiLU
|
||||
// linear_1 = gate_proj (6400->6400) for GLU
|
||||
// linear_2 = pool_proj (6400->6400) for GLU
|
||||
// linear_3 = out_proj (6400->3584)
|
||||
model.mm_0_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 0, "weight"));
|
||||
model.mm_0_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 0, "bias"));
|
||||
model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
|
||||
model.mm_1_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "bias"));
|
||||
model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
|
||||
model.mm_2_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "bias"));
|
||||
model.mm_3_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 3, "weight"));
|
||||
model.mm_3_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 3, "bias"));
|
||||
// ln_speech (LayerNorm before adaptor)
|
||||
model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"));
|
||||
model.mm_norm_pre_b = get_tensor(string_format(TN_MM_NORM_PRE, "bias"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
{
|
||||
model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
|
||||
|
|
@ -2809,6 +2836,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
|||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
case PROJECTOR_TYPE_MERALION:
|
||||
case PROJECTOR_TYPE_MUSIC_FLAMINGO:
|
||||
{
|
||||
n_patches = img->nx;
|
||||
|
|
@ -3298,6 +3326,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
case PROJECTOR_TYPE_LFM2:
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
case PROJECTOR_TYPE_MERALION:
|
||||
case PROJECTOR_TYPE_MUSIC_FLAMINGO:
|
||||
case PROJECTOR_TYPE_JANUS_PRO:
|
||||
case PROJECTOR_TYPE_PHI4:
|
||||
|
|
@ -3463,6 +3492,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
|||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
case PROJECTOR_TYPE_MUSIC_FLAMINGO:
|
||||
return ctx->model.mm_2_w->ne[1];
|
||||
case PROJECTOR_TYPE_MERALION:
|
||||
return ctx->model.mm_3_w->ne[1]; // out_proj output dim
|
||||
case PROJECTOR_TYPE_INTERNVL:
|
||||
case PROJECTOR_TYPE_NEMOTRON_V2_VL:
|
||||
return ctx->model.mm_3_w->ne[1];
|
||||
|
|
@ -3523,6 +3554,7 @@ bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
|
|||
case PROJECTOR_TYPE_QWEN2A:
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
case PROJECTOR_TYPE_MERALION:
|
||||
case PROJECTOR_TYPE_MUSIC_FLAMINGO:
|
||||
return true;
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -95,6 +95,28 @@ ggml_cgraph * clip_graph_whisper_enc::build() {
|
|||
FFN_GELU_ERF,
|
||||
-1);
|
||||
|
||||
} else if (proj_type == PROJECTOR_TYPE_MERALION) {
|
||||
// stack (above) -> ln -> linear0+silu -> GLU -> out
|
||||
cur = ggml_norm(ctx0, cur, hparams.eps);
|
||||
cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
|
||||
cur = ggml_add(ctx0, cur, model.mm_norm_pre_b);
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.mm_0_w, cur);
|
||||
cur = ggml_add(ctx0, cur, model.mm_0_b);
|
||||
cur = ggml_silu(ctx0, cur);
|
||||
|
||||
ggml_tensor * gate = ggml_mul_mat(ctx0, model.mm_1_w, cur);
|
||||
gate = ggml_add(ctx0, gate, model.mm_1_b);
|
||||
gate = ggml_silu(ctx0, gate);
|
||||
|
||||
ggml_tensor * pool = ggml_mul_mat(ctx0, model.mm_2_w, cur);
|
||||
pool = ggml_add(ctx0, pool, model.mm_2_b);
|
||||
|
||||
cur = ggml_mul(ctx0, gate, pool);
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.mm_3_w, cur);
|
||||
cur = ggml_add(ctx0, cur, model.mm_3_b);
|
||||
|
||||
} else if (proj_type == PROJECTOR_TYPE_GLMA) {
|
||||
cur = ggml_norm(ctx0, cur, hparams.eps);
|
||||
cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
|
||||
|
|
|
|||
|
|
@ -476,6 +476,7 @@ struct mtmd_context {
|
|||
} break;
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
case PROJECTOR_TYPE_MERALION:
|
||||
{
|
||||
audio_preproc = std::make_unique<mtmd_audio_preprocessor_whisper>(ctx_a);
|
||||
} break;
|
||||
|
|
|
|||
Loading…
Reference in New Issue