Fixed find_hparam calls. Fixed e_score_correction_bias to use a bias tensor instead of a weight. Removed all ssm_conv bias terms.
parent e55caf5be0
commit 560190af97
@@ -5186,21 +5186,16 @@ class KimiLinearModel(TextModel):
         assert len(_num_kv_heads) == self.hparams["num_hidden_layers"]
         self.gguf_writer.add_head_count_kv(_num_kv_heads)

-        ssm_d_conv = self.hparams.get("ssm_d_conv") or linear_attn_config.get("short_conv_kernel_size")
-        if ssm_d_conv is not None:
+        if (ssm_d_conv := linear_attn_config.get("short_conv_kernel_size")) is not None:
             self.gguf_writer.add_ssm_conv_kernel(ssm_d_conv)
-        kda_head_dim = self.hparams.get("kda_head_dim") or linear_attn_config.get("head_dim")
-        if kda_head_dim is not None:
+        if (kda_head_dim := linear_attn_config.get("head_dim")) is not None:
             self.gguf_writer.add_kda_head_dim(kda_head_dim)

         # MLA params - use add_* methods that handle arch substitution
         # Support both HuggingFace naming (q_lora_rank, kv_lora_rank) and internal naming (n_lora_q, n_lora_kv)
-        q_lora_rank = self.hparams.get("q_lora_rank", self.hparams.get("n_lora_q"))
-        kv_lora_rank = self.hparams.get("kv_lora_rank", self.hparams.get("n_lora_kv"))
-
-        if q_lora_rank is not None:
+        if (q_lora_rank := self.find_hparam(["q_lora_rank", "n_lora_q"], optional=False)) is not None:
             self.gguf_writer.add_q_lora_rank(q_lora_rank)
-        if kv_lora_rank is not None:
+        if (kv_lora_rank := self.find_hparam(["kv_lora_rank", "n_lora_kv"], optional=False)) is not None:
             self.gguf_writer.add_kv_lora_rank(kv_lora_rank)

         # MLA head dimensions
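All of the hparam reads above (and in the next hunk) now go through TextModel.find_hparam behind a walrus assignment. A minimal sketch of the behaviour these calls assume (find_hparam_sketch is an illustrative stand-in, not the converter's actual helper): try each candidate key in order, return the first one present, and let optional decide whether a miss yields None or an exception.

from typing import Any

def find_hparam_sketch(hparams: dict[str, Any], keys: list[str], optional: bool = False) -> Any:
    # Return the value of the first candidate key that is present in hparams.
    for key in keys:
        if key in hparams:
            return hparams[key]
    # On a miss, either tolerate it (optional=True) or fail loudly.
    if optional:
        return None
    raise KeyError(f"could not find any of {keys} in hparams")

# The walrus pattern from the hunks then reads as:
hparams = {"q_lora_rank": 1536}
if (q_lora_rank := find_hparam_sketch(hparams, ["q_lora_rank", "n_lora_q"], optional=True)) is not None:
    print(q_lora_rank)  # 1536

Under these assumed semantics, a non-optional lookup raises before the `is not None` check, so the guard only really matters for the optional lookups such as rope_dim below.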
@@ -5226,39 +5221,32 @@ class KimiLinearModel(TextModel):
             self.gguf_writer.add_value_length_mla(v_head_dim)

         # Rotation - use qk_rope_head_dim for Kimi
-        rope_dim = self.find_hparam(["qk_rope_head_dim", "n_rot"])
-        if rope_dim is not None:
+        if (rope_dim := self.find_hparam(["qk_rope_head_dim", "n_rot"], optional=True)) is not None:
             self.gguf_writer.add_rope_dimension_count(rope_dim)
         else:
             # Default to head_dim
             head_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
             self.gguf_writer.add_rope_dimension_count(head_dim)

-        n_experts = self.find_hparam(["num_experts"])
-        if n_experts is not None:
+        if (n_experts := self.find_hparam(["num_experts"], optional=False)) is not None:
             self.gguf_writer.add_expert_count(n_experts)
-        n_experts_used = self.find_hparam(["num_experts_per_token"])
-        if n_experts_used is not None:
+        if (n_experts_used := self.find_hparam(["num_experts_per_token"], optional=False)) is not None:
             self.gguf_writer.add_expert_used_count(n_experts_used)

         # moe_intermediate_size (1024 for Kimi)
-        moe_intermediate_size = self.find_hparam(["moe_intermediate_size"])
-        if moe_intermediate_size is not None:
+        if (moe_intermediate_size := self.find_hparam(["moe_intermediate_size"], optional=False)) is not None:
             self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)

         # num_shared_experts (1 for Kimi)
-        num_shared_experts = self.find_hparam(["num_shared_experts"])
-        if num_shared_experts is not None:
+        if (num_shared_experts := self.find_hparam(["num_shared_experts"], optional=False)) is not None:
             self.gguf_writer.add_expert_shared_count(num_shared_experts)

         # first_k_dense_replace (1 for Kimi - first layer uses dense MLP)
-        first_k_dense_replace = self.find_hparam(["first_k_dense_replace"])
-        if first_k_dense_replace is not None:
+        if (first_k_dense_replace := self.find_hparam(["first_k_dense_replace"])) is not None:
             self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace)

         # Routed scaling factor (expert_weights_scale = 2.446 for Kimi)
-        routed_scaling_factor = self.find_hparam(["routed_scaling_factor"])
-        if routed_scaling_factor is not None:
+        if (routed_scaling_factor := self.find_hparam(["routed_scaling_factor"], optional=False)) is not None:
             self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)

     def prepare_tensors(self):
@@ -5292,8 +5280,7 @@ class KimiLinearModel(TextModel):

         # Kimi specific bias
         if name.endswith("e_score_correction_bias"):
-            new_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_EXP_PROBS_B, bid)
-            return [(new_name, data_torch)]
+            name = name.replace("e_score_correction_bias", "e_score_correction.bias")

         # Handle A_log: HF stores as [1, 1, num_heads, 1]
         # llama.cpp expects ggml ne = [1, num_heads, 1, 1]
@@ -438,6 +438,7 @@ class TensorNameMap:
             "model.layers.{bid}.block_sparse_moe.e_score_correction",       # minimax-m2
             "backbone.layers.{bid}.mixer.gate.e_score_correction",          # nemotron-h-moe
             "model.layers.{bid}.mlp.e_score_correction",                    # exaone-moe
+            "model.layers.{bid}.block_sparse_moe.gate.e_score_correction",  # kimi
         ),

         # Feed-forward up
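Taken together with the converter change above, this mapping entry is what turns the renamed tensor into a ".bias"-suffixed GGUF tensor. A simplified sketch of the assumed name-resolution step (KIMI_MAPPING and map_name_sketch are made-up names; the real lookup lives in gguf-py's TensorNameMap): strip a ".weight"/".bias" suffix, map the remaining base name, and re-attach the suffix.

KIMI_MAPPING = {
    "model.layers.{bid}.block_sparse_moe.gate.e_score_correction": "blk.{bid}.exp_probs_b",
}

def map_name_sketch(name: str, bid: int) -> str:
    for suffix in (".weight", ".bias"):
        if name.endswith(suffix):
            # Generalize the layer index back to the {bid} placeholder, then map.
            base = name.removesuffix(suffix).replace(f".{bid}.", ".{bid}.")
            if base in KIMI_MAPPING:
                return KIMI_MAPPING[base].format(bid=bid) + suffix
    raise ValueError(f"unmapped tensor: {name}")

print(map_name_sketch("model.layers.3.block_sparse_moe.gate.e_score_correction.bias", bid=3))
# -> blk.3.exp_probs_b.bias

Under this assumed behaviour the exp_probs_b tensor is written with a "bias" suffix, which is why the load_tensors hunk below now tries the "bias" variant first and only falls back to "weight".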
@@ -6825,11 +6825,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head}, 0);
                     }

-                    // Conv bias may not exist in all models - make optional
-                    layer.ssm_q_conv_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "bias", i), {n_embd_head_k_kda * n_head}, TENSOR_NOT_REQUIRED);
-                    layer.ssm_k_conv_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "bias", i), {n_embd_head_k_kda * n_head}, TENSOR_NOT_REQUIRED);
-                    layer.ssm_v_conv_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "bias", i), {n_embd_head_v_kda * n_head}, TENSOR_NOT_REQUIRED);
-
                     // q, k, v projections
                     // Python: q_proj, k_proj, v_proj
                     layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k_kda * n_head}, 0);
@@ -6923,7 +6918,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                        layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED);

                        // exp_probs_b (e_score_correction_bias in vLLM)
-                       layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "weight", i), {n_expert}, 0);
+                       layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
+                       if (!layer.ffn_exp_probs_b) {
+                           layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "weight", i), {n_expert}, TENSOR_NOT_REQUIRED);
+                       }
                    }
                }
            } break;
@@ -415,11 +415,8 @@ struct llama_layer {
     // Kimi Linear KDA (using ssm_ prefix for consistency)
     // Note: ssm_dt_b already exists above (mamba bias), reused for Kimi dt_bias
     struct ggml_tensor * ssm_q_conv = nullptr;
-    struct ggml_tensor * ssm_q_conv_b = nullptr;
     struct ggml_tensor * ssm_k_conv = nullptr;
-    struct ggml_tensor * ssm_k_conv_b = nullptr;
     struct ggml_tensor * ssm_v_conv = nullptr;
-    struct ggml_tensor * ssm_v_conv_b = nullptr;
     struct ggml_tensor * ssm_f_a = nullptr;
     struct ggml_tensor * ssm_f_b = nullptr;
     struct ggml_tensor * ssm_beta = nullptr;
@@ -5,7 +5,7 @@

 // Causal Conv1d function for Q,K,V
 // When qkv is 0, it is Q, 1 is K, 2 is V
-static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_tensor * conv_states_all, ggml_tensor * conv_state_all, int64_t qkv, ggml_tensor * x, ggml_tensor * proj_w, ggml_tensor * conv_w, ggml_tensor * conv_b, int64_t d_conv, int64_t head_dim, int64_t n_head, int64_t n_seq_tokens, int64_t n_seqs, int64_t n_tokens, int64_t kv_head) {
+static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_tensor * conv_states_all, ggml_tensor * conv_state_all, int64_t qkv, ggml_tensor * x, ggml_tensor * proj_w, ggml_tensor * conv_w, int64_t d_conv, int64_t head_dim, int64_t n_head, int64_t n_seq_tokens, int64_t n_seqs, int64_t n_tokens, int64_t kv_head) {
     const int64_t d_inner = head_dim * n_head;
     const int64_t conv_state_size = (d_conv - 1) * d_inner;
     const int64_t n_embd_r_total = 3 * conv_state_size; // Q + K + V
@@ -56,9 +56,6 @@ static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_t
     ggml_tensor * Xcur = ggml_ssm_conv(ctx0, conv_x, conv_weight);
     // Reshape to 2D for bias add: {d_inner, n_tokens}
     Xcur = ggml_reshape_2d(ctx0, Xcur, d_inner, n_tokens);
-    if (conv_b) {
-        Xcur = ggml_add(ctx0, Xcur, conv_b);
-    }
     Xcur = ggml_silu(ctx0, Xcur);

     return ggml_reshape_4d(ctx0, Xcur, head_dim, n_head, n_seq_tokens, n_seqs);
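With the bias gone, causal_conv1d reduces to a per-channel causal convolution followed by SiLU. A framework-free sketch of the assumed math (ignoring the carried-over conv state and the Q/K/V selection; causal_conv1d_sketch is an illustrative name, not the ggml implementation):

import math

def causal_conv1d_sketch(x: list[list[float]], w: list[list[float]]) -> list[list[float]]:
    # x: [n_tokens][d_inner] inputs, w: [d_conv][d_inner] depthwise weights.
    d_conv, d_inner = len(w), len(w[0])
    out = []
    for t in range(len(x)):
        row = []
        for c in range(d_inner):
            # Causal window: the current token plus the previous d_conv - 1 tokens.
            acc = 0.0
            for k in range(d_conv):
                src = t - (d_conv - 1) + k
                if src >= 0:
                    acc += x[src][c] * w[k][c]
            # No bias term any more; go straight to SiLU: x * sigmoid(x).
            row.append(acc / (1.0 + math.exp(-acc)))
        out.append(row)
    return out

print(causal_conv1d_sketch([[1.0], [2.0], [3.0]], [[0.5], [0.5]]))

The real graph additionally prepends the per-sequence conv state kept in conv_states_all, which the sketch leaves out.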
@@ -140,9 +137,9 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
             ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
             cb(conv_states_all, "conv_states_all", il);
             ggml_tensor * conv_state_all = build_rs(inp_rs, conv_states_all, hparams.n_embd_r(), n_seqs);
-            ggml_tensor * Qcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 0, cur, layer.wq, layer.ssm_q_conv, layer.ssm_q_conv_b, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
-            ggml_tensor * Kcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 1, cur, layer.wk, layer.ssm_k_conv, layer.ssm_k_conv_b, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
-            ggml_tensor * Vcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 2, cur, layer.wv, layer.ssm_v_conv, layer.ssm_v_conv_b, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
+            ggml_tensor * Qcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 0, cur, layer.wq, layer.ssm_q_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
+            ggml_tensor * Kcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 1, cur, layer.wk, layer.ssm_k_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
+            ggml_tensor * Vcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 2, cur, layer.wv, layer.ssm_v_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);

             // g1 = -exp(A_log) * softplus(f_b(f_a(x)) + dt_bias)
             ggml_tensor * f_a = ggml_mul_mat(ctx0, layer.ssm_f_a, cur);