stage K loads through shmem

Ruben Ortlam 2026-02-08 10:08:20 +01:00
parent 50a420e044
commit b7b67f8742
2 changed files with 42 additions and 10 deletions


@@ -3218,7 +3218,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         // Nvidia prefers shared memory use to load large tiles of K.
         // Switch to loading from global memory when it would use too much shared memory.
         // AMD prefers loading K directly from global memory
-        const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0;
+        const uint32_t k_load_shmem = 1; // device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0;
         return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, flags};
     };
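
For context on the comment above: the hsk < 256 cutoff in the (now commented-out) expression exists because the staged K tile has to fit in shared memory next to Qf, masksh and tmpsh. A rough back-of-the-envelope of that footprint, using the kvsh declaration added below (Bc * (D / 4 + 1) vec4 elements) with assumed values — Bc = 32, HSK = HSV = 128, and 8-byte f16vec4 storage; none of these numbers come from the commit itself:

    #include <cstdint>
    #include <cstdio>

    // Illustrative tile parameters only; the real ones are picked per-device
    // in ggml_vk_load_shaders.
    constexpr uint32_t Bc  = 32;                      // K/V columns staged per tile (assumed)
    constexpr uint32_t HSK = 128, HSV = 128;          // head sizes (assumed)
    constexpr uint32_t D   = HSK > HSV ? HSK : HSV;

    // Mirrors the shader: kvsh is FLOAT_TYPEV4 kvsh[Bc * (D / 4 + 1)],
    // and an f16vec4 element is 8 bytes (assuming fp16 shared storage).
    constexpr uint32_t kvsh_stride = D / 4 + 1;       // one padding vec4 per row
    constexpr uint32_t kvsh_bytes  = Bc * kvsh_stride * 8;

    int main() {
        std::printf("kvsh tile: %u vec4 elements, %u bytes\n", Bc * kvsh_stride, kvsh_bytes);
        // HSK = 128 -> 32 * 33 * 8 = 8448 bytes of extra shared memory per workgroup,
        // on top of what Qf, masksh and tmpsh already use; at HSK = 256 it roughly doubles.
        return 0;
    }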


@@ -50,8 +50,12 @@ shared ACC_TYPEV4 tmpsh_accv4[tmpsh_size];

 const uint32_t masksh_stride = Br + 1;
 shared FLOAT_TYPE masksh[Bc * masksh_stride];
-const uint qfstride = HSK / 4 + 1;
-shared FLOAT_TYPEV4 Qf[Br * qfstride];
+const uint32_t qf_stride = HSK / 4 + 1;
+shared FLOAT_TYPEV4 Qf[Br * qf_stride];
+
+const uint32_t D = HSK > HSV ? HSK : HSV;
+const uint32_t kvsh_stride = D / 4 + 1;
+shared FLOAT_TYPEV4 kvsh[K_LOAD_SHMEM != 0 ? Bc * kvsh_stride : 1];

 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM
@@ -75,7 +79,7 @@ void main() {
         uint32_t r = (idx + tid) / (HSK / 4);
         if (r < Br && d < HSK / 4 &&
             i * Br + r < N) {
-            Qf[r * qfstride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
+            Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
         }
     }
     barrier();
@@ -184,10 +188,33 @@ void main() {
         }
     }

+    if (K_LOAD_SHMEM != 0) {
+        [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
+            uint32_t d = (idx + tid) % (HSK / 4);
+            uint32_t c = (idx + tid) / (HSK / 4);
+            if (c < Bc && d < HSK / 4) {
+                FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0);
+                if (!KV_bounds_check || j * Bc + c < KV) {
+#if BLOCK_SIZE > 1
+                    uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
+                    uint ib = coord / BLOCK_SIZE;
+                    uint iqs = (coord % BLOCK_SIZE);
+                    K_Tf = FLOAT_TYPEV4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
+#else
+                    K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
+#endif
+                }
+                kvsh[c * kvsh_stride + d] = K_Tf;
+            }
+        }
+        barrier();
+    }
+
     [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
         FLOAT_TYPEV4 Q_cache[rows_per_thread];
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-            Q_cache[r] = Qf[tile_row(r) * qfstride + d * D_split + d_tid];
+            Q_cache[r] = Qf[tile_row(r) * qf_stride + d * D_split + d_tid];
         }

         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
@@ -195,14 +222,19 @@ void main() {
                 continue;
             }

+            FLOAT_TYPEV4 K_Tf;
+            if (K_LOAD_SHMEM != 0) {
+                K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
+            } else {
 #if BLOCK_SIZE > 1
-            uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
-            uint ib = coord / BLOCK_SIZE;
-            uint iqs = (coord % BLOCK_SIZE);
-            FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
+                uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
+                uint ib = coord / BLOCK_SIZE;
+                uint iqs = (coord % BLOCK_SIZE);
+                K_Tf = FLOAT_TYPEV4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
 #else
-            FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
+                K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
 #endif
+            }
             [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
                 Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf));
             }
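
Both the staged path and the direct path in the last hunk use the same BLOCK_SIZE > 1 addressing: coord counts scalar K elements, so ib = coord / BLOCK_SIZE picks the quant block and iqs = coord % BLOCK_SIZE the position inside it. A small standalone sketch of that arithmetic with made-up numbers (BLOCK_SIZE = 32 and the k_stride/tile coordinates below are illustrative, not values from the commit):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Assumed values for illustration only.
        const uint32_t BLOCK_SIZE = 32;         // elements per quant block (e.g. 32 for Q4_0/Q8_0)
        const uint32_t k_stride   = 4;          // row stride as used by the shader's coord formula
        const uint32_t j = 0, Bc = 32, c = 5;   // which K row of the current tile
        const uint32_t d = 7;                   // vec4 index along the head dimension

        // Same formula as the shader: row offset in scalar elements plus 4*d for the vec4.
        const uint32_t coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
        const uint32_t ib    = coord / BLOCK_SIZE;  // quant block index passed to dequantize4
        const uint32_t iqs   = coord % BLOCK_SIZE;  // element offset inside that block

        std::printf("coord=%u -> ib=%u, iqs=%u\n", coord, ib, iqs);  // coord=668 -> ib=20, iqs=28
        return 0;
    }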