fixed typo and split wkv_b into wk_b and wv_b

Yee Man Chan 2026-01-10 22:08:38 +08:00
parent d26fe50178
commit dce064c0a3
3 changed files with 27 additions and 4 deletions
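
Background for the split: with MLA's absorption optimization, the no-PE part of the attention score, q_nope . (W_kb . c_kv), is computed as (W_kb^T . q_nope) . c_kv, so the key half of kv_b_proj must be stored transposed while the value half stays as-is. A minimal sketch with made-up sizes (not Kimi's actual dimensions) showing the two forms agree:

import torch

torch.manual_seed(0)
qk_nope_head_dim, kv_lora_rank = 4, 8               # toy sizes, illustration only
w_kb = torch.randn(qk_nope_head_dim, kv_lora_rank)  # one head's key slice of kv_b_proj
q_nope = torch.randn(qk_nope_head_dim)
c_kv = torch.randn(kv_lora_rank)                    # compressed KV latent

# naive: expand the latent to a full no-PE key, then dot with the query
score_naive = q_nope @ (w_kb @ c_kv)
# absorbed: fold w_kb into the query instead; this is why k_b_proj is stored transposed
score_absorbed = (w_kb.T @ q_nope) @ c_kv
assert torch.allclose(score_naive, score_absorbed)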


@@ -5275,7 +5275,8 @@ class KimiLinearModel(TextModel):
        # Kimi specific bias
        if name.endswith("e_score_correction_bias"):
            name = name.replace("e_score_correction_bias", "e_score_correction.bias")
            new_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_EXP_PROBS_B, bid)
            return [(new_name, data_torch)]

        # process the experts separately
        if name.find("block_sparse_moe.experts") != -1:
@@ -5305,7 +5306,27 @@ class KimiLinearModel(TextModel):
                    tensors.append((new_name, data_torch))
                return tensors
            return []

        # note: MLA with the absorption optimization, needs these two split and k_b_proj transposed
        if name.endswith("kv_b_proj.weight"):
            name_kb = name.replace("kv_b_proj", "k_b_proj")
            name_vb = name.replace("kv_b_proj", "v_b_proj")

            n_head_kv = self.hparams["num_key_value_heads"]
            v_head_dim = self.hparams["v_head_dim"]
            qk_nope_head_dim = self.hparams["qk_nope_head_dim"]

            assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)

            kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
            k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
            k_b = k_b.transpose(1, 2)

            return [
                (self.map_tensor_name(name_kb), k_b),
                (self.map_tensor_name(name_vb), v_b)
            ]

        mapped_name = self.map_tensor_name(name)
        logger.info(f"Returning {mapped_name}: shape after = {tuple(data_torch.shape)}")
        return [(mapped_name, data_torch)]
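
To sanity-check the reshape path above, here is a standalone sketch with illustrative sizes (real values come from the checkpoint's config.json, not these) that mirrors the view/split/transpose and asserts the resulting shapes:

import torch

# illustrative sizes only
n_head_kv, qk_nope_head_dim, v_head_dim, kv_lora_rank = 2, 4, 3, 8

data_torch = torch.randn(n_head_kv * (v_head_dim + qk_nope_head_dim), kv_lora_rank)

kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
k_b = k_b.transpose(1, 2)

assert k_b.shape == (n_head_kv, kv_lora_rank, qk_nope_head_dim)  # transposed for absorption
assert v_b.shape == (n_head_kv, v_head_dim, kv_lora_rank)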


@@ -3317,6 +3317,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.ATTN_Q_B,
        MODEL_TENSOR.ATTN_KV_A_MQA,
        MODEL_TENSOR.ATTN_KV_B,
        MODEL_TENSOR.ATTN_K_B,
        MODEL_TENSOR.ATTN_V_B,
        MODEL_TENSOR.ATTN_Q_A_NORM,
        MODEL_TENSOR.ATTN_KV_A_NORM,
        MODEL_TENSOR.FFN_NORM,
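
The two new entries register ATTN_K_B and ATTN_V_B as valid tensor types for this architecture, so TensorNameMap will build name mappings for them and map_tensor_name can resolve the split k_b_proj / v_b_proj tensors. A quick check, assuming the architecture enum is MODEL_ARCH.KIMI_LINEAR (the enum name is not visible in this hunk):

from gguf.constants import MODEL_ARCH, MODEL_TENSOR, MODEL_TENSORS

# MODEL_ARCH.KIMI_LINEAR is an assumed name; substitute the actual enum member
arch_tensors = MODEL_TENSORS[MODEL_ARCH.KIMI_LINEAR]
assert MODEL_TENSOR.ATTN_K_B in arch_tensors
assert MODEL_TENSOR.ATTN_V_B in arch_tensors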


@@ -403,7 +403,7 @@ class TensorNameMap:
            "model.layers.{bid}.feed_forward.expert_bias",                       # lfm2moe
            "model.layers.{bid}.block_sparse_moe.e_score_correction",            # minimax-m2
            "backbone.layers.{bid}.mixer.gate.e_score_correction",               # nemotron-h-moe
            "model.layers.{bid}.block_sparse_moe.gate.e_score_correction",       # kimi
            "model.layers.{bid}.block_sparse_moe.gate.e_score_correction_bias",  # kimi
        ),
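
With the _bias suffix restored, the HF tensor name matches this entry directly. A minimal lookup sketch (the Kimi Linear arch enum name is assumed, as above):

import gguf

# MODEL_ARCH.KIMI_LINEAR is an assumed name; not shown in this diff
tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.KIMI_LINEAR, 1)
name = tmap.get_name("model.layers.0.block_sparse_moe.gate.e_score_correction_bias",
                     try_suffixes=(".weight", ".bias"))
print(name)  # should print the GGUF-side name for FFN_EXP_PROBS_B, e.g. "blk.0.exp_probs_b"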
        # Feed-forward up
@@ -812,7 +812,7 @@ class TensorNameMap:
        ),

        MODEL_TENSOR.SSM_DT_B: (
            "model.layers.{bid}.self_attn.dt_bias",
        ),

        MODEL_TENSOR.TIME_MIX_W0: (
            "model.layers.{bid}.attention.w0",  # rwkv7
        ),