refactor: simplify partial RoPE with weight reordering

suhyun-hwang 2026-01-10 20:56:26 +09:00
parent db84faff3a
commit 96294c6ad9
3 changed files with 51 additions and 58 deletions
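Note on the conversion change in the first file: reordering the rows of q_b_proj.weight within each head is a pure permutation of the projected query, so after conversion the RoPE dims sit at the front of every head instead of after the nope dims. A minimal PyTorch sketch of that property (sizes are illustrative, not taken from the model):

    import torch

    n_head, nope, rope, hidden = 4, 128, 64, 512        # illustrative sizes only
    qk = nope + rope

    w = torch.randn(n_head * qk, hidden)                 # q_b_proj.weight, [nope | rope] per head
    x = torch.randn(hidden)

    # conversion-time reorder, mirroring the change in the first file below
    w3 = w.view(n_head, qk, hidden)
    w_reordered = torch.cat([w3[:, nope:, :], w3[:, :nope, :]], dim=1).reshape(n_head * qk, hidden)

    q_old = (w @ x).view(n_head, qk)                      # per head: [nope | rope]
    q_new = (w_reordered @ x).view(n_head, qk)            # per head: [rope | nope]

    assert torch.allclose(q_old[:, nope:], q_new[:, :rope])  # rope dims now lead
    assert torch.allclose(q_old[:, :nope], q_new[:, rope:])  # nope dims follow unchanged

Because the permutation is baked into the weight, the runtime can treat each head as one contiguous [rope | nope] vector instead of slicing the rope part out of the middle.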


@@ -7843,6 +7843,15 @@ class VaetkiModel(TextModel):
elif name.startswith("language_model."):
    name = name.replace("language_model.", "model.")
if name.endswith("q_b_proj.weight"):
    n_head = self.hparams["num_attention_heads"]
    qk_nope_head_dim = self.hparams["qk_nope_head_dim"]
    qk_rope_head_dim = self.hparams["qk_rope_head_dim"]
    qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
    # reorder each head's rows from [nope | rope] to [rope | nope] so RoPE can rotate the leading dims at runtime
    data_torch = data_torch.view(n_head, qk_head_dim, -1)
    data_torch = torch.cat([data_torch[:, qk_nope_head_dim:, :], data_torch[:, :qk_nope_head_dim, :]], dim=1)
    data_torch = data_torch.reshape(n_head * qk_head_dim, -1)
# VAETKI WBLRMSNorm: add 1 to weights for standard RMSNorm compatibility
norm_weight_patterns = [
    "input_layernorm.weight",


@@ -1153,15 +1153,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
}
{
uint32_t n_swa_temp = 0;
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, n_swa_temp, false);
if (n_swa_temp > 0) {
hparams.n_swa = n_swa_temp;
ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer);
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
}
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
if (hparams.n_swa > 0) {
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer);
}
switch (hparams.n_layer) {

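The graph rewrite in the last file leans on the fact that applying RoPE with n_dims = n_rot rotates only the leading n_rot entries of each head and copies the remainder unchanged (as the new comment in the builder notes). A rough NumPy sketch of that behaviour; the half-split pairing and the partial_rope helper are illustrative only, since the exact pairing ggml_rope_ext uses depends on rope_type:

    import numpy as np

    def partial_rope(x, pos, n_rot, base=10000.0):
        # x: (n_head, head_dim) for one token; only x[:, :n_rot] is rotated
        out = x.copy()
        half = n_rot // 2
        inv_freq = base ** (-np.arange(half) / half)
        angles = pos * inv_freq
        cos, sin = np.cos(angles), np.sin(angles)
        x1, x2 = x[:, :half], x[:, half:n_rot]
        out[:, :half] = x1 * cos - x2 * sin
        out[:, half:n_rot] = x1 * sin + x2 * cos
        return out                                        # dims n_rot..head_dim untouched

    q = np.random.randn(4, 192)                           # [rope(64) | nope(128)] per head
    q_roped = partial_rope(q, pos=7, n_rot=64)
    assert np.allclose(q_roped[:, 64:], q[:, 64:])        # nope part passes through

With the weights already in [rope | nope] order, one reshape plus one rope call covers the whole query, which is why the separate q_nope/q_pe views and the trailing concat on the query path are no longer needed.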

@@ -20,7 +20,7 @@ llm_build_vaetki::llm_build_vaetki(const llama_model & model, const llm_graph_pa
ggml_tensor * inp_pos = build_inp_pos();
llm_graph_input_attn_kv_iswa * inp_attn = build_attn_inp_kv_iswa();
auto * inp_attn = build_attn_inp_kv_iswa();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -44,76 +44,64 @@ llm_build_vaetki::llm_build_vaetki(const llama_model & model, const llm_graph_pa
q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
cb(q, "q", il);
ggml_tensor * q_nope =
ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(q->type, n_embd_head_k_mla),
ggml_row_size(q->type, n_embd_head_k_mla) * n_head, 0);
cb(q_nope, "q_nope", il);
ggml_tensor * q_pe = ggml_view_3d(
ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(q->type, n_embd_head_k_mla),
ggml_row_size(q->type, n_embd_head_k_mla) * n_head, ggml_row_size(q->type, n_embd_head_qk_nope));
cb(q_pe, "q_pe", il);
// q is now [rope | nope] after weight reordering in conversion
// reshape to {n_embd_head_k_mla, n_head, n_tokens}
q = ggml_reshape_3d(ctx0, q, n_embd_head_k_mla, n_head, n_tokens);
ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
cb(kv_cmpr_pe, "kv_cmpr_pe", il);
ggml_tensor * kv_cmpr =
ggml_view_2d(ctx0, kv_cmpr_pe, kv_lora_rank, n_tokens,
ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0);
// {kv_lora_rank, n_tokens}
ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe,
kv_lora_rank, n_tokens,
ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0);
cb(kv_cmpr, "kv_cmpr", il);
ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, n_embd_head_qk_rope, 1, n_tokens,
ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
ggml_row_size(kv_cmpr_pe->type, kv_lora_rank));
// {n_embd_head_qk_rope, 1, n_tokens}
ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe,
n_embd_head_qk_rope, 1, n_tokens,
ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
ggml_row_size(kv_cmpr_pe->type, kv_lora_rank));
cb(k_pe, "k_pe", il);
q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(q_pe, "q_pe_rope", il);
// apply rope - rotates first n_rot dims, copies rest unchanged
ggml_tensor * Qcur = ggml_rope_ext(ctx0, q, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow);
cb(k_pe, "k_pe_rope", il);
// convert interleaved RoPE to split format
q_pe = ggml_reshape_4d(ctx0, q_pe, 2, n_embd_head_qk_rope/2, n_head, n_tokens);
q_pe = ggml_permute(ctx0, q_pe, 1, 0, 2, 3);
q_pe = ggml_cont(ctx0, q_pe);
q_pe = ggml_reshape_3d(ctx0, q_pe, n_embd_head_qk_rope, n_head, n_tokens);
cb(q_pe, "q_pe_split", il);
k_pe = ggml_reshape_4d(ctx0, k_pe, 2, n_embd_head_qk_rope/2, 1, n_tokens);
k_pe = ggml_permute(ctx0, k_pe, 1, 0, 2, 3);
k_pe = ggml_cont(ctx0, k_pe);
k_pe = ggml_reshape_3d(ctx0, k_pe, n_embd_head_qk_rope, 1, n_tokens);
cb(k_pe, "k_pe_split", il);
kv_cmpr = build_norm(kv_cmpr, model.layers[il].attn_kv_a_norm, nullptr, LLM_NORM_RMS, il);
cb(kv_cmpr, "kv_cmpr_norm", il);
ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr);
cb(kv, "kv", il);
ggml_tensor * k_nope =
ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v_mla),
ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v_mla) * n_head, 0);
cb(k_nope, "k_nope_view", il);
// {n_embd_head_qk_nope, n_head, n_tokens}
ggml_tensor * k_nope = ggml_view_3d(ctx0, kv,
n_embd_head_qk_nope, n_head, n_tokens,
ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v_mla),
ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v_mla) * n_head, 0);
cb(k_nope, "k_nope", il);
ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, n_embd_head_v_mla, n_head, n_tokens,
ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v_mla),
ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v_mla) * n_head,
ggml_row_size(kv->type, n_embd_head_qk_nope));
cb(Vcur, "Vcur_view", il);
// {n_embd_head_v_mla, n_head, n_tokens}
ggml_tensor * Vcur = ggml_view_3d(ctx0, kv,
n_embd_head_v_mla, n_head, n_tokens,
ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v_mla),
ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v_mla) * n_head,
ggml_row_size(kv->type, n_embd_head_qk_nope));
cb(Vcur, "Vcur", il);
Vcur = ggml_cont(ctx0, Vcur);
cb(Vcur, "Vcur_cont", il);
ggml_tensor * Qcur = ggml_concat(ctx0, q_nope, q_pe, 0);
cb(Qcur, "Qcur", il);
ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
ggml_tensor * q_pe_ref = ggml_view_3d(ctx0, Qcur,
n_embd_head_qk_rope, n_head, n_tokens,
Qcur->nb[1], Qcur->nb[2], 0);
ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe_ref), k_nope, 0);
cb(Kcur, "Kcur", il);
cur = build_attn(inp_attn,