Improve comments
parent f62852a33a
commit da33f8e6e3
@@ -25,6 +25,7 @@ __global__ void gated_delta_net_cuda(const float * q,
                                      float scale) {
     const int64_t h_idx    = blockIdx.x;
     const int64_t sequence = blockIdx.y;
+    // each warp owns one column, using warp-level primitives to reduce across rows
     const int lane = threadIdx.x;
     const int col  = blockIdx.z * blockDim.y + threadIdx.y;
 
@@ -70,7 +71,7 @@ __global__ void gated_delta_net_cuda(const float * q,
         const int i = r * WARP_SIZE + lane;
         kv_shard += s_shard[r] * k_t[i];
     }
-    float kv_col = warp_reduce_sum(kv_shard); // reduce within warp
+    float kv_col = warp_reduce_sum(kv_shard);
 
     // delta[col] = (v[col] - g * kv[col]) * beta
     float delta_col = (v_t[col] - g_val * kv_col) * beta_val;
@@ -85,7 +86,7 @@ __global__ void gated_delta_net_cuda(const float * q,
         attn_partial += s_shard[r] * q_t[i];
     }
 
-    float attn_col = warp_reduce_sum(attn_partial); // reduce within warp
+    float attn_col = warp_reduce_sum(attn_partial);
 
     if (lane == 0) {
         attn_data[col] = attn_col * scale;
@@ -99,7 +100,7 @@ __global__ void gated_delta_net_cuda(const float * q,
         kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i];
     }
 
-    float kv_col = warp_reduce_sum(kv_shard); // reduce within warp
+    float kv_col = warp_reduce_sum(kv_shard);
 
     // delta[col] = (v[col] - kv[col]) * beta
     float delta_col = (v_t[col] - kv_col) * beta_val;
@@ -114,7 +115,7 @@ __global__ void gated_delta_net_cuda(const float * q,
         attn_partial += s_shard[r] * q_t[i];
     }
 
-    float attn_col = warp_reduce_sum(attn_partial); // reduce within warp
+    float attn_col = warp_reduce_sum(attn_partial);
 
     if (lane == 0) {
         attn_data[col] = attn_col * scale;
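Note for readers outside the tree: warp_reduce_sum does not appear in this diff. A shuffle-based butterfly reduction is the conventional shape for such a helper; the sketch below is an assumption about its implementation, not code from this repository, assuming 32-lane warps and full-warp participation.

// Assumed shape of the warp_reduce_sum helper used above; the real
// implementation lives elsewhere in the source tree.
#define WARP_SIZE 32

__device__ __forceinline__ float warp_reduce_sum(float x) {
    // XOR-shuffle butterfly: the lane stride halves each step, so after
    // log2(WARP_SIZE) = 5 steps every lane holds the sum over all 32 lanes.
#pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset);
    }
    return x;
}

Because the butterfly leaves the result in every lane, the kernel only needs the if (lane == 0) guard to make sure attn_data[col] is written once per warp.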
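The comment added in the first hunk describes the work distribution: threadIdx.x is the lane within a warp and threadIdx.y selects a column, so each warp reduces one output column across rows. Below is a hypothetical launch shape consistent with that indexing; COLS_PER_BLOCK, n_heads, n_seqs, and head_dim are illustrative names, not taken from this diff.

// Hypothetical launch configuration matching the lane/col indexing above;
// everything except gated_delta_net_cuda and WARP_SIZE is an assumption.
static void launch_gated_delta_net(int n_heads, int n_seqs, int head_dim) {
    constexpr int COLS_PER_BLOCK = 4;       // warps per block (assumed)
    dim3 block(WARP_SIZE, COLS_PER_BLOCK);  // threadIdx.x = lane, threadIdx.y = column slot
    dim3 grid(n_heads,                      // blockIdx.x -> h_idx
              n_seqs,                       // blockIdx.y -> sequence
              (head_dim + COLS_PER_BLOCK - 1) / COLS_PER_BLOCK);  // blockIdx.z tiles columns
    // gated_delta_net_cuda<<<grid, block>>>(/* q, ..., scale */); // full argument list not shown in this diff
    (void)grid; (void)block;
}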
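In equation form, the delta computation the inline comments describe is the following, where S is the recurrent state whose rows are spread across the warp's lanes (s_shard) and c is the column the warp owns; this is a reading of the code, not notation from the source:

\[
\mathrm{kv}_c = \sum_i S_{i,c}\,k_i, \qquad
\delta_c = \bigl(v_c - g\,\mathrm{kv}_c\bigr)\,\beta
\]

In the later hunk the gate moves inside the sum as a per-row factor, \( \mathrm{kv}_c = \sum_i e^{g_i} S_{i,c}\,k_i \), so the delta simplifies to \( \delta_c = (v_c - \mathrm{kv}_c)\,\beta \).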