Refactor and optimize

Piotr Wilkin 2026-01-13 01:45:08 +01:00
parent 34e1ed9093
commit f98f285620
1 changed file with 10 additions and 32 deletions


@@ -118,16 +118,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_graph_context_delta::build_delta_net
q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
// Permute g based on mode
if (is_kda) {
    // KDA: g [S_k, H_v, n_tokens, n_seqs] -> [S_k, n_tokens, H_k, n_seqs]
    g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
} else {
    // GDA: g [H_v, n_tokens, n_seqs] -> [n_tokens, 1, H_k, n_seqs]
    g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
}
g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 0, 2, 1, 3), is_kda ? S_k : 1, n_tokens, H_k, n_seqs);
beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
cb(q, "q_perm", il);
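
For readers skimming this hunk: ggml_permute(ctx0, t, a0, a1, a2, a3) sends source dimension i to position a_i of the result, and ggml_cont_4d then materializes that strided view contiguously with the requested shape, so permute(..., 0, 2, 1, 3) followed by ggml_cont_4d is the usual "swap dims 1 and 2" idiom used for q/k/v above. The refactored line for g keeps that same permute for both modes and only switches the leading dimension between S_k (KDA) and 1 (GDA). A minimal sketch of the idiom, meant to be dropped into code that already has a ggml_context; the helper name is made up and not part of this file:

    #include "ggml.h"

    // Sketch: swap dims 1 and 2 of a 4-D tensor, as the q/k/v permutes above do.
    // ggml_permute(ctx, t, 0, 2, 1, 3) sends source dim 1 to position 2 and dim 2
    // to position 1; ggml_cont_4d then copies the strided view into a contiguous
    // tensor with the target shape so later ops can assume plain memory layout.
    static ggml_tensor * swap_dims_1_2(ggml_context * ctx, ggml_tensor * t) {
        ggml_tensor * view = ggml_permute(ctx, t, 0, 2, 1, 3); // view: [ne0, ne2, ne1, ne3]
        return ggml_cont_4d(ctx, view, t->ne[0], t->ne[2], t->ne[1], t->ne[3]);
    }
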
@@ -145,14 +136,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_graph_context_delta::build_delta_net
k = ggml_pad(ctx0, k, 0, pad, 0, 0);
v = ggml_pad(ctx0, v, 0, pad, 0, 0);
beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
g = ggml_pad(ctx0, g, 0, pad, 0, 0);
if (is_kda) {
    // KDA: g shape [S_k, n_tokens, H_k, n_seqs] -> pad along dim 1
    g = ggml_pad(ctx0, g, 0, pad, 0, 0);
} else {
    // GDA: g shape [n_tokens, 1, H_k, n_seqs] -> pad along dim 0
    g = ggml_pad(ctx0, g, pad, 0, 0, 0);
}
cb(q, "q_pad", il);
cb(k, "k_pad", il);
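
ggml_pad(ctx0, t, p0, p1, p2, p3) appends p_i zero-valued elements at the end of dimension i, which is what lets a single call pad the token dimension (dim 1) of q, k, v, beta and now g alike. A small sketch of rounding the token dimension up to a multiple of the chunk size; the helper name and the local pad computation are illustrative, not taken from this file:

    #include "ggml.h"

    // Sketch: zero-pad dimension 1 (the token dimension in the layouts above)
    // so its length becomes a multiple of chunk_size. ggml_pad appends zeros at
    // the end of each dimension: here p1 = pad and every other p_i is 0.
    static ggml_tensor * pad_tokens_to_chunk(ggml_context * ctx, ggml_tensor * t, int64_t chunk_size) {
        const int64_t n_tokens = t->ne[1];
        const int64_t pad      = (chunk_size - n_tokens % chunk_size) % chunk_size;
        return pad > 0 ? ggml_pad(ctx, t, 0, (int) pad, 0, 0) : t;
    }
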
@@ -176,18 +161,20 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_graph_context_delta::build_delta_net
// Reshape g for chunks
ggml_tensor * g_cumsum;
ggml_tensor * g_cumsum_t;
if (is_kda) {
    // KDA: g [S_k, n_tokens+pad, H_k, n_seqs] -> [S_k, chunk_size, n_chunks, H_k * n_seqs]
    g = ggml_reshape_4d(ctx0, g, S_k, chunk_size, n_chunks, H_k * n_seqs);
    // Cumsum along chunk_size dimension (ne[1])
    // GGML cumsum operates on ne[0], so we need to transpose, cumsum, transpose back
    g = ggml_cont(ctx0, ggml_transpose(ctx0, g)); // [chunk_size, S_k, n_chunks, H_k * n_seqs]
    g_cumsum = ggml_cumsum(ctx0, g);
    g_cumsum = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); // [S_k, chunk_size, n_chunks, H_k * n_seqs]
    g_cumsum_t = ggml_cumsum(ctx0, g);
    g_cumsum = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum_t)); // [S_k, chunk_size, n_chunks, H_k * n_seqs]
} else {
    // GDA: g [n_tokens+pad, 1, H_k, n_seqs] -> [chunk_size, 1, n_chunks, H_k * n_seqs]
    g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
    g_cumsum = ggml_cumsum(ctx0, g);
    g_cumsum_t = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_k * n_seqs);
}
cb(g_cumsum, "g_cumsum", il);
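
The transpose/cumsum/transpose dance exists because, as the comment in the KDA branch notes, ggml_cumsum accumulates along ne[0]. Keeping both g_cumsum_t (the cumsum in transposed layout) and g_cumsum (transposed back) is what lets the later hunks reuse the transposed form instead of recomputing it. A sketch of the pattern as a standalone helper; the name and the pair return are mine, not from this file:

    #include "ggml.h"
    #include <utility>

    // Sketch: cumulative sum along dim 1. ggml_cumsum runs along ne[0], so
    // transpose first, take the cumsum, then transpose back. Both forms are
    // returned so a caller can reuse the transposed cumsum directly.
    static std::pair<ggml_tensor *, ggml_tensor *> cumsum_dim1(ggml_context * ctx, ggml_tensor * t) {
        ggml_tensor * t_t  = ggml_cont(ctx, ggml_transpose(ctx, t));     // swap dims 0 and 1
        ggml_tensor * cs_t = ggml_cumsum(ctx, t_t);                      // cumsum along old dim 1
        ggml_tensor * cs   = ggml_cont(ctx, ggml_transpose(ctx, cs_t));  // back to the original layout
        return { cs, cs_t };
    }
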
@@ -217,7 +204,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_graph_context_delta::build_delta_net
// GDA: Use decay mask approach (g broadcasts over K dimension)
// g_cumsum [chunk_size, 1, n_chunks, H_v * n_seqs]
ggml_tensor * gcs_i = g_cumsum;
ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
ggml_tensor * gcs_j = g_cumsum_t;
g_exp_pos = ggml_exp(ctx0, g_cumsum_t);
ggml_tensor * gcs_j_broadcast = ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
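
The mask built here is a pairwise difference of cumulative gate values: gcs_i varies along ne[0], the repeated gcs_j along ne[1], and the subtraction broadcasts one against the other. A plain-C++ reference of what that produces for one chunk and one head; index orientation and names are illustrative only:

    #include <vector>

    // Reference for the broadcast subtraction above: given per-position
    // cumulative gate values cum[0..n), mask[j][i] = cum[j] - cum[i].
    // Exponentiating an entry gives the total decay accumulated between
    // the two positions inside the chunk.
    static std::vector<std::vector<float>> decay_mask_ref(const std::vector<float> & cum) {
        const size_t n = cum.size();
        std::vector<std::vector<float>> mask(n, std::vector<float>(n, 0.0f));
        for (size_t j = 0; j < n; ++j) {
            for (size_t i = 0; i < n; ++i) {
                mask[j][i] = cum[j] - cum[i];
            }
        }
        return mask;
    }
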
@@ -247,17 +235,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_graph_context_delta::build_delta_net
// Compute u = A @ v and w = A @ (g.exp() * k)
v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
ggml_tensor * gexp;
if (is_kda) {
    // KDA: Reuse g_exp_pos computed earlier
    gexp = g_exp_pos;
} else {
    // GDA: g_cumsum [chunk_size, 1, n_chunks, H_k * n_seqs]
    ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
    gexp = ggml_exp(ctx0, g_cumsum_t);
}
ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, g_exp_pos);
cb(kbeta_gexp, "kbeta_gexp", il);
ggml_tensor * k_cumdecay = ggml_cont(ctx0, ggml_transpose(ctx0,
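
The comment at the top of this hunk spells out the per-chunk formulas, u = A @ v and w = A @ (g.exp() * k). A plain-C++ reference of those two products for one chunk and one head, treating the gate as a per-position scalar as in the GDA layout (the KDA path carries a per-dimension gate); all names here are illustrative:

    #include <cmath>
    #include <vector>

    // Reference: u[t] = sum_s A[t][s] * v[s], w[t] = sum_s A[t][s] * exp(g_cum[s]) * k[s],
    // with A of shape [T x T], k of shape [T x Dk], v of shape [T x Dv].
    static void chunk_u_w_ref(const std::vector<std::vector<float>> & A,
                              const std::vector<std::vector<float>> & k,
                              const std::vector<std::vector<float>> & v,
                              const std::vector<float>              & g_cum,
                              std::vector<std::vector<float>>       & u,
                              std::vector<std::vector<float>>       & w) {
        const size_t T = A.size(), Dk = k[0].size(), Dv = v[0].size();
        u.assign(T, std::vector<float>(Dv, 0.0f));
        w.assign(T, std::vector<float>(Dk, 0.0f));
        for (size_t t = 0; t < T; ++t) {
            for (size_t s = 0; s < T; ++s) {
                for (size_t d = 0; d < Dv; ++d) u[t][d] += A[t][s] * v[s][d];
                for (size_t d = 0; d < Dk; ++d) w[t][d] += A[t][s] * std::exp(g_cum[s]) * k[s][d];
            }
        }
    }
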
@@ -330,7 +308,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_graph_context_delta::build_delta_net
ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk);
ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk);
ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk);
ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk);
ggml_tensor * gexp_chunk = get_slice_2d(ctx0, g_exp_pos, chunk);
cb(attn_chunk, "attn_chunk", il);
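
get_slice_2d is a helper defined elsewhere in this file, so its exact behaviour is not visible in this diff; judging by the reshapes above, the chunk index it takes runs along ne[2]. A hypothetical equivalent using a zero-copy view, offered purely as a guess at the idea rather than the real implementation (the real helper may additionally collapse the result to 2-D):

    #include "ggml.h"

    // Hypothetical sketch: take chunk c of a tensor laid out as
    // [ne0, ne1, n_chunks, ne3] by viewing one slab of dimension 2.
    // Strides come from the source tensor; the byte offset is c slabs of nb[2].
    static ggml_tensor * slice_chunk(ggml_context * ctx, ggml_tensor * t, int64_t c) {
        return ggml_view_4d(ctx, t,
            t->ne[0], t->ne[1], 1, t->ne[3],
            t->nb[1], t->nb[2], t->nb[3],
            c * t->nb[2]);
    }
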