more cleaning on python code

This commit is contained in:
younesbelkada 2025-07-03 18:09:30 +04:00
parent fdd5cff4ba
commit 14c37ec047
5 changed files with 204 additions and 0 deletions

View File

@ -6535,6 +6535,144 @@ class UltravoxWhisperEncoderModel(WhisperEncoderModel):
super().set_gguf_parameters()
self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
@ModelBase.register("FalconH1ForCausalLM")
class FalconH1Model(Mamba2Model):
model_arch = gguf.MODEL_ARCH.FALCON_H1
def __init__(self, *args, **kwargs):
# Set the hparam prefixes for Falcon Mamba2
self.hparam_prefixes = ["mamba"]
# Initialize the base Mamba2Model
super().__init__(*args, **kwargs)
# Use Llama conversion for attention
self._transformer_model_class = LlamaModel
# n_group and d_inner are used during reshape_tensors for mamaba2
self.d_model = self.find_hparam(["hidden_size", "d_model"])
self.n_group = self.find_hparam(["n_groups"])
self.d_inner = self.find_hparam(["expand"]) * self.d_model
# Initialize any Falcon Mamba2 specific attributes
self.has_attention = True # Falcon Mamba2 has attention components
# Load Falcon-H1 multipliers from hyperparameters
self.attention_in_multiplier = self.find_hparam(["attention_in_multiplier"], optional=True)
self.attention_out_multiplier = self.find_hparam(["attention_out_multiplier"], optional=True)
self.ssm_in_multiplier = self.find_hparam(["ssm_in_multiplier"], optional=True)
self.ssm_out_multiplier = self.find_hparam(["ssm_out_multiplier"], optional=True)
self.mlp_multipliers = self.find_hparam(["mlp_multipliers"], optional=True)
self.ssm_multipliers = self.find_hparam(["ssm_multipliers"], optional=True)
self.intermediate_size = self.find_hparam(["intermediate_size"])
def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
prefixed = []
for pfx in self.hparam_prefixes:
prefixed.extend(
"_".join([pfx, k])
for k in keys
)
keys = list(keys) + prefixed
return super().find_hparam(keys, *args, **kwargs)
def _generate_mup_vector(self, block_id: int) -> torch.Tensor:
zxbcdt_multipliers = self.hparams["ssm_multipliers"]
intermediate_size = self.hparams["mamba_d_ssm"]
groups_time_state_size = self.hparams["mamba_n_groups"] * self.hparams["mamba_d_state"]
vector_shape = (2 * intermediate_size + 2 * groups_time_state_size + self.hparams["mamba_n_heads"])
mup_vector = torch.ones(1, 1, vector_shape)
mup_vector[:, :, :intermediate_size] *= zxbcdt_multipliers[0]
mup_vector[:, :, intermediate_size:2 * intermediate_size] *= zxbcdt_multipliers[1]
mup_vector[:, :, 2 * intermediate_size:2 * intermediate_size + groups_time_state_size] *= zxbcdt_multipliers[2]
mup_vector[:, :, 2 * intermediate_size + groups_time_state_size:2 * intermediate_size + 2 * groups_time_state_size] *= zxbcdt_multipliers[3]
mup_vector[:, :, 2 * intermediate_size + 2 * groups_time_state_size:] *= zxbcdt_multipliers[4]
return mup_vector
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
for name, tensor in super().get_tensors():
if name.startswith("model.backbone") or name.startswith("model.lm_head"):
name = name.removeprefix("model.")
yield name, tensor
if self.ssm_multipliers is not None:
# Insert MUP vector after mamba.dt_bias
if "mamba.dt_bias" in name:
block_match = re.search(r"(?:model\.layers\.)?(\d+)\.mamba\.dt_bias", name)
if block_match:
block_id = int(block_match.group(1))
# Generate MUP vector with correct name format
mup_tensor = self._generate_mup_vector(block_id)
mup_name = f"blk.{block_id}.ssm_mup_vec"
logger.debug(f"Inserting MUP vector for block {block_id}: {mup_name}")
yield mup_name, mup_tensor
def set_gguf_parameters(self):
super().set_gguf_parameters()
## General Params ##
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
## Mamba mixer params ##
self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
self.gguf_writer.add_ssm_group_count(self.n_group)
self.gguf_writer.add_ssm_inner_size(self.d_inner)
self.gguf_writer.add_ssm_head_dim(d_head := self.find_hparam(["d_head"]))
self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
## Attention params ##
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in self.hparams else self.hparams["num_attention_heads"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
self.gguf_writer.add_key_length(self.hparams["head_dim"])
self.gguf_writer.add_value_length(self.hparams["head_dim"])
self.gguf_writer.add_float32("falcon-h1.key_multiplier", self.hparams["key_multiplier"])
## Other params
self.gguf_writer.add_float32("falcon-h1.lm_head_multiplier", self.hparams["lm_head_multiplier"])
self.gguf_writer.add_float32("falcon-h1.embedding_multiplier", self.hparams["embedding_multiplier"])
## Validation ##
assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"
# Add Falcon Mamba2 specific configuration
self.gguf_writer.add_uint32("falcon-h1.ssm.mamba_chunk_size", self.hparams["mamba_chunk_size"])
self.gguf_writer.add_uint32("falcon-h1.attention.head_dim", self.hparams["head_dim"])
self.gguf_writer.add_uint32("falcon-h1.ssm.mamba_d_ssm", self.hparams["mamba_d_ssm"])
self.gguf_writer.add_uint32("falcon-h1.num_attention_heads", self.find_hparam(["num_attention_heads"]))
self.gguf_writer.add_uint32("falcon-h1.num_key_value_heads",
self.find_hparam(["num_key_value_heads"], optional=True) or
self.find_hparam(["num_attention_heads"]))
# Add multipliers as metadata instead of tensors
self.gguf_writer.add_float32("falcon-h1.attention_in_multiplier", self.attention_in_multiplier)
self.gguf_writer.add_float32("falcon-h1.attention_out_multiplier", self.attention_out_multiplier)
self.gguf_writer.add_float32("falcon-h1.ssm_in_multiplier", self.ssm_in_multiplier)
self.gguf_writer.add_float32("falcon-h1.ssm_out_multiplier", self.ssm_out_multiplier)
# Add MLP multipliers
if isinstance(self.mlp_multipliers, (list, tuple)) and len(self.mlp_multipliers) == 2:
self.gguf_writer.add_float32("falcon-h1.mlp_gate_multiplier", self.mlp_multipliers[0])
self.gguf_writer.add_float32("falcon-h1.mlp_down_multiplier", self.mlp_multipliers[1])
# Add has MuP flag if SSM multipliers are present
if self.ssm_multipliers is not None:
self.gguf_writer.add_bool("falcon-h1.ssm.has_mup", True)
# Add any other Falcon Mamba2 specific configuration
self.gguf_writer.add_bool("falcon-h1.mamba_use_mlp", self.find_hparam(["mamba_use_mlp"], optional=True))
self.gguf_writer.add_bool("falcon-h1.mamba_norm_before_gate", self.find_hparam(["mamba_norm_before_gate"], optional=True))
self.gguf_writer.add_bool("falcon-h1.mamba_rms_norm", self.find_hparam(["mamba_rms_norm"], optional=True))
self.gguf_writer.add_float32("falcon-h1.rope_theta", self.find_hparam(["rope_theta"], optional=True))
###### CONVERSION LOGIC ######

View File

@ -172,6 +172,7 @@ class Keys:
TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
GROUP_COUNT = "{arch}.ssm.group_count"
DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
HEAD_DIM = "{arch}.ssm.head_dim"
class WKV:
HEAD_SIZE = "{arch}.wkv.head_size"
@ -288,6 +289,7 @@ class MODEL_ARCH(IntEnum):
LLAMA4 = auto()
DECI = auto()
FALCON = auto()
FALCON_H1 = auto()
BAICHUAN = auto()
GROK = auto()
GPT2 = auto()
@ -525,6 +527,7 @@ class MODEL_TENSOR(IntEnum):
POSNET_ATTN_K = auto()
POSNET_ATTN_V = auto()
POSNET_ATTN_OUT = auto()
SSM_MUP_VEC = auto()
# vision
V_MMPROJ = auto()
V_MMPROJ_FC = auto()
@ -660,6 +663,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.DOTS1: "dots1",
MODEL_ARCH.ARCEE: "arcee",
MODEL_ARCH.ERNIE4_5: "ernie4_5",
MODEL_ARCH.FALCON_H1: "falcon_h1",
}
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@ -736,6 +740,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
MODEL_TENSOR.SSM_MUP_VEC: "blk.{bid}.ssm_mup_vec",
MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
@ -2211,6 +2216,41 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.FALCON_H1: [
# Token embedding
MODEL_TENSOR.TOKEN_EMBD,
# Input layernorm
MODEL_TENSOR.ATTN_NORM,
# Attention components
MODEL_TENSOR.ATTN_Q, # Query projection
MODEL_TENSOR.ATTN_K, # Key projection
MODEL_TENSOR.ATTN_V, # Value projection
MODEL_TENSOR.ATTN_OUT, # Output projection
# SSM components (Mamba2 specific)
MODEL_TENSOR.SSM_MUP_VEC, # Mup vector
MODEL_TENSOR.SSM_IN, # Input projection for SSM
MODEL_TENSOR.SSM_CONV1D, # Convolution layer
MODEL_TENSOR.SSM_DT, # Delta time projection
MODEL_TENSOR.SSM_A, # A parameter (log form)
MODEL_TENSOR.SSM_D, # D parameter
MODEL_TENSOR.SSM_NORM, # Normalization in SSM
MODEL_TENSOR.SSM_OUT, # Output projection
# Pre-feedforward layernorm
MODEL_TENSOR.FFN_PRE_NORM,
# Feed-forward network components
MODEL_TENSOR.FFN_GATE, # Gate projection (SwiGLU)
MODEL_TENSOR.FFN_DOWN, # Down projection
MODEL_TENSOR.FFN_UP, # Up projection
# Post-feedforward layernorm
MODEL_TENSOR.OUTPUT_NORM, # Final layer norm
MODEL_TENSOR.OUTPUT, # Output projection (lm_head)
],
# TODO
}

