also stage V through shmem when this is done for K

Ruben Ortlam 2026-02-08 12:42:53 +01:00
parent 8fbd3575e0
commit b626e3296d
4 changed files with 88 additions and 54 deletions

View File

@@ -3222,12 +3222,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
const uint32_t D_lsb = D ^ (D & (D-1));
uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
// Nvidia prefers shared memory use to load large tiles of K.
// Nvidia prefers shared memory use to load large tiles of K/V.
// Switch to loading from global memory when it would use too much shared memory.
// AMD prefers loading K directly from global memory
const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256 ? 1 : 0;
const uint32_t shmem_staging = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256 ? 1 : 0;
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, flags};
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, shmem_staging, flags};
};
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
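
A note on the selection logic above: D ^ (D & (D-1)) isolates the lowest set bit of D, since D & (D-1) clears that bit and the XOR keeps only it, so D_split can never split a head dimension unevenly. A standalone C++ sketch of the whole decision, with the vendor ID and head sizes as plain parameters instead of ggml's vk_device state, and D taken as the larger of the two head sizes as the shaders below do:

#include <algorithm>
#include <cstdint>
#include <cstdio>

constexpr uint32_t VK_VENDOR_ID_NVIDIA = 0x10de; // PCI vendor ID, as used by Vulkan

// Mirrors the shader-variant selection above: the split factor comes from the
// lowest set bit of D, and shared-memory staging is enabled only on Nvidia
// for head sizes below 256.
static void pick_fa_config(uint32_t vendor_id, uint32_t subgroup_size,
                           uint32_t hsk, uint32_t hsv) {
    const uint32_t D       = std::max(hsk, hsv);
    const uint32_t D_lsb   = D ^ (D & (D - 1)); // lowest set bit of D
    const uint32_t D_split = std::min(std::min(subgroup_size, 8u), D_lsb / 4);
    const uint32_t shmem_staging =
        vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256 ? 1 : 0;
    std::printf("D=%u D_split=%u shmem_staging=%u\n", D, D_split, shmem_staging);
}

int main() {
    pick_fa_config(VK_VENDOR_ID_NVIDIA, 32, 128, 128); // D_split=8, staging on
    pick_fa_config(0x1002 /* AMD */,    64, 576, 512); // large heads: staging off
}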

View File

@@ -55,7 +55,7 @@ shared FLOAT_TYPEV4 Qf[Br * qf_stride];
const uint32_t D = HSK > HSV ? HSK : HSV;
const uint32_t kvsh_stride = D / 4 + 1;
shared FLOAT_TYPEV4 kvsh[K_LOAD_SHMEM != 0 ? Bc * kvsh_stride : 1];
shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1];
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
@@ -188,12 +188,12 @@ void main() {
}
}
if (K_LOAD_SHMEM != 0) {
if (SHMEM_STAGING != 0) {
barrier();
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
if (c < Bc && d < HSK / 4) {
if (c < Bc) {
FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0);
if (!KV_bounds_check || j * Bc + c < KV) {
#if BLOCK_SIZE > 1
@@ -224,7 +224,7 @@ void main() {
}
FLOAT_TYPEV4 K_Tf;
if (K_LOAD_SHMEM != 0) {
if (SHMEM_STAGING != 0) {
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
} else {
#if BLOCK_SIZE > 1
@@ -294,7 +294,7 @@ void main() {
}
}
if (K_LOAD_SHMEM != 0) {
if (SHMEM_STAGING != 0) {
barrier();
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSV / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSV / 4);
@@ -331,7 +331,7 @@ void main() {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
FLOAT_TYPEV4 Vf;
if (K_LOAD_SHMEM != 0) {
if (SHMEM_STAGING != 0) {
Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
} else {
#if BLOCK_SIZE > 1
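
The staging loops in this file share one pattern: the workgroup walks the Bc x (HSK/4) (or HSV/4) tile of vec4 elements in flat order, each thread derives its KV column c and head-dimension chunk d from idx + tid, and out-of-range KV positions are zero-filled so the inner loops need no further bounds checks. This is also why the d < HSK / 4 test could be dropped above: d is a remainder modulo HSK / 4 and is in range by construction. A host-side sketch of the decomposition, with made-up tile and workgroup sizes:

#include <cstdint>
#include <cstdio>
#include <vector>

// Models the cooperative staging loop: WG_SIZE threads stripe across a
// Bc x (HSK/4) tile of vec4 elements; each flat index splits into a KV
// column c and a vec4 chunk d within that column.
int main() {
    const uint32_t Bc = 16, HSK = 64, WG_SIZE = 96; // hypothetical sizes
    const uint32_t chunks = HSK / 4;                // vec4 chunks per KV column
    std::vector<int> hits(Bc * chunks, 0);

    for (uint32_t idx = 0; idx < Bc * chunks; idx += WG_SIZE) {
        for (uint32_t tid = 0; tid < WG_SIZE; ++tid) { // threads run concurrently
            const uint32_t d = (idx + tid) % chunks;   // always < chunks
            const uint32_t c = (idx + tid) / chunks;
            if (c < Bc) {                              // guards the final partial pass
                hits[c * chunks + d]++;                // the shader writes kvsh here
            }
        }
    }
    for (int h : hits) {
        if (h != 1) { std::puts("element missed or duplicated"); return 1; }
    }
    std::puts("every tile element staged exactly once");
}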

View File

@@ -9,7 +9,7 @@ layout (constant_id = 4) const uint32_t HSV = 32;
layout (constant_id = 5) const uint32_t Clamp = 0;
layout (constant_id = 6) const uint32_t D_split = 16;
layout (constant_id = 7) const uint32_t SubGroupSize = 32;
layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0;
layout (constant_id = 8) const uint32_t SHMEM_STAGING = 0;
layout (constant_id = 9) const uint32_t Flags = 0;
const bool USE_MASK_OPT = (Flags & 1) != 0;
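
These constant_id slots line up with the ten-element array built in ggml_vk_load_shaders in the first file (shmem_staging is element 8, flags element 9). For reference, a generic Vulkan sketch of how such an array reaches the constant_id slots through VkSpecializationInfo; ggml-vulkan's actual pipeline-creation plumbing is more involved, this only shows the shape of the mapping:

#include <array>
#include <cstdint>
#include <vulkan/vulkan.h>

// Ten uint32 specialization constants, one per constant_id 0..9 above.
static std::array<uint32_t, 10> consts;
static std::array<VkSpecializationMapEntry, 10> entries;

VkSpecializationInfo make_spec_info() {
    for (uint32_t i = 0; i < entries.size(); ++i) {
        entries[i].constantID = i;                                   // matches layout(constant_id = i)
        entries[i].offset     = static_cast<uint32_t>(i * sizeof(uint32_t)); // position inside pData
        entries[i].size       = sizeof(uint32_t);
    }
    VkSpecializationInfo info{};
    info.mapEntryCount = static_cast<uint32_t>(entries.size());
    info.pMapEntries   = entries.data();
    info.dataSize      = consts.size() * sizeof(uint32_t);
    info.pData         = consts.data();
    return info; // goes into VkPipelineShaderStageCreateInfo::pSpecializationInfo
}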

View File

@@ -54,10 +54,11 @@ shared f16vec4 Psh[Bc * psh_stride];
const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4;
shared ACC_TYPEV4 sfsh[Bc * sfshstride];
const uint32_t kshstride = (K_LOAD_SHMEM != 0 ? HSK_pad : MatBr) / 4 + 2; // in units of f16vec4
const uint32_t D_pad = HSK_pad > HSV_pad ? HSK_pad : HSV_pad;
const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; // in units of f16vec4
const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups
const uint vsh_stride = v_cols;
shared f16vec4 ksh[(kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)];
shared f16vec4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)];
shared ACC_TYPE slope[Br];
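
Two things change in this declaration: the staging stride now covers the larger of the two padded head sizes, so the same buffer can hold a K tile and later a V tile, and the + 2 keeps consecutive rows from starting a power-of-two apart, the familiar padding trick against shared-memory bank conflicts. The buffer must still fit the unstaged V-preload layout, hence the max with vsh_stride. The sizing arithmetic with sample numbers (all values hypothetical):

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t HSK_pad = 128, HSV_pad = 128;    // head sizes rounded up to 16
    const uint32_t Bc = 32, MatBc = 16, MatBr = 16; // hypothetical tile shape
    const uint32_t row_split = 4, SHMEM_STAGING = 1;

    const uint32_t D_pad       = std::max(HSK_pad, HSV_pad);
    const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; // f16vec4 units
    const uint32_t vsh_stride  = MatBc / 4 * row_split; // width of the V-preload layout

    // One shared buffer serves both layouts, so size it for the larger one.
    const uint32_t elems = Bc * std::max(kvsh_stride, vsh_stride);
    std::printf("kvsh: %u f16vec4 = %u bytes\n", elems, elems * 8); // one f16vec4 = 8 bytes
}
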
@@ -78,15 +79,15 @@ void main() {
#define tile_row(r) (row_tid * rows_per_thread + (r))
// Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK).
if ((HSK % 16) != 0) {
if ((HSK % 16) != 0 || (HSV % 16) != 0) {
[[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) {
if (i + tid < Br * qstride) {
Qf[i + tid] = f16vec4(0);
}
}
[[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) {
if (i + tid < Bc * kshstride) {
ksh[i + tid] = f16vec4(0);
[[unroll]] for (uint i = 0; i < Bc * kvsh_stride; i += gl_WorkGroupSize.x) {
if (i + tid < Bc * kvsh_stride) {
kvsh[i + tid] = f16vec4(0);
}
}
barrier();
@@ -231,13 +232,13 @@ void main() {
}
}
if (K_LOAD_SHMEM != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
if (c < Bc && d < HSK / 4) {
if (SHMEM_STAGING != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK_pad / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK_pad / 4);
uint32_t c = (idx + tid) / (HSK_pad / 4);
if (c < Bc) {
f16vec4 K_Tf = f16vec4(0);
if (!KV_bounds_check || j * Bc + c < KV) {
if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
@@ -248,7 +249,7 @@ void main() {
#endif
}
ksh[c * kshstride + d] = K_Tf;
kvsh[c * kvsh_stride + d] = K_Tf;
}
}
barrier();
@@ -262,7 +263,7 @@ void main() {
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
[[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
if (K_LOAD_SHMEM == 0) {
if (SHMEM_STAGING == 0) {
#if BLOCK_SIZE == 1
if (KV_bounds_check || d * 16 + 16 > HSK) {
#endif
@@ -283,7 +284,7 @@ void main() {
#endif
}
ksh[row * kshstride + col_vec] = K_Tf;
kvsh[row * kvsh_stride + col_vec] = K_Tf;
}
}
barrier();
@@ -295,8 +296,8 @@ void main() {
if (KV_bounds_check || d * 16 + 16 > HSK)
#endif
{
uint coord = (gl_SubgroupID * MatBc) * kshstride;
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
uint coord = (gl_SubgroupID * MatBc) * kvsh_stride;
coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
#if BLOCK_SIZE == 1
else {
@@ -305,8 +306,8 @@ void main() {
}
#endif
} else {
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
uint coord = (gl_SubgroupID * MatBc) * kvsh_stride + d * 16 / 4;
coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
@@ -397,6 +398,29 @@ void main() {
}
}
if (SHMEM_STAGING != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSV_pad / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSV_pad / 4);
uint32_t c = (idx + tid) / (HSV_pad / 4);
if (c < Bc) {
f16vec4 V_Tf = f16vec4(0);
if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
V_Tf = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V));
#else
V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
#endif
}
kvsh[c * kvsh_stride + d] = V_Tf;
}
}
}
barrier();
const uint num_hsv_tiles = (HSV + MatBc * row_split - 1) / (MatBc * row_split); // round up
// Each subgroup handles HSV/4 columns
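
The block added above mirrors the K staging earlier in this file, but over HSV_pad: an element is fetched from global memory only when its KV row is in range and, for padded head sizes, its vec4 chunk lies below HSV / 4; everything else is written as an explicit zero, so the cooperative-matrix loads below can always read full 16-wide tiles out of kvsh without bounds checks. The predicate in isolation, with hypothetical sizes:

#include <cstdint>
#include <cstdio>

// The zero-fill predicate from the staging loop above: V is read from global
// memory only for in-range KV rows and, when HSV was padded, in-range chunks;
// every other element stays an explicit zero in shared memory.
static bool loads_from_global(bool KV_bounds_check, uint32_t j, uint32_t Bc,
                              uint32_t c, uint32_t d, uint32_t KV,
                              uint32_t HSV, uint32_t HSV_pad) {
    return (!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4);
}

int main() {
    // Hypothetical: HSV = 40 pads to 48, so vec4 chunks 10 and 11 stay zero.
    std::printf("%d\n", loads_from_global(true, 0, 16, 0, 9, 100, 40, 48));  // 1
    std::printf("%d\n", loads_from_global(true, 0, 16, 0, 10, 100, 40, 48)); // 0, padding
    std::printf("%d\n", loads_from_global(true, 6, 16, 15, 3, 100, 40, 48)); // 0, row 111 >= KV
}
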
@@ -410,6 +434,7 @@ void main() {
const uint v_total = v_rows * v_cols;
const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x;
if (SHMEM_STAGING == 0) {
#if BLOCK_SIZE == 1
// For f16, only preload if not aligned
if (KV_bounds_check) {
@@ -428,43 +453,52 @@
if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
#if BLOCK_SIZE > 1
ksh[row * vsh_stride + col] = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V));
kvsh[row * vsh_stride + col] = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V));
#else
ksh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
#endif
} else {
ksh[row * vsh_stride + col] = f16vec4(0.0f);
kvsh[row * vsh_stride + col] = f16vec4(0.0f);
}
}
#if BLOCK_SIZE == 1
}
#endif
}
barrier();
[[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) {
coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);
#if BLOCK_SIZE == 1
if (!KV_bounds_check) {
// F16 values can be loaded directly from global memory
const uint v_tile_row = j * Bc + bc_chunk * MatBc;
const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
} else
#endif
{
const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
coopMatLoad(QMat, ksh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
SfMat = coopMatMulAdd(KMat, QMat, SfMat);
}
// Store SfMat to sfsh and load into Of
const uint osh_stride = row_split * MatBc / 4;
const uint o_offset = gl_SubgroupID * MatBc / 4;
coopMatStore(SfMat, sfsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor);
if (hsv_offset < HSV_pad) {
[[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) {
coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);
if (SHMEM_STAGING == 0) {
#if BLOCK_SIZE == 1
if (!KV_bounds_check) {
// F16 values can be loaded directly from global memory
const uint v_tile_row = j * Bc + bc_chunk * MatBc;
const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
} else
#endif
{
const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
} else {
const uint v_tile_offset = bc_chunk * MatBc * kvsh_stride + (hsv_tile * row_split + gl_SubgroupID) * (MatBc / 4);
coopMatLoad(QMat, kvsh, v_tile_offset, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
SfMat = coopMatMulAdd(KMat, QMat, SfMat);
}
// Store SfMat to sfsh and load into Of
coopMatStore(SfMat, sfsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
barrier();
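
The restructured tail of this hunk does two things. With staging off, the old global-memory and preload paths for V are kept as they were; with staging on, each subgroup instead loads its V tile straight out of kvsh, stepping bc_chunk * MatBc rows down and (hsv_tile * row_split + gl_SubgroupID) * MatBc / 4 f16vec4s across. And since num_hsv_tiles above rounds up, the multiply-and-store is now wrapped in if (hsv_offset < HSV_pad), letting subgroups whose tile falls entirely in the rounded-up tail skip the work. A sketch of that guard, assuming hsv_offset is the subgroup's starting column, consistent with the rounding above (the configuration numbers are made up):

#include <cstdint>
#include <cstdio>

// Which (hsv_tile, subgroup) pairs do real work: num_hsv_tiles rounds up, so
// trailing subgroups of the last tile can start past the padded head size.
int main() {
    const uint32_t HSV = 80, HSV_pad = 80;    // hypothetical head size
    const uint32_t MatBc = 16, row_split = 4; // hypothetical tile config
    const uint32_t num_hsv_tiles = (HSV + MatBc * row_split - 1) / (MatBc * row_split);
    for (uint32_t tile = 0; tile < num_hsv_tiles; ++tile) {
        for (uint32_t sg = 0; sg < row_split; ++sg) {
            const uint32_t hsv_offset = (tile * row_split + sg) * MatBc;
            std::printf("tile %u sg %u offset %3u -> %s\n", tile, sg, hsv_offset,
                        hsv_offset < HSV_pad ? "compute" : "skip");
        }
    }
}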