Adapt autoregressive version from @ymcki
parent f98f285620
commit 08b1ed8633
@@ -386,6 +386,9 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_graph_context_delta::build_delta_net
 /**
  * Unified autoregressive Delta Net implementation (single token processing).
  *
+ * This implementation uses matrix multiplication instead of elementwise operations + summation,
+ * which is more efficient and mathematically equivalent. See inline comments for equivalences.
+ *
  * Input tensor format matches qwen3next conventions:
  * @param q Query tensor [S_k, H_k, 1, n_seqs]
  * @param k Key tensor [S_k, H_k, 1, n_seqs]
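The doc comment asserts that the matmul formulation is mathematically equivalent to the elementwise-multiply-plus-summation path it replaces. A minimal sketch of the first of those equivalences, sum_k(state * k) = state^T @ k, in plain standalone C++ (not ggml; names mirror the diff, values are illustrative):

#include <cassert>
#include <cmath>
#include <vector>

using Mat = std::vector<std::vector<float>>;

// Generic dense matmul: C (m x n) = A (m x p) * B (p x n)
static Mat matmul(const Mat & A, const Mat & B) {
    size_t m = A.size(), p = B.size(), n = B[0].size();
    Mat C(m, std::vector<float>(n, 0.0f));
    for (size_t i = 0; i < m; ++i)
        for (size_t kk = 0; kk < p; ++kk)
            for (size_t j = 0; j < n; ++j)
                C[i][j] += A[i][kk] * B[kk][j];
    return C;
}

static Mat transpose(const Mat & A) {
    Mat T(A[0].size(), std::vector<float>(A.size()));
    for (size_t i = 0; i < A.size(); ++i)
        for (size_t j = 0; j < A[0].size(); ++j)
            T[j][i] = A[i][j];
    return T;
}

int main() {
    const size_t S_k = 4, S_v = 3;
    // One head, one sequence: state[k][v], key as an S_k x 1 column.
    Mat state = {{0.5f, -1.0f, 2.0f}, {1.5f, 0.25f, -0.5f},
                 {-2.0f, 1.0f, 0.75f}, {0.1f, -0.3f, 1.25f}};
    Mat key   = {{1.0f}, {-0.5f}, {2.0f}, {0.25f}};

    // Old path: broadcast key over the k rows of state, then sum over k
    // (what ggml_mul + transpose/ggml_sum_rows/transpose computed).
    std::vector<float> kv_mem(S_v, 0.0f);
    for (size_t k = 0; k < S_k; ++k)
        for (size_t v = 0; v < S_v; ++v)
            kv_mem[v] += state[k][v] * key[k][0];

    // New path: k_state = state^T @ k (what ggml_mul_mat computes).
    Mat k_state = matmul(transpose(state), key); // S_v x 1

    for (size_t v = 0; v < S_v; ++v)
        assert(std::fabs(kv_mem[v] - k_state[v][0]) < 1e-6f);
}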
@@ -487,35 +490,33 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_graph_context_delta::build_delta_net
         state = ggml_mul(ctx0, state, g_t);
     }
 
-    // kv_mem = sum_k(state * k) = (state * k.unsqueeze(-1)).sum(dim=-2)
-    // k shape after unsqueeze: [1, S_k, H_v, n_seqs]
-    ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_k, H_v, n_seqs);
-    ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed);
-    // Sum over K dimension (ne[1]): transpose, sum_rows, transpose back
-    kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem))));
-
+    // Equivalence to previous version:
+    // Previous: kv_mem = sum_k(state * k) using elementwise mult + sum_rows
+    // Current: k_state = state_t @ k_t using matrix multiplication
+    // These are equivalent because: sum_k(A * B) = A @ B when dimensions align
+    ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
+    ggml_tensor * k_t = ggml_reshape_4d(ctx0, k, S_k, 1, H_k, n_seqs);
+    ggml_tensor * k_state = ggml_mul_mat(ctx0, state_t, k_t);
 
-    // v_t with singleton dimension: [S_v, 1, H_v, n_seqs]
+    // v_diff = v - k_state (equivalent to v - kv_mem in previous version)
     ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
+    ggml_tensor * v_diff = ggml_sub(ctx0, v_t, k_state);
+    ggml_tensor * k_beta = ggml_mul(ctx0, k_t, beta_t);
 
-    // delta = (v - kv_mem) * beta
-    ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem);
-    ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t);
-
-    // State update: state = state + k.unsqueeze(-1) * delta
-    // k_t_unsqueezed: [1, S_k, H_v, n_seqs], delta: [S_v, 1, H_v, n_seqs]
-    // Broadcasting: [S_v, S_k, H_v, n_seqs]
-    ggml_tensor * k_t_delta = ggml_mul(ctx0,
-        ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_k, H_v, n_seqs),
-        delta);
-    state = ggml_add(ctx0, state, k_t_delta);
-
-    // Output: sum_k(state * q) = (state * q.unsqueeze(-1)).sum(dim=-2)
-    ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_k, H_v, n_seqs);
-    ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed);
-    // Sum over K dimension
-    ggml_tensor * core_attn_out = ggml_transpose(ctx0,
-        ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q))));
+    // Equivalence to previous version:
+    // Previous: state += k.unsqueeze(-1) * delta where delta = (v - kv_mem) * beta
+    // Current: state += v_diff^T @ k_beta^T using matrix multiplication
+    // These are equivalent because: outer_product(k, v_diff * beta) = v_diff^T @ k^T
+    state = ggml_add(ctx0, state, ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_diff)), ggml_cont(ctx0, ggml_transpose(ctx0, k_beta))));
+
+    // Equivalence to previous version:
+    // Previous: core_attn_out = sum_k(state * q) using elementwise mult + sum_rows
+    // Current: core_attn_out = state_t @ q using matrix multiplication
+    // These are equivalent because: sum_k(A * B) = A @ B when dimensions align
+    q = ggml_reshape_4d(ctx0, q, S_k, 1, H_k, n_seqs);
+    state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
+    ggml_tensor * core_attn_out = ggml_mul_mat(ctx0, state_t, q);
+    // core_attn_out should be [S_v, 1, H_v, n_seqs] after this
 
     cb(core_attn_out, "output_tokens", il);
     cb(state, "new_state", il);
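The state update is the second equivalence: the removed code materialized k across the value dimension with ggml_repeat_4d and multiplied elementwise, while the new code computes the same rank-1 outer product with a single ggml_mul_mat. A plain-C++ sketch of that identity (again not ggml; sizes and values are made up):

#include <cassert>
#include <cmath>
#include <vector>

using Mat = std::vector<std::vector<float>>;

int main() {
    // k_beta = k * beta (length S_k), v_diff = v - k_state (length S_v).
    std::vector<float> k_beta = {0.4f, -0.8f, 1.2f};  // S_k = 3
    std::vector<float> v_diff = {2.0f, -0.5f};        // S_v = 2
    const size_t S_k = k_beta.size(), S_v = v_diff.size();

    // Old path: repeat k across the v dimension (ggml_repeat_4d), then
    // elementwise-multiply by delta broadcast across the k dimension.
    Mat upd_broadcast(S_k, std::vector<float>(S_v));
    for (size_t i = 0; i < S_k; ++i) {
        std::vector<float> k_row(S_v, k_beta[i]);           // the repeat
        for (size_t j = 0; j < S_v; ++j)
            upd_broadcast[i][j] = k_row[j] * v_diff[j];     // the elementwise mul
    }

    // New path: one rank-1 matrix product, (S_k x 1) @ (1 x S_v); this is
    // the v_diff^T / k_beta^T ggml_mul_mat in the diff. The inner dimension
    // is 1, so each entry of the product is a single scalar product.
    Mat upd_matmul(S_k, std::vector<float>(S_v));
    for (size_t i = 0; i < S_k; ++i)
        for (size_t j = 0; j < S_v; ++j)
            upd_matmul[i][j] = k_beta[i] * v_diff[j];       // outer product entry

    for (size_t i = 0; i < S_k; ++i)
        for (size_t j = 0; j < S_v; ++j)
            assert(std::fabs(upd_broadcast[i][j] - upd_matmul[i][j]) < 1e-6f);
}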
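Putting all three equivalences together, a toy single-head, single-sequence reference can run one token through both formulations and check that the updated state and the output agree. This is a hedged sketch, not the ggml graph: it assumes beta_t acts as a per-head scalar for the current token, which is what the broadcasting against k_t in the diff suggests.

#include <cassert>
#include <cmath>
#include <vector>

using Vec = std::vector<float>;
using Mat = std::vector<Vec>;

// Old formulation: elementwise ops + sums over k.
Vec step_elementwise(Mat & state, const Vec & q, const Vec & k,
                     const Vec & v, float beta) {
    size_t S_k = k.size(), S_v = v.size();
    Vec kv_mem(S_v, 0.0f);
    for (size_t i = 0; i < S_k; ++i)
        for (size_t j = 0; j < S_v; ++j)
            kv_mem[j] += state[i][j] * k[i];                // sum_k(state * k)
    Vec delta(S_v);
    for (size_t j = 0; j < S_v; ++j)
        delta[j] = (v[j] - kv_mem[j]) * beta;               // delta = (v - kv_mem) * beta
    for (size_t i = 0; i < S_k; ++i)
        for (size_t j = 0; j < S_v; ++j)
            state[i][j] += k[i] * delta[j];                 // state += k (outer) delta
    Vec out(S_v, 0.0f);
    for (size_t i = 0; i < S_k; ++i)
        for (size_t j = 0; j < S_v; ++j)
            out[j] += state[i][j] * q[i];                   // out = sum_k(state * q)
    return out;
}

// New formulation: beta folded into k (k_beta), rank-1 matmul update,
// state^T matvecs for k_state and the output.
Vec step_matmul(Mat & state, const Vec & q, const Vec & k,
                const Vec & v, float beta) {
    size_t S_k = k.size(), S_v = v.size();
    Vec k_state(S_v, 0.0f);                                 // k_state = state^T @ k
    for (size_t j = 0; j < S_v; ++j)
        for (size_t i = 0; i < S_k; ++i)
            k_state[j] += state[i][j] * k[i];
    Vec v_diff(S_v);
    for (size_t j = 0; j < S_v; ++j)
        v_diff[j] = v[j] - k_state[j];                      // v_diff = v - k_state
    for (size_t i = 0; i < S_k; ++i)
        for (size_t j = 0; j < S_v; ++j)
            state[i][j] += (k[i] * beta) * v_diff[j];       // v_diff^T @ k_beta^T
    Vec out(S_v, 0.0f);                                     // out = state^T @ q
    for (size_t j = 0; j < S_v; ++j)
        for (size_t i = 0; i < S_k; ++i)
            out[j] += state[i][j] * q[i];
    return out;
}

int main() {
    Mat s1 = {{0.2f, -0.1f}, {0.4f, 0.3f}, {-0.5f, 0.6f}};  // S_k=3, S_v=2
    Mat s2 = s1;
    Vec q = {1.0f, -1.0f, 0.5f}, k = {0.3f, 0.7f, -0.2f}, v = {2.0f, -1.5f};
    Vec o1 = step_elementwise(s1, q, k, v, 0.9f);
    Vec o2 = step_matmul(s2, q, k, v, 0.9f);
    for (size_t j = 0; j < o1.size(); ++j)
        assert(std::fabs(o1[j] - o2[j]) < 1e-5f);
    for (size_t i = 0; i < s1.size(); ++i)
        for (size_t j = 0; j < s1[0].size(); ++j)
            assert(std::fabs(s1[i][j] - s2[i][j]) < 1e-5f);
}

The only real reordering is where beta is applied: the old path scales v - kv_mem, the new path folds beta into k_beta, and the rank-1 update k[i] * beta * v_diff[j] is identical either way.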