Change to decay mask approach

This commit is contained in:
Piotr Wilkin 2026-01-13 14:52:05 +01:00
parent 08b1ed8633
commit e9ad184926
1 changed files with 70 additions and 19 deletions

View File

@ -181,25 +181,54 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_graph_context_delta::build_delta_net
// Build attention matrix A for the WY representation solve
// For GDA: A[j,i] = sum_k(k[j,k] * exp(g[j] - g[i]) * k[i,k]) = (k @ k^T) * exp(g[j] - g[i])
// For KDA: A[j,i] = sum_k(k[j,k] * exp(g[j,k] - g[i,k]) * k[i,k])
// = sum_k(k[j,k] * exp(g[j,k]) * k[i,k] * exp(-g[i,k]))
// = (k * exp(g))^T @ (k * exp(-g))
// The KDA formulation factorizes into a matmul, avoiding the need for a 5D decay mask!
// For KDA: A[j,i] = sum_k(k_beta[j,k] * exp(g[j,k] - g[i,k]) * k[i,k])
// KDA uses decay mask with S_k packed into batch to compute exp(g[j,k] - g[i,k]) per-key
ggml_tensor * k_decay;
ggml_tensor * decay_mask = nullptr;
ggml_tensor * g_exp_pos = nullptr; // For KDA: exp(g_cumsum)
ggml_tensor * g_exp_neg = nullptr; // For KDA: exp(-g_cumsum)
ggml_tensor * g_exp_pos = nullptr;
if (is_kda) {
// KDA: Exact computation using factorization
// k_pos_beta = k_beta * exp(g_cumsum), k_neg = k * exp(-g_cumsum)
// A = k_pos_beta^T @ k_neg (via mul_mat)
// KDA: Use decay mask with S_k in leading dimension for efficient mul_mat reduction
// A[j,i] = sum_k(k_beta[j,k] * exp(g[j,k] - g[i,k]) * k[i,k])
// By putting S_k in dim 0, mul_mat implicitly sums over it
const int64_t CHB = n_chunks * H_k * n_seqs;
// g_cumsum_t is [chunk_size, S_k, n_chunks, H_k * n_seqs]
// Reshape to [chunk_size, S_k, CHB] then build decay mask
ggml_tensor * gcs = ggml_reshape_3d(ctx0, g_cumsum_t, chunk_size, S_k, CHB);
ggml_tensor * gcs_i = ggml_reshape_4d(ctx0, gcs, chunk_size, 1, S_k, CHB);
ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, gcs, 1, chunk_size, S_k, CHB);
// Build decay mask: [chunk_size, chunk_size, S_k, CHB]
ggml_tensor * gcs_j_bc = ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, S_k, CHB);
decay_mask = ggml_sub(ctx0, gcs_j_bc, gcs_i);
cb(decay_mask, "decay_mask_kda", il);
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
decay_mask = ggml_exp(ctx0, decay_mask);
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
// Permute to [S_k, chunk_size_j, chunk_size_i, CHB] for mul_mat reduction over S_k
decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB);
// Reshape k and k_beta for broadcasting with decay_mask
// k_i: indexed at position i (dim 2 of decay_mask)
// k_beta_j: indexed at position j (dim 1 of decay_mask)
ggml_tensor * k_i = ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB);
ggml_tensor * k_beta_j = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, 1, CHB);
// decay_k_beta_j[s,j,i,b] = decay[s,j,i,b] * k_beta[s,j,b]
ggml_tensor * decay_k_beta_j = ggml_mul(ctx0, decay_mask, k_beta_j);
// mul_mat sums over S_k: result[j,1,i,CHB] = sum_s decay_k_beta_j[s,j,i,b] * k_i[s,1,i,b]
k_decay = ggml_mul_mat(ctx0, decay_k_beta_j, k_i);
k_decay = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, k_decay, chunk_size, chunk_size, n_chunks, H_k * n_seqs)));
// g_exp_pos is still needed for later (kbeta_gexp, etc.)
g_exp_pos = ggml_exp(ctx0, g_cumsum);
g_exp_neg = ggml_exp(ctx0, ggml_neg(ctx0, g_cumsum));
ggml_tensor * k_pos_beta = ggml_mul(ctx0, k_beta, g_exp_pos);
ggml_tensor * k_neg = ggml_mul(ctx0, k, g_exp_neg);
k_decay = ggml_mul_mat(ctx0, k_pos_beta, k_neg);
} else {
// GDA: Use decay mask approach (g broadcasts over K dimension)
// g_cumsum [chunk_size, 1, n_chunks, H_v * n_seqs]
@ -245,19 +274,41 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_graph_context_delta::build_delta_net
// Attention scores q @ k^T with decay
// For GDA: attn_kq[j,i] = sum_k(q[j,k] * exp(g[j] - g[i]) * k[i,k])
// For KDA: attn_kq[j,i] = sum_k(q[j,k] * exp(g[j,k] - g[i,k]) * k[i,k])
// = (q * exp(g))^T @ (k * exp(-g))
ggml_tensor * attn_kq;
if (is_kda) {
// KDA: Use factorization
ggml_tensor * q_exp_pos = ggml_mul(ctx0, q, g_exp_pos);
ggml_tensor * k_exp_neg = ggml_mul(ctx0, k, g_exp_neg);
attn_kq = ggml_mul_mat(ctx0, q_exp_pos, k_exp_neg);
// KDA: Same approach as k_decay - use decay_mask with S_k in leading dim
const int64_t CHB = n_chunks * H_k * n_seqs;
// Rebuild decay mask (same structure as k_decay)
ggml_tensor * gcs = ggml_reshape_3d(ctx0, g_cumsum_t, chunk_size, S_k, CHB);
ggml_tensor * gcs_i = ggml_reshape_4d(ctx0, gcs, chunk_size, 1, S_k, CHB);
ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, gcs, 1, chunk_size, S_k, CHB);
ggml_tensor * gcs_j_bc = ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, S_k, CHB);
ggml_tensor * decay_mask_kq = ggml_sub(ctx0, gcs_j_bc, gcs_i);
decay_mask_kq = ggml_mul(ctx0, decay_mask_kq, diag_mask);
decay_mask_kq = ggml_exp(ctx0, decay_mask_kq);
decay_mask_kq = ggml_mul(ctx0, decay_mask_kq, diag_mask);
// Permute to [S_k, chunk_size_j, chunk_size_i, CHB]
decay_mask_kq = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask_kq, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB);
// q_j: indexed at position j, k_i: indexed at position i
ggml_tensor * q_j = ggml_reshape_4d(ctx0, q, S_k, chunk_size, 1, CHB);
ggml_tensor * k_i = ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB);
// decay_q_j[s,j,i,b] = decay[s,j,i,b] * q[s,j,b]
ggml_tensor * decay_q_j = ggml_mul(ctx0, decay_mask_kq, q_j);
// mul_mat sums over S_k
attn_kq = ggml_mul_mat(ctx0, decay_q_j, k_i);
attn_kq = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, attn_kq, chunk_size, chunk_size, n_chunks, H_k * n_seqs)));
} else {
// GDA: Use decay mask
attn_kq = ggml_mul_mat(ctx0, k, q);
attn_kq = ggml_mul(ctx0, attn_kq, decay_mask);
attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
}
attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
cb(attn_kq, "attn_kq", il);
// Compute g_last and g_diff for state updates