View File

@ -867,6 +867,15 @@ class GGUFWriter:
def add_ssm_dt_b_c_rms(self, value: bool) -> None:
self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
def add_ssm_head_dim(self, value: int) -> None:
self.add_uint32(Keys.SSM.HEAD_DIM.format(arch=self.arch), value)
def add_attn_head_count(self, count: int) -> None:
self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
def add_key_value_head_count(self, count: int) -> None:
self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
def add_tokenizer_model(self, model: str) -> None:
self.add_string(Keys.Tokenizer.MODEL, model)

View File

@ -286,12 +286,14 @@ class TensorNameMap:
# Post feed-forward norm
MODEL_TENSOR.FFN_PRE_NORM: (
"model.layers.{bid}.pre_feedforward_layernorm", # gemma2
"model.layers.{bid}.pre_ff_layernorm.weight",
),
# Post feed-forward norm
MODEL_TENSOR.FFN_POST_NORM: (
"model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
"model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
"model.layers.{bid}.feed_forward.up_proj",
),
MODEL_TENSOR.FFN_GATE_INP: (
@ -362,6 +364,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
"model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
"model.layers.{bid}.feed_forward.down_proj",
),
# AWQ-activation gate
@ -547,11 +550,13 @@ class TensorNameMap:
MODEL_TENSOR.SSM_IN: (
"model.layers.{bid}.in_proj",
"backbone.layers.{bid}.mixer.in_proj",
"model.layers.{bid}.mamba.in_proj",
),
MODEL_TENSOR.SSM_CONV1D: (
"model.layers.{bid}.conv1d",
"backbone.layers.{bid}.mixer.conv1d",
"model.layers.{bid}.mamba.conv1d",
),
MODEL_TENSOR.SSM_X: (
@ -562,16 +567,19 @@ class TensorNameMap:
MODEL_TENSOR.SSM_DT: (
"model.layers.{bid}.dt_proj",
"backbone.layers.{bid}.mixer.dt_proj",
"model.layers.{bid}.mamba.dt_proj",
),
MODEL_TENSOR.SSM_A: (
"model.layers.{bid}.A_log",
"backbone.layers.{bid}.mixer.A_log",
"model.layers.{bid}.mamba.A_log",
),
MODEL_TENSOR.SSM_D: (
"model.layers.{bid}.D",
"backbone.layers.{bid}.mixer.D",
"model.layers.{bid}.mamba.D",
),
MODEL_TENSOR.SSM_NORM: (
@ -1168,6 +1176,14 @@ class TensorNameMap:
"resampler.attn.out_proj",
),
MODEL_TENSOR.SSM_MUP_VEC: (
"model.layers.{bid}.mamba.mup_vector", # falcon-h1
),
MODEL_TENSOR.SSM_NORM: (
"model.layers.{bid}.mamba.norm",
),
MODEL_TENSOR.V_RESMPL_KV: (
"resampler.kv_proj",
),

View File

@ -14921,6 +14921,7 @@ struct llm_build_falcon_h1 : public llm_graph_context {
y = ggml_reshape_4d(ctx0, y, d_ssm / n_group, n_group, n_seq_tokens, n_seqs);
y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
}
y = ggml_reshape_3d(ctx0, y, d_ssm, n_seq_tokens, n_seqs);
// {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}