Shard columns across warps

This reduces register pressure (avoids spill for S_v = 128) and gives
the warp-scheduler more CTAs to schedule (thus hiding data-access
latencies).
This commit is contained in:
Oliver Simons 2026-03-10 15:03:15 +01:00
parent 55d2e3a361
commit 2352caa102
1 changed file with 45 additions and 25 deletions

View File

@ -26,7 +26,8 @@ __global__ void gated_delta_net_cuda(const float * q,
float scale) {
const int64_t h_idx = blockIdx.x;
const int64_t sequence = blockIdx.y;
const int col = threadIdx.x; // each thread owns one column
const int lane = threadIdx.x;
const int col = blockIdx.z * blockDim.y + threadIdx.y;
const int64_t iq1 = fastmodulo_s64(h_idx, neqk1_magic);
const int64_t iq3 = fastdiv_s64(sequence, rq3_magic);
@ -40,11 +41,13 @@ __global__ void gated_delta_net_cuda(const float * q,
curr_state += state_offset;
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
// Load state column into registers
float s[S_v];
static_assert(S_v % WARP_SIZE == 0, "S_v must be a multiple of WARP_SIZE");
constexpr int ROWS_PER_LANE = S_v / WARP_SIZE;
float s_shard[ROWS_PER_LANE];
#pragma unroll
for (int i = 0; i < S_v; i++) {
s[i] = curr_state[i * S_v + col];
for (int r = 0; r < ROWS_PER_LANE; r++) {
const int i = r * WARP_SIZE + lane;
s_shard[r] = curr_state[i * S_v + col];
}
for (int t = 0; t < n_tokens; t++) {
@ -62,46 +65,61 @@ __global__ void gated_delta_net_cuda(const float * q,
const float g_val = expf(*g_t);
// kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
float kv_col = 0.0f;
float kv_shard = 0.0f;
#pragma unroll
for (int i = 0; i < S_v; i++) {
kv_col += s[i] * k_t[i];
for (int r = 0; r < ROWS_PER_LANE; r++) {
const int i = r * WARP_SIZE + lane;
kv_shard += s_shard[r] * k_t[i];
}
float kv_col = warp_reduce_sum(kv_shard); // reduce within warp
// delta[col] = (v[col] - g * kv[col]) * beta
float delta_col = (v_t[col] - g_val * kv_col) * beta_val;
// fused: S[i][col] = g * S[i][col] + k[i] * delta[col]
// attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
float attn_col = 0.0f;
float attn_partial = 0.0f;
#pragma unroll
for (int i = 0; i < S_v; i++) {
s[i] = g_val * s[i] + k_t[i] * delta_col;
attn_col += s[i] * q_t[i];
for (int r = 0; r < ROWS_PER_LANE; r++) {
const int i = r * WARP_SIZE + lane;
s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col;
attn_partial += s_shard[r] * q_t[i];
}
attn_data[col] = attn_col * scale;
float attn_col = warp_reduce_sum(attn_partial); // reduce within warp
if (lane == 0) {
attn_data[col] = attn_col * scale;
}
} else {
// kv[col] = sum_i g[i] * S[i][col] * k[i]
float kv_col = 0.0f;
float kv_shard = 0.0f;
#pragma unroll
for (int i = 0; i < S_v; i++) {
kv_col += expf(g_t[i]) * s[i] * k_t[i];
for (int r = 0; r < ROWS_PER_LANE; r++) {
const int i = r * WARP_SIZE + lane;
kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i];
}
float kv_col = warp_reduce_sum(kv_shard); // reduce within warp
// delta[col] = (v[col] - kv[col]) * beta
float delta_col = (v_t[col] - kv_col) * beta_val;
// fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col]
// attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
float attn_col = 0.0f;
float attn_partial = 0.0f;
#pragma unroll
for (int i = 0; i < S_v; i++) {
s[i] = expf(g_t[i]) * s[i] + k_t[i] * delta_col;
attn_col += s[i] * q_t[i];
for (int r = 0; r < ROWS_PER_LANE; r++) {
const int i = r * WARP_SIZE + lane;
s_shard[r] = expf(g_t[i]) * s_shard[r] + k_t[i] * delta_col;
attn_partial += s_shard[r] * q_t[i];
}
attn_data[col] = attn_col * scale;
float attn_col = warp_reduce_sum(attn_partial); // reduce within warp
if (lane == 0) {
attn_data[col] = attn_col * scale;
}
}
attn_data += S_v * H;
@ -109,8 +127,9 @@ __global__ void gated_delta_net_cuda(const float * q,
// Write state back to global memory
#pragma unroll
for (int i = 0; i < S_v; i++) {
state[i * S_v + col] = s[i];
for (int r = 0; r < ROWS_PER_LANE; r++) {
const int i = r * WARP_SIZE + lane;
state[i * S_v + col] = s_shard[r];
}
}
@ -126,8 +145,9 @@ static void launch_gated_delta_net(
int64_t neqk1, int64_t rq3,
float scale, cudaStream_t stream) {
dim3 grid_dims(H, n_seqs, 1);
dim3 block_dims(S_v, 1, 1);
const int num_warps = 4;
dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps);
dim3 block_dims(WARP_SIZE, num_warps, 1);
const fastdiv_consts_s64 neqk1_magic = init_fastdiv_s64(neqk1);
const fastdiv_consts_s64 rq3_magic = init_fastdiv_s64(rq3);