diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 6370fa8fa0..16746e75d9 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -3222,12 +3222,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
         const uint32_t D_lsb = D ^ (D & (D-1));
         uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);

-        // Nvidia prefers shared memory use to load large tiles of K.
+        // Nvidia prefers shared memory use to load large tiles of K/V.
         // Switch to loading from global memory when it would use too much shared memory.
         // AMD prefers loading K directly from global memory
-        const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256 ? 1 : 0;
+        const uint32_t shmem_staging = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256 ? 1 : 0;

-        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, flags};
+        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, shmem_staging, flags};
     };

 #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
index 6d85212d44..7324d770a0 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
@@ -55,7 +55,7 @@ shared FLOAT_TYPEV4 Qf[Br * qf_stride];

 const uint32_t D = HSK > HSV ? HSK : HSV;
 const uint32_t kvsh_stride = D / 4 + 1;
-shared FLOAT_TYPEV4 kvsh[K_LOAD_SHMEM != 0 ? Bc * kvsh_stride : 1];
+shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1];

 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM
@@ -188,12 +188,12 @@ void main() {
             }
         }

-        if (K_LOAD_SHMEM != 0) {
+        if (SHMEM_STAGING != 0) {
             barrier();
             [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
                 uint32_t d = (idx + tid) % (HSK / 4);
                 uint32_t c = (idx + tid) / (HSK / 4);
-                if (c < Bc && d < HSK / 4) {
+                if (c < Bc) {
                     FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0);
                     if (!KV_bounds_check || j * Bc + c < KV) {
 #if BLOCK_SIZE > 1
@@ -224,7 +224,7 @@ void main() {
             }

             FLOAT_TYPEV4 K_Tf;
-            if (K_LOAD_SHMEM != 0) {
+            if (SHMEM_STAGING != 0) {
                 K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
             } else {
 #if BLOCK_SIZE > 1
@@ -294,7 +294,7 @@ void main() {
             }
         }

-        if (K_LOAD_SHMEM != 0) {
+        if (SHMEM_STAGING != 0) {
             barrier();
             [[unroll]] for (uint32_t idx = 0; idx < Bc * HSV / 4; idx += gl_WorkGroupSize.x) {
                 uint32_t d = (idx + tid) % (HSV / 4);
@@ -331,7 +331,7 @@ void main() {

         [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
             FLOAT_TYPEV4 Vf;
-            if (K_LOAD_SHMEM != 0) {
+            if (SHMEM_STAGING != 0) {
                 Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
             } else {
 #if BLOCK_SIZE > 1
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
index 4142c1e6ea..0a077f6876 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
@@ -9,7 +9,7 @@ layout (constant_id = 4) const uint32_t HSV = 32;
 layout (constant_id = 5) const uint32_t Clamp = 0;
 layout (constant_id = 6) const uint32_t D_split = 16;
 layout (constant_id = 7) const uint32_t SubGroupSize = 32;
-layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0;
+layout (constant_id = 8) const uint32_t SHMEM_STAGING = 0;
 layout (constant_id = 9) const uint32_t Flags = 0;

 const bool USE_MASK_OPT = (Flags & 1) != 0;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
index 68bef90e48..4776c5e0e2 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
@@ -54,10 +54,11 @@ shared f16vec4 Psh[Bc * psh_stride];
 const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4;
 shared ACC_TYPEV4 sfsh[Bc * sfshstride];

-const uint32_t kshstride = (K_LOAD_SHMEM != 0 ? HSK_pad : MatBr) / 4 + 2; // in units of f16vec4
+const uint32_t D_pad = HSK_pad > HSV_pad ? HSK_pad : HSV_pad;
+const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; // in units of f16vec4
 const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups
 const uint vsh_stride = v_cols;
-shared f16vec4 ksh[(kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)];
+shared f16vec4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)];

 shared ACC_TYPE slope[Br];

@@ -78,15 +79,15 @@ void main() {
 #define tile_row(r) (row_tid * rows_per_thread + (r))

     // Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK).
-    if ((HSK % 16) != 0) {
+    if ((HSK % 16) != 0 || (HSV % 16) != 0) {
         [[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) {
             if (i + tid < Br * qstride) {
                 Qf[i + tid] = f16vec4(0);
             }
         }
-        [[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) {
-            if (i + tid < Bc * kshstride) {
-                ksh[i + tid] = f16vec4(0);
+        [[unroll]] for (uint i = 0; i < Bc * kvsh_stride; i += gl_WorkGroupSize.x) {
+            if (i + tid < Bc * kvsh_stride) {
+                kvsh[i + tid] = f16vec4(0);
             }
         }
         barrier();
@@ -231,13 +232,13 @@ void main() {
             }
         }

-        if (K_LOAD_SHMEM != 0) {
-            [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
-                uint32_t d = (idx + tid) % (HSK / 4);
-                uint32_t c = (idx + tid) / (HSK / 4);
-                if (c < Bc && d < HSK / 4) {
+        if (SHMEM_STAGING != 0) {
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK_pad / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t d = (idx + tid) % (HSK_pad / 4);
+                uint32_t c = (idx + tid) / (HSK_pad / 4);
+                if (c < Bc) {
                     f16vec4 K_Tf = f16vec4(0);
-                    if (!KV_bounds_check || j * Bc + c < KV) {
+                    if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) {
 #if BLOCK_SIZE > 1
                         uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
                         uint ib = coord / BLOCK_SIZE;
@@ -248,7 +249,7 @@ void main() {
 #endif
                     }

-                    ksh[c * kshstride + d] = K_Tf;
+                    kvsh[c * kvsh_stride + d] = K_Tf;
                 }
             }
             barrier();
@@ -262,7 +263,7 @@ void main() {
         coopmat QMat;

         [[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
-            if (K_LOAD_SHMEM == 0) {
+            if (SHMEM_STAGING == 0) {
 #if BLOCK_SIZE == 1
                 if (KV_bounds_check || d * 16 + 16 > HSK) {
 #endif
@@ -283,7 +284,7 @@ void main() {
 #endif
                     }

-                    ksh[row * kshstride + col_vec] = K_Tf;
+                    kvsh[row * kvsh_stride + col_vec] = K_Tf;
                 }
             }
             barrier();
@@ -295,8 +296,8 @@ void main() {
                 if (KV_bounds_check || d * 16 + 16 > HSK)
 #endif
                 {
-                    uint coord = (gl_SubgroupID * MatBc) * kshstride;
-                    coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
+                    uint coord = (gl_SubgroupID * MatBc) * kvsh_stride;
+                    coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
                 }
 #if BLOCK_SIZE == 1
                 else {
@@ -305,8 +306,8 @@ void main() {
                 }
 #endif
             } else {
-                uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
-                coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
+                uint coord = (gl_SubgroupID * MatBc) * kvsh_stride + d * 16 / 4;
+                coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
             }

             coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
@@ -397,6 +398,29 @@ void main() {
             }
         }

+        if (SHMEM_STAGING != 0) {
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * HSV_pad / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t d = (idx + tid) % (HSV_pad / 4);
+                uint32_t c = (idx + tid) / (HSV_pad / 4);
+                if (c < Bc) {
+                    f16vec4 V_Tf = f16vec4(0);
+                    if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) {
+#if BLOCK_SIZE > 1
+                        uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        V_Tf = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V));
+#else
+                        V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
+#endif
+                    }
+
+                    kvsh[c * kvsh_stride + d] = V_Tf;
+                }
+            }
+        }
+        barrier();
+
         const uint num_hsv_tiles = (HSV + MatBc * row_split - 1) / (MatBc * row_split); // round up

         // Each subgroup handles HSV/4 columns
@@ -410,6 +434,7 @@
         const uint v_total = v_rows * v_cols;
         const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x;

+        if (SHMEM_STAGING == 0) {
 #if BLOCK_SIZE == 1
         // For f16, only preload if not aligned
         if (KV_bounds_check) {
@@ -428,43 +453,52 @@

             if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
 #if BLOCK_SIZE > 1
-                ksh[row * vsh_stride + col] = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V));
+                kvsh[row * vsh_stride + col] = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V));
 #else
-                ksh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
+                kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
 #endif
             } else {
-                ksh[row * vsh_stride + col] = f16vec4(0.0f);
+                kvsh[row * vsh_stride + col] = f16vec4(0.0f);
             }
         }
+
 #if BLOCK_SIZE == 1
         }
 #endif
-
+        }
         barrier();

-        [[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) {
-            coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);
-
-#if BLOCK_SIZE == 1
-            if (!KV_bounds_check) {
-                // F16 values can be loaded directly from global memory
-                const uint v_tile_row = j * Bc + bc_chunk * MatBc;
-                const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
-                coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
-            } else
-#endif
-            {
-                const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
-                coopMatLoad(QMat, ksh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
-            }
-
-            SfMat = coopMatMulAdd(KMat, QMat, SfMat);
-        }
-
-        // Store SfMat to sfsh and load into Of
         const uint osh_stride = row_split * MatBc / 4;
         const uint o_offset = gl_SubgroupID * MatBc / 4;
-        coopMatStore(SfMat, sfsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor);
+
+        if (hsv_offset < HSV_pad) {
+            [[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) {
+                coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);
+
+                if (SHMEM_STAGING == 0) {
+#if BLOCK_SIZE == 1
+                    if (!KV_bounds_check) {
+                        // F16 values can be loaded directly from global memory
+                        const uint v_tile_row = j * Bc + bc_chunk * MatBc;
+                        const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
+                        coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
+                    } else
+#endif
+                    {
+                        const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
+                        coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
+                    }
+                } else {
+                    const uint v_tile_offset = bc_chunk * MatBc * kvsh_stride + (hsv_tile * row_split + gl_SubgroupID) * (MatBc / 4);
+                    coopMatLoad(QMat, kvsh, v_tile_offset, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
+                }
+
+                SfMat = coopMatMulAdd(KMat, QMat, SfMat);
+            }
+
+            // Store SfMat to sfsh and load into Of
+            coopMatStore(SfMat, sfsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor);
+        }

         barrier();