Revert "move kv shmem staging to function"
This reverts commit 0349025db8.
This commit is contained in:
parent
0349025db8
commit
e880cb2e0d
|
|
@ -262,23 +262,6 @@ void gqaStore<T: __BuiltinFloatingPointType>(const in uint32_t r, const in uint3
|
||||||
data_ov4[o_offset + offset] = vector<D_TYPE, 4>(elems);
|
data_ov4[o_offset + offset] = vector<D_TYPE, 4>(elems);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stages one Bc-row block of K or V data from `loader` into the `kvsh` buffer.
//
// Parameters:
//   loader - source of the data; `loader.load(row, d)` yields one 4-wide vector.
//   j      - block index; rows j * Bc .. j * Bc + Bc - 1 are staged.
//   HS     - head size in scalar elements of the tensor being staged
//            (callers pass HSK for K and HSV for V); each row holds HS / 4 vectors.
//   tid    - index of the calling thread, added to the strided loop counter.
//   KV     - number of valid rows; rows at or beyond it are stored as zeros
//            when KV_bounds_check is enabled.
//
// The copy is bracketed by workgroup barriers: the first one runs before kvsh
// is overwritten, the second one runs after all stores have been issued.
void stageKVToShmem<L: IKVLoader>(L loader, uint j, uint HS, uint tid, uint KV) {
    GroupMemoryBarrierWithGroupSync();
    [unroll] for (uint idx = 0; idx < Bc * HS / 4; idx += WorkGroupSize) {
        // Split the flat vector index into column-in-row (d) and row-in-block (c).
        uint d = (idx + tid) % (HS / 4);
        uint c = (idx + tid) / (HS / 4);
        // The first test skips the per-thread row check when the whole stride
        // lies inside the tile. It must use HS (not HSK): the tile size depends
        // on the tensor being staged, and using HSK while staging V with
        // HSV < HSK would let rows with c >= Bc through.
        if (idx + WorkGroupSize <= Bc * HS / 4 || c < Bc) {
            vector<FLOAT, 4> K_Tf = vector<FLOAT, 4>(0);
            // Rows past the end of the sequence keep the zero fill.
            if (!KV_bounds_check || j * Bc + c < KV) {
                K_Tf = loader.load(j * Bc + c, d);
            }

            kvsh[c * kvsh_stride + d] = K_Tf;
        }
    }
    GroupMemoryBarrierWithGroupSync();
}
|
|
||||||
|
|
||||||
[shader("compute")]
|
[shader("compute")]
|
||||||
[numthreads(WorkGroupSize, 1, 1)]
|
[numthreads(WorkGroupSize, 1, 1)]
|
||||||
void main(
|
void main(
|
||||||
|
|
@ -436,7 +419,20 @@ void main(
|
||||||
}
|
}
|
||||||
|
|
||||||
if (SHMEM_STAGING != 0) {
|
if (SHMEM_STAGING != 0) {
|
||||||
stageKVToShmem(kloader, j, HSK, tid, idcs.KV);
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
[unroll] for (uint idx = 0; idx < Bc * HSK / 4; idx += WorkGroupSize) {
|
||||||
|
uint d = (idx + tid) % (HSK / 4);
|
||||||
|
uint c = (idx + tid) / (HSK / 4);
|
||||||
|
if (idx + WorkGroupSize <= Bc * HSK / 4 || c < Bc) {
|
||||||
|
vector<FLOAT, 4> K_Tf = vector<FLOAT, 4>(0);
|
||||||
|
if (!KV_bounds_check || j * Bc + c < idcs.KV) {
|
||||||
|
K_Tf = kloader.load(j * Bc + c, d);
|
||||||
|
}
|
||||||
|
|
||||||
|
kvsh[c * kvsh_stride + d] = K_Tf;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
}
|
}
|
||||||
|
|
||||||
// More d iterations means Q register caching becomes relevant
|
// More d iterations means Q register caching becomes relevant
|
||||||
|
|
@ -537,7 +533,20 @@ void main(
|
||||||
}
|
}
|
||||||
|
|
||||||
if (SHMEM_STAGING != 0) {
|
if (SHMEM_STAGING != 0) {
|
||||||
stageKVToShmem(vloader, j, HSV, tid, idcs.KV);
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
[unroll] for (uint idx = 0; idx < Bc * HSV / 4; idx += WorkGroupSize) {
|
||||||
|
uint d = (idx + tid) % (HSV / 4);
|
||||||
|
uint c = (idx + tid) / (HSV / 4);
|
||||||
|
if (idx + WorkGroupSize <= Bc * HSV / 4 || c < Bc) {
|
||||||
|
vector<FLOAT, 4> V_Tf = vector<FLOAT, 4>(0);
|
||||||
|
if (!KV_bounds_check || j * Bc + c < idcs.KV) {
|
||||||
|
V_Tf = vloader.load(j * Bc + c, d);
|
||||||
|
}
|
||||||
|
|
||||||
|
kvsh[c * kvsh_stride + d] = V_Tf;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
}
|
}
|
||||||
|
|
||||||
[unroll] for (uint c = 0; c < cols_per_thread; ++c) {
|
[unroll] for (uint c = 0; c < cols_per_thread; ++c) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue