diff --git a/src/models/delta.cpp b/src/models/delta.cpp index 887cfc7f89..79618cf22e 100644 --- a/src/models/delta.cpp +++ b/src/models/delta.cpp @@ -118,16 +118,7 @@ std::pair llm_graph_context_delta::build_delta_net q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs); k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs); v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - - // Permute g based on mode - if (is_kda) { - // KDA: g [S_k, H_v, n_tokens, n_seqs] -> [S_k, n_tokens, H_k, n_seqs] - g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs); - } else { - // GDA: g [H_v, n_tokens, n_seqs] -> [n_tokens, 1, H_k, n_seqs] - g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); - } - + g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 0, 2, 1, 3), is_kda ? S_k : 1, n_tokens, H_k, n_seqs); beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); cb(q, "q_perm", il); @@ -145,14 +136,8 @@ std::pair llm_graph_context_delta::build_delta_net k = ggml_pad(ctx0, k, 0, pad, 0, 0); v = ggml_pad(ctx0, v, 0, pad, 0, 0); beta = ggml_pad(ctx0, beta, 0, pad, 0, 0); + g = ggml_pad(ctx0, g, 0, pad, 0, 0); - if (is_kda) { - // KDA: g shape [S_k, n_tokens, H_k, n_seqs] -> pad along dim 1 - g = ggml_pad(ctx0, g, 0, pad, 0, 0); - } else { - // GDA: g shape [n_tokens, 1, H_k, n_seqs] -> pad along dim 0 - g = ggml_pad(ctx0, g, pad, 0, 0, 0); - } cb(q, "q_pad", il); cb(k, "k_pad", il); @@ -176,18 +161,20 @@ std::pair llm_graph_context_delta::build_delta_net // Reshape g for chunks ggml_tensor * g_cumsum; + ggml_tensor * g_cumsum_t; if (is_kda) { // KDA: g [S_k, n_tokens+pad, H_k, n_seqs] -> [S_k, chunk_size, n_chunks, H_k * n_seqs] g = ggml_reshape_4d(ctx0, g, S_k, chunk_size, n_chunks, H_k * n_seqs); // Cumsum along chunk_size dimension (ne[1]) // GGML cumsum operates on ne[0], so we need to transpose, cumsum, transpose back g = ggml_cont(ctx0, ggml_transpose(ctx0, g)); // [chunk_size, S_k, n_chunks, H_k * n_seqs] - g_cumsum = ggml_cumsum(ctx0, g); - g_cumsum = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); // [S_k, chunk_size, n_chunks, H_k * n_seqs] + g_cumsum_t = ggml_cumsum(ctx0, g); + g_cumsum = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum_t)); // [S_k, chunk_size, n_chunks, H_k * n_seqs] } else { // GDA: g [n_tokens+pad, 1, H_k, n_seqs] -> [chunk_size, 1, n_chunks, H_k * n_seqs] g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs); g_cumsum = ggml_cumsum(ctx0, g); + g_cumsum_t = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_k * n_seqs); } cb(g_cumsum, "g_cumsum", il); @@ -217,7 +204,8 @@ std::pair llm_graph_context_delta::build_delta_net // GDA: Use decay mask approach (g broadcasts over K dimension) // g_cumsum [chunk_size, 1, n_chunks, H_v * n_seqs] ggml_tensor * gcs_i = g_cumsum; - ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); + ggml_tensor * gcs_j = g_cumsum_t; + g_exp_pos = ggml_exp(ctx0, g_cumsum_t); ggml_tensor * gcs_j_broadcast = ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); @@ -247,17 +235,7 @@ std::pair llm_graph_context_delta::build_delta_net // Compute u = A @ v and w = A @ (g.exp() * k) v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); - ggml_tensor * gexp; - if (is_kda) { - // KDA: Reuse g_exp_pos computed earlier - gexp = g_exp_pos; - } else { - // GDA: g_cumsum [chunk_size, 1, n_chunks, H_k * n_seqs] - ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); - gexp = ggml_exp(ctx0, g_cumsum_t); - } - - ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); + ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, g_exp_pos); cb(kbeta_gexp, "kbeta_gexp", il); ggml_tensor * k_cumdecay = ggml_cont(ctx0, ggml_transpose(ctx0, @@ -330,7 +308,7 @@ std::pair llm_graph_context_delta::build_delta_net ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk); - ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); + ggml_tensor * gexp_chunk = get_slice_2d(ctx0, g_exp_pos, chunk); cb(attn_chunk, "attn_chunk", il);