From da33f8e6e305d5e0040d81d3536f0a5bf547bff0 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Wed, 11 Mar 2026 10:10:16 +0100 Subject: [PATCH] Improve comments --- ggml/src/ggml-cuda/gated_delta_net.cu | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 8d4008432c..e2f156cb19 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -25,6 +25,7 @@ __global__ void gated_delta_net_cuda(const float * q, float scale) { const int64_t h_idx = blockIdx.x; const int64_t sequence = blockIdx.y; + // each warp owns one column, using warp-level primitives to reduce across rows const int lane = threadIdx.x; const int col = blockIdx.z * blockDim.y + threadIdx.y; @@ -70,7 +71,7 @@ __global__ void gated_delta_net_cuda(const float * q, const int i = r * WARP_SIZE + lane; kv_shard += s_shard[r] * k_t[i]; } - float kv_col = warp_reduce_sum(kv_shard); // reduce within warp + float kv_col = warp_reduce_sum(kv_shard); // delta[col] = (v[col] - g * kv[col]) * beta float delta_col = (v_t[col] - g_val * kv_col) * beta_val; @@ -85,7 +86,7 @@ __global__ void gated_delta_net_cuda(const float * q, attn_partial += s_shard[r] * q_t[i]; } - float attn_col = warp_reduce_sum(attn_partial); // reduce within warp + float attn_col = warp_reduce_sum(attn_partial); if (lane == 0) { attn_data[col] = attn_col * scale; @@ -99,7 +100,7 @@ __global__ void gated_delta_net_cuda(const float * q, kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i]; } - float kv_col = warp_reduce_sum(kv_shard); // reduce within warp + float kv_col = warp_reduce_sum(kv_shard); // delta[col] = (v[col] - kv[col]) * beta float delta_col = (v_t[col] - kv_col) * beta_val; @@ -114,7 +115,7 @@ __global__ void gated_delta_net_cuda(const float * q, attn_partial += s_shard[r] * q_t[i]; } - float attn_col = warp_reduce_sum(attn_partial); // reduce within warp + float attn_col = warp_reduce_sum(attn_partial); if (lane == 0) { attn_data[col] = attn_col * scale;