diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index da17405b6e..5270c3f317 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -67,7 +67,7 @@ shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1]; #ifdef MMQ -shared block_a_cache kblocksh[Bc * qf_stride]; +shared block_a_cache kblocksh[SHMEM_STAGING != 0 ? Bc * qf_stride : 1]; #endif shared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1]; @@ -147,12 +147,19 @@ void main() { Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals)); +#ifdef DATA_A_Q8_0 + if (buf_iqs == 0) { + // sum is only needed for q4_0 + Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0); + } +#else // DATA_A_Q4_0 const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w; const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8); if (buf_iqs == 0) { Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd); } +#endif #endif } barrier(); @@ -431,8 +438,8 @@ void main() { } } else { [[unroll]] for (uint32_t d = 0; d < 4; d++) { - uint vui = (uint(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1]) << 16) | - uint(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0]); + uint vui = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0], + k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1])); k_quants[d ] = int32_t( vui & 0x0F0F0F0F); k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);