Shard columns across warps
This reduces register pressure (avoiding spills for S_v = 128) and exposes more, smaller CTAs for the GPU to schedule concurrently, which helps hide data-access latencies.
This commit is contained in:
parent
55d2e3a361
commit
2352caa102
|
|
@ -26,7 +26,8 @@ __global__ void gated_delta_net_cuda(const float * q,
|
|||
float scale) {
|
||||
const int64_t h_idx = blockIdx.x;
|
||||
const int64_t sequence = blockIdx.y;
|
||||
const int col = threadIdx.x; // each thread owns one column
|
||||
const int lane = threadIdx.x;
|
||||
const int col = blockIdx.z * blockDim.y + threadIdx.y;
|
||||
|
||||
const int64_t iq1 = fastmodulo_s64(h_idx, neqk1_magic);
|
||||
const int64_t iq3 = fastdiv_s64(sequence, rq3_magic);
|
||||
|
|
@ -40,11 +41,13 @@ __global__ void gated_delta_net_cuda(const float * q,
|
|||
curr_state += state_offset;
|
||||
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
|
||||
|
||||
// Load state column into registers
|
||||
float s[S_v];
|
||||
static_assert(S_v % WARP_SIZE == 0, "S_v must be a multiple of WARP_SIZE");
|
||||
constexpr int ROWS_PER_LANE = S_v / WARP_SIZE;
|
||||
float s_shard[ROWS_PER_LANE];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < S_v; i++) {
|
||||
s[i] = curr_state[i * S_v + col];
|
||||
for (int r = 0; r < ROWS_PER_LANE; r++) {
|
||||
const int i = r * WARP_SIZE + lane;
|
||||
s_shard[r] = curr_state[i * S_v + col];
|
||||
}
|
||||
|
||||
for (int t = 0; t < n_tokens; t++) {
|
||||
|
|
@ -62,46 +65,61 @@ __global__ void gated_delta_net_cuda(const float * q,
|
|||
const float g_val = expf(*g_t);
|
||||
|
||||
// kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
|
||||
float kv_col = 0.0f;
|
||||
float kv_shard = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < S_v; i++) {
|
||||
kv_col += s[i] * k_t[i];
|
||||
for (int r = 0; r < ROWS_PER_LANE; r++) {
|
||||
const int i = r * WARP_SIZE + lane;
|
||||
kv_shard += s_shard[r] * k_t[i];
|
||||
}
|
||||
float kv_col = warp_reduce_sum(kv_shard); // reduce within warp
|
||||
|
||||
// delta[col] = (v[col] - g * kv[col]) * beta
|
||||
float delta_col = (v_t[col] - g_val * kv_col) * beta_val;
|
||||
|
||||
// fused: S[i][col] = g * S[i][col] + k[i] * delta[col]
|
||||
// attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
|
||||
float attn_col = 0.0f;
|
||||
float attn_partial = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < S_v; i++) {
|
||||
s[i] = g_val * s[i] + k_t[i] * delta_col;
|
||||
attn_col += s[i] * q_t[i];
|
||||
for (int r = 0; r < ROWS_PER_LANE; r++) {
|
||||
const int i = r * WARP_SIZE + lane;
|
||||
s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col;
|
||||
attn_partial += s_shard[r] * q_t[i];
|
||||
}
|
||||
|
||||
attn_data[col] = attn_col * scale;
|
||||
float attn_col = warp_reduce_sum(attn_partial); // reduce within warp
|
||||
|
||||
if (lane == 0) {
|
||||
attn_data[col] = attn_col * scale;
|
||||
}
|
||||
} else {
|
||||
// kv[col] = sum_i g[i] * S[i][col] * k[i]
|
||||
float kv_col = 0.0f;
|
||||
float kv_shard = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < S_v; i++) {
|
||||
kv_col += expf(g_t[i]) * s[i] * k_t[i];
|
||||
for (int r = 0; r < ROWS_PER_LANE; r++) {
|
||||
const int i = r * WARP_SIZE + lane;
|
||||
kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i];
|
||||
}
|
||||
|
||||
float kv_col = warp_reduce_sum(kv_shard); // reduce within warp
|
||||
|
||||
// delta[col] = (v[col] - kv[col]) * beta
|
||||
float delta_col = (v_t[col] - kv_col) * beta_val;
|
||||
|
||||
// fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col]
|
||||
// attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
|
||||
float attn_col = 0.0f;
|
||||
float attn_partial = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < S_v; i++) {
|
||||
s[i] = expf(g_t[i]) * s[i] + k_t[i] * delta_col;
|
||||
attn_col += s[i] * q_t[i];
|
||||
for (int r = 0; r < ROWS_PER_LANE; r++) {
|
||||
const int i = r * WARP_SIZE + lane;
|
||||
s_shard[r] = expf(g_t[i]) * s_shard[r] + k_t[i] * delta_col;
|
||||
attn_partial += s_shard[r] * q_t[i];
|
||||
}
|
||||
|
||||
attn_data[col] = attn_col * scale;
|
||||
float attn_col = warp_reduce_sum(attn_partial); // reduce within warp
|
||||
|
||||
if (lane == 0) {
|
||||
attn_data[col] = attn_col * scale;
|
||||
}
|
||||
}
|
||||
|
||||
attn_data += S_v * H;
|
||||
|
|
@ -109,8 +127,9 @@ __global__ void gated_delta_net_cuda(const float * q,
|
|||
|
||||
// Write state back to global memory
|
||||
#pragma unroll
|
||||
for (int i = 0; i < S_v; i++) {
|
||||
state[i * S_v + col] = s[i];
|
||||
for (int r = 0; r < ROWS_PER_LANE; r++) {
|
||||
const int i = r * WARP_SIZE + lane;
|
||||
state[i * S_v + col] = s_shard[r];
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -126,8 +145,9 @@ static void launch_gated_delta_net(
|
|||
int64_t neqk1, int64_t rq3,
|
||||
float scale, cudaStream_t stream) {
|
||||
|
||||
dim3 grid_dims(H, n_seqs, 1);
|
||||
dim3 block_dims(S_v, 1, 1);
|
||||
const int num_warps = 4;
|
||||
dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps);
|
||||
dim3 block_dims(WARP_SIZE, num_warps, 1);
|
||||
|
||||
const fastdiv_consts_s64 neqk1_magic = init_fastdiv_s64(neqk1);
|
||||
const fastdiv_consts_s64 rq3_magic = init_fastdiv_s64(rq3);
|
||||
|
|
|
|||
Loading…
Reference in New Issue