move kv shmem staging to function
This commit is contained in:
parent
2c623bfaea
commit
0349025db8
|
|
@ -262,6 +262,23 @@ void gqaStore<T: __BuiltinFloatingPointType>(const in uint32_t r, const in uint3
|
|||
data_ov4[o_offset + offset] = vector<D_TYPE, 4>(elems);
|
||||
}
|
||||
|
||||
void stageKVToShmem<L: IKVLoader>(L loader, uint j, uint HS, uint tid, uint KV) {
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
[unroll] for (uint idx = 0; idx < Bc * HS / 4; idx += WorkGroupSize) {
|
||||
uint d = (idx + tid) % (HS / 4);
|
||||
uint c = (idx + tid) / (HS / 4);
|
||||
if (idx + WorkGroupSize <= Bc * HSK / 4 || c < Bc) {
|
||||
vector<FLOAT, 4> K_Tf = vector<FLOAT, 4>(0);
|
||||
if (!KV_bounds_check || j * Bc + c < KV) {
|
||||
K_Tf = loader.load(j * Bc + c, d);
|
||||
}
|
||||
|
||||
kvsh[c * kvsh_stride + d] = K_Tf;
|
||||
}
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
}
|
||||
|
||||
[shader("compute")]
|
||||
[numthreads(WorkGroupSize, 1, 1)]
|
||||
void main(
|
||||
|
|
@ -419,20 +436,7 @@ void main(
|
|||
}
|
||||
|
||||
if (SHMEM_STAGING != 0) {
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
[unroll] for (uint idx = 0; idx < Bc * HSK / 4; idx += WorkGroupSize) {
|
||||
uint d = (idx + tid) % (HSK / 4);
|
||||
uint c = (idx + tid) / (HSK / 4);
|
||||
if (idx + WorkGroupSize <= Bc * HSK / 4 || c < Bc) {
|
||||
vector<FLOAT, 4> K_Tf = vector<FLOAT, 4>(0);
|
||||
if (!KV_bounds_check || j * Bc + c < idcs.KV) {
|
||||
K_Tf = kloader.load(j * Bc + c, d);
|
||||
}
|
||||
|
||||
kvsh[c * kvsh_stride + d] = K_Tf;
|
||||
}
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
stageKVToShmem(kloader, j, HSK, tid, idcs.KV);
|
||||
}
|
||||
|
||||
// More d iterations means Q register caching becomes relevant
|
||||
|
|
@ -533,20 +537,7 @@ void main(
|
|||
}
|
||||
|
||||
if (SHMEM_STAGING != 0) {
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
[unroll] for (uint idx = 0; idx < Bc * HSV / 4; idx += WorkGroupSize) {
|
||||
uint d = (idx + tid) % (HSV / 4);
|
||||
uint c = (idx + tid) / (HSV / 4);
|
||||
if (idx + WorkGroupSize <= Bc * HSV / 4 || c < Bc) {
|
||||
vector<FLOAT, 4> V_Tf = vector<FLOAT, 4>(0);
|
||||
if (!KV_bounds_check || j * Bc + c < idcs.KV) {
|
||||
V_Tf = vloader.load(j * Bc + c, d);
|
||||
}
|
||||
|
||||
kvsh[c * kvsh_stride + d] = V_Tf;
|
||||
}
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
stageKVToShmem(vloader, j, HSV, tid, idcs.KV);
|
||||
}
|
||||
|
||||
[unroll] for (uint c = 0; c < cols_per_thread; ++c) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue