diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index bd017dfec4..0e1b6aae99 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -5159,17 +5159,14 @@ class KimiLinearModel(TextModel):
         super().set_gguf_parameters()
         self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
 
-        self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
-        # Use find_hparam for context length
-        # Kimi uses model_max_length
-        n_ctx = self.find_hparam(["max_position_embeddings", "model_max_length", "n_ctx", "n_positions"], optional=True)
-        if n_ctx is not None:
-            self.gguf_writer.add_context_length(n_ctx)
-        else:
-            # Default to 4096 if not found
-            logger.warning("No context length found in config, defaulting to 4096")
-            self.gguf_writer.add_context_length(4096)
+        if (score_func := self.find_hparam(["moe_router_activation_func"], optional=True)) is not None:
+            if score_func == "sigmoid":
+                self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+            elif score_func == "softmax":
+                self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
+            else:
+                raise ValueError(f"Unsupported expert score gating function value: {score_func}")
 
         # KDA & MLA params
         # Get ssm_d_conv from linear_attn_config.short_conv_kernel_size or ssm_d_conv
@@ -5226,7 +5223,7 @@ class KimiLinearModel(TextModel):
         self.gguf_writer.add_value_length_mla(v_head_dim)
 
         # Rotation - use qk_rope_head_dim for Kimi
-        rope_dim = self.hparams.get("qk_rope_head_dim") or self.hparams.get("n_rot")
+        rope_dim = self.find_hparam(["qk_rope_head_dim", "n_rot"])
         if rope_dim is not None:
             self.gguf_writer.add_rope_dimension_count(rope_dim)
         else:
@@ -5234,41 +5231,30 @@ class KimiLinearModel(TextModel):
             head_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
             self.gguf_writer.add_rope_dimension_count(head_dim)
 
-        # Copied from Qwen2Moe as this model inherits parts of it
-        # YaRN is not enabled by default
-        # To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts
-        rope_scaling = self.hparams.get("rope_scaling") or {}
-        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
-            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
-            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
-            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
-
-        # MoE params
-        n_experts = self.hparams.get("num_local_experts", self.hparams.get("num_experts"))
+        n_experts = self.find_hparam(["num_experts"])
         if n_experts is not None:
             self.gguf_writer.add_expert_count(n_experts)
 
-        # Support both num_experts_per_tok and num_experts_per_token
-        n_experts_used = self.hparams.get("num_experts_per_tok", self.hparams.get("num_experts_per_token"))
+        n_experts_used = self.find_hparam(["num_experts_per_token"])
        if n_experts_used is not None:
             self.gguf_writer.add_expert_used_count(n_experts_used)
 
         # moe_intermediate_size (1024 for Kimi)
-        moe_intermediate_size = self.hparams.get("moe_intermediate_size")
+        moe_intermediate_size = self.find_hparam(["moe_intermediate_size"])
         if moe_intermediate_size is not None:
             self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
 
         # num_shared_experts (1 for Kimi)
-        num_shared_experts = self.hparams.get("num_shared_experts")
+        num_shared_experts = self.find_hparam(["num_shared_experts"])
         if num_shared_experts is not None:
             self.gguf_writer.add_expert_shared_count(num_shared_experts)
 
         # first_k_dense_replace (1 for Kimi - first layer uses dense MLP)
-        first_k_dense_replace = self.hparams.get("first_k_dense_replace")
+        first_k_dense_replace = self.find_hparam(["first_k_dense_replace"])
         if first_k_dense_replace is not None:
             self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace)
 
         # Routed scaling factor (expert_weights_scale = 2.446 for Kimi)
-        routed_scaling_factor = self.hparams.get("routed_scaling_factor")
+        routed_scaling_factor = self.find_hparam(["routed_scaling_factor"])
         if routed_scaling_factor is not None:
             self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)
@@ -5301,19 +5287,20 @@ class KimiLinearModel(TextModel):
             data_torch = data_torch.reshape(1, d_inner, 1, d_conv)
             logger.info(f"Reshaped conv1d weight {name}: [d_inner={d_inner}, 1, d_conv={d_conv}] -> numpy {tuple(data_torch.shape)} -> ggml ne=[{d_conv}, 1, {d_inner}, 1]")
 
-        # Handle A_log: HF stores as [1, 1, num_heads, 1]
-        # llama.cpp expects ggml ne = [1, num_heads, 1, 1]
-        # GGUF reverses numpy shape: numpy (1, 1, num_heads, 1) -> ggml ne = [1, num_heads, 1, 1]
-        # So no transformation needed! The shapes already match after GGUF reversal.
-        if name.endswith(".A_log"):
-            if data_torch.ndim == 4:
-                logger.info(f"A_log {name}: numpy {tuple(data_torch.shape)} -> ggml ne={list(reversed(data_torch.shape))}")
-
         # Kimi specific bias
         if name.endswith("e_score_correction_bias"):
             new_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_EXP_PROBS_B, bid)
             return [(new_name, data_torch)]
 
+        # Handle A_log: HF stores as [1, 1, num_heads, 1]
+        # llama.cpp expects ggml ne = [1, num_heads, 1, 1]
+        # GGUF reverses numpy shape: numpy (1, 1, num_heads, 1) -> ggml ne = [1, num_heads, 1, 1]
+        if name.endswith(".A_log"):
+            data_torch = -torch.exp(data_torch)
+        if name.endswith(".dt_bias"):
+            name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
+            logger.info("Changed dt_bias to dt_proj.bias")
+
         # process the experts separately
         if name.find("block_sparse_moe.experts") != -1:
             n_experts = self.hparams.get("num_local_experts", self.hparams.get("num_experts"))
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 43ea4eec0c..73e7bae6e1 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -462,7 +462,7 @@ class MODEL_ARCH(IntEnum):
     MIMO2 = auto()
     LLAMA_EMBED = auto()
     MAINCODER = auto()
-    KIMI_LINEAR = auto()  # Kimi-Linear (hybrid MLA+KDA)
+    KIMI_LINEAR = auto()
 
 
 class VISION_PROJECTOR_TYPE(IntEnum):
@@ -559,10 +559,9 @@ class MODEL_TENSOR(IntEnum):
     SSM_F_A = auto()  # Kimi Linear
     SSM_F_B = auto()  # Kimi Linear
     SSM_BETA = auto()  # Kimi Linear
-    SSM_A_LOG = auto()  # Kimi Linear
+    SSM_DT_B = auto()  # Kimi Linear
     SSM_G_A = auto()  # Kimi Linear
     SSM_G_B = auto()  # Kimi Linear
-    SSM_DT_B = auto()  # Kimi Linear
     TIME_MIX_W0 = auto()
     TIME_MIX_W1 = auto()
     TIME_MIX_W2 = auto()
@@ -894,7 +893,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.MIMO2: "mimo2",
     MODEL_ARCH.LLAMA_EMBED: "llama-embed",
     MODEL_ARCH.MAINCODER: "maincoder",
-    MODEL_ARCH.KIMI_LINEAR:    "kimi-linear",
+    MODEL_ARCH.KIMI_LINEAR: "kimi-linear",
 }
 
 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -988,10 +987,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.SSM_F_A: "blk.{bid}.ssm_f_a",  # Kimi Linear
     MODEL_TENSOR.SSM_F_B: "blk.{bid}.ssm_f_b",  # Kimi Linear
     MODEL_TENSOR.SSM_BETA: "blk.{bid}.ssm_beta",  # Kimi Linear
-    MODEL_TENSOR.SSM_A_LOG: "blk.{bid}.ssm_a",  # Kimi Linear
     MODEL_TENSOR.SSM_G_A: "blk.{bid}.ssm_g_a",  # Kimi Linear
     MODEL_TENSOR.SSM_G_B: "blk.{bid}.ssm_g_b",  # Kimi Linear
-    MODEL_TENSOR.SSM_DT_B: "blk.{bid}.ssm_dt",  # Kimi Linear
     MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
     MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
     MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
@@ -3433,11 +3430,11 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.SSM_F_A,
         MODEL_TENSOR.SSM_F_B,
         MODEL_TENSOR.SSM_BETA,
-        MODEL_TENSOR.SSM_A_LOG,
+        MODEL_TENSOR.SSM_A,
         MODEL_TENSOR.SSM_G_A,
         MODEL_TENSOR.SSM_G_B,
+        MODEL_TENSOR.SSM_DT,
         MODEL_TENSOR.SSM_NORM,
-        MODEL_TENSOR.SSM_DT_B,
         MODEL_TENSOR.FFN_EXP_PROBS_B,
         MODEL_TENSOR.FFN_GATE_SHEXP,
         MODEL_TENSOR.FFN_DOWN_SHEXP,
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 99da6891f8..d96119ebe9 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -438,7 +438,6 @@ class TensorNameMap:
             "model.layers.{bid}.block_sparse_moe.e_score_correction",  # minimax-m2
             "backbone.layers.{bid}.mixer.gate.e_score_correction",  # nemotron-h-moe
             "model.layers.{bid}.mlp.e_score_correction",  # exaone-moe
-            "model.layers.{bid}.block_sparse_moe.gate.e_score_correction_bias",  # kimi
         ),
 
         # Feed-forward up
@@ -556,7 +555,6 @@ class TensorNameMap:
 
         MODEL_TENSOR.FFN_GATE_CHEXP: (
             "model.layers.{bid}.mlp.chunk_experts.gate_proj",  # grovemoe
-            "model.layers.{bid}.block_sparse_moe.shared_experts.gate_proj",  # kimi
         ),
 
         # Feed-forward down
@@ -764,6 +762,7 @@ class TensorNameMap:
             "model.layers.layers.{bid}.mixer.dt_proj",  # plamo2
             "model.layers.{bid}.linear_attn.dt_proj",  # qwen3next
             "backbone.layers.{bid}.mixer.dt",  # nemotron-h-moe
+            "model.layers.{bid}.self_attn.dt_proj",  # kimi
         ),
 
         MODEL_TENSOR.SSM_DT_NORM: (
@@ -777,6 +776,7 @@ class TensorNameMap:
             "model.layers.{bid}.mamba.A_log",  # jamba falcon-h1 granite-hybrid
             "model.layers.layers.{bid}.mixer.A_log",  # plamo2
             "model.layers.{bid}.linear_attn.A_log",  # qwen3next
+            "model.layers.{bid}.self_attn.A_log",  # kimi
         ),
 
         MODEL_TENSOR.SSM_B_NORM: (
@@ -836,18 +836,12 @@ class TensorNameMap:
         MODEL_TENSOR.SSM_BETA: (
             "model.layers.{bid}.self_attn.b_proj",
         ),
-        MODEL_TENSOR.SSM_A_LOG: (
-            "model.layers.{bid}.self_attn.A_log",
-        ),
         MODEL_TENSOR.SSM_G_A: (
             "model.layers.{bid}.self_attn.g_a_proj",
         ),
         MODEL_TENSOR.SSM_G_B: (
             "model.layers.{bid}.self_attn.g_b_proj",
         ),
-        MODEL_TENSOR.SSM_DT_B: (
-            "model.layers.{bid}.self_attn.dt_bias",
-        ),
         MODEL_TENSOR.TIME_MIX_W0: (
             "model.layers.{bid}.attention.w0",  # rwkv7
         ),
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 654276542d..a8bf1c9b80 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -380,8 +380,6 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
     { LLM_TENSOR_SSM_F_A, "blk.%d.ssm_f_a" },
     { LLM_TENSOR_SSM_F_B, "blk.%d.ssm_f_b" },
     { LLM_TENSOR_SSM_BETA, "blk.%d.ssm_beta" },
-    { LLM_TENSOR_SSM_A_LOG, "blk.%d.ssm_a" },
-    { LLM_TENSOR_SSM_DT_B, "blk.%d.ssm_dt" },
     { LLM_TENSOR_SSM_G_A, "blk.%d.ssm_g_a" },
     { LLM_TENSOR_SSM_G_B, "blk.%d.ssm_g_b" },
     { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
@@ -2336,10 +2334,10 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
             LLM_TENSOR_SSM_F_A,
             LLM_TENSOR_SSM_F_B,
             LLM_TENSOR_SSM_BETA,
-            LLM_TENSOR_SSM_A_LOG,
-            LLM_TENSOR_SSM_DT_B,
+            LLM_TENSOR_SSM_A,
             LLM_TENSOR_SSM_G_A,
             LLM_TENSOR_SSM_G_B,
+            LLM_TENSOR_SSM_DT,
             LLM_TENSOR_SSM_NORM,
             // MLA
             LLM_TENSOR_ATTN_Q_A,
@@ -2461,8 +2459,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_SSM_F_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_SSM_F_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_SSM_BETA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_SSM_A_LOG, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_SSM_DT_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_SSM_G_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_SSM_G_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
diff --git a/src/llama-arch.h b/src/llama-arch.h
index e5816acee1..f092f72834 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -408,8 +408,6 @@ enum llm_tensor {
     LLM_TENSOR_SSM_F_A,    // kimi: forget gate projection A
     LLM_TENSOR_SSM_F_B,    // kimi: forget gate projection B
     LLM_TENSOR_SSM_BETA,   // kimi: beta mixing coefficient
-    LLM_TENSOR_SSM_A_LOG,  // kimi: A_log (pre-converted in GGUF)
-    LLM_TENSOR_SSM_DT_B,   // kimi: dt bias
     LLM_TENSOR_SSM_G_A,    // kimi: output gate projection A
     LLM_TENSOR_SSM_G_B,    // kimi: output gate projection B
     LLM_TENSOR_TIME_MIX_W0,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 9b796b3675..53f9f389e4 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -2468,7 +2468,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false);
                 ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false);
                 ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
-                ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
+                ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func);
 
                 switch (hparams.n_layer) {
                     case 27: type = LLM_TYPE_48B_A3B; break; // Kimi-Linear-48B-A3B
@@ -6839,14 +6839,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         // b_proj (beta mixing coefficient)
                        layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), {n_embd, n_head}, 0);
 
-                        // A_log - Shape in GGUF: [1, num_heads, 1, 1] (4D) or [1, num_heads] (2D after quantization)
-                        layer.ssm_a_log = create_tensor(tn(LLM_TENSOR_SSM_A_LOG, i), {1, n_head, 1, 1}, TENSOR_NOT_REQUIRED);
-                        if (!layer.ssm_a_log) {
-                            layer.ssm_a_log = create_tensor(tn(LLM_TENSOR_SSM_A_LOG, i), {1, n_head}, 0);
+                        // A_log - Shape in GGUF: [1, num_heads, 1, 1] (4D) or [1, num_heads] (2D after quantization). Note: -exp(A_log) is applied in convert_hf_to_gguf.py
+                        layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head, 1, 1}, TENSOR_NOT_REQUIRED);
+                        if (!layer.ssm_a) {
+                            layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0);
                         }
 
                         // dt_bias - shape [n_embd_head_k_kda * n_head] = [4096]
-                        layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT_B, i), {n_embd_head_k_kda * n_head}, 0);
+                        layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_embd_head_k_kda * n_head}, 0);
 
                         // g_a_proj, g_b_proj (output gate)
                         layer.ssm_g_a = create_tensor(tn(LLM_TENSOR_SSM_G_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0);
@@ -6918,11 +6918,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                        layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED);
 
                         // exp_probs_b (e_score_correction_bias in vLLM)
-                        // Try "bias" first (standard), then "weight" (for compatibility)
-                        layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
-                        if (!layer.ffn_exp_probs_b) {
-                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "weight", i), {n_expert}, TENSOR_NOT_REQUIRED);
-                        }
+                        layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "weight", i), {n_expert}, 0);
                     }
                 }
             } break;
diff --git a/src/llama-model.h b/src/llama-model.h
index 40078dbdbd..a4900b093e 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -422,7 +422,6 @@ struct llama_layer {
     struct ggml_tensor * ssm_f_a = nullptr;
     struct ggml_tensor * ssm_f_b = nullptr;
     struct ggml_tensor * ssm_beta = nullptr;
-    struct ggml_tensor * ssm_a_log = nullptr;
     struct ggml_tensor * ssm_g_a = nullptr;
     struct ggml_tensor * ssm_g_b = nullptr;
     struct ggml_tensor * ssm_o_norm = nullptr;
diff --git a/src/models/kimi-linear.cpp b/src/models/kimi-linear.cpp
index 6db782641d..6013cd0b77 100644
--- a/src/models/kimi-linear.cpp
+++ b/src/models/kimi-linear.cpp
@@ -127,7 +127,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
 
         // Check layer type by checking which tensors exist
         // KDA layers have ssm_a_log tensor, MLA layers have wkv_a_mqa tensor
-        bool is_kda = (layer.ssm_a_log != nullptr);
+        bool is_kda = (layer.ssm_a != nullptr);
         bool is_mla = (layer.wkv_a_mqa != nullptr);
 
         if (is_kda) {
@@ -152,12 +152,10 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
             g1 = ggml_softplus(ctx0, g1);
             g1 = ggml_reshape_3d(ctx0, g1, head_dim, n_head, n_tokens);
 
-            // A_log shape is [1, n_head] or [1, n_head, 1, 1], need to broadcast to [head_dim, n_head, n_tokens]
-            // First compute -exp(A_log), then reshape for broadcasting
-            ggml_tensor * A_neg_exp = ggml_neg(ctx0, ggml_exp(ctx0, layer.ssm_a_log));
+            // A_log shape is [1, n_head] or [1, n_head, 1, 1], need to broadcast to [head_dim, n_head, n_tokens]. No need to compute -exp(A_log) here because it is done in convert_hf_to_gguf.py
             // Reshape to [1, n_head, 1] for broadcasting with g1 [head_dim, n_head, n_tokens]
-            A_neg_exp = ggml_reshape_3d(ctx0, A_neg_exp, 1, n_head, 1);
-            g1 = ggml_mul(ctx0, g1, A_neg_exp);
+            ggml_tensor * A = ggml_reshape_3d(ctx0, layer.ssm_a, 1, n_head, 1);
+            g1 = ggml_mul(ctx0, g1, A);
             cb(g1, "kda_g1", il);
 
             // Compute beta (mixing coefficient)
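Note (reviewer aid, not part of the patch): a minimal standalone sketch of the A_log pre-conversion that the convert_hf_to_gguf.py change above performs. The head count and tensor values below are made-up placeholders, not taken from the Kimi config.

    # Hypothetical illustration only. The converter now stores -exp(A_log) in the GGUF
    # SSM_A tensor, so kimi-linear.cpp can multiply by it directly instead of computing
    # ggml_neg(ggml_exp(...)) at graph-build time.
    import torch

    num_heads = 32                            # placeholder head count
    a_log = torch.randn(1, 1, num_heads, 1)   # stand-in for the HF "A_log" tensor
    a = -torch.exp(a_log)                     # value written to the GGUF tensor

    # GGUF reverses the numpy shape, so (1, 1, num_heads, 1) is read back by llama.cpp
    # as ggml ne = [1, num_heads, 1, 1]; no reshape is needed at conversion time.
    print(a.shape)                            # torch.Size([1, 1, 32, 1])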