diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index e12c8b9250..ab09bb7eb7 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -112,6 +112,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_MISTRAL3, "mistral3" }, + { LLM_ARCH_KIMI_LINEAR, "kimi-linear" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -2540,6 +2541,54 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_KIMI_LINEAR, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + // Dense FFN (layer 0 only) + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + // MoE FFN (layers 1+) + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + // Shared experts + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + // KDA (using SSM_ enum prefix, keeping GGUF names for backward compat) + { LLM_TENSOR_SSM_CONV1D_Q, "blk.%d.ssm_conv1d_q" }, + { LLM_TENSOR_SSM_CONV1D_K, "blk.%d.ssm_conv1d_k" }, + { LLM_TENSOR_SSM_CONV1D_V, "blk.%d.ssm_conv1d_v" }, + { LLM_TENSOR_SSM_F_A, "blk.%d.ssm_f_a" }, + { LLM_TENSOR_SSM_F_B, "blk.%d.ssm_f_b" }, + { LLM_TENSOR_SSM_BETA, "blk.%d.ssm_beta" }, + { LLM_TENSOR_SSM_A_LOG, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_DT_B, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_G_A, "blk.%d.ssm_g_a" }, + { LLM_TENSOR_SSM_G_B, "blk.%d.ssm_g_b" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + // MLA + { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, + { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, + { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, + { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, + { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -2644,6 +2693,17 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SSM_C_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + // Kimi KDA - Conv tensors are 4D [d_conv, 1, d_inner, 1], reshaped to 2D at runtime + {LLM_TENSOR_SSM_CONV1D_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_CONV1D_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_CONV1D_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_F_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_F_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_BETA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_A_LOG, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_DT_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_SSM_G_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_G_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -2801,6 +2861,7 @@ bool llm_arch_is_recurrent(const llm_arch & arch) { case LLM_ARCH_RWKV6QWEN2: case LLM_ARCH_RWKV7: case LLM_ARCH_ARWKV7: + case LLM_ARCH_KIMI_LINEAR: // KDA layers use delta attention with recurrent state return true; default: return false; @@ -2817,6 +2878,9 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_LFM2MOE: case LLM_ARCH_NEMOTRON_H: case LLM_ARCH_QWEN3NEXT: + // Kimi: Currently using recurrent-only mode since MLA doesn't use KV cache + // TODO: Enable hybrid when MLA KV caching is implemented + // case LLM_ARCH_KIMI_LINEAR: return true; default: return false; diff --git a/src/llama-arch.h b/src/llama-arch.h index 438963cef0..2b965850c5 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -116,6 +116,7 @@ enum llm_arch { LLM_ARCH_RND1, LLM_ARCH_PANGU_EMBED, LLM_ARCH_MISTRAL3, + LLM_ARCH_KIMI_LINEAR, LLM_ARCH_UNKNOWN, }; @@ -385,6 +386,17 @@ enum llm_tensor { LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next + // Kimi Linear KDA (using SSM_ prefix for consistency) + LLM_TENSOR_SSM_CONV1D_Q, // kimi: Q conv1d weight + LLM_TENSOR_SSM_CONV1D_K, // kimi: K conv1d weight + LLM_TENSOR_SSM_CONV1D_V, // kimi: V conv1d weight + LLM_TENSOR_SSM_F_A, // kimi: forget gate projection A + LLM_TENSOR_SSM_F_B, // kimi: forget gate projection B + LLM_TENSOR_SSM_BETA, // kimi: beta mixing coefficient + LLM_TENSOR_SSM_A_LOG, // kimi: A_log (pre-converted in GGUF) + LLM_TENSOR_SSM_DT_B, // kimi: dt bias + LLM_TENSOR_SSM_G_A, // kimi: output gate projection A + LLM_TENSOR_SSM_G_B, // kimi: output gate projection B LLM_TENSOR_TIME_MIX_W0, LLM_TENSOR_TIME_MIX_W1, LLM_TENSOR_TIME_MIX_W2, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index e04f0fc4f9..3278cf2ef8 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1387,7 +1387,7 @@ void llama_context::output_reorder() { // uint32_t llama_context::graph_max_nodes() const { - if (model.arch == LLM_ARCH_QWEN3NEXT) { + if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR) { return std::max(8192u, 32u*model.n_tensors()); } return std::max(1024u, 8u*model.n_tensors()); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 42ccb5b76a..e41d65398f 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1816,11 +1816,14 @@ ggml_tensor * llm_graph_context::build_rs( ggml_build_forward_expand(gf, output_states); // copy extra states which won't be changed further (between n_seqs and n_rs) - ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra); - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, - states_extra, - ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s)))); + // Skip if there are no extra states to copy (n_rs == n_seqs) + if (arch != LLM_ARCH_KIMI_LINEAR || n_rs > n_seqs) { // arch check for backward compat + ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra); + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + states_extra, + ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s)))); + } return output_states; } diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 8cdbaf69fc..88d266b8da 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -133,6 +133,13 @@ uint32_t llama_hparams::n_embd_r() const { return n_embd * (n_shortconv_l_cache - 1); } + if (kda_head_dim != 0) { + // for Kimi KDA layers + // Conv state for Q, K, V: 3 * (d_conv - 1) * n_head * head_dim + const uint32_t d_inner = n_head() * kda_head_dim; // 32 * 128 = 4096 + return 3 * (kda_d_conv > 0 ? kda_d_conv - 1 : 3) * d_inner; + } + // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed // Corresponds to Mamba's conv_states size @@ -145,6 +152,13 @@ uint32_t llama_hparams::n_embd_s() const { return n_embd * wkv_head_size; } + if (kda_head_dim != 0) { + // for Kimi KDA layers + // Full recurrent state: head_dim * head_dim * n_head + // h tensor shape for delta attention: [head_dim, head_dim, n_head] + return kda_head_dim * kda_head_dim * n_head(); // 128 * 128 * 32 = 524288 + } + // corresponds to Mamba's ssm_states size return ssm_d_state * ssm_d_inner; } diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 6eff334a5f..80170650eb 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -133,6 +133,10 @@ struct llama_hparams { uint32_t ssm_dt_rank = 0; uint32_t ssm_n_group = 0; + // for Kimi Delta Attention (KDA) + uint32_t kda_head_dim = 0; // head_dim for KDA layers (128 for Kimi) + uint32_t kda_d_conv = 0; // conv kernel size for KDA (4 for Kimi) + // for hybrid state space models std::array recurrent_layer_arr; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 584efbf3c8..763f0dfecb 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2283,6 +2283,54 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_KIMI_LINEAR: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv, false); + ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); + + // KDA (Delta Attention) parameters + hparams.kda_head_dim = 128; // linear_attn_config.head_dim + hparams.kda_d_conv = 4; // linear_attn_config.short_conv_kernel_size + + // MLA qk_rope_head_dim (for reference) + // qk_rope_head_dim = 64, qk_nope_head_dim = 128, qk_head_dim = 192 + + // Mark KDA layers as recurrent using n_head_kv pattern (like Jamba) + // MLA layers are at: 3, 7, 11, 15, 19, 23, 26 (7 MLA layers total) + // KDA layers are all others: 0, 1, 2, 4, 5, 6, 8, 9, 10, 12, 13, 14, 16, 17, 18, 20, 21, 22, 24, 25 (20 KDA layers) + // Set n_head_kv = 0 for KDA layers (recurrent), n_head_kv = n_head for MLA layers (attention) + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + bool is_mla = (i == 3 || i == 7 || i == 11 || i == 15 || i == 19 || i == 23 || i == 26); + hparams.n_head_kv_arr[i] = is_mla ? hparams.n_head() : 0; + hparams.recurrent_layer_arr[i] = !is_mla; // KDA layers are recurrent + } + + // MoE parameters - Kimi uses moe_intermediate_size = 1024 + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + + // Default values if not in GGUF + if (hparams.n_ff_exp == 0) hparams.n_ff_exp = 1024; // moe_intermediate_size + if (hparams.n_ff_shexp == 0) hparams.n_ff_shexp = 9216; // shared_expert_intermediate_size = intermediate_size + if (hparams.n_expert_shared == 0) hparams.n_expert_shared = 1; // num_shared_experts + if (hparams.n_layer_dense_lead == 0) hparams.n_layer_dense_lead = 1; // first_k_dense_replace + if (hparams.expert_weights_scale == 0.0f) hparams.expert_weights_scale = 2.446f; // routed_scaling_factor + + // MoE gating function - Kimi uses sigmoid (moe_router_activation_func: sigmoid) + if (hparams.expert_gating_func == 0) hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + + switch (hparams.n_layer) { + case 27: type = LLM_TYPE_48B; break; // Kimi-Linear-48B-A3B + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -6395,6 +6443,148 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); } } break; + case LLM_ARCH_KIMI_LINEAR: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // Check for KDA specific tensors to determine layer type or if it's a mixed model + // Assuming KDA layer if KDA tensors are present + + // KDA uses head_dim = 128 (from linear_attn_config.head_dim) + const int64_t n_embd_head_k_kda = 128; + const int64_t n_embd_head_v_kda = 128; + const int64_t ssm_d_conv = hparams.ssm_d_conv > 0 ? hparams.ssm_d_conv : 4; + + // Try loading KDA specific tensors (using SSM_ prefix) + // Conv1d weights: try 4D first, then 3D (quantization may remove trailing 1) + // 4D: [d_conv, 1, d_inner, 1], 3D: [d_conv, 1, d_inner] + layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_q_conv) { + layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, TENSOR_NOT_REQUIRED); + } + + if (layer.ssm_q_conv) { + // KDA Layer - Conv1d weights may be 3D or 4D + layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_k_conv) { + layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, 0); + } + layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_v_conv) { + layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head}, 0); + } + + // Conv bias may not exist in all models - make optional + layer.ssm_q_conv_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "bias", i), {n_embd_head_k_kda * n_head}, TENSOR_NOT_REQUIRED); + layer.ssm_k_conv_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "bias", i), {n_embd_head_k_kda * n_head}, TENSOR_NOT_REQUIRED); + layer.ssm_v_conv_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "bias", i), {n_embd_head_v_kda * n_head}, TENSOR_NOT_REQUIRED); + + // q, k, v projections + // Python: q_proj, k_proj, v_proj + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k_kda * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_head_k_kda * n_head}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_head_v_kda * n_head}, 0); + + // KDA specific projections + // f_a_proj, f_b_proj + layer.ssm_f_a = create_tensor(tn(LLM_TENSOR_SSM_F_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); // head_dim + layer.ssm_f_b = create_tensor(tn(LLM_TENSOR_SSM_F_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); // projection_size + + // b_proj (beta mixing coefficient) + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), {n_embd, n_head}, 0); + + // A_log - Shape in GGUF: [1, num_heads, 1, 1] (4D) or [1, num_heads] (2D after quantization) + layer.ssm_a_log = create_tensor(tn(LLM_TENSOR_SSM_A_LOG, i), {1, n_head, 1, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_a_log) { + layer.ssm_a_log = create_tensor(tn(LLM_TENSOR_SSM_A_LOG, i), {1, n_head}, 0); + } + + // dt_bias - shape [n_embd_head_k_kda * n_head] = [4096] + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT_B, i), {n_embd_head_k_kda * n_head}, 0); + + // g_a_proj, g_b_proj (output gate) + layer.ssm_g_a = create_tensor(tn(LLM_TENSOR_SSM_G_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); + layer.ssm_g_b = create_tensor(tn(LLM_TENSOR_SSM_G_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); + + // o_norm (reusing SSM_NORM) + layer.ssm_o_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {n_embd_head_k_kda}, 0); // FusedRMSNormGated + layer.ssm_o_norm_b = create_tensor(tn(LLM_TENSOR_SSM_NORM, "bias", i), {n_embd_head_k_kda}, TENSOR_NOT_REQUIRED); + + // o_proj + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v_kda * n_head, n_embd}, 0); + + } else { + // MLA Layer - use MLA-specific head dimensions + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla > 0 ? hparams.n_embd_head_k_mla : 192; + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla > 0 ? hparams.n_embd_head_v_mla : 128; + + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, TENSOR_NOT_REQUIRED); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + if (layer.attn_q_a_norm) { + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); + } else { + // Kimi MLA without Q compression: wq = [n_embd, n_head * n_embd_head_k_mla] + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); + } + + // Kimi: qk_rope_head_dim = 64 (actual RoPE dimension for MLA) + // Note: hparams.n_rot may be 72 (from conversion) but actual is 64 + const int64_t qk_rope_head_dim = 64; // From config: qk_rope_head_dim + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim}, 0); + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_k_mla - qk_rope_head_dim + n_embd_head_v_mla)}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); + } + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + // MoE intermediate size (different from dense FFN) + const int64_t n_ff_exp = hparams.n_ff_exp > 0 ? hparams.n_ff_exp : 1024; + + // Kimi uses n_layer_dense_lead to determine which layers use dense FFN vs MoE + // first_k_dense_replace = 1 means layer 0 uses dense FFN, layers 1+ use MoE + if (i < (int) hparams.n_layer_dense_lead) { + // Dense FFN layer - use normal n_ff + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + // MoE layer - use n_ff_exp (1024) instead of n_ff (9216) + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Shared experts use moe_intermediate_size * num_shared_experts + // Kimi: shared_expert_intermediate_size = 1024 * 1 = 1024 + // Tensors are 2D: [n_embd, n_ff_shexp] or [n_ff_shexp, n_embd] + const int64_t n_ff_shexp_actual = n_ff_exp * (hparams.n_expert_shared > 0 ? hparams.n_expert_shared : 1); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp_actual, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED); + + // exp_probs_b (e_score_correction_bias in vLLM) + // Try "bias" first (standard), then "weight" (for compatibility) + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + if (!layer.ffn_exp_probs_b) { + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "weight", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + } + } + } break; case LLM_ARCH_COGVLM: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -7563,6 +7753,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_KIMI_LINEAR: + { + llm = std::make_unique(*this, params); + } break; default: GGML_ABORT("fatal error"); } @@ -7718,6 +7912,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ARCTIC: case LLM_ARCH_DEEPSEEK: case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_KIMI_LINEAR: case LLM_ARCH_PLM: case LLM_ARCH_CHATGLM: case LLM_ARCH_GLM4: diff --git a/src/llama-model.h b/src/llama-model.h index f8342cf2cb..b067b686d2 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -84,6 +84,7 @@ enum llm_type { LLM_TYPE_35B, LLM_TYPE_36B, LLM_TYPE_40B, + LLM_TYPE_48B, LLM_TYPE_65B, LLM_TYPE_70B, LLM_TYPE_120B, @@ -404,6 +405,23 @@ struct llama_layer { struct ggml_tensor * ffn_act_beta = nullptr; struct ggml_tensor * ffn_act_eps = nullptr; + // Kimi Linear KDA (using ssm_ prefix for consistency) + // Note: ssm_dt_b already exists above (mamba bias), reused for Kimi dt_bias + struct ggml_tensor * ssm_q_conv = nullptr; + struct ggml_tensor * ssm_q_conv_b = nullptr; + struct ggml_tensor * ssm_k_conv = nullptr; + struct ggml_tensor * ssm_k_conv_b = nullptr; + struct ggml_tensor * ssm_v_conv = nullptr; + struct ggml_tensor * ssm_v_conv_b = nullptr; + struct ggml_tensor * ssm_f_a = nullptr; + struct ggml_tensor * ssm_f_b = nullptr; + struct ggml_tensor * ssm_beta = nullptr; + struct ggml_tensor * ssm_a_log = nullptr; + struct ggml_tensor * ssm_g_a = nullptr; + struct ggml_tensor * ssm_g_b = nullptr; + struct ggml_tensor * ssm_o_norm = nullptr; + struct ggml_tensor * ssm_o_norm_b = nullptr; + struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 0b23eaef3a..7b8bf6e524 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -724,7 +724,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer; // sanity checks for models that have attention layers - if (qs.n_attention_wv != 0 && !is_clip_model) + // Skip this check for Kimi models which have hybrid KDA+MLA architecture + // (only MLA layers have attn_kv_b weights, KDA layers don't) + if (qs.n_attention_wv != 0 && !is_clip_model && model.arch != LLM_ARCH_KIMI_LINEAR) { const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin(); // attention layers have a non-zero number of kv heads diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index a73c4c448b..7af74b0218 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1738,26 +1738,33 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // read bpe merges and populate bpe ranks const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); + // Kimi-K2 uses custom tokenization without traditional BPE merges + const bool is_kimi_k2 = (tokenizer_pre == "kimi-k2"); + if (merges_keyidx == -1) { - throw std::runtime_error("cannot find tokenizer merges in model file\n"); - } - - const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); - for (int i = 0; i < n_merges; i++) { - const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); - //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); - - std::string first; - std::string second; - - const size_t pos = word.find(' ', 1); - - if (pos != std::string::npos) { - first = word.substr(0, pos); - second = word.substr(pos + 1); + if (!is_kimi_k2) { + throw std::runtime_error("cannot find tokenizer merges in model file\n"); } + // Kimi-K2 doesn't need merges, skip + LLAMA_LOG_INFO("%s: Kimi-K2 tokenizer detected, skipping BPE merges\n", __func__); + } else { + const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); + for (int i = 0; i < n_merges; i++) { + const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); + //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); - bpe_ranks.emplace(std::make_pair(first, second), i); + std::string first; + std::string second; + + const size_t pos = word.find(' ', 1); + + if (pos != std::string::npos) { + first = word.substr(0, pos); + second = word.substr(pos + 1); + } + + bpe_ranks.emplace(std::make_pair(first, second), i); + } } // default special tokens