refactor: simplify partial RoPE with weight reordering
parent db84faff3a
commit 96294c6ad9
@@ -7843,6 +7843,15 @@ class VaetkiModel(TextModel):
         elif name.startswith("language_model."):
             name = name.replace("language_model.", "model.")
 
+            if name.endswith("q_b_proj.weight"):
+                n_head = self.hparams["num_attention_heads"]
+                qk_nope_head_dim = self.hparams["qk_nope_head_dim"]
+                qk_rope_head_dim = self.hparams["qk_rope_head_dim"]
+                qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+                data_torch = data_torch.view(n_head, qk_head_dim, -1)
+                data_torch = torch.cat([data_torch[:, qk_nope_head_dim:, :], data_torch[:, :qk_nope_head_dim, :]], dim=1)
+                data_torch = data_torch.reshape(n_head * qk_head_dim, -1)
+
         # VAETKI WBLRMSNorm: add 1 to weights for standard RMSNorm compatibility
         norm_weight_patterns = [
             "input_layernorm.weight",
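The transform above only permutes rows of q_b_proj.weight, and a row permutation of a projection matrix commutes with the matmul, so after conversion each attention head's query comes out as [rope | nope] instead of [nope | rope]. A minimal standalone sketch of that property (toy sizes and random weights are illustrative; only the hparam names come from the diff):

import torch

n_head, qk_nope_head_dim, qk_rope_head_dim, in_dim = 2, 4, 2, 8
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim

w = torch.randn(n_head * qk_head_dim, in_dim)   # original: [nope | rope] per head

v = w.view(n_head, qk_head_dim, -1)
reordered = torch.cat([v[:, qk_nope_head_dim:, :],    # rope rows first ...
                       v[:, :qk_nope_head_dim, :]],   # ... then nope rows
                      dim=1).reshape(n_head * qk_head_dim, -1)

x = torch.randn(in_dim)
q_old = (w @ x).view(n_head, qk_head_dim)
q_new = (reordered @ x).view(n_head, qk_head_dim)

# per head, the reordered projection emits [rope | nope]: rotary dims now lead
assert torch.equal(q_new[:, :qk_rope_head_dim], q_old[:, qk_nope_head_dim:])
assert torch.equal(q_new[:, qk_rope_head_dim:], q_old[:, :qk_nope_head_dim])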
@@ -1153,15 +1153,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
                 }
 
-                {
-                    uint32_t n_swa_temp = 0;
-                    ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, n_swa_temp, false);
-                    if (n_swa_temp > 0) {
-                        hparams.n_swa = n_swa_temp;
-                        ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
-                        ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer);
-                        hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
-                    }
-                }
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
+                ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
+                if (hparams.n_swa > 0) {
+                    hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+                    ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer);
+                }
 
                 switch (hparams.n_layer) {
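The simplification above leans on the loader contract that an optional get_key (required=false) leaves its destination untouched when the key is absent, so reading straight into hparams.n_swa (which defaults to 0) makes the n_swa_temp copy unnecessary. A rough Python model of that contract (the dict, key strings, and helper are stand-ins for illustration, not the llama.cpp API):

# model of get_key(key, dst, required=false): missing + optional keeps dst
def get_key(metadata: dict, key: str, current, required: bool = True):
    if key in metadata:
        return metadata[key]
    if required:
        raise KeyError(key)
    return current  # optional and missing: old value survives

hparams = {"n_swa": 0}  # default: no sliding window

for metadata in ({}, {"attention.sliding_window": 4096}):
    hparams["n_swa"] = get_key(metadata, "attention.sliding_window",
                               hparams["n_swa"], required=False)
    # reading directly is safe: an absent key keeps the 0 default, and the
    # swa_type/pattern setup stays guarded behind n_swa > 0
    if hparams["n_swa"] > 0:
        print("enable SWA, window =", hparams["n_swa"])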
@@ -20,7 +20,7 @@ llm_build_vaetki::llm_build_vaetki(const llama_model & model, const llm_graph_pa
 
     ggml_tensor * inp_pos = build_inp_pos();
 
-    llm_graph_input_attn_kv_iswa * inp_attn = build_attn_inp_kv_iswa();
+    auto * inp_attn = build_attn_inp_kv_iswa();
 
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -44,76 +44,64 @@ llm_build_vaetki::llm_build_vaetki(const llama_model & model, const llm_graph_pa
         q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
         cb(q, "q", il);
 
-        ggml_tensor * q_nope =
-            ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(q->type, n_embd_head_k_mla),
-                         ggml_row_size(q->type, n_embd_head_k_mla) * n_head, 0);
-        cb(q_nope, "q_nope", il);
-
-        ggml_tensor * q_pe = ggml_view_3d(
-            ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(q->type, n_embd_head_k_mla),
-            ggml_row_size(q->type, n_embd_head_k_mla) * n_head, ggml_row_size(q->type, n_embd_head_qk_nope));
-        cb(q_pe, "q_pe", il);
+        // q is now [rope | nope] after weight reordering in conversion
+        // reshape to {n_embd_head_k_mla, n_head, n_tokens}
+        q = ggml_reshape_3d(ctx0, q, n_embd_head_k_mla, n_head, n_tokens);
 
         ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
         cb(kv_cmpr_pe, "kv_cmpr_pe", il);
 
-        ggml_tensor * kv_cmpr =
-            ggml_view_2d(ctx0, kv_cmpr_pe, kv_lora_rank, n_tokens,
-                         ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0);
+        // {kv_lora_rank, n_tokens}
+        ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe,
+                kv_lora_rank, n_tokens,
+                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0);
         cb(kv_cmpr, "kv_cmpr", il);
 
-        ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, n_embd_head_qk_rope, 1, n_tokens,
-                                          ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
-                                          ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
-                                          ggml_row_size(kv_cmpr_pe->type, kv_lora_rank));
+        // {n_embd_head_qk_rope, 1, n_tokens}
+        ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe,
+                n_embd_head_qk_rope, 1, n_tokens,
+                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
+                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
+                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank));
         cb(k_pe, "k_pe", il);
 
-        q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
-                             ext_factor, attn_factor, beta_fast, beta_slow);
-        cb(q_pe, "q_pe_rope", il);
+        // apply rope - rotates first n_rot dims, copies rest unchanged
+        ggml_tensor * Qcur = ggml_rope_ext(ctx0, q, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
+                freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow);
+        cb(Qcur, "Qcur", il);
 
-        k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
-                             ext_factor, attn_factor, beta_fast, beta_slow);
+        k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
+                freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow);
         cb(k_pe, "k_pe_rope", il);
 
-        // convert interleaved RoPE to split format
-        q_pe = ggml_reshape_4d(ctx0, q_pe, 2, n_embd_head_qk_rope/2, n_head, n_tokens);
-        q_pe = ggml_permute(ctx0, q_pe, 1, 0, 2, 3);
-        q_pe = ggml_cont(ctx0, q_pe);
-        q_pe = ggml_reshape_3d(ctx0, q_pe, n_embd_head_qk_rope, n_head, n_tokens);
-        cb(q_pe, "q_pe_split", il);
-
-        k_pe = ggml_reshape_4d(ctx0, k_pe, 2, n_embd_head_qk_rope/2, 1, n_tokens);
-        k_pe = ggml_permute(ctx0, k_pe, 1, 0, 2, 3);
-        k_pe = ggml_cont(ctx0, k_pe);
-        k_pe = ggml_reshape_3d(ctx0, k_pe, n_embd_head_qk_rope, 1, n_tokens);
-        cb(k_pe, "k_pe_split", il);
-
         kv_cmpr = build_norm(kv_cmpr, model.layers[il].attn_kv_a_norm, nullptr, LLM_NORM_RMS, il);
         cb(kv_cmpr, "kv_cmpr_norm", il);
 
         ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr);
         cb(kv, "kv", il);
 
-        ggml_tensor * k_nope =
-            ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
-                         ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v_mla),
-                         ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v_mla) * n_head, 0);
-        cb(k_nope, "k_nope_view", il);
+        // {n_embd_head_qk_nope, n_head, n_tokens}
+        ggml_tensor * k_nope = ggml_view_3d(ctx0, kv,
+                n_embd_head_qk_nope, n_head, n_tokens,
+                ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v_mla),
+                ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v_mla) * n_head, 0);
+        cb(k_nope, "k_nope", il);
 
-        ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, n_embd_head_v_mla, n_head, n_tokens,
-                                          ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v_mla),
-                                          ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v_mla) * n_head,
-                                          ggml_row_size(kv->type, n_embd_head_qk_nope));
-        cb(Vcur, "Vcur_view", il);
+        // {n_embd_head_v_mla, n_head, n_tokens}
+        ggml_tensor * Vcur = ggml_view_3d(ctx0, kv,
+                n_embd_head_v_mla, n_head, n_tokens,
+                ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v_mla),
+                ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v_mla) * n_head,
+                ggml_row_size(kv->type, n_embd_head_qk_nope));
+        cb(Vcur, "Vcur", il);
 
         Vcur = ggml_cont(ctx0, Vcur);
         cb(Vcur, "Vcur_cont", il);
 
-        ggml_tensor * Qcur = ggml_concat(ctx0, q_nope, q_pe, 0);
-        cb(Qcur, "Qcur", il);
-
-        ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
+        ggml_tensor * q_pe_ref = ggml_view_3d(ctx0, Qcur,
+                n_embd_head_qk_rope, n_head, n_tokens,
+                Qcur->nb[1], Qcur->nb[2], 0);
+        ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe_ref), k_nope, 0);
         cb(Kcur, "Kcur", il);
 
         cur = build_attn(inp_attn,
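Two properties make the new single-rope path equivalent to the old split/rope/concat path: ggml_rope_ext rotates only the first n_rot dims of each head and copies the rest unchanged, so applying it to the whole [rope | nope] head equals roping the rope slice alone; and because Qcur and Kcur are now assembled with the same [rope | nope] ordering, the per-head dot products, and hence the attention scores, match the old [nope | rope] layout. The q_pe_ref view exists only to hand ggml_repeat a {n_embd_head_qk_rope, n_head, n_tokens} broadcast shape for the single-head k_pe, since there is no standalone q_pe tensor anymore. A numeric sketch of both properties (toy neox-style rotation and hypothetical sizes; this does not reproduce ggml's rope internals):

import torch

n_rot, d_head, pos, base = 4, 10, 7, 10000.0

def rope_first_dims(x, n_rot, pos):
    # neox-style partial rope: rotate pairs (i, i + n_rot/2) within the
    # first n_rot dims, pass the remaining dims through unchanged
    out = x.clone()
    half = n_rot // 2
    theta = pos * base ** (-torch.arange(half, dtype=x.dtype) * 2 / n_rot)
    a, b = x[:half], x[half:n_rot]
    out[:half]      = a * torch.cos(theta) - b * torch.sin(theta)
    out[half:n_rot] = a * torch.sin(theta) + b * torch.cos(theta)
    return out

q = torch.randn(d_head)                     # [rope | nope] layout
q_full  = rope_first_dims(q, n_rot, pos)    # new path: one rope over the head
q_split = torch.cat([rope_first_dims(q[:n_rot], n_rot, pos), q[n_rot:]])
assert torch.allclose(q_full, q_split)      # the nope tail is copied verbatim

k = torch.randn(d_head)
# moving the rope block on BOTH q and k is one shared permutation of the
# head dims, and dot products are invariant under shared permutations
perm = torch.cat([torch.arange(n_rot, d_head), torch.arange(n_rot)])
assert torch.allclose(q @ k, q[perm] @ k[perm])

The permutation argument is also why the conversion-time reordering of q_b_proj needs no matching change on the score side: as long as k is concatenated in the same [rope | nope] order, nothing downstream of the q·k product can tell the layouts apart.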