Improve comments

This commit is contained in:
Oliver Simons 2026-03-11 10:10:16 +01:00
parent f62852a33a
commit da33f8e6e3
1 changed file with 5 additions and 4 deletions

View File

@ -25,6 +25,7 @@ __global__ void gated_delta_net_cuda(const float * q,
float scale) {
const int64_t h_idx = blockIdx.x;
const int64_t sequence = blockIdx.y;
// each warp owns one column, using warp-level primitives to reduce across rows
const int lane = threadIdx.x;
const int col = blockIdx.z * blockDim.y + threadIdx.y;
@ -70,7 +71,7 @@ __global__ void gated_delta_net_cuda(const float * q,
const int i = r * WARP_SIZE + lane;
kv_shard += s_shard[r] * k_t[i];
}
float kv_col = warp_reduce_sum(kv_shard); // reduce within warp
float kv_col = warp_reduce_sum(kv_shard);
// delta[col] = (v[col] - g * kv[col]) * beta
float delta_col = (v_t[col] - g_val * kv_col) * beta_val;
@ -85,7 +86,7 @@ __global__ void gated_delta_net_cuda(const float * q,
attn_partial += s_shard[r] * q_t[i];
}
float attn_col = warp_reduce_sum(attn_partial); // reduce within warp
float attn_col = warp_reduce_sum(attn_partial);
if (lane == 0) {
attn_data[col] = attn_col * scale;
@ -99,7 +100,7 @@ __global__ void gated_delta_net_cuda(const float * q,
kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i];
}
float kv_col = warp_reduce_sum(kv_shard); // reduce within warp
float kv_col = warp_reduce_sum(kv_shard);
// delta[col] = (v[col] - kv[col]) * beta
float delta_col = (v_t[col] - kv_col) * beta_val;
@ -114,7 +115,7 @@ __global__ void gated_delta_net_cuda(const float * q,
attn_partial += s_shard[r] * q_t[i];
}
float attn_col = warp_reduce_sum(attn_partial); // reduce within warp
float attn_col = warp_reduce_sum(attn_partial);
if (lane == 0) {
attn_data[col] = attn_col * scale;