diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b0932794d4..264651839b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -56,6 +56,7 @@ add_library(llama models/deci.cpp models/deepseek.cpp models/deepseek2.cpp + models/delta.cpp models/dots1.cpp models/dream.cpp models/ernie4-5-moe.cpp diff --git a/src/models/delta.cpp b/src/models/delta.cpp new file mode 100644 index 0000000000..887cfc7f89 --- /dev/null +++ b/src/models/delta.cpp @@ -0,0 +1,584 @@ +#include "models.h" +#include "ggml.h" +#include +#include +#include + +llm_graph_context_delta::llm_graph_context_delta(const llm_graph_params & params) : llm_graph_context_mamba(params) {} + +/** + * Unified Delta Net implementation supporting both GDA and KDA modes. + * + * GDA (Gated Delta Attention): g has shape [H, T, B] in GGML (PyTorch: [B, T, H]) + * - Per-head gating, broadcasts over K dimension + * + * KDA (Key-wise Delta Attention): g has shape [K, H, T, B] in GGML (PyTorch: [B, T, H, K]) + * - Per-key gating + * + * The mode is auto-detected based on g's dimensionality. + * + * Tensor dimension convention: + * GGML: ne[0] is innermost (fastest varying), ne[3] is outermost + * PyTorch: dim 0 is outermost, dim -1 is innermost + * So GGML [A, B, C, D] corresponds to PyTorch [D, C, B, A] + */ + +// Helper to get a slice along dimension 2 (n_chunks dimension) +static ggml_tensor * get_slice_2d(ggml_context * ctx, ggml_tensor * t, int64_t chunk) { + return ggml_view_4d(ctx, t, + t->ne[0], t->ne[1], 1, t->ne[3], + t->nb[1], t->nb[2], t->nb[3], + chunk * t->nb[2]); +} + +/** + * Unified chunked Delta Net implementation. + * + * Input tensor format matches qwen3next conventions: + * @param q Query tensor [S_k, H_k, n_tokens, n_seqs] + * @param k Key tensor [S_k, H_k, n_tokens, n_seqs] + * @param v Value tensor [S_v, H_v, n_tokens, n_seqs] + * @param g Gate tensor: + * GDA: [H_v, n_tokens, n_seqs] + * KDA: [S_k, H_v, n_tokens, n_seqs] + * @param beta Beta tensor [H_v, 1, n_tokens, n_seqs] + * @param state State tensor [S_v, S_v * H_v, 1, n_seqs] + * @param causal_mask Lower triangular mask [chunk_size, chunk_size] + * @param identity Identity matrix [chunk_size, chunk_size] + * @param diag_mask Diagonal mask [chunk_size, chunk_size] + * @param il Layer index (for debugging callbacks) + * @param chunk_size Chunk size for chunked processing + * @param eps_norm Epsilon for L2 normalization + * + * @return Pair of (output_tokens, new_state) + */ +std::pair llm_graph_context_delta::build_delta_net_unified_chunking( + ggml_context * ctx0, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state_reshaped, + ggml_tensor * causal_mask, + ggml_tensor * identity, + ggml_tensor * diag_mask, + int il, + int64_t chunk_size, + float eps_norm) { + + // Input format: [S, H, n_tokens, n_seqs] (matching qwen3next convention) + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + // Detect KDA vs GDA based on g's shape + // GDA: g has shape [H_v, n_tokens, n_seqs] + // KDA: g has shape [S_k, H_v, n_tokens, n_seqs] (4D with ne[0]=S_k) + const bool is_kda = (g->ne[0] == S_k && g->ne[1] == H_v); + + // Validate tensor shapes + GGML_ASSERT(v->ne[2] == n_tokens); + GGML_ASSERT(k->ne[2] == n_tokens); + GGML_ASSERT(state_reshaped->ne[0] == S_v && state_reshaped->ne[1] == S_v && state_reshaped->ne[2] == H_v && state_reshaped->ne[3] == n_seqs); + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); + GGML_ASSERT(H_k == H_v); + + if (is_kda) { + // KDA: g shape [S_k, H_v, n_tokens, n_seqs] + GGML_ASSERT(g->ne[0] == S_k && g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs); + } else { + // GDA: g shape [H_v, n_tokens, n_seqs] + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + } + + // L2 normalize q and k + q = ggml_l2_norm(ctx0, q, eps_norm); + k = ggml_l2_norm(ctx0, k, eps_norm); + + const float scale = 1.0f / sqrtf((float)S_v); + q = ggml_scale(ctx0, q, scale); + + beta = ggml_sigmoid(ctx0, beta); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(beta, "beta_in", il); + cb(g, "g_in", il); + + // Permute tensors to working format [S, n_tokens, H, n_seqs] + // Input: [S, H, n_tokens, n_seqs] -> permute(0, 2, 1, 3) -> [S, n_tokens, H, n_seqs] + q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs); + k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs); + v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + + // Permute g based on mode + if (is_kda) { + // KDA: g [S_k, H_v, n_tokens, n_seqs] -> [S_k, n_tokens, H_k, n_seqs] + g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs); + } else { + // GDA: g [H_v, n_tokens, n_seqs] -> [n_tokens, 1, H_k, n_seqs] + g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); + } + + beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); + + cb(q, "q_perm", il); + cb(k, "k_perm", il); + cb(v, "v_perm", il); + cb(beta, "beta_perm", il); + cb(g, "g_perm", il); + cb(state_reshaped, "state_in", il); + + // Padding for chunk processing + const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size; + const int64_t n_chunks = (n_tokens + pad) / chunk_size; + + q = ggml_pad(ctx0, q, 0, pad, 0, 0); + k = ggml_pad(ctx0, k, 0, pad, 0, 0); + v = ggml_pad(ctx0, v, 0, pad, 0, 0); + beta = ggml_pad(ctx0, beta, 0, pad, 0, 0); + + if (is_kda) { + // KDA: g shape [S_k, n_tokens, H_k, n_seqs] -> pad along dim 1 + g = ggml_pad(ctx0, g, 0, pad, 0, 0); + } else { + // GDA: g shape [n_tokens, 1, H_k, n_seqs] -> pad along dim 0 + g = ggml_pad(ctx0, g, pad, 0, 0, 0); + } + + cb(q, "q_pad", il); + cb(k, "k_pad", il); + cb(v, "v_pad", il); + cb(beta, "beta_pad", il); + cb(g, "g_pad", il); + + ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); + ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); + + cb(v_beta, "v_beta", il); + cb(k_beta, "k_beta", il); + + // Reshape to chunks + q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); + k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); + k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs); + v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs); + v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); + beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); + + // Reshape g for chunks + ggml_tensor * g_cumsum; + if (is_kda) { + // KDA: g [S_k, n_tokens+pad, H_k, n_seqs] -> [S_k, chunk_size, n_chunks, H_k * n_seqs] + g = ggml_reshape_4d(ctx0, g, S_k, chunk_size, n_chunks, H_k * n_seqs); + // Cumsum along chunk_size dimension (ne[1]) + // GGML cumsum operates on ne[0], so we need to transpose, cumsum, transpose back + g = ggml_cont(ctx0, ggml_transpose(ctx0, g)); // [chunk_size, S_k, n_chunks, H_k * n_seqs] + g_cumsum = ggml_cumsum(ctx0, g); + g_cumsum = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); // [S_k, chunk_size, n_chunks, H_k * n_seqs] + } else { + // GDA: g [n_tokens+pad, 1, H_k, n_seqs] -> [chunk_size, 1, n_chunks, H_k * n_seqs] + g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs); + g_cumsum = ggml_cumsum(ctx0, g); + } + + cb(g_cumsum, "g_cumsum", il); + + // Build attention matrix A for the WY representation solve + // For GDA: A[j,i] = sum_k(k[j,k] * exp(g[j] - g[i]) * k[i,k]) = (k @ k^T) * exp(g[j] - g[i]) + // For KDA: A[j,i] = sum_k(k[j,k] * exp(g[j,k] - g[i,k]) * k[i,k]) + // = sum_k(k[j,k] * exp(g[j,k]) * k[i,k] * exp(-g[i,k])) + // = (k * exp(g))^T @ (k * exp(-g)) + // The KDA formulation factorizes into a matmul, avoiding the need for a 5D decay mask! + + ggml_tensor * k_decay; + ggml_tensor * decay_mask = nullptr; + ggml_tensor * g_exp_pos = nullptr; // For KDA: exp(g_cumsum) + ggml_tensor * g_exp_neg = nullptr; // For KDA: exp(-g_cumsum) + + if (is_kda) { + // KDA: Exact computation using factorization + // k_pos_beta = k_beta * exp(g_cumsum), k_neg = k * exp(-g_cumsum) + // A = k_pos_beta^T @ k_neg (via mul_mat) + g_exp_pos = ggml_exp(ctx0, g_cumsum); + g_exp_neg = ggml_exp(ctx0, ggml_neg(ctx0, g_cumsum)); + ggml_tensor * k_pos_beta = ggml_mul(ctx0, k_beta, g_exp_pos); + ggml_tensor * k_neg = ggml_mul(ctx0, k, g_exp_neg); + k_decay = ggml_mul_mat(ctx0, k_pos_beta, k_neg); + } else { + // GDA: Use decay mask approach (g broadcasts over K dimension) + // g_cumsum [chunk_size, 1, n_chunks, H_v * n_seqs] + ggml_tensor * gcs_i = g_cumsum; + ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); + ggml_tensor * gcs_j_broadcast = ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); + decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); + + cb(decay_mask, "decay_mask", il); + + decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); + decay_mask = ggml_exp(ctx0, decay_mask); + decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); + + ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); + k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); + } + + ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); + + cb(attn, "attn_pre_solve", il); + + // Solve triangular system: (I + L) @ X = I, where L is strictly lower triangular + ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); + ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); + ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); + attn = ggml_mul(ctx0, lin_solve, causal_mask); + attn = ggml_add(ctx0, attn, identity); + + cb(attn, "attn_solved", il); + + // Compute u = A @ v and w = A @ (g.exp() * k) + v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); + + ggml_tensor * gexp; + if (is_kda) { + // KDA: Reuse g_exp_pos computed earlier + gexp = g_exp_pos; + } else { + // GDA: g_cumsum [chunk_size, 1, n_chunks, H_k * n_seqs] + ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); + gexp = ggml_exp(ctx0, g_cumsum_t); + } + + ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); + cb(kbeta_gexp, "kbeta_gexp", il); + + ggml_tensor * k_cumdecay = ggml_cont(ctx0, ggml_transpose(ctx0, + ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); + cb(k_cumdecay, "k_cumdecay", il); + + // Attention scores q @ k^T with decay + // For GDA: attn_kq[j,i] = sum_k(q[j,k] * exp(g[j] - g[i]) * k[i,k]) + // For KDA: attn_kq[j,i] = sum_k(q[j,k] * exp(g[j,k] - g[i,k]) * k[i,k]) + // = (q * exp(g))^T @ (k * exp(-g)) + ggml_tensor * attn_kq; + if (is_kda) { + // KDA: Use factorization + ggml_tensor * q_exp_pos = ggml_mul(ctx0, q, g_exp_pos); + ggml_tensor * k_exp_neg = ggml_mul(ctx0, k, g_exp_neg); + attn_kq = ggml_mul_mat(ctx0, q_exp_pos, k_exp_neg); + } else { + // GDA: Use decay mask + attn_kq = ggml_mul_mat(ctx0, k, q); + attn_kq = ggml_mul(ctx0, attn_kq, decay_mask); + } + attn_kq = ggml_mul(ctx0, attn_kq, diag_mask); + cb(attn_kq, "attn_kq", il); + + // Compute g_last and g_diff for state updates + ggml_tensor * g_last; + ggml_tensor * g_diff_exp; + ggml_tensor * g_last_exp; + + if (is_kda) { + // KDA: g_cumsum [S_k, chunk_size, n_chunks, H_k * n_seqs] + // Get last element along chunk_size dimension (ne[1]) + g_last = ggml_view_4d(ctx0, g_cumsum, + g_cumsum->ne[0], 1, g_cumsum->ne[2], g_cumsum->ne[3], + g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3], + (g_cumsum->ne[1] - 1) * g_cumsum->nb[1]); + g_last = ggml_cont(ctx0, g_last); + g_last_exp = ggml_exp(ctx0, g_last); + + // g_diff = g_last - g_cumsum + ggml_tensor * g_last_broadcast = ggml_repeat_4d(ctx0, g_last, + g_cumsum->ne[0], g_cumsum->ne[1], g_cumsum->ne[2], g_cumsum->ne[3]); + ggml_tensor * g_diff = ggml_sub(ctx0, g_last_broadcast, g_cumsum); + g_diff_exp = ggml_exp(ctx0, g_diff); + } else { + // GDA: g_cumsum [chunk_size, 1, n_chunks, H_k * n_seqs] + g_last = ggml_view_4d(ctx0, g_cumsum, + 1, 1, g_cumsum->ne[2], g_cumsum->ne[3], + g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3], + (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum)); + g_last = ggml_cont(ctx0, g_last); + g_last_exp = ggml_exp(ctx0, g_last); + + ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last)); + g_diff_exp = ggml_exp(ctx0, g_diff); + } + + cb(g_last, "g_last", il); + cb(g_last_exp, "g_last_exp", il); + + ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp); + cb(key_gdiff, "key_gdiff", il); + + // Process chunks + ggml_tensor * new_state = state_reshaped; + ggml_tensor * core_attn_out = nullptr; + + for (int64_t chunk = 0; chunk < n_chunks; chunk++) { + ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); + ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); + ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); + ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk); + ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); + + cb(attn_chunk, "attn_chunk", il); + + ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), + S_v, S_v, 1, H_v * n_seqs); + + // v_prime = k_cumdecay @ state + ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); + cb(v_prime, "v_prime_chunk", il); + + // v_new = v - v_prime + ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime); + ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); + cb(v_new, "v_new_chunk", il); + + // attn_inter = (q * g.exp()) @ state + ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk); + ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); + cb(attn_inter, "attn_inter_chunk", il); + + // output = attn_inter + attn @ v_new + ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk); + cb(v_attn, "v_attn_chunk", il); + + ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn); + cb(core_attn_out_chunk, "core_attn_out_chunk", il); + + core_attn_out = core_attn_out == nullptr + ? core_attn_out_chunk + : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2); + + // State update: state = state * g_last_exp + key_gdiff^T @ v_new + ggml_tensor * k_gdiff = ggml_cont(ctx0, get_slice_2d(ctx0, key_gdiff, chunk)); + ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, k_gdiff))); + + ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk)); + + if (is_kda) { + // KDA: g_last_exp [S_k, 1, n_chunks, H_k * n_seqs] + // State: [S_v, S_v, H_v, n_seqs] + // Need to reshape g_last_exp to broadcast correctly over V dimension only + gexp_last_chunk = ggml_reshape_4d(ctx0, gexp_last_chunk, + 1, gexp_last_chunk->ne[0], H_v, n_seqs); // [1, S_k, H_v, n_seqs] + // Transpose to [S_k, 1, H_v, n_seqs] then broadcast + gexp_last_chunk = ggml_cont(ctx0, ggml_permute(ctx0, gexp_last_chunk, 1, 0, 2, 3)); + } else { + // GDA: g_last_exp [1, 1, n_chunks, H_k * n_seqs] + // Broadcasts over both K and V dimensions + gexp_last_chunk = ggml_reshape_4d(ctx0, gexp_last_chunk, + gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs); + } + + new_state = ggml_add(ctx0, + ggml_mul(ctx0, new_state, gexp_last_chunk), + ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); + } + + // Truncate padding and permute back + ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, + S_v, n_tokens, H_v, n_seqs, + ggml_row_size(core_attn_out->type, S_v), + ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks), + ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0); + output_tokens = ggml_cont(ctx0, output_tokens); + + cb(output_tokens, "output_tokens", il); + + output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3); + output_tokens = ggml_cont(ctx0, output_tokens); + + return {output_tokens, new_state}; +} + + +/** + * Unified autoregressive Delta Net implementation (single token processing). + * + * Input tensor format matches qwen3next conventions: + * @param q Query tensor [S_k, H_k, 1, n_seqs] + * @param k Key tensor [S_k, H_k, 1, n_seqs] + * @param v Value tensor [S_v, H_v, 1, n_seqs] + * @param g Gate tensor: + * GDA: [H_v, 1, n_seqs] + * KDA: [S_k, H_v, 1, n_seqs] + * @param beta Beta tensor [H_v, 1, 1, n_seqs] + * @param state State tensor [S_v, S_v * H_v, 1, n_seqs] + * @param il Layer index (for debugging callbacks) + * @param eps_norm Epsilon for L2 normalization + * + * @return Pair of (output_tokens, new_state) + */ +std::pair llm_graph_context_delta::build_delta_net_unified_autoregressive( + ggml_context * ctx0, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + int il, + float eps_norm) { + + // Input format: [S, H, n_tokens, n_seqs] (matching qwen3next convention) + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(n_tokens == 1); // Autoregressive mode is for single token + + // Detect KDA vs GDA based on g's shape + // GDA: g has shape [H_v, 1, n_seqs] or [H_v, n_tokens, n_seqs] + // KDA: g has shape [S_k, H_v, 1, n_seqs] or [S_k, H_v, n_tokens, n_seqs] + const bool is_kda = (g->ne[0] == S_k && g->ne[1] == H_v); + + // Validate shapes + GGML_ASSERT(v->ne[2] == n_tokens); + GGML_ASSERT(k->ne[2] == n_tokens); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs); + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); + GGML_ASSERT(H_k == H_v); + + if (is_kda) { + GGML_ASSERT(g->ne[0] == S_k && g->ne[1] == H_v); + } else { + GGML_ASSERT(g->ne[0] == H_v); + } + + // L2 normalize q and k + q = ggml_l2_norm(ctx0, q, eps_norm); + k = ggml_l2_norm(ctx0, k, eps_norm); + + const float scale = 1.0f / sqrtf((float)S_v); + q = ggml_scale(ctx0, q, scale); + beta = ggml_sigmoid(ctx0, beta); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(beta, "beta_in", il); + cb(g, "g_in", il); + + // Reshape g and beta for broadcasting + ggml_tensor * g_t; + ggml_tensor * beta_t; + + if (is_kda) { + // KDA: g [S_k, H_v, 1, n_seqs] -> [S_k, 1, H_k, n_seqs] + // For state multiplication, need [1, S_k, H_v, n_seqs] to broadcast over V only + g_t = ggml_reshape_4d(ctx0, g, S_k, 1, H_k, n_seqs); + } else { + // GDA: g [H_v, 1, n_seqs] -> [1, 1, H_k, n_seqs] + // For state multiplication, broadcasts over both K and V + g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs); + } + + beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs); + + // Apply exponential to g_t + g_t = ggml_exp(ctx0, g_t); + + // State decay: state = state * exp(g) + if (is_kda) { + // KDA: g_t [S_k, 1, H_k, n_seqs], state [S_v, S_v, H_v, n_seqs] + // Need to broadcast g_t over V dimension (ne[0] of state) + // Permute g_t to [1, S_k, H_k, n_seqs] for correct broadcasting + ggml_tensor * g_broadcast = ggml_cont(ctx0, ggml_permute(ctx0, g_t, 1, 0, 2, 3)); + state = ggml_mul(ctx0, state, g_broadcast); + } else { + // GDA: g_t [1, 1, H_k, n_seqs] broadcasts over both dimensions + state = ggml_mul(ctx0, state, g_t); + } + + // kv_mem = sum_k(state * k) = (state * k.unsqueeze(-1)).sum(dim=-2) + // k shape after unsqueeze: [1, S_k, H_v, n_seqs] + ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_k, H_v, n_seqs); + ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed); + // Sum over K dimension (ne[1]): transpose, sum_rows, transpose back + kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem)))); + + // v_t with singleton dimension: [S_v, 1, H_v, n_seqs] + ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); + + // delta = (v - kv_mem) * beta + ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); + ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t); + + // State update: state = state + k.unsqueeze(-1) * delta + // k_t_unsqueezed: [1, S_k, H_v, n_seqs], delta: [S_v, 1, H_v, n_seqs] + // Broadcasting: [S_v, S_k, H_v, n_seqs] + ggml_tensor * k_t_delta = ggml_mul(ctx0, + ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_k, H_v, n_seqs), + delta); + state = ggml_add(ctx0, state, k_t_delta); + + // Output: sum_k(state * q) = (state * q.unsqueeze(-1)).sum(dim=-2) + ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_k, H_v, n_seqs); + ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed); + // Sum over K dimension + ggml_tensor * core_attn_out = ggml_transpose(ctx0, + ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q)))); + + cb(core_attn_out, "output_tokens", il); + cb(state, "new_state", il); + + return {core_attn_out, state}; +} + + +/** + * Main entry point that dispatches to chunked or autoregressive based on n_tokens. + * + * Input tensor format matches qwen3next conventions: + * @param q Query tensor [S_k, H_k, n_tokens, n_seqs] + * @param k Key tensor [S_k, H_k, n_tokens, n_seqs] + * @param v Value tensor [S_v, H_v, n_tokens, n_seqs] + * @param g Gate tensor (GDA: [H_v, n_tokens, n_seqs], KDA: [S_k, H_v, n_tokens, n_seqs]) + * @param beta Beta tensor [H_v, 1, n_tokens, n_seqs] + * @param state State tensor [S_v, S_v * H_v, 1, n_seqs] + */ +std::pair llm_graph_context_delta::build_delta_net_unified( + ggml_context * ctx0, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + ggml_tensor * diag_mask, + int il, + int64_t chunk_size, + float eps_norm) { + + // Input format: [S, H, n_tokens, n_seqs] (matching qwen3next convention) + const int64_t n_tokens = q->ne[2]; + + if (n_tokens == 1) { + return build_delta_net_unified_autoregressive( + ctx0, q, k, v, g, beta, state, il, eps_norm); + } + return build_delta_net_unified_chunking( + ctx0, q, k, v, g, beta, state, causal_mask, identity, diag_mask, + il, chunk_size, eps_norm); +} diff --git a/src/models/models.h b/src/models/models.h index 6c40f48042..7b558d9c12 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -17,6 +17,53 @@ struct llm_graph_context_mamba : public llm_graph_context { }; +struct llm_graph_context_delta : public llm_graph_context_mamba { + llm_graph_context_delta(const llm_graph_params & params); + + virtual ~llm_graph_context_delta() = default; + + std::pair build_delta_net_unified_chunking( + ggml_context * ctx0, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + ggml_tensor * diag_mask, + int il, + int64_t chunk_size, + float eps_norm); + + std::pair build_delta_net_unified_autoregressive( + ggml_context * ctx0, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + int il, + float eps_norm); + + std::pair build_delta_net_unified( + ggml_context * ctx0, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + ggml_tensor * diag_mask, + int il, + int64_t chunk_size, + float eps_norm); +}; + // Base class for RWKV-related models struct llm_build_rwkv6_base : public llm_graph_context { const llama_model & model; @@ -445,7 +492,7 @@ struct llm_build_qwen3vl : public llm_graph_context { struct llm_build_qwen3vlmoe : public llm_graph_context { llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_qwen3next : public llm_graph_context_mamba { +struct llm_build_qwen3next : public llm_graph_context_delta { llm_build_qwen3next(const llama_model & model, const llm_graph_params & params); private: ggml_tensor * build_layer_attn( diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp index 57b6659baf..25b2a42e86 100644 --- a/src/models/qwen3next.cpp +++ b/src/models/qwen3next.cpp @@ -4,7 +4,7 @@ #define CHUNK_SIZE 64 llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params), model(model) { + llm_graph_context_delta(params), model(model) { ggml_tensor * cur; ggml_tensor * inpL; @@ -746,7 +746,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); - state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs); + state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); cb(state, "state_predelta", il); // if head keys and value keys are different, repeat to force tensors into matching shapes @@ -775,13 +775,10 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens - std::pair attn_out; // pair of (output, new_state) - if (n_seq_tokens == 1) { - attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); - } else { - attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il); - } + std::pair attn_out = build_delta_net_unified(ctx0, q_conv, k_conv, v_conv, + gate, beta, state, causal_mask, identity, diag_mask, + il, CHUNK_SIZE, hparams.f_norm_rms_eps); + ggml_tensor * output = attn_out.first; ggml_tensor * new_state = attn_out.second; cb(output, "attn_output", il);