diff --git a/src/models/delta.cpp b/src/models/delta.cpp
index d3c5cce3e1..533d5ff3e8 100644
--- a/src/models/delta.cpp
+++ b/src/models/delta.cpp
@@ -181,25 +181,54 @@ std::pair llm_graph_context_delta::build_delta_net

     // Build attention matrix A for the WY representation solve
     // For GDA: A[j,i] = sum_k(k[j,k] * exp(g[j] - g[i]) * k[i,k]) = (k @ k^T) * exp(g[j] - g[i])
-    // For KDA: A[j,i] = sum_k(k[j,k] * exp(g[j,k] - g[i,k]) * k[i,k])
-    //                 = sum_k(k[j,k] * exp(g[j,k]) * k[i,k] * exp(-g[i,k]))
-    //                 = (k * exp(g))^T @ (k * exp(-g))
-    // The KDA formulation factorizes into a matmul, avoiding the need for a 5D decay mask!
+    // For KDA: A[j,i] = sum_k(k_beta[j,k] * exp(g[j,k] - g[i,k]) * k[i,k])
+    // KDA uses a decay mask with S_k packed into the batch dimension to compute exp(g[j,k] - g[i,k]) per key

     ggml_tensor * k_decay;
     ggml_tensor * decay_mask = nullptr;
-    ggml_tensor * g_exp_pos  = nullptr; // For KDA: exp(g_cumsum)
-    ggml_tensor * g_exp_neg  = nullptr; // For KDA: exp(-g_cumsum)
+    ggml_tensor * g_exp_pos  = nullptr;

     if (is_kda) {
-        // KDA: Exact computation using factorization
-        // k_pos_beta = k_beta * exp(g_cumsum), k_neg = k * exp(-g_cumsum)
-        // A = k_pos_beta^T @ k_neg (via mul_mat)
+        // KDA: Use a decay mask with S_k in the leading dimension for an efficient mul_mat reduction
+        // A[j,i] = sum_k(k_beta[j,k] * exp(g[j,k] - g[i,k]) * k[i,k])
+        // By putting S_k in dim 0, mul_mat implicitly sums over it
+
+        const int64_t CHB = n_chunks * H_k * n_seqs;
+
+        // g_cumsum_t is [chunk_size, S_k, n_chunks, H_k * n_seqs]
+        // Reshape to [chunk_size, S_k, CHB], then build the decay mask
+        ggml_tensor * gcs   = ggml_reshape_3d(ctx0, g_cumsum_t, chunk_size, S_k, CHB);
+        ggml_tensor * gcs_i = ggml_reshape_4d(ctx0, gcs, chunk_size, 1, S_k, CHB);
+        ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, gcs, 1, chunk_size, S_k, CHB);
+
+        // Build decay mask: [chunk_size, chunk_size, S_k, CHB]
+        ggml_tensor * gcs_j_bc = ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, S_k, CHB);
+        decay_mask = ggml_sub(ctx0, gcs_j_bc, gcs_i);
+
+        cb(decay_mask, "decay_mask_kda", il);
+
+        decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
+        decay_mask = ggml_exp(ctx0, decay_mask);
+        decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
+
+        // Permute to [S_k, chunk_size_j, chunk_size_i, CHB] for the mul_mat reduction over S_k
+        decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB);
+
+        // Reshape k and k_beta for broadcasting with decay_mask
+        // k_i:      indexed at position i (dim 2 of decay_mask)
+        // k_beta_j: indexed at position j (dim 1 of decay_mask)
+        ggml_tensor * k_i      = ggml_reshape_4d(ctx0, k,      S_k, 1, chunk_size, CHB);
+        ggml_tensor * k_beta_j = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, 1, CHB);
+
+        // decay_k_beta_j[s,j,i,b] = decay[s,j,i,b] * k_beta[s,j,b]
+        ggml_tensor * decay_k_beta_j = ggml_mul(ctx0, decay_mask, k_beta_j);
+
+        // mul_mat sums over S_k: result[j,1,i,CHB] = sum_s decay_k_beta_j[s,j,i,b] * k_i[s,1,i,b]
+        k_decay = ggml_mul_mat(ctx0, decay_k_beta_j, k_i);
+        k_decay = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, k_decay, chunk_size, chunk_size, n_chunks, H_k * n_seqs)));
+
+        // g_exp_pos is still needed later (kbeta_gexp, etc.)
         g_exp_pos = ggml_exp(ctx0, g_cumsum);
-        g_exp_neg = ggml_exp(ctx0, ggml_neg(ctx0, g_cumsum));
-        ggml_tensor * k_pos_beta = ggml_mul(ctx0, k_beta, g_exp_pos);
-        ggml_tensor * k_neg      = ggml_mul(ctx0, k, g_exp_neg);
-        k_decay = ggml_mul_mat(ctx0, k_pos_beta, k_neg);
     } else {
         // GDA: Use decay mask approach (g broadcasts over K dimension)
         // g_cumsum [chunk_size, 1, n_chunks, H_v * n_seqs]
@@ -245,19 +274,41 @@ std::pair llm_graph_context_delta::build_delta_net
     // Attention scores q @ k^T with decay
     // For GDA: attn_kq[j,i] = sum_k(q[j,k] * exp(g[j] - g[i]) * k[i,k])
     // For KDA: attn_kq[j,i] = sum_k(q[j,k] * exp(g[j,k] - g[i,k]) * k[i,k])
-    //                       = (q * exp(g))^T @ (k * exp(-g))
     ggml_tensor * attn_kq;
     if (is_kda) {
-        // KDA: Use factorization
-        ggml_tensor * q_exp_pos = ggml_mul(ctx0, q, g_exp_pos);
-        ggml_tensor * k_exp_neg = ggml_mul(ctx0, k, g_exp_neg);
-        attn_kq = ggml_mul_mat(ctx0, q_exp_pos, k_exp_neg);
+        // KDA: Same approach as k_decay - use decay_mask with S_k in leading dim
+        const int64_t CHB = n_chunks * H_k * n_seqs;
+
+        // Rebuild decay mask (same structure as k_decay)
+        ggml_tensor * gcs      = ggml_reshape_3d(ctx0, g_cumsum_t, chunk_size, S_k, CHB);
+        ggml_tensor * gcs_i    = ggml_reshape_4d(ctx0, gcs, chunk_size, 1, S_k, CHB);
+        ggml_tensor * gcs_j    = ggml_reshape_4d(ctx0, gcs, 1, chunk_size, S_k, CHB);
+        ggml_tensor * gcs_j_bc = ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, S_k, CHB);
+        ggml_tensor * decay_mask_kq = ggml_sub(ctx0, gcs_j_bc, gcs_i);
+
+        decay_mask_kq = ggml_mul(ctx0, decay_mask_kq, diag_mask);
+        decay_mask_kq = ggml_exp(ctx0, decay_mask_kq);
+        decay_mask_kq = ggml_mul(ctx0, decay_mask_kq, diag_mask);
+
+        // Permute to [S_k, chunk_size_j, chunk_size_i, CHB]
+        decay_mask_kq = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask_kq, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB);
+
+        // q_j: indexed at position j, k_i: indexed at position i
+        ggml_tensor * q_j = ggml_reshape_4d(ctx0, q, S_k, chunk_size, 1, CHB);
+        ggml_tensor * k_i = ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB);
+
+        // decay_q_j[s,j,i,b] = decay[s,j,i,b] * q[s,j,b]
+        ggml_tensor * decay_q_j = ggml_mul(ctx0, decay_mask_kq, q_j);
+
+        // mul_mat sums over S_k
+        attn_kq = ggml_mul_mat(ctx0, decay_q_j, k_i);
+        attn_kq = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, attn_kq, chunk_size, chunk_size, n_chunks, H_k * n_seqs)));
     } else {
         // GDA: Use decay mask
         attn_kq = ggml_mul_mat(ctx0, k, q);
         attn_kq = ggml_mul(ctx0, attn_kq, decay_mask);
+        attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
     }
-    attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
     cb(attn_kq, "attn_kq", il);

     // Compute g_last and g_diff for state updates
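
Reviewer note: below is a minimal standalone sketch (plain C++, no ggml; the file name, T, S, and all variable names are illustrative, not from the patch) checking that the per-key decay-mask form introduced by this patch and the factorization it removes compute the same strictly lower-triangular A in exact arithmetic. It assumes diag_mask acts as the strictly lower-triangular indicator (1 where j > i, 0 elsewhere).

// check_kda_decay.cpp -- hypothetical reference check, not part of the patch
#include <cassert>
#include <cmath>
#include <cstdio>
#include <random>
#include <vector>

int main() {
    const int T = 4; // chunk_size (illustrative)
    const int S = 3; // S_k, per-head key dimension (illustrative)

    std::mt19937 rng(0);
    std::uniform_real_distribution<double> u(-1.0, 1.0);

    // One (chunk, head, seq) slice, row-major [T][S]: key, beta-scaled key, cumulative log-decay
    std::vector<double> k(T * S), kb(T * S), g(T * S);
    for (int n = 0; n < T * S; n++) { k[n] = u(rng); kb[n] = u(rng); g[n] = u(rng); }

    // (1) Removed factorization: A = (k_beta * exp(g)) @ (k * exp(-g))^T, masked to j > i
    std::vector<double> A_fact(T * T, 0.0);
    for (int j = 0; j < T; j++)
        for (int i = 0; i < j; i++)
            for (int s = 0; s < S; s++)
                A_fact[j * T + i] += kb[j * S + s] * std::exp(g[j * S + s])
                                   * k [i * S + s] * std::exp(-g[i * S + s]);

    // (2) Decay-mask form used by the patch: exponentiate only the within-chunk
    //     differences g[j,s] - g[i,s] where j > i, scale by k_beta along j, and
    //     reduce over s -- the sum that ggml_mul_mat performs over the leading
    //     S_k dimension after the permute/reshape above
    std::vector<double> A_mask(T * T, 0.0);
    for (int s = 0; s < S; s++)                 // s plays the role of the batched key index
        for (int j = 0; j < T; j++)
            for (int i = 0; i < j; i++) {       // strictly lower triangular
                const double decay = std::exp(g[j * S + s] - g[i * S + s]);
                A_mask[j * T + i] += decay * kb[j * S + s] * k[i * S + s];
            }

    for (int n = 0; n < T * T; n++)
        assert(std::fabs(A_fact[n] - A_mask[n]) < 1e-9);
    std::puts("ok: both forms compute the same A[j][i] for j > i");
    return 0;
}

The same check applies to attn_kq with q in place of k_beta. One plausible motivation for the switch (not stated in the patch): form (1) exponentiates the raw cumulative sums, and in f32 exp(-g_cumsum) overflows once g_cumsum drops below about -88 (log(FLT_MAX) is roughly 88.7), whereas form (2) only ever exponentiates differences between positions within the same chunk.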