move kv shmem staging to function

This commit is contained in:
Ruben Ortlam 2026-03-09 15:02:25 +01:00
parent 2c623bfaea
commit 0349025db8
1 changed files with 19 additions and 28 deletions

View File

@ -262,6 +262,23 @@ void gqaStore<T: __BuiltinFloatingPointType>(const in uint32_t r, const in uint3
data_ov4[o_offset + offset] = vector<D_TYPE, 4>(elems);
}
void stageKVToShmem<L: IKVLoader>(L loader, uint j, uint HS, uint tid, uint KV) {
GroupMemoryBarrierWithGroupSync();
[unroll] for (uint idx = 0; idx < Bc * HS / 4; idx += WorkGroupSize) {
uint d = (idx + tid) % (HS / 4);
uint c = (idx + tid) / (HS / 4);
if (idx + WorkGroupSize <= Bc * HSK / 4 || c < Bc) {
vector<FLOAT, 4> K_Tf = vector<FLOAT, 4>(0);
if (!KV_bounds_check || j * Bc + c < KV) {
K_Tf = loader.load(j * Bc + c, d);
}
kvsh[c * kvsh_stride + d] = K_Tf;
}
}
GroupMemoryBarrierWithGroupSync();
}
[shader("compute")]
[numthreads(WorkGroupSize, 1, 1)]
void main(
@ -419,20 +436,7 @@ void main(
}
if (SHMEM_STAGING != 0) {
GroupMemoryBarrierWithGroupSync();
[unroll] for (uint idx = 0; idx < Bc * HSK / 4; idx += WorkGroupSize) {
uint d = (idx + tid) % (HSK / 4);
uint c = (idx + tid) / (HSK / 4);
if (idx + WorkGroupSize <= Bc * HSK / 4 || c < Bc) {
vector<FLOAT, 4> K_Tf = vector<FLOAT, 4>(0);
if (!KV_bounds_check || j * Bc + c < idcs.KV) {
K_Tf = kloader.load(j * Bc + c, d);
}
kvsh[c * kvsh_stride + d] = K_Tf;
}
}
GroupMemoryBarrierWithGroupSync();
stageKVToShmem(kloader, j, HSK, tid, idcs.KV);
}
// More d iterations means Q register caching becomes relevant
@ -533,20 +537,7 @@ void main(
}
if (SHMEM_STAGING != 0) {
GroupMemoryBarrierWithGroupSync();
[unroll] for (uint idx = 0; idx < Bc * HSV / 4; idx += WorkGroupSize) {
uint d = (idx + tid) % (HSV / 4);
uint c = (idx + tid) / (HSV / 4);
if (idx + WorkGroupSize <= Bc * HSV / 4 || c < Bc) {
vector<FLOAT, 4> V_Tf = vector<FLOAT, 4>(0);
if (!KV_bounds_check || j * Bc + c < idcs.KV) {
V_Tf = vloader.load(j * Bc + c, d);
}
kvsh[c * kvsh_stride + d] = V_Tf;
}
}
GroupMemoryBarrierWithGroupSync();
stageKVToShmem(vloader, j, HSV, tid, idcs.KV);
}
[unroll] for (uint c = 0; c < cols_per_thread; ++c) {