From 82afe402cba340796a629bb21a494cb7cea238c4 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Fri, 20 Mar 2026 14:32:40 +0100 Subject: [PATCH] fix SHMEM_STAGING indexing --- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 5270c3f317..f9edc92051 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -407,9 +407,11 @@ void main() { #if defined(DATA_A_Q4_0) if (d_per_step < 8) { [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) { - uint sub = (d_block + d) % 4; + uint pos = d_tid * (HSK_per_thread / 4) + d_block + d; + uint sub = pos % 4; + uint shift = ((pos % 8) >= 4) ? 4 : 0; + uint vui = kblocksh[buf_ib].qs[sub]; - uint shift = ((d_block + d) >= 4) ? 4 : 0; k_quants[d] = int32_t((vui >> shift) & 0x0F0F0F0F); } } else { @@ -422,7 +424,7 @@ void main() { } #elif defined(DATA_A_Q8_0) [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) { - k_quants[d] = kblocksh[buf_ib].qs[d_block % 8 + d]; + k_quants[d] = kblocksh[buf_ib].qs[(d_tid * (HSK_per_thread / 4) + d_block) % 8 + d]; } #endif } else {