diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 5270c3f317..f9edc92051 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -407,9 +407,11 @@ void main() { #if defined(DATA_A_Q4_0) if (d_per_step < 8) { [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) { - uint sub = (d_block + d) % 4; + uint pos = d_tid * (HSK_per_thread / 4) + d_block + d; + uint sub = pos % 4; + uint shift = ((pos % 8) >= 4) ? 4 : 0; + uint vui = kblocksh[buf_ib].qs[sub]; - uint shift = ((d_block + d) >= 4) ? 4 : 0; k_quants[d] = int32_t((vui >> shift) & 0x0F0F0F0F); } } else { @@ -422,7 +424,7 @@ void main() { } #elif defined(DATA_A_Q8_0) [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) { - k_quants[d] = kblocksh[buf_ib].qs[d_block % 8 + d]; + k_quants[d] = kblocksh[buf_ib].qs[(d_tid * (HSK_per_thread / 4) + d_block) % 8 + d]; } #endif } else {