Add a new class llm_graph_input_mem_hybrid_k to accommodate the new MLA change, switch the operand order of the ggml_concat calls in kimi-linear.cpp to match the new MLA layout, and remove the fallback that loaded exp_probs_b as a "weight" tensor (the bias tensor is now required).
This commit is contained in:
parent bb02b5d515
commit f1525b3695
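For the concat-order part of this change: ggml_concat(ctx, a, b, dim) places a before b along dim, so swapping the operands decides which piece of the concatenated Q/K occupies the leading rows of each token. The standalone sketch below is not part of the commit; it only uses the public ggml API, with kv_lora_rank = 512 and qk_rope_head_dim = 64 taken from the comments in the diff, to show the layout produced by the new operand order.

// Standalone sketch (not part of this commit): ggml_concat operand order.
// The first operand fills the leading ne[0] elements of the result.
#include <cstdio>
#include "ggml.h"

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t n_tokens = 4;

    // stand-ins for the per-token MLA pieces: [kv_lora_rank, 1, n_tokens] and [qk_rope_head_dim, 1, n_tokens]
    struct ggml_tensor * kv_cmpr = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 512, 1, n_tokens);
    struct ggml_tensor * k_pe    = ggml_new_tensor_3d(ctx, GGML_TYPE_F32,  64, 1, n_tokens);

    // new order: compressed KV first, RoPE part second along dim 0
    struct ggml_tensor * Kcur = ggml_concat(ctx, kv_cmpr, k_pe, 0);

    // per token, elements 0..511 come from kv_cmpr and 512..575 from k_pe
    printf("Kcur: [%lld, %lld, %lld]\n", (long long) Kcur->ne[0], (long long) Kcur->ne[1], (long long) Kcur->ne[2]);

    ggml_free(ctx);
    return 0;
}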
@@ -533,6 +533,47 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
     return res;
 }
 
+void llm_graph_input_mem_hybrid_k::set_input(const llama_ubatch * ubatch) {
+    mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
+
+    mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
+
+    const int64_t n_rs = mctx->get_recr()->get_n_rs();
+
+    if (inp_rs->s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
+        int32_t * data = (int32_t *) inp_rs->s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mctx->get_recr()->s_copy(i);
+        }
+    }
+}
+
+bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
+
+    res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
+    res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
+
+    res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
+
+    res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
+    res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
+
+    res &= inp_rs->head == mctx->get_recr()->get_head();
+    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
+
+    return res;
+}
+
 void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
     const auto * attn_ctx = mctx->get_attn();
 
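The new set_input() fills inp_rs->s_copy with int32 source-cell indices on the host side; how those indices are consumed is outside this hunk. As a generic point of reference (a simplified sketch, not the actual build_rs code in this repository), an I32 index tensor like s_copy is the kind of input ggml_get_rows uses to gather the selected recurrent states:

// Generic sketch (not from this commit): an int32 index tensor selecting state slots,
// which is the role the s_copy data filled above plays on the graph side.
#include "ggml.h"

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t d_state = 8; // illustrative state width
    const int64_t n_rs    = 4; // number of recurrent cells in the ubatch

    // [d_state, n_rs] per-cell states and an I32 index vector playing the part of s_copy
    struct ggml_tensor * states = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_state, n_rs);
    struct ggml_tensor * s_copy = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_rs);

    // gathered[:, i] = states[:, s_copy[i]]
    struct ggml_tensor * gathered = ggml_get_rows(ctx, states, s_copy);
    GGML_ASSERT(gathered->ne[0] == d_state && gathered->ne[1] == n_rs);

    ggml_free(ctx);
    return 0;
}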
@@ -2272,6 +2313,17 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
 }
 
+llm_graph_input_mem_hybrid_k * llm_graph_context::build_inp_mem_hybrid_k() const {
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
+
+    auto inp_rs   = build_rs_inp_impl    (ctx0, ubatch, mctx_cur->get_recr());
+    auto inp_attn = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
+
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid_k>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
+
+    return (llm_graph_input_mem_hybrid_k *) res->add_input(std::move(inp));
+}
+
 llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const {
     const auto * mctx_cur = static_cast<const llama_memory_hybrid_iswa_context *>(mctx);
 
@@ -433,6 +433,34 @@ public:
     const llama_memory_hybrid_context * mctx;
 };
 
+class llm_graph_input_mem_hybrid_k : public llm_graph_input_i {
+public:
+    llm_graph_input_mem_hybrid_k(
+            const llama_cparams & cparams,
+            std::unique_ptr<llm_graph_input_attn_k> inp_attn,
+            std::unique_ptr<llm_graph_input_rs> inp_rs,
+            const llama_memory_hybrid_context * mctx) :
+        inp_attn(std::move(inp_attn)),
+        inp_rs(std::move(inp_rs)),
+        cparams(cparams),
+        mctx(mctx) { }
+    virtual ~llm_graph_input_mem_hybrid_k() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    bool can_reuse(const llm_graph_params & params) override;
+
+    std::unique_ptr<llm_graph_input_attn_k> inp_attn;
+    std::unique_ptr<llm_graph_input_rs> inp_rs;
+
+    llm_graph_input_attn_k * get_attn() const { return inp_attn.get(); }
+    llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
+
+    const llama_cparams cparams;
+
+    const llama_memory_hybrid_context * mctx;
+};
+
 class llm_graph_input_mem_hybrid_iswa : public llm_graph_input_i {
 public:
     llm_graph_input_mem_hybrid_iswa(
@@ -960,6 +988,7 @@ struct llm_graph_context {
     //
 
     llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
+    llm_graph_input_mem_hybrid_k * build_inp_mem_hybrid_k() const;
 
     llm_graph_input_mem_hybrid_iswa * build_inp_mem_hybrid_iswa() const;
 
@@ -2454,12 +2454,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         case LLM_ARCH_KIMI_LINEAR:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla, false);
-                ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false);
-                ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv, false);
-                ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
-                ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv, false);
-                ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.kda_head_dim, false);
+                ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl);
+                ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl);
+                ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
+                ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot);
+                ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
+                ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.kda_head_dim);
 
                 // MLA qk_rope_head_dim (for reference)
                 // qk_rope_head_dim = 64, qk_nope_head_dim = 128, qk_head_dim = 192
@@ -2471,11 +2471,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 }
 
                 // MoE parameters - Kimi uses moe_intermediate_size = 1024
-                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
-                ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
-                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false);
-                ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false);
-                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
+                ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
                 ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func);
 
                 switch (hparams.n_layer) {
@@ -6863,8 +6862,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         // MLA Layer - use MLA-specific head dimensions
                         const int64_t q_lora_rank = hparams.n_lora_q;
                         const int64_t kv_lora_rank = hparams.n_lora_kv;
-                        const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla;
-                        const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla;
+                        const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla();
+                        const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla();
 
                         layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, TENSOR_NOT_REQUIRED);
                         layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0);
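These call sites now read the MLA head sizes through hparams.n_embd_head_k_mla() / hparams.n_embd_head_v_mla() instead of plain fields, matching the hparams change above where the raw GGUF values land in the *_impl members. The accessor bodies are not part of this diff; a hypothetical sketch of the assumed behavior (fall back to the generic per-head size when no MLA-specific value was provided) could look like this:

// Hypothetical sketch only - the real accessors are defined elsewhere and are not shown in
// this diff; the fallback behavior here is an assumption, not confirmed by the change.
#include <cstdint>

static uint32_t n_embd_head_k_mla(uint32_t n_embd_head_k_mla_impl, uint32_t n_embd_head_k) {
    return n_embd_head_k_mla_impl != 0 ? n_embd_head_k_mla_impl : n_embd_head_k;
}

int main() {
    // Kimi-Linear stores qk_head_dim = 192 in the GGUF, so the fallback is not taken here.
    return n_embd_head_k_mla(192, 64) == 192 ? 0 : 1;
}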
@@ -6917,10 +6916,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp_actual, n_embd}, TENSOR_NOT_REQUIRED);
                         layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED);
 
-                        layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
-                        if (!layer.ffn_exp_probs_b) {
-                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "weight", i), {n_expert}, TENSOR_NOT_REQUIRED);
-                        }
+                        layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0);
                     }
                 }
             } break;
@@ -72,7 +72,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
     // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM)
     // So we don't need inp_pos
 
-    auto * inp = build_inp_mem_hybrid();
+    auto * inp = build_inp_mem_hybrid_k();
     auto * inp_rs = inp->get_recr();
     auto * inp_attn = inp->get_attn();
 
@@ -104,8 +104,8 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
     GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
     // MLA params
-    const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla;
-    const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla;
+    const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla();
+    const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla();
     const int64_t kv_lora_rank = hparams.n_lora_kv;
     // qk_rope_head_dim = 64 (from Kimi config) which is hparams.n_rot
     // Confirmed from tensor shape: wkv_a_mqa [2304, 576] = [n_embd, kv_lora_rank + qk_rope_head_dim]
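As a quick consistency check of the dimension comments above (standalone, not part of the commit), the quoted Kimi numbers line up:

// Dimension bookkeeping for the Kimi MLA path, using only the numbers quoted in the
// comments above (wkv_a_mqa [2304, 576], qk_rope_head_dim = 64, qk_nope_head_dim = 128).
#include <cstdint>

constexpr int64_t kv_lora_rank     = 512;
constexpr int64_t qk_rope_head_dim = 64;   // == hparams.n_rot
constexpr int64_t qk_nope_head_dim = 128;

// wkv_a_mqa projects n_embd = 2304 down to kv_lora_rank + qk_rope_head_dim columns
static_assert(kv_lora_rank + qk_rope_head_dim == 576, "matches the wkv_a_mqa shape [2304, 576]");

// per-head K width after concatenating the nope and rope parts (qk_head_dim)
static_assert(qk_nope_head_dim + qk_rope_head_dim == 192, "qk_head_dim");

int main() { return 0; }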
@@ -258,14 +258,14 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
 
             // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
             // note: rope must go first for in-place context shifting in build_rope_shift()
-            Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
+            Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0);
             cb(Qcur, "Qcur", il);
 
             kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
             cb(kv_cmpr, "kv_cmpr_reshape", il);
 
             // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
-            ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0);
+            ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0);
             cb(Kcur, "Kcur", il);
 
             // {kv_lora_rank, 1, n_tokens}
@@ -299,7 +299,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
             // Need to broadcast k_pe from [qk_rope, 1, n_tokens] to [qk_rope, n_head, n_tokens]
             ggml_tensor * k_pe_target = ggml_new_tensor_3d(ctx0, k_pe->type, n_embd_head_qk_rope, n_head, n_tokens);
             ggml_tensor * k_pe_repeated = ggml_repeat(ctx0, k_pe, k_pe_target);
-            ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, k_pe_repeated, 0);
+            ggml_tensor * Kcur = ggml_concat(ctx0, k_pe_repeated, k_nope, 0);
             cb(Kcur, "mla_K", il);
 
             // Direct softmax attention (with MHA KV cache)