Rename variables

This commit is contained in:
Oliver Simons 2026-03-11 15:49:49 +01:00
parent e26d75b083
commit 1623bbc906
1 changed file with 8 additions and 8 deletions

View File

@ -43,10 +43,10 @@ __global__ void gated_delta_net_cuda(const float * q,
constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v;
static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size");
constexpr int ROWS_PER_LANE = (S_v + warp_size - 1) / warp_size;
float s_shard[ROWS_PER_LANE];
constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;
float s_shard[rows_per_lane];
#pragma unroll
for (int r = 0; r < ROWS_PER_LANE; r++) {
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = curr_state[i * S_v + col];
}
@ -68,7 +68,7 @@ __global__ void gated_delta_net_cuda(const float * q,
// kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
float kv_shard = 0.0f;
#pragma unroll
for (int r = 0; r < ROWS_PER_LANE; r++) {
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
kv_shard += s_shard[r] * k_t[i];
}
@ -81,7 +81,7 @@ __global__ void gated_delta_net_cuda(const float * q,
// attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
float attn_partial = 0.0f;
#pragma unroll
for (int r = 0; r < ROWS_PER_LANE; r++) {
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col;
attn_partial += s_shard[r] * q_t[i];
@ -96,7 +96,7 @@ __global__ void gated_delta_net_cuda(const float * q,
// kv[col] = sum_i g[i] * S[i][col] * k[i]
float kv_shard = 0.0f;
#pragma unroll
for (int r = 0; r < ROWS_PER_LANE; r++) {
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i];
}
@ -110,7 +110,7 @@ __global__ void gated_delta_net_cuda(const float * q,
// attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
float attn_partial = 0.0f;
#pragma unroll
for (int r = 0; r < ROWS_PER_LANE; r++) {
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = expf(g_t[i]) * s_shard[r] + k_t[i] * delta_col;
attn_partial += s_shard[r] * q_t[i];
@ -128,7 +128,7 @@ __global__ void gated_delta_net_cuda(const float * q,
// Write state back to global memory
#pragma unroll
for (int r = 0; r < ROWS_PER_LANE; r++) {
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
state[i * S_v + col] = s_shard[r];
}