From 2352caa10279a3d6b7568339d9b5c8914a61294e Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Tue, 10 Mar 2026 15:03:15 +0100 Subject: [PATCH] Shard columns across warps This reduces register pressure (avoids spill for S_v = 128) and gives the warp-scheduler more CTAs to schedule (thus hiding data-access latencies). --- ggml/src/ggml-cuda/gated_delta_net.cu | 70 +++++++++++++++++---------- 1 file changed, 45 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 801d828cea..4125c25fa5 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -26,7 +26,8 @@ __global__ void gated_delta_net_cuda(const float * q, float scale) { const int64_t h_idx = blockIdx.x; const int64_t sequence = blockIdx.y; - const int col = threadIdx.x; // each thread owns one column + const int lane = threadIdx.x; + const int col = blockIdx.z * blockDim.y + threadIdx.y; const int64_t iq1 = fastmodulo_s64(h_idx, neqk1_magic); const int64_t iq3 = fastdiv_s64(sequence, rq3_magic); @@ -40,11 +41,13 @@ __global__ void gated_delta_net_cuda(const float * q, curr_state += state_offset; attn_data += (sequence * n_tokens * H + h_idx) * S_v; - // Load state column into registers - float s[S_v]; + static_assert(S_v % WARP_SIZE == 0, "S_v must be a multiple of WARP_SIZE"); + constexpr int ROWS_PER_LANE = S_v / WARP_SIZE; + float s_shard[ROWS_PER_LANE]; #pragma unroll - for (int i = 0; i < S_v; i++) { - s[i] = curr_state[i * S_v + col]; + for (int r = 0; r < ROWS_PER_LANE; r++) { + const int i = r * WARP_SIZE + lane; + s_shard[r] = curr_state[i * S_v + col]; } for (int t = 0; t < n_tokens; t++) { @@ -62,46 +65,61 @@ __global__ void gated_delta_net_cuda(const float * q, const float g_val = expf(*g_t); // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i] - float kv_col = 0.0f; + float kv_shard = 0.0f; #pragma unroll - for (int i = 0; i < S_v; i++) { - kv_col += s[i] * k_t[i]; + for (int r = 0; r < ROWS_PER_LANE; r++) { + const int i = r * WARP_SIZE + lane; + kv_shard += s_shard[r] * k_t[i]; } + float kv_col = warp_reduce_sum(kv_shard); // reduce within warp // delta[col] = (v[col] - g * kv[col]) * beta float delta_col = (v_t[col] - g_val * kv_col) * beta_val; // fused: S[i][col] = g * S[i][col] + k[i] * delta[col] // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i] - float attn_col = 0.0f; + float attn_partial = 0.0f; #pragma unroll - for (int i = 0; i < S_v; i++) { - s[i] = g_val * s[i] + k_t[i] * delta_col; - attn_col += s[i] * q_t[i]; + for (int r = 0; r < ROWS_PER_LANE; r++) { + const int i = r * WARP_SIZE + lane; + s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col; + attn_partial += s_shard[r] * q_t[i]; } - attn_data[col] = attn_col * scale; + float attn_col = warp_reduce_sum(attn_partial); // reduce within warp + + if (lane == 0) { + attn_data[col] = attn_col * scale; + } } else { // kv[col] = sum_i g[i] * S[i][col] * k[i] - float kv_col = 0.0f; + float kv_shard = 0.0f; #pragma unroll - for (int i = 0; i < S_v; i++) { - kv_col += expf(g_t[i]) * s[i] * k_t[i]; + for (int r = 0; r < ROWS_PER_LANE; r++) { + const int i = r * WARP_SIZE + lane; + kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i]; } + float kv_col = warp_reduce_sum(kv_shard); // reduce within warp + // delta[col] = (v[col] - kv[col]) * beta float delta_col = (v_t[col] - kv_col) * beta_val; // fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col] // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i] - float attn_col = 0.0f; + float attn_partial = 0.0f; #pragma unroll - for (int i = 0; i < S_v; i++) { - s[i] = expf(g_t[i]) * s[i] + k_t[i] * delta_col; - attn_col += s[i] * q_t[i]; + for (int r = 0; r < ROWS_PER_LANE; r++) { + const int i = r * WARP_SIZE + lane; + s_shard[r] = expf(g_t[i]) * s_shard[r] + k_t[i] * delta_col; + attn_partial += s_shard[r] * q_t[i]; } - attn_data[col] = attn_col * scale; + float attn_col = warp_reduce_sum(attn_partial); // reduce within warp + + if (lane == 0) { + attn_data[col] = attn_col * scale; + } } attn_data += S_v * H; @@ -109,8 +127,9 @@ __global__ void gated_delta_net_cuda(const float * q, // Write state back to global memory #pragma unroll - for (int i = 0; i < S_v; i++) { - state[i * S_v + col] = s[i]; + for (int r = 0; r < ROWS_PER_LANE; r++) { + const int i = r * WARP_SIZE + lane; + state[i * S_v + col] = s_shard[r]; } } @@ -126,8 +145,9 @@ static void launch_gated_delta_net( int64_t neqk1, int64_t rq3, float scale, cudaStream_t stream) { - dim3 grid_dims(H, n_seqs, 1); - dim3 block_dims(S_v, 1, 1); + const int num_warps = 4; + dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps); + dim3 block_dims(WARP_SIZE, num_warps, 1); const fastdiv_consts_s64 neqk1_magic = init_fastdiv_s64(neqk1); const fastdiv_consts_s64 rq3_magic = init_fastdiv_s64(rq3);