diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 8d4008432c..e2f156cb19 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -25,6 +25,7 @@ __global__ void gated_delta_net_cuda(const float * q, float scale) { const int64_t h_idx = blockIdx.x; const int64_t sequence = blockIdx.y; + // each warp owns one column, using warp-level primitives to reduce across rows const int lane = threadIdx.x; const int col = blockIdx.z * blockDim.y + threadIdx.y; @@ -70,7 +71,7 @@ __global__ void gated_delta_net_cuda(const float * q, const int i = r * WARP_SIZE + lane; kv_shard += s_shard[r] * k_t[i]; } - float kv_col = warp_reduce_sum(kv_shard); // reduce within warp + float kv_col = warp_reduce_sum(kv_shard); // delta[col] = (v[col] - g * kv[col]) * beta float delta_col = (v_t[col] - g_val * kv_col) * beta_val; @@ -85,7 +86,7 @@ __global__ void gated_delta_net_cuda(const float * q, attn_partial += s_shard[r] * q_t[i]; } - float attn_col = warp_reduce_sum(attn_partial); // reduce within warp + float attn_col = warp_reduce_sum(attn_partial); if (lane == 0) { attn_data[col] = attn_col * scale; @@ -99,7 +100,7 @@ __global__ void gated_delta_net_cuda(const float * q, kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i]; } - float kv_col = warp_reduce_sum(kv_shard); // reduce within warp + float kv_col = warp_reduce_sum(kv_shard); // delta[col] = (v[col] - kv[col]) * beta float delta_col = (v_t[col] - kv_col) * beta_val; @@ -114,7 +115,7 @@ __global__ void gated_delta_net_cuda(const float * q, attn_partial += s_shard[r] * q_t[i]; } - float attn_col = warp_reduce_sum(attn_partial); // reduce within warp + float attn_col = warp_reduce_sum(attn_partial); if (lane == 0) { attn_data[col] = attn_col * scale;