Revert "move kv shmem staging to function"

This reverts commit 0349025db8.
This commit is contained in:
Ruben Ortlam 2026-03-13 08:18:33 +01:00
parent 0349025db8
commit e880cb2e0d
1 changed files with 28 additions and 19 deletions

View File

@ -262,23 +262,6 @@ void gqaStore<T: __BuiltinFloatingPointType>(const in uint32_t r, const in uint3
data_ov4[o_offset + offset] = vector<D_TYPE, 4>(elems);
}
void stageKVToShmem<L: IKVLoader>(L loader, uint j, uint HS, uint tid, uint KV) {
GroupMemoryBarrierWithGroupSync();
[unroll] for (uint idx = 0; idx < Bc * HS / 4; idx += WorkGroupSize) {
uint d = (idx + tid) % (HS / 4);
uint c = (idx + tid) / (HS / 4);
if (idx + WorkGroupSize <= Bc * HSK / 4 || c < Bc) {
vector<FLOAT, 4> K_Tf = vector<FLOAT, 4>(0);
if (!KV_bounds_check || j * Bc + c < KV) {
K_Tf = loader.load(j * Bc + c, d);
}
kvsh[c * kvsh_stride + d] = K_Tf;
}
}
GroupMemoryBarrierWithGroupSync();
}
[shader("compute")]
[numthreads(WorkGroupSize, 1, 1)]
void main(
@ -436,7 +419,20 @@ void main(
}
if (SHMEM_STAGING != 0) {
stageKVToShmem(kloader, j, HSK, tid, idcs.KV);
GroupMemoryBarrierWithGroupSync();
[unroll] for (uint idx = 0; idx < Bc * HSK / 4; idx += WorkGroupSize) {
uint d = (idx + tid) % (HSK / 4);
uint c = (idx + tid) / (HSK / 4);
if (idx + WorkGroupSize <= Bc * HSK / 4 || c < Bc) {
vector<FLOAT, 4> K_Tf = vector<FLOAT, 4>(0);
if (!KV_bounds_check || j * Bc + c < idcs.KV) {
K_Tf = kloader.load(j * Bc + c, d);
}
kvsh[c * kvsh_stride + d] = K_Tf;
}
}
GroupMemoryBarrierWithGroupSync();
}
// More d iterations means Q register caching becomes relevant
@ -537,7 +533,20 @@ void main(
}
if (SHMEM_STAGING != 0) {
stageKVToShmem(vloader, j, HSV, tid, idcs.KV);
GroupMemoryBarrierWithGroupSync();
[unroll] for (uint idx = 0; idx < Bc * HSV / 4; idx += WorkGroupSize) {
uint d = (idx + tid) % (HSV / 4);
uint c = (idx + tid) / (HSV / 4);
if (idx + WorkGroupSize <= Bc * HSV / 4 || c < Bc) {
vector<FLOAT, 4> V_Tf = vector<FLOAT, 4>(0);
if (!KV_bounds_check || j * Bc + c < idcs.KV) {
V_Tf = vloader.load(j * Bc + c, d);
}
kvsh[c * kvsh_stride + d] = V_Tf;
}
}
GroupMemoryBarrierWithGroupSync();
}
[unroll] for (uint c = 0; c < cols_per_thread; ++c) {