Refactor and optimize
This commit is contained in:
parent
34e1ed9093
commit
f98f285620
@@ -118,16 +118,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_graph_context_delta::build_delta_net
     q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
     k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
     v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-
-    // Permute g based on mode
-    if (is_kda) {
-        // KDA: g [S_k, H_v, n_tokens, n_seqs] -> [S_k, n_tokens, H_k, n_seqs]
-        g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
-    } else {
-        // GDA: g [H_v, n_tokens, n_seqs] -> [n_tokens, 1, H_k, n_seqs]
-        g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
-    }
-
+    g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 0, 2, 1, 3), is_kda ? S_k : 1, n_tokens, H_k, n_seqs);
     beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
 
     cb(q, "q_perm", il);
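The if/else collapses because both modes can use the same (0, 2, 1, 3) permutation once the GDA head axis is written as size 1: ggml_permute only relabels the axes as a metadata-only view, and ggml_cont_4d then materializes the data contiguously in the requested shape. A minimal standalone sketch of that permute-then-cont pattern, with hypothetical toy sizes standing in for S_k/H_k/n_tokens/n_seqs (plain ggml, not part of this commit):

#include "ggml.h"
#include <stdio.h>

int main() {
    ggml_init_params params = { 16 * 1024 * 1024, NULL, false };
    ggml_context * ctx = ggml_init(params);

    // Toy sizes standing in for S_k, H_k, n_tokens, n_seqs.
    const int S_k = 4, H_k = 2, n_tokens = 3, n_seqs = 1;

    // g as produced upstream in KDA mode: [S_k, H_k, n_tokens, n_seqs].
    ggml_tensor * g = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S_k, H_k, n_tokens, n_seqs);
    ggml_set_f32(g, 1.0f);

    // ggml_permute(..., 0, 2, 1, 3) swaps ne[1] and ne[2] without moving data;
    // ggml_cont_4d then copies into the contiguous target layout.
    ggml_tensor * g_perm = ggml_cont_4d(ctx, ggml_permute(ctx, g, 0, 2, 1, 3),
                                        S_k, n_tokens, H_k, n_seqs);

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, g_perm);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    printf("g_perm shape: [%d, %d, %d, %d]\n",
           (int) g_perm->ne[0], (int) g_perm->ne[1],
           (int) g_perm->ne[2], (int) g_perm->ne[3]);

    ggml_free(ctx);
    return 0;
}

With the per-mode difference reduced to the `is_kda ? S_k : 1` ternary, every tensor now carries n_tokens on ne[1], which is what lets the padding in the next hunk be unified as well.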
@@ -145,14 +136,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_graph_context_delta::build_delta_net
     k = ggml_pad(ctx0, k, 0, pad, 0, 0);
     v = ggml_pad(ctx0, v, 0, pad, 0, 0);
     beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
+    g = ggml_pad(ctx0, g, 0, pad, 0, 0);
 
-    if (is_kda) {
-        // KDA: g shape [S_k, n_tokens, H_k, n_seqs] -> pad along dim 1
-        g = ggml_pad(ctx0, g, 0, pad, 0, 0);
-    } else {
-        // GDA: g shape [n_tokens, 1, H_k, n_seqs] -> pad along dim 0
-        g = ggml_pad(ctx0, g, pad, 0, 0, 0);
-    }
 
     cb(q, "q_pad", il);
     cb(k, "k_pad", il);
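Because the token axis is now ne[1] in both modes, a single ggml_pad(..., 0, pad, 0, 0) per tensor rounds it up to a chunk multiple. A small sketch of how ggml_pad appends zeros on one axis, again with made-up sizes:

#include "ggml.h"
#include <stdio.h>

int main() {
    ggml_init_params params = { 16 * 1024 * 1024, NULL, false };
    ggml_context * ctx = ggml_init(params);

    // Hypothetical sizes: 5 tokens, chunks of 4 -> pad ne[1] by 3.
    const int n_tokens = 5, chunk_size = 4;
    const int pad = (chunk_size - n_tokens % chunk_size) % chunk_size;

    ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 8, n_tokens, 2, 1);
    ggml_set_f32(k, 1.0f);

    // ggml_pad appends zeros at the high end of each dimension; only ne[1] here.
    ggml_tensor * k_pad = ggml_pad(ctx, k, 0, pad, 0, 0);

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, k_pad);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    printf("ne[1]: %d -> %d\n", (int) k->ne[1], (int) k_pad->ne[1]); // 5 -> 8
    printf("padded slot: %.0f (zero)\n", ggml_get_f32_nd(k_pad, 0, 7, 0, 0));

    ggml_free(ctx);
    return 0;
}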
@@ -176,18 +161,20 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_graph_context_delta::build_delta_net
 
     // Reshape g for chunks
     ggml_tensor * g_cumsum;
+    ggml_tensor * g_cumsum_t;
     if (is_kda) {
         // KDA: g [S_k, n_tokens+pad, H_k, n_seqs] -> [S_k, chunk_size, n_chunks, H_k * n_seqs]
         g = ggml_reshape_4d(ctx0, g, S_k, chunk_size, n_chunks, H_k * n_seqs);
         // Cumsum along chunk_size dimension (ne[1])
         // GGML cumsum operates on ne[0], so we need to transpose, cumsum, transpose back
         g = ggml_cont(ctx0, ggml_transpose(ctx0, g)); // [chunk_size, S_k, n_chunks, H_k * n_seqs]
-        g_cumsum = ggml_cumsum(ctx0, g);
-        g_cumsum = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); // [S_k, chunk_size, n_chunks, H_k * n_seqs]
+        g_cumsum_t = ggml_cumsum(ctx0, g);
+        g_cumsum = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum_t)); // [S_k, chunk_size, n_chunks, H_k * n_seqs]
     } else {
         // GDA: g [n_tokens+pad, 1, H_k, n_seqs] -> [chunk_size, 1, n_chunks, H_k * n_seqs]
         g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
         g_cumsum = ggml_cumsum(ctx0, g);
+        g_cumsum_t = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_k * n_seqs);
     }
 
     cb(g_cumsum, "g_cumsum", il);
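The refactor names the transposed cumulative sum g_cumsum_t and keeps it alive in both branches, so later hunks can reuse it instead of re-deriving it with an extra transpose. The transpose/cumsum/transpose dance itself follows from the comment in the hunk: ggml_cumsum accumulates along ne[0] only. A toy sketch of that pattern, assuming that ne[0] behavior:

#include "ggml.h"
#include <stdio.h>

int main() {
    ggml_init_params params = { 16 * 1024 * 1024, NULL, false };
    ggml_context * ctx = ggml_init(params);

    // Toy tensor [S_k = 2, chunk_size = 4] of ones; accumulate along ne[1].
    ggml_tensor * g = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 4);
    ggml_set_f32(g, 1.0f);

    // Bring chunk_size into ne[0], cumsum, then transpose back.
    ggml_tensor * g_t   = ggml_cont(ctx, ggml_transpose(ctx, g));     // [4, 2]
    ggml_tensor * gcs_t = ggml_cumsum(ctx, g_t);                      // rows: 1 2 3 4
    ggml_tensor * gcs   = ggml_cont(ctx, ggml_transpose(ctx, gcs_t)); // back to [2, 4]

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, gcs);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    // Walking ne[1] at fixed ne[0] should print 1, 2, 3, 4.
    for (int j = 0; j < 4; ++j) {
        printf("%.0f ", ggml_get_f32_nd(gcs, 0, j, 0, 0));
    }
    printf("\n");

    ggml_free(ctx);
    return 0;
}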
@@ -217,7 +204,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_graph_context_delta::build_delta_net
             // GDA: Use decay mask approach (g broadcasts over K dimension)
             // g_cumsum [chunk_size, 1, n_chunks, H_v * n_seqs]
             ggml_tensor * gcs_i = g_cumsum;
-            ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
+            ggml_tensor * gcs_j = g_cumsum_t;
+            g_exp_pos = ggml_exp(ctx0, g_cumsum_t);
             ggml_tensor * gcs_j_broadcast = ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
             decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
 
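The decay mask is the pairwise difference of cumulative decays: repeating the row vector gcs_j down ne[0] and letting the column vector gcs_i broadcast along ne[1] gives mask[i, j] = g_cum[j] - g_cum[i]. A self-contained sketch of the repeat-and-subtract construction (toy chunk size; relies on ggml_sub broadcasting its second operand):

#include "ggml.h"
#include <stdio.h>

int main() {
    ggml_init_params params = { 16 * 1024 * 1024, NULL, false };
    ggml_context * ctx = ggml_init(params);

    const int chunk = 4;

    // Cumulative decay per position, as a column [chunk, 1] and a row [1, chunk].
    ggml_tensor * gcs_i = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, chunk, 1);
    for (int i = 0; i < chunk; ++i) ggml_set_f32_1d(gcs_i, i, (float) i);
    ggml_tensor * gcs_j = ggml_cont(ctx, ggml_transpose(ctx, gcs_i)); // [1, chunk]

    // Repeat the row into a [chunk, chunk] matrix, then subtract the column
    // (broadcast by ggml_sub): mask[i, j] = g_cum[j] - g_cum[i].
    ggml_tensor * gcs_j_bc = ggml_repeat_4d(ctx, gcs_j, chunk, chunk, 1, 1);
    ggml_tensor * mask     = ggml_sub(ctx, gcs_j_bc, gcs_i);

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, mask);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    for (int j = 0; j < chunk; ++j) {        // row j prints j - i for each i
        for (int i = 0; i < chunk; ++i) {
            printf("%5.1f", ggml_get_f32_nd(mask, i, j, 0, 0));
        }
        printf("\n");
    }

    ggml_free(ctx);
    return 0;
}

Storing g_exp_pos = exp(g_cumsum_t) here, in the GDA branch as well, is what makes the dead-code removal in the next hunk possible.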
@@ -247,17 +235,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_graph_context_delta::build_delta_net
     // Compute u = A @ v and w = A @ (g.exp() * k)
     v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
 
-    ggml_tensor * gexp;
-    if (is_kda) {
-        // KDA: Reuse g_exp_pos computed earlier
-        gexp = g_exp_pos;
-    } else {
-        // GDA: g_cumsum [chunk_size, 1, n_chunks, H_k * n_seqs]
-        ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
-        gexp = ggml_exp(ctx0, g_cumsum_t);
-    }
-
-    ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
+    ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, g_exp_pos);
     cb(kbeta_gexp, "kbeta_gexp", il);
 
     ggml_tensor * k_cumdecay = ggml_cont(ctx0, ggml_transpose(ctx0,
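Since g_exp_pos now exists in both modes, the gexp temporary and the GDA-only transpose/exp recomputation are dead code; kbeta_gexp multiplies k_beta by the cached tensor directly. The `u = A @ v` comment also leans on ggml's matmul convention: ggml_mul_mat contracts over ne[0] of both operands, which is why v_beta is transposed first. A tiny sketch verifying that convention:

#include "ggml.h"
#include <stdio.h>

int main() {
    ggml_init_params params = { 16 * 1024 * 1024, NULL, false };
    ggml_context * ctx = ggml_init(params);

    // ggml_mul_mat(a, b): a [K, M], b [K, N] -> result [M, N],
    // res[m, n] = sum_k a[k, m] * b[k, n].
    const int K = 3, M = 2, N = 2;
    ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, K, M);
    ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, K, N);
    ggml_set_f32(a, 1.0f);
    ggml_set_f32(b, 2.0f);

    ggml_tensor * c = ggml_mul_mat(ctx, a, b); // every entry = K * 1 * 2 = 6

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    printf("c[0,0] = %.0f (expected 6)\n", ggml_get_f32_nd(c, 0, 0, 0, 0));

    ggml_free(ctx);
    return 0;
}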
@@ -330,7 +308,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_graph_context_delta::build_delta_net
         ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk);
         ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk);
         ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk);
-        ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk);
+        ggml_tensor * gexp_chunk = get_slice_2d(ctx0, g_exp_pos, chunk);
 
         cb(attn_chunk, "attn_chunk", il);
 
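get_slice_2d is a helper defined elsewhere in this file; the diff only shows its call sites, where the per-chunk loop slices one chunk out of each precomputed tensor. A hypothetical reconstruction of what such a helper might look like, assuming chunks are indexed along ne[2] (the real implementation may differ):

#include "ggml.h"

// Hypothetical sketch -- not the actual helper from the commit. Assuming the
// chunked tensors are laid out [ne0, ne1, n_chunks, H * n_seqs] with "chunk"
// indexing ne[2], a slice is a zero-copy view at that chunk's byte offset.
static ggml_tensor * get_slice_2d_sketch(ggml_context * ctx, ggml_tensor * t, int64_t chunk) {
    return ggml_view_4d(ctx, t,
                        t->ne[0], t->ne[1], 1, t->ne[3], // keep every dim but ne[2]
                        t->nb[1], t->nb[2], t->nb[3],    // strides unchanged
                        chunk * t->nb[2]);               // offset to the chunk
}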