From f6b533d898ce84bae8d9fa8dfc6697ac087800bf Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Wed, 28 Jan 2026 18:52:45 +0100 Subject: [PATCH] Vulkan Flash Attention Coopmat1 Refactor (#19075) * vulkan: use coopmat for flash attention p*v matrix multiplication * fix P loading issue * fix barrier position * remove reduction that is no longer needed * move max thread reduction into loop * remove osh padding * add bounds checks and padding * remove unused code * fix shmem sizes, loop duration and accesses * don't overwrite Qf, add new shared psh buffer instead * add missing bounds checks * use subgroup reductions * optimize * move bounds check, reduce barriers * support other Bc values and other subgroup sizes * remove D_split * replace Of register array with shared memory Ofsh array * parallelize HSV across the rowgroups * go back to Of in registers, not shmem * vectorize sfsh * don't store entire K tile in shmem * fixes * load large k tiles to shmem on Nvidia * adapt shared memory host check function to shader changes * remove Bc 32 case * remove unused variable * fix missing mask reduction tmspsh barrier * fix mask bounds check * fix rowmax f16 under/overflow to inf * fix flash_attn_cm2 BLOCK_SIZE preprocessor directives --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 64 ++- .../vulkan-shaders/flash_attn_base.glsl | 6 + .../vulkan-shaders/flash_attn_cm1.comp | 438 +++++++++++------- .../vulkan-shaders/flash_attn_cm2.comp | 6 +- 4 files changed, 329 insertions(+), 185 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 514f290d09..3852867c29 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3162,17 +3162,31 @@ static void ggml_vk_load_shaders(vk_device& device) { // For scalar, use 128 (arbitrary) // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs. const uint32_t D = (hsk|hsv); - uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1) - ? scalar_flash_attention_workgroup_size - : ((small_rows && (D % 32) == 0) ? 256 : 128); auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache); + uint32_t wg_size; + switch (path) { + case FA_COOPMAT2: + wg_size = ((small_rows && (D % 32) == 0) ? 256 : 128); + break; + case FA_COOPMAT1: + wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc + break; + default: + wg_size = scalar_flash_attention_workgroup_size; + break; + } + // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. const uint32_t D_lsb = D ^ (D & (D-1)); uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4); - return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split}; + // Nvidia prefers shared memory use to load large tiles of K + // AMD prefers loading K directly from global memory + const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA ? 1 : 0; + + return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem}; }; #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ @@ -3187,15 +3201,15 @@ static void ggml_vk_load_shaders(vk_device& device) { if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } \ } else { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } \ } \ } \ @@ -8344,41 +8358,49 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; - VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported); + VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported); return supported; } -static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) { +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) { // Needs to be kept up to date on shader changes GGML_UNUSED(hsv); - const uint32_t wg_size = scalar_flash_attention_workgroup_size; - const uint32_t Br = coopmat1_flash_attention_num_large_rows; - const uint32_t Bc = scalar_flash_attention_Bc; + const auto rows_cols = fa_rows_cols(FA_COOPMAT1, hsk, hsv, 0, kv_type, false, false); + const uint32_t Br = rows_cols[0]; + const uint32_t Bc = rows_cols[1]; + + const uint32_t MatBr = 16, MatBc = 16; + + const uint32_t row_split = Bc / MatBc; const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16); const uint32_t acctype = f32acc ? 4 : 2; const uint32_t f16vec4 = 8; - const uint32_t tmpsh = wg_size * sizeof(float); - const uint32_t tmpshv4 = wg_size * 4 * acctype; + const uint32_t tmpsh = (Bc / MatBc) * sizeof(float); const uint32_t qstride = hsk_pad / 4 + 2; const uint32_t Qf = Br * qstride * f16vec4; + const uint32_t psh_stride = Br / 4 + 2; + const uint32_t Psh = Bc * psh_stride * f16vec4; + const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br; const uint32_t sfsh = Bc * sfshstride * acctype; - const uint32_t kshstride = hsk_pad / 4 + 2; - const uint32_t ksh = Bc * kshstride * f16vec4; + const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA; + const uint32_t kshstride = (k_load_shmem ? hsk_pad : MatBr) / 4 + 2; + const uint32_t vsh_stride = MatBc / 4 * row_split; + const uint32_t ksh = ((kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)) * f16vec4; - const uint32_t slope = Br * sizeof(float); + const uint32_t slope = Br * acctype; - const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope; + const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; - VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported); + VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported); return supported; } @@ -8442,7 +8464,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) || (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc); - const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32); + const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32, k->type); if (!coopmat_shape_supported || !coopmat_shmem_supported) { path = FA_SCALAR; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 29b5c7c3a4..23a4d2c005 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -8,6 +8,8 @@ layout (constant_id = 3) const uint32_t HSK = 32; layout (constant_id = 4) const uint32_t HSV = 32; layout (constant_id = 5) const uint32_t Clamp = 0; layout (constant_id = 6) const uint32_t D_split = 16; +layout (constant_id = 7) const uint32_t SubGroupSize = 32; +layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0; // Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths const uint32_t HSK_pad = (HSK + 15) & ~15; @@ -74,6 +76,10 @@ layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16 layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed; #endif +#ifndef BLOCK_SIZE +#define BLOCK_SIZE 1 +#endif + #if defined(DATA_A_F32) #undef BLOCK_SIZE #define BLOCK_SIZE 4 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 0eb50fe58f..83d52d19d6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -7,6 +7,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable #extension GL_KHR_shader_subgroup_vote : enable #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_cooperative_matrix : enable @@ -14,12 +15,13 @@ #include "types.glsl" #include "flash_attn_base.glsl" -const uint32_t HSK_per_thread = HSK / D_split; -const uint32_t HSV_per_thread = HSV / D_split; +// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd +const uint32_t MatBr = 16; +const uint32_t MatBc = 16; -const uint32_t row_split = 4; +const uint32_t row_split = Bc / MatBc; const uint32_t rows_per_thread = Br / row_split; -const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split; +const uint32_t cols_per_iter = gl_WorkGroupSize.x / row_split; const uint32_t cols_per_thread = Bc / cols_per_iter; @@ -40,24 +42,24 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY return elem; } -// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd -const uint32_t MatBr = 16; -const uint32_t MatBc = 16; - -shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; -shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x]; +shared float tmpsh[row_split]; const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4 shared f16vec4 Qf[Br * qstride]; +const uint psh_stride = Br / 4 + 2; +shared f16vec4 Psh[Bc * psh_stride]; + // Avoid padding for hsk==256 to make it fit in 48KB shmem. -const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br; -shared ACC_TYPE sfsh[Bc * sfshstride]; +const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4; +shared ACC_TYPEV4 sfsh[Bc * sfshstride]; -const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4 -shared f16vec4 ksh[Bc * kshstride]; +const uint32_t kshstride = (K_LOAD_SHMEM != 0 ? HSK_pad : MatBr) / 4 + 2; // in units of f16vec4 +const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups +const uint vsh_stride = v_cols; +shared f16vec4 ksh[(kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)]; -shared float slope[Br]; +shared ACC_TYPE slope[Br]; void main() { #ifdef NEEDS_INIT_IQ_SHMEM @@ -69,9 +71,9 @@ void main() { const uint32_t tid = gl_LocalInvocationIndex; const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split; + const uint32_t d_per_thread = (HSV/4 + threads_per_rowgroup - 1) / threads_per_rowgroup; const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup; - const uint32_t d_tid = gl_LocalInvocationIndex % D_split; - const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split; + const uint32_t col_tid = gl_LocalInvocationIndex % threads_per_rowgroup; #define tile_row(r) (row_tid * rows_per_thread + (r)) @@ -102,9 +104,9 @@ void main() { } barrier(); - ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4]; - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + ACC_TYPEV4 Of[rows_per_thread][d_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + [[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) { Of[r][d] = ACC_TYPEV4(0.0); } } @@ -125,13 +127,11 @@ void main() { uint r = tid; slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2); } - barrier(); } else { if (tid < Br) { uint r = tid; - slope[r] = 1.0; + slope[r] = ACC_TYPE(1.0); } - barrier(); } #if BLOCK_SIZE > 1 @@ -149,19 +149,45 @@ void main() { [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { - float mask_cache[Bc * Br / WorkGroupSize]; + f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize]; if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; float max_mask = NEG_FLT_MAX_OVER_2; - [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { - uint32_t c = (idx + tid) % Bc; - uint32_t r = (idx + tid) / Bc; - if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { - if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { - float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) / (Br / 4); + uint32_t r = (idx + tid) % (Br / 4); + if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) { + if ((!KV_bounds_check || j * Bc + c < KV)) { + f16vec4 m; + if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) { + m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]); + max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3])); + } else if (i * Br + r * 4 + 2 < p.nem1) { + m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)], + 0.0); + max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])); + } else if (i * Br + r * 4 + 1 < p.nem1) { + m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)], + 0.0, + 0.0); + max_mask = max(max(max_mask, float(m[0])), float(m[1])); + } else if (i * Br + r * 4 < p.nem1) { + m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], + 0.0, + 0.0, + 0.0); + max_mask = max(max_mask, float(m[0])); + } else { + m = f16vec4(0.0); + } mask_cache[idx / WorkGroupSize] = m; - max_mask = max(max_mask, m); } } } @@ -180,26 +206,28 @@ void main() { } } - [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { - uint32_t d = (idx + tid) % (HSK / 4); - uint32_t c = (idx + tid) / (HSK / 4); - if (c < Bc && d < HSK / 4) { - f16vec4 K_Tf = f16vec4(0); - if (!KV_bounds_check || j * Bc + c < KV) { + if (K_LOAD_SHMEM != 0) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK / 4); + uint32_t c = (idx + tid) / (HSK / 4); + if (c < Bc && d < HSK / 4) { + f16vec4 K_Tf = f16vec4(0); + if (!KV_bounds_check || j * Bc + c < KV) { #if BLOCK_SIZE > 1 - uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); #else - K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); + K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); #endif - } + } - ksh[c * kshstride + d] = K_Tf; + ksh[c * kshstride + d] = K_Tf; + } } + barrier(); } - barrier(); // K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16 @@ -208,11 +236,55 @@ void main() { coopmat KMat; coopmat QMat; - for (uint32_t d = 0; d < HSK_pad / 16; ++d) { - coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); + [[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) { + if (K_LOAD_SHMEM == 0) { +#if BLOCK_SIZE == 1 + if (KV_bounds_check || d * 16 + 16 > HSK) { +#endif + barrier(); + [[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) { + uint32_t col_vec = (idx + tid) % (MatBr / 4); + uint32_t row = (idx + tid) / (MatBr / 4); + if (idx + tid < Bc * MatBr / 4) { + f16vec4 K_Tf = f16vec4(0); + if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); +#else + K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]); +#endif + } - uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4; - coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor); + ksh[row * kshstride + col_vec] = K_Tf; + } + } + barrier(); +#if BLOCK_SIZE == 1 + } +#endif + +#if BLOCK_SIZE == 1 + if (KV_bounds_check || d * 16 + 16 > HSK) +#endif + { + uint coord = (gl_SubgroupID * MatBc) * kshstride; + coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor); + } +#if BLOCK_SIZE == 1 + else { + const uint coord = k_offset / 4 + (j * Bc + gl_SubgroupID * MatBc) * k_stride / 4 + d * 16 / 4; + coopMatLoad(KMat, data_kv4, coord, k_stride / 4, gl_CooperativeMatrixLayoutRowMajor); + } +#endif + } else { + uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4; + coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor); + } + + coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); SfMat = coopMatMulAdd(KMat, QMat, SfMat); } @@ -222,26 +294,26 @@ void main() { barrier(); if (p.logit_softcap != 0.0f) { - [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { - uint32_t c = (idx + tid) / Br; - uint32_t r = (idx + tid) % Br; - if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { - sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r])); + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) / (Br / 4); + uint32_t r = (idx + tid) % (Br / 4); + if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) { + sfsh[c * sfshstride + r] = ACC_TYPEV4(p.logit_softcap * tanh(sfsh[c * sfshstride + r])); } } barrier(); } if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { - bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - - [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { - uint32_t c = (idx + tid) % Bc; - uint32_t r = (idx + tid) / Bc; - if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { - if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { - float f = mask_cache[idx / WorkGroupSize]; - sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * f); + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) / (Br / 4); + uint32_t r = (idx + tid) % (Br / 4); + if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) { + if (!KV_bounds_check || j * Bc + c < KV) { + // Mask nem1 bounds check is handled when loading masks + ACC_TYPEV4 masks = ACC_TYPEV4(mask_cache[idx / WorkGroupSize]); + ACC_TYPEV4 slopes = ACC_TYPEV4(slope[r * 4], slope[r * 4 + 1], slope[r * 4 + 2], slope[r * 4 + 3]); + sfsh[c * sfshstride + r] += slopes * masks; } } } @@ -250,51 +322,145 @@ void main() { float eMf[rows_per_thread]; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint r_vec = tile_row(r) / 4; + const uint r_comp = tile_row(r) % 4; + float rowmaxf = NEG_FLT_MAX_OVER_2; [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { continue; } - rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride])); + rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp])); } float Moldf = Mf[r]; + // Compute max across the row + rowmaxf = subgroupMax(rowmaxf); + // M = max(rowmax, Mold) // P = e^(S - M) // eM = e^(Mold - M) Mf[r] = max(rowmaxf, Moldf); eMf[r] = exp(Moldf - Mf[r]); - } - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d]; - } - } - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Lf[r] = eMf[r]*Lf[r]; } - [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { - continue; - } - float Pf[rows_per_thread]; + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]); - Lf[r] += Pf[r]; + Of[r][d_local] = ACC_TYPE(eMf[r]) * Of[r][d_local]; } - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); -#else - vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); + } + + // Calculate and store Pf in Psh + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + const uint col = c * cols_per_iter + col_tid; + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; r += 4) { + const uint row = tile_row(r); + if (KV_bounds_check && j * Bc + col >= KV) { + Psh[col * psh_stride + row / 4] = f16vec4(0.0f); + } else { + const vec4 mfvec = vec4(Mf[r], Mf[r + 1], Mf[r + 2], Mf[r + 3]); + const f16vec4 Pf = f16vec4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec)); + [[unroll]] for (uint32_t vec_idx = 0; vec_idx < 4; ++vec_idx) { + Lf[r + vec_idx] += Pf[vec_idx]; + } + Psh[col * psh_stride + row / 4] = Pf; + } + } + } + + const uint num_hsv_tiles = (HSV + MatBc * row_split - 1) / (MatBc * row_split); // round up + + // Each subgroup handles HSV/4 columns + [[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) { + const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16; + + SfMat = coopmat(0); + + // Preload V tiles for [Bc, 16 * num subgroups] + const uint v_rows = Bc; + const uint v_total = v_rows * v_cols; + const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x; + +#if BLOCK_SIZE == 1 + // For f16, only preload if not aligned + if (KV_bounds_check) { #endif - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d] += ACC_TYPE(Pf[r]) * ACC_TYPEV4(Vf); + [[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) { + const uint idx = i * gl_WorkGroupSize.x + tid; + const uint row = idx / v_cols; + const uint col = idx % v_cols; + + const uint v_row = j * Bc + row; + const uint v_col = hsv_tile * MatBc * row_split + col * 4; + + const uint coord = v_row * v_stride * BLOCK_SIZE + v_col; + const uint ib = coord / BLOCK_SIZE; + const uint iqs = coord % BLOCK_SIZE; + + if (!KV_bounds_check || (v_row < KV && v_col < HSV)) { +#if BLOCK_SIZE > 1 + ksh[row * vsh_stride + col] = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V)); +#else + ksh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4]; +#endif + } else { + ksh[row * vsh_stride + col] = f16vec4(0.0f); + } + } +#if BLOCK_SIZE == 1 + } +#endif + + barrier(); + + [[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) { + coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor); + +#if BLOCK_SIZE == 1 + if (!KV_bounds_check) { + // F16 values can be loaded directly from global memory + const uint v_tile_row = j * Bc + bc_chunk * MatBc; + const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4; + coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor); + } else +#endif + { + const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4); + coopMatLoad(QMat, ksh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor); + } + + SfMat = coopMatMulAdd(KMat, QMat, SfMat); + } + + // Store SfMat to sfsh and load into Of + const uint osh_stride = row_split * MatBc / 4; + const uint o_offset = gl_SubgroupID * MatBc / 4; + coopMatStore(SfMat, sfsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor); + + barrier(); + + const uint hsv_per_tile = row_split * MatBc; + const uint hsv_base = hsv_tile * hsv_per_tile; + const uint d_values_per_tile = hsv_per_tile / 4; + + const uint d_start = hsv_tile * d_values_per_tile; + const uint d_end = min(d_start + d_values_per_tile, HSV / 4); + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + + [[unroll]] for (uint32_t d_local = 0; d_local < d_per_thread; ++d_local) { + const uint d = d_local * threads_per_rowgroup + col_tid; + const uint hsv_col = 4 * d; + + if (hsv_col >= hsv_base && hsv_col < hsv_base + hsv_per_tile && hsv_col < HSV) { + const uint local_hsv = (hsv_col - hsv_base) / 4; + Of[r][d_local] += ACC_TYPEV4(sfsh[row * osh_stride + local_hsv]); + } } } } @@ -302,69 +468,8 @@ void main() { barrier(); } - // prevent race on tmpsh - barrier(); - - // reduce across threads - - float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread]; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - FLOAT_TYPE M = Mf[r]; - tmpsh[tid] = M; - // Compute max across the row - barrier(); - [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { - M = max(M, tmpsh[tid ^ s]); - barrier(); - tmpsh[tid] = M; - barrier(); - } - rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup]; - barrier(); - } - - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Moldf[r] = Mf[r]; - - // M = max(rowmax, Mold) - // eM = e^(Mold - M) - Mf[r] = max(rowmaxf[r], Moldf[r]); - eMf[r] = exp(Moldf[r] - Mf[r]); - - Lf[r] = eMf[r]*Lf[r]; - } - - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - FLOAT_TYPE L = Lf[r]; - tmpsh[tid] = L; - // Compute sum across the row - barrier(); - [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { - L += tmpsh[tid ^ s]; - barrier(); - tmpsh[tid] = L; - barrier(); - } - Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup]; - barrier(); - } - - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - - Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d]; - tmpshv4[tid] = Of[r][d]; - - barrier(); - [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { - Of[r][d] += tmpshv4[tid ^ s]; - barrier(); - tmpshv4[tid] = Of[r][d]; - barrier(); - } - Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup]; - barrier(); - } + Lf[r] = subgroupAdd(Lf[r]); } // If there is split_k, then the split_k resolve shader does the final @@ -375,9 +480,12 @@ void main() { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d = d0 + col_tid; + if (d >= HSV/4) break; + const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); + perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N); } } } @@ -404,8 +512,9 @@ void main() { if (sink > Mf[r]) { ms = exp(Mf[r] - sink); - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - Of[r][d] *= ACC_TYPE(ms); + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d_local = d0 / threads_per_rowgroup; + Of[r][d_local] *= ACC_TYPE(ms); } } else { vs = exp(sink - Mf[r]); @@ -420,11 +529,12 @@ void main() { Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]); } - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d] *= ACC_TYPE(Lfrcp[r]); + Of[r][d_local] *= ACC_TYPE(Lfrcp[r]); #if defined(ACC_TYPE_MAX) - Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX); + Of[r][d_local] = clamp(Of[r][d_local], -ACC_TYPE_MAX, ACC_TYPE_MAX); #endif } } @@ -434,9 +544,12 @@ void main() { if (p.gqa_ratio > 1) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d = d0 + col_tid; + if (d >= HSV / 4) break; + const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); + perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N); } } } @@ -444,9 +557,12 @@ void main() { } else { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (i * Br + tile_row(r) < N) { - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d = d0 + col_tid; + if (d >= HSV / 4) break; + const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4 * d + comp] = D_TYPE(Of[r][d_local][comp]); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index d49a8da65f..54f1b0b622 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -55,7 +55,7 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele return max(elem0, elem1); } -#if defined(BLOCK_SIZE) +#if BLOCK_SIZE > 1 #define DECODEFUNC , DEQUANTFUNC #else #define DECODEFUNC @@ -85,7 +85,7 @@ void main() { tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); -#if defined(BLOCK_SIZE) +#if BLOCK_SIZE > 1 tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE); tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE); #endif @@ -98,7 +98,7 @@ void main() { if (Clamp != gl_CooperativeMatrixClampModeConstantNV) { q_stride &= ~7; -#if !defined(BLOCK_SIZE) +#if BLOCK_SIZE == 1 k_stride &= ~7; v_stride &= ~7; #endif