From 08b1ed86334622c2698e5ae0bc23c3e92700dfb5 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Tue, 13 Jan 2026 02:00:38 +0100 Subject: [PATCH] Adapt autoregressive version from @ymcki --- src/models/delta.cpp | 51 ++++++++++++++++++++++---------------------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/src/models/delta.cpp b/src/models/delta.cpp index 79618cf22e..d3c5cce3e1 100644 --- a/src/models/delta.cpp +++ b/src/models/delta.cpp @@ -386,6 +386,9 @@ std::pair llm_graph_context_delta::build_delta_net /** * Unified autoregressive Delta Net implementation (single token processing). * + * This implementation uses matrix multiplication instead of elementwise operations + summation, + * which is more efficient and mathematically equivalent. See inline comments for equivalences. + * * Input tensor format matches qwen3next conventions: * @param q Query tensor [S_k, H_k, 1, n_seqs] * @param k Key tensor [S_k, H_k, 1, n_seqs] @@ -487,35 +490,33 @@ std::pair llm_graph_context_delta::build_delta_net state = ggml_mul(ctx0, state, g_t); } - // kv_mem = sum_k(state * k) = (state * k.unsqueeze(-1)).sum(dim=-2) - // k shape after unsqueeze: [1, S_k, H_v, n_seqs] - ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_k, H_v, n_seqs); - ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed); - // Sum over K dimension (ne[1]): transpose, sum_rows, transpose back - kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem)))); + // Equivalence to previous version: + // Previous: kv_mem = sum_k(state * k) using elementwise mult + sum_rows + // Current: k_state = state_t @ k_t using matrix multiplication + // These are equivalent because: sum_k(A * B) = A @ B when dimensions align + ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state)); + ggml_tensor * k_t = ggml_reshape_4d(ctx0, k, S_k, 1, H_k, n_seqs); + ggml_tensor * k_state = ggml_mul_mat(ctx0, state_t, k_t); - // v_t with singleton dimension: [S_v, 1, H_v, n_seqs] + // v_diff = v - k_state (equivalent to v - kv_mem in previous version) ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); + ggml_tensor * v_diff = ggml_sub(ctx0, v_t, k_state); + ggml_tensor * k_beta = ggml_mul(ctx0, k_t, beta_t); - // delta = (v - kv_mem) * beta - ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); - ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t); - - // State update: state = state + k.unsqueeze(-1) * delta - // k_t_unsqueezed: [1, S_k, H_v, n_seqs], delta: [S_v, 1, H_v, n_seqs] - // Broadcasting: [S_v, S_k, H_v, n_seqs] - ggml_tensor * k_t_delta = ggml_mul(ctx0, - ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_k, H_v, n_seqs), - delta); - state = ggml_add(ctx0, state, k_t_delta); - - // Output: sum_k(state * q) = (state * q.unsqueeze(-1)).sum(dim=-2) - ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_k, H_v, n_seqs); - ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed); - // Sum over K dimension - ggml_tensor * core_attn_out = ggml_transpose(ctx0, - ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q)))); + // Equivalence to previous version: + // Previous: state += k.unsqueeze(-1) * delta where delta = (v - kv_mem) * beta + // Current: state += v_diff^T @ k_beta^T using matrix multiplication + // These are equivalent because: outer_product(k, v_diff * beta) = v_diff^T @ k^T + state = ggml_add(ctx0, state, ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_diff)), ggml_cont(ctx0, ggml_transpose(ctx0, k_beta)))); + // Equivalence to previous version: + // Previous: core_attn_out = sum_k(state * q) using elementwise mult + sum_rows + // Current: core_attn_out = state_t @ q using matrix multiplication + // These are equivalent because: sum_k(A * B) = A @ B when dimensions align + q = ggml_reshape_4d(ctx0, q, S_k, 1, H_k, n_seqs); + state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state)); + ggml_tensor * core_attn_out = ggml_mul_mat(ctx0, state_t, q); + // core_attn_out should be [S_v, 1, H_v, n_seqs] after this cb(core_attn_out, "output_tokens", il); cb(state, "new_state", il);