Adapt autoregressive version from @ymcki

Piotr Wilkin 2026-01-13 02:00:38 +01:00
parent f98f285620
commit 08b1ed8633
1 changed file with 26 additions and 25 deletions

@@ -386,6 +386,9 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_graph_context_delta::build_delta_net
 /**
  * Unified autoregressive Delta Net implementation (single token processing).
  *
+ * This implementation uses matrix multiplication instead of elementwise operations + summation,
+ * which is more efficient and mathematically equivalent. See inline comments for equivalences.
+ *
  * Input tensor format matches qwen3next conventions:
  * @param q Query tensor [S_k, H_k, 1, n_seqs]
  * @param k Key tensor [S_k, H_k, 1, n_seqs]
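
The added doc-comment lines and the "Equivalence to previous version" comments in the hunk below all appeal to the same identity, which can be written out per head as a reading aid. With recurrent state $S \in \mathbb{R}^{S_v \times S_k}$, key $k \in \mathbb{R}^{S_k}$, value $v \in \mathbb{R}^{S_v}$, query $q \in \mathbb{R}^{S_k}$ and mixing factor $\beta$ (dimension names follow the tensor comments; the matrix notation itself is not part of the patch), the single-token update is:

    \mathrm{kv\_mem} = S\,k, \qquad
    \Delta = \beta\,(v - S\,k), \qquad
    S \leftarrow S + \Delta\,k^{\top}, \qquad
    o = S\,q

Since $(S\,k)_i = \sum_j S_{ij} k_j$, broadcast-multiplying the state by $k$ and reducing with sum_rows yields the same kv_mem as a single ggml_mul_mat, and the same holds for the output $o = S\,q$; the rank-1 term $\Delta\,k^{\top}$ is the outer product the new code forms as v_diff^T @ k_beta^T.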
@@ -487,35 +490,33 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_graph_context_delta::build_delta_net
         state = ggml_mul(ctx0, state, g_t);
     }
-    // kv_mem = sum_k(state * k) = (state * k.unsqueeze(-1)).sum(dim=-2)
-    // k shape after unsqueeze: [1, S_k, H_v, n_seqs]
-    ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_k, H_v, n_seqs);
-    ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed);
-    // Sum over K dimension (ne[1]): transpose, sum_rows, transpose back
-    kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem))));
+    // Equivalence to previous version:
+    // Previous: kv_mem = sum_k(state * k) using elementwise mult + sum_rows
+    // Current: k_state = state_t @ k_t using matrix multiplication
+    // These are equivalent because: sum_k(A * B) = A @ B when dimensions align
+    ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
+    ggml_tensor * k_t = ggml_reshape_4d(ctx0, k, S_k, 1, H_k, n_seqs);
+    ggml_tensor * k_state = ggml_mul_mat(ctx0, state_t, k_t);
     // v_t with singleton dimension: [S_v, 1, H_v, n_seqs]
+    // v_diff = v - k_state (equivalent to v - kv_mem in previous version)
     ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
+    ggml_tensor * v_diff = ggml_sub(ctx0, v_t, k_state);
+    ggml_tensor * k_beta = ggml_mul(ctx0, k_t, beta_t);
-    // delta = (v - kv_mem) * beta
-    ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem);
-    ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t);
-    // State update: state = state + k.unsqueeze(-1) * delta
-    // k_t_unsqueezed: [1, S_k, H_v, n_seqs], delta: [S_v, 1, H_v, n_seqs]
-    // Broadcasting: [S_v, S_k, H_v, n_seqs]
-    ggml_tensor * k_t_delta = ggml_mul(ctx0,
-        ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_k, H_v, n_seqs),
-        delta);
-    state = ggml_add(ctx0, state, k_t_delta);
-    // Output: sum_k(state * q) = (state * q.unsqueeze(-1)).sum(dim=-2)
-    ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_k, H_v, n_seqs);
-    ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed);
-    // Sum over K dimension
-    ggml_tensor * core_attn_out = ggml_transpose(ctx0,
-        ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q))));
+    // Equivalence to previous version:
+    // Previous: state += k.unsqueeze(-1) * delta where delta = (v - kv_mem) * beta
+    // Current: state += v_diff^T @ k_beta^T using matrix multiplication
+    // These are equivalent because: outer_product(k, v_diff * beta) = v_diff^T @ k^T
+    state = ggml_add(ctx0, state, ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_diff)), ggml_cont(ctx0, ggml_transpose(ctx0, k_beta))));
+    // Equivalence to previous version:
+    // Previous: core_attn_out = sum_k(state * q) using elementwise mult + sum_rows
+    // Current: core_attn_out = state_t @ q using matrix multiplication
+    // These are equivalent because: sum_k(A * B) = A @ B when dimensions align
+    q = ggml_reshape_4d(ctx0, q, S_k, 1, H_k, n_seqs);
+    state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
+    ggml_tensor * core_attn_out = ggml_mul_mat(ctx0, state_t, q);
+    // core_attn_out should be [S_v, 1, H_v, n_seqs] after this
     cb(core_attn_out, "output_tokens", il);
     cb(state, "new_state", il);
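
As a sanity check on the rewrite, here is a minimal standalone sketch in plain C++ (no ggml) that runs the removed broadcast-multiply-and-sum formulation and the new matrix-product formulation of the single-token update side by side on the same random data and compares the outputs. The toy dimensions, random inputs, scalar beta and helper names are illustrative assumptions, not part of the patch; the per-head gating step is left out because it is identical in both paths.

// Sanity sketch only: toy sizes, random data and the scalar beta below are
// illustrative assumptions; the real code operates on 4-D ggml tensors per head.
#include <cmath>
#include <cstdio>
#include <random>
#include <vector>

using vec = std::vector<float>;
using mat = std::vector<vec>;   // mat[i][j]: i indexes S_v, j indexes S_k

int main() {
    const int S_k = 5, S_v = 4; // toy head dimensions
    std::mt19937 rng(1);
    std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
    auto rvec = [&](int n) { vec x(n); for (float & e : x) e = dist(rng); return x; };

    mat state(S_v, vec(S_k));
    for (vec & row : state) row = rvec(S_k);
    const vec k = rvec(S_k), v = rvec(S_v), q = rvec(S_k);
    const float beta = 0.7f;    // stand-in for the per-head beta_t gate

    // --- removed formulation: broadcast multiply, then sum over the K dimension ---
    mat s1 = state;
    vec kv_mem(S_v, 0.0f), out1(S_v, 0.0f);
    for (int i = 0; i < S_v; ++i)                       // kv_mem = sum_k(state * k)
        for (int j = 0; j < S_k; ++j) kv_mem[i] += s1[i][j] * k[j];
    for (int i = 0; i < S_v; ++i) {                     // state += k.unsqueeze(-1) * delta
        const float delta = (v[i] - kv_mem[i]) * beta;  // delta = (v - kv_mem) * beta
        for (int j = 0; j < S_k; ++j) s1[i][j] += k[j] * delta;
    }
    for (int i = 0; i < S_v; ++i)                       // core_attn_out = sum_k(state * q)
        for (int j = 0; j < S_k; ++j) out1[i] += s1[i][j] * q[j];

    // --- new formulation: matrix-vector products and a rank-1 (outer product) update ---
    auto matvec = [](const mat & A, const vec & x) {    // A @ x
        vec y(A.size(), 0.0f);
        for (size_t i = 0; i < A.size(); ++i)
            for (size_t j = 0; j < x.size(); ++j) y[i] += A[i][j] * x[j];
        return y;
    };
    mat s2 = state;
    const vec k_state = matvec(s2, k);                  // k_state = state @ k
    vec k_beta(S_k);
    for (int j = 0; j < S_k; ++j) k_beta[j] = k[j] * beta;
    for (int i = 0; i < S_v; ++i) {                     // state += outer(v - k_state, k_beta)
        const float v_diff = v[i] - k_state[i];
        for (int j = 0; j < S_k; ++j) s2[i][j] += v_diff * k_beta[j];
    }
    const vec out2 = matvec(s2, q);                     // core_attn_out = state @ q

    float max_err = 0.0f;
    for (int i = 0; i < S_v; ++i) {
        const float e = std::fabs(out1[i] - out2[i]);
        if (e > max_err) max_err = e;
    }
    std::printf("max |removed - new| = %g\n", max_err); // expect ~0 (fp rounding only)
    return 0;
}

Both paths perform the same arithmetic; the practical difference in the patch is that the ggml_mul_mat form drops the ggml_repeat_4d and the transpose/sum_rows/transpose chains of the removed code, which is presumably where the "more efficient" note in the doc comment comes from.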