From e1b40fa53af32d1796a38153b9f4ebac44d9d59d Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Thu, 5 Mar 2026 11:35:45 +0100 Subject: [PATCH] fix slang issues --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 6 +++ .../vulkan-shaders/flash_attn.slang | 39 +++++++++---------- .../vulkan-shaders/flash_attn_loader.slang | 6 +-- 3 files changed, 28 insertions(+), 23 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3c81805b84..650bce3f3c 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -8840,6 +8840,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type)); +#ifdef GGML_VULKAN_ENABLE_SLANG + if (tuning_params.path != FA_SCALAR) { +#endif // For F32, the shader treats it as a block of size 4 (for vec4 loads) if (k->type == GGML_TYPE_F32) { k_stride /= 4; @@ -8847,6 +8850,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx if (v->type == GGML_TYPE_F32) { v_stride /= 4; } +#ifdef GGML_VULKAN_ENABLE_SLANG + } +#endif const uint32_t alignment = tuning_params.block_cols; bool aligned = (KV % alignment) == 0 && diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.slang b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.slang index b62204263c..cb3590fdf0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.slang +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.slang @@ -1,18 +1,18 @@ import types; import flash_attn_loader; -[vk::specialization_constant] const uint WorkGroupSize = 128; -[vk::specialization_constant] const uint Br = 1; -[vk::specialization_constant] const uint Bc = 32; -[vk::specialization_constant] const uint HSK = 32; -[vk::specialization_constant] const uint HSV = 32; -[vk::specialization_constant] const uint Clamp = 0; -[vk::specialization_constant] const uint D_split = 16; -[vk::specialization_constant] const uint row_split = 1; -[vk::specialization_constant] const uint SubGroupSize = 32; -[vk::specialization_constant] const uint SHMEM_STAGING = 0; -[vk::specialization_constant] const uint Flags = 0; -[vk::specialization_constant] const uint LIMIT_OCCUPANCY_SHMEM = 0; +[vk::constant_id( 0)] const uint WorkGroupSize = 128; +[vk::constant_id( 1)] const uint Br = 1; +[vk::constant_id( 2)] const uint Bc = 32; +[vk::constant_id( 3)] const uint HSK = 32; +[vk::constant_id( 4)] const uint HSV = 32; +[vk::constant_id( 5)] const uint Clamp = 0; +[vk::constant_id( 6)] const uint D_split = 16; +[vk::constant_id( 7)] const uint row_split = 1; +[vk::constant_id( 8)] const uint SubGroupSize = 32; +[vk::constant_id( 9)] const uint SHMEM_STAGING = 0; +[vk::constant_id(10)] const uint Flags = 0; +[vk::constant_id(11)] const uint LIMIT_OCCUPANCY_SHMEM = 0; static const bool USE_MASK_OPT = (Flags & 1) != 0; static const bool MASK_ENABLE = (Flags & 2) != 0; @@ -131,7 +131,7 @@ T perElemOpStoreCol0(const uint r, const uint32_t { if (r < N && c == 0) { uint offset = iq2 + r; - data_o[o_offset + offset] = (elem as D_TYPE).value; + data_o[o_offset + offset] = floatCast(elem); } return elem; } @@ -240,10 +240,9 @@ typealias VLoader = ScalarKVLoader; // Store the output when doing grouped query attention. // Rows index by Q's dimension 2, and the first N rows are valid. -void gqaStore(const in uint32_t r, const in uint32_t c, const in vector elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) -{ +void gqaStore(const in uint32_t r, const in uint32_t c, const in vector elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) { uint32_t offset = (iq2 + r) * HSV / 4 + c; - data_ov4[o_offset + offset] = (elems as vector).value; + data_ov4[o_offset + offset] = vector(elems); } [shader("compute")] @@ -255,7 +254,7 @@ void main( const Indices idcs = init_indices(wgid); const uint subgroup_invocation_id = WaveGetLaneIndex(); - const uint subgroup_id = tid / SubGroupSize; + const uint subgroup_id = tid / WaveGetLaneCount(); const uint threads_per_rowgroup = WorkGroupSize / row_split; const uint row_tid = tid / threads_per_rowgroup; @@ -386,7 +385,7 @@ void main( tmpsh[subgroup_id] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; } GroupMemoryBarrierWithGroupSync(); - [unroll] for (uint s = 0; s < num_subgroups; ++s) { + [unroll] for (uint s = 0; s < WaveGetNumWaves(); ++s) { max_mask = max(max_mask, tmpsh[s]); } if (max_mask <= NEG_FLT_MAX_OVER_2) { @@ -752,8 +751,8 @@ void main( [unroll] for (uint d = 0; d < HSV_per_thread / 4; ++d) { [unroll] for (uint r = 0; r < rows_per_thread; ++r) { Of[r][d] *= FLOAT(Lfrcp[r]); -#if defined(FLOAT_MAX) - Of[r][d] = clamp(Of[r][d], -FLOAT_MAX, FLOAT_MAX); +#if defined(FLOAT_TYPE_MAX) + Of[r][d] = clamp(Of[r][d], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX); #endif } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_loader.slang b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_loader.slang index c57a910c07..8b89b3174b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_loader.slang +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_loader.slang @@ -21,7 +21,7 @@ public struct ScalarKVLoader : IKVLoader { } public vector load(uint element_idx, uint head_dim4_idx) { - return (buf[offset + element_idx * stride4 + head_dim4_idx] as vector).value; + return vector(buf[offset + element_idx * stride4 + head_dim4_idx]); } } @@ -43,8 +43,8 @@ public struct Q8_0KVLoader : IKVLoader { const uint ib = coord / QUANT_K_Q8_0; const uint iqs = (coord % QUANT_K_Q8_0); - const vector v0 = (unpack_s8s32(int32_t(buf[offset + ib].qs[iqs / 2])).xy as vector).value; // vec4 used due to #12147 - const vector v1 = (unpack_s8s32(int32_t(buf[offset + ib].qs[iqs / 2 + 1])).xy as vector).value; + const vector v0 = vector(unpack_s8s32(int32_t(buf[offset + ib].qs[iqs / 2])).xy); // vec4 used due to #12147 + const vector v1 = vector(unpack_s8s32(int32_t(buf[offset + ib].qs[iqs / 2 + 1])).xy); return FLOAT(buf[offset + ib].d) * vector(v0.x, v0.y, v1.x, v1.y); }