diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 977aff62d8..1bee3e187c 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -2858,11 +2858,10 @@ struct vk_fa_tuning_params {
     }
 };
 
-static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
+static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type);
 static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
 
 static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
-    GGML_UNUSED(kv_type);
     vk_fa_tuning_params result{};
 
     result.path = FA_SCALAR;
@@ -2914,7 +2913,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device,
 
     result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
 
-    if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) {
+    if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, kv_type)) {
         result.block_rows /= 2;
     }
 
@@ -3445,21 +3444,47 @@ static void ggml_vk_load_shaders(vk_device& device) {
     if (device->fp16) {
         CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
         CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
-        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
-        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
-        CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
-        CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
-        CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
-        CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )
+
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+        if (device->integer_dot_product && device->subgroup_clustered) {
+            CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _int8)
+            CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _int8)
+            CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _int8)
+            CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _int8)
+            CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _int8)
+            CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _int8)
+        } else
+#endif
+        {
+            CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
+            CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
+            CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
+            CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
+            CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
+            CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )
+        }
     } else {
         CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
         CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
-        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
-        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
-        CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
-        CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
-        CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
-        CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)
+
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+        if (device->integer_dot_product && device->subgroup_clustered) {
+            CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32_int8)
+            CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32_int8)
+            CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32_int8)
+            CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32_int8)
+            CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32_int8)
+            CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32_int8)
+        } else
+#endif
+        {
+            CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
+            CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
+            CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
+            CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
+            CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
+            CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)
+        }
     }
 #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
     if (device->coopmat1_fa_support) {
@@ -8780,7 +8805,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
         }
     }
 }
 
-static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
+static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
     GGML_UNUSED(f32acc);
     // Needs to be kept up to date on shader changes
     const uint32_t wg_size = params.workgroup_size;
@@ -8789,21 +8814,51 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
 
     const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
 
+    const bool mmq = device->integer_dot_product && device->subgroup_clustered &&
+                     (kv_type == GGML_TYPE_Q4_0 || kv_type == GGML_TYPE_Q4_1 ||
+                      kv_type == GGML_TYPE_Q5_0 || kv_type == GGML_TYPE_Q5_1 ||
+                      kv_type == GGML_TYPE_Q8_0 || kv_type == GGML_TYPE_IQ4_NL);
+
     // tmpsh is overestimated slightly
     const uint32_t tmpsh = wg_size * sizeof(float);
     const uint32_t tmpshv4 = wg_size * 4 * float_type_size;
 
     const uint32_t masksh = Bc * (Br + 1) * float_type_size;
 
-    const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
+    uint32_t Qf, kvsh, kblocksh_size;
+    if (mmq) {
+        // block_b_cache: int32_t qs[8] + FLOAT_TYPEV2 ds
+        const uint32_t block_b_size = 8 * sizeof(int32_t) + 2 * float_type_size;
+        Qf = Br * (hsk / 32) * block_b_size;
 
-    const uint32_t D = std::max(hsk, hsv);
-    const uint32_t kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
+        // kvsh uses D = HSV (K goes through kblocksh instead)
+        kvsh = params.shmem_staging ? Bc * (hsv / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
 
-    const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh;
+        // block_a_cache size depends on quant type
+        uint32_t block_a_size;
+        switch (kv_type) {
+            case GGML_TYPE_Q4_0:   block_a_size = 4 * sizeof(uint32_t) + float_type_size; break;
+            case GGML_TYPE_Q4_1:   block_a_size = 4 * sizeof(uint32_t) + 2 * float_type_size; break;
+            case GGML_TYPE_Q5_0:   block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + float_type_size; break;
+            case GGML_TYPE_Q5_1:   block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + 2 * float_type_size; break;
+            case GGML_TYPE_Q8_0:
+            case GGML_TYPE_IQ4_NL: block_a_size = 8 * sizeof(int32_t) + float_type_size; break;
+            default:               block_a_size = 0; break;
+        }
+        kblocksh_size = params.shmem_staging ? Bc * (hsk / 32) * block_a_size : block_a_size;
+    } else {
+        Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
+
+        const uint32_t D = std::max(hsk, hsv);
+        kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
+
+        kblocksh_size = 0;
+    }
+
+    const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh + kblocksh_size;
 
     const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
 
-    VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
+    VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", mmq=" << mmq << ", total_size=" << total_size << ", supported=" << supported);
 
     return supported;
 }
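Note for reviewers: the MMQ branch of the shared-memory accounting above can be sanity-checked on the host with a few lines of standalone C++. The tile sizes below are illustrative placeholders (the real Br/Bc/workgroup values come from vk_fa_tuning_params), so treat this as a sketch of the arithmetic, not the tuned configuration:

    // Standalone model of the MMQ shmem budget; tile sizes are illustrative.
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t wg_size = 128, Br = 32, Bc = 32, hsk = 128, hsv = 128;
        const uint32_t float_type_size = 2; // fp16 device

        const uint32_t tmpsh   = wg_size * sizeof(float);
        const uint32_t tmpshv4 = wg_size * 4 * float_type_size;
        const uint32_t masksh  = Bc * (Br + 1) * float_type_size;

        // block_b_cache: int32_t qs[8] + FLOAT_TYPEV2 ds, one entry per 32 K-dim values
        const uint32_t block_b_size = 8 * sizeof(int32_t) + 2 * float_type_size;
        const uint32_t Qf = Br * (hsk / 32) * block_b_size;

        // With shmem staging: V staged as vec4 rows, K staged as block_a_cache (Q8_0 here)
        const uint32_t kvsh         = Bc * (hsv / 4 + 1) * 4 * float_type_size;
        const uint32_t block_a_size = 8 * sizeof(int32_t) + float_type_size;
        const uint32_t kblocksh     = Bc * (hsk / 32) * block_a_size;

        printf("total = %u bytes\n", tmpsh + tmpshv4 + masksh + Qf + kvsh + kblocksh);
        // 512 + 1024 + 2112 + 4608 + 8448 + 4352 = 21056 bytes, under a 32 KiB limit
        return 0;
    }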
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
index 11b7dce857..6e6bdabc92 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
@@ -10,6 +10,13 @@
 #extension GL_EXT_shader_subgroup_extended_types_float16 : require
 #endif
 
+#ifdef MMQ
+#extension GL_EXT_integer_dot_product : require
+#extension GL_KHR_shader_subgroup_clustered : require
+
+#include "mul_mmq_shmem_types.glsl"
+#endif
+
 #extension GL_KHR_shader_subgroup_shuffle : enable
 #extension GL_KHR_shader_subgroup_vote : enable
@@ -41,15 +48,34 @@
 shared FLOAT_TYPEV4 tmpshv4[tmpsh_size];
 
 const uint32_t masksh_stride = Br + 1;
 shared FLOAT_TYPE masksh[Bc * masksh_stride];
 
+#ifndef MMQ
 const uint32_t qf_stride = HSK / 4 + 1;
 shared FLOAT_TYPEV4 Qf[Br * qf_stride];
+#else
+const uint32_t qf_stride = HSK / 32;
+shared block_b_cache Qf[Br * qf_stride];
+#endif
 
+#ifndef MMQ
 const uint32_t D = HSK > HSV ? HSK : HSV;
+#else
+const uint32_t D = HSV;
+#endif
 const uint32_t kvsh_stride = D / 4 + 1;
 shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1];
 
+#ifdef MMQ
+
+shared block_a_cache kblocksh[SHMEM_STAGING != 0 ? Bc * qf_stride : 1];
+#endif
+
 shared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1];
 
+#ifdef MMQ
+#include "flash_attn_mmq_funcs.glsl"
+#endif
+
 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM
     init_iq_shmem(gl_WorkGroupSize);
@@ -82,10 +108,39 @@ void main() {
     [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
         uint32_t d = (idx + tid) % (HSK / 4);
         uint32_t r = (idx + tid) / (HSK / 4);
-        if (r < Br && d < HSK / 4 &&
-            i * Br + r < N) {
+        const bool is_in_bounds = r < Br && d < HSK / 4 && i * Br + r < N;
+#ifndef MMQ
+        if (is_in_bounds) {
             Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
         }
+#else
+        const uint buf_ib = r * qf_stride + d / 8;
+        const uint buf_iqs = d % 8;
+
+        FLOAT_TYPEV4 vals = is_in_bounds ? FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale) : FLOAT_TYPEV4(0.0f);
+        const FLOAT_TYPEV4 abs_vals = abs(vals);
+
+        const FLOAT_TYPE thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
+        const FLOAT_TYPE amax = subgroupClusteredMax(thread_max, 8);
+        const FLOAT_TYPE qd = amax / FLOAT_TYPE(127.0);
+        const FLOAT_TYPE qd_inv = qd != FLOAT_TYPE(0.0) ? FLOAT_TYPE(1.0) / qd : FLOAT_TYPE(0.0);
+        vals = round(vals * qd_inv);
+
+        Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals));
+
+#if defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
+        if (buf_iqs == 0) {
+            Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0);
+        }
+#else // Q4_0, Q4_1, Q5_0, Q5_1
+        const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w;
+        const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8);
+
+        if (buf_iqs == 0) {
+            Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd);
+        }
+#endif
+#endif
     }
 
     barrier();
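Note for reviewers: the hunk above quantizes the Q tile on the fly into a Q8_1-style block format. Each 8-thread cluster owns one 32-wide block (a vec4 per thread), reduces the block's absolute maximum with subgroupClusteredMax, stores d = amax/127, and for the offset-carrying K quants also keeps d * sum(q) for the dot-product correction later. A scalar C++ model of the same math (the BlockB struct and names are illustrative, not shader code):

    // Scalar model of the per-block Q quantization: one 32-value block,
    // i.e. what one 8-thread cluster produces. Illustrative sketch.
    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    struct BlockB {
        int8_t qs[32];
        float  d;    // ds.x in the shader: amax / 127
        float  dsum; // ds.y: d * sum(q); zero for Q8_0/IQ4_NL K caches
    };

    BlockB quantize_q8_1(const float *x) {
        float amax = 0.0f;
        for (int i = 0; i < 32; i++) amax = std::max(amax, std::fabs(x[i])); // subgroupClusteredMax
        const float d     = amax / 127.0f;
        const float d_inv = d != 0.0f ? 1.0f / d : 0.0f;

        BlockB b{};
        b.d = d;
        float sum = 0.0f;
        for (int i = 0; i < 32; i++) {
            b.qs[i] = (int8_t)std::lround(x[i] * d_inv);
            sum += b.qs[i];                                                  // subgroupClusteredAdd
        }
        b.dsum = d * sum; // only consumed by the Q4_0/Q4_1/Q5_0/Q5_1 correction
        return b;
    }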
@@ -195,6 +250,7 @@ void main() {
 
         if (SHMEM_STAGING != 0) {
             barrier();
+#ifndef MMQ
             [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
                 uint32_t d = (idx + tid) % (HSK / 4);
                 uint32_t c = (idx + tid) / (HSK / 4);
@@ -214,9 +270,29 @@ void main() {
                     kvsh[c * kvsh_stride + d] = K_Tf;
                 }
             }
+#else // MMQ
+            const uint ints_per_block = 8 / QUANT_R_MMQ;
+            const uint quant_iters = Bc * HSK / 32 * ints_per_block;
+            [[unroll]] for (uint32_t idx = 0; idx < quant_iters; idx += gl_WorkGroupSize.x) {
+                const uint32_t iqs = (idx + tid) % ints_per_block;
+                const uint32_t ib = (idx + tid) / ints_per_block;
+                const uint32_t c = ib / (HSK / 32);
+                const uint32_t block = ib % (HSK / 32);
+                if (idx + gl_WorkGroupSize.x <= quant_iters || c < Bc) {
+                    const uint buf_ib = c * qf_stride + block;
+                    if (!KV_bounds_check || j * Bc + c < KV) {
+                        const uint global_ib = (j * Bc + c) * k_stride + block;
+                        k_block_to_shmem(buf_ib, global_ib, iqs, k_offset);
+                    } else {
+                        k_block_to_shmem_zero(buf_ib, iqs);
+                    }
+                }
+            }
+#endif // MMQ
 
             barrier();
         }
 
+#ifndef MMQ
         // More d iterations means Q register caching becomes relevant
         // Few iterations means the additional registers needed are worse than the speed-up from caching
         if (HSK_per_thread / 4 > 4) {
@@ -275,6 +351,110 @@ void main() {
                 }
             }
         }
+#else // MMQ
+        const uint hsk4 = HSK_per_thread / 4;
+        const uint d_per_step = (hsk4 % 8 == 0) ? 8 :
+                                (hsk4 % 4 == 0) ? 4 :
+                                (hsk4 % 2 == 0) ? 2 : 1;
+
+        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+            if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
+                continue;
+            }
+
+            [[unroll]] for (uint32_t d_block = 0; d_block < HSK_per_thread / 4; d_block += d_per_step) {
+                int32_t k_quants[d_per_step];
+                ACC_TYPEV2 k_dm;
+
+                if (SHMEM_STAGING != 0) {
+                    const uint k_block_idx = (d_tid * (HSK_per_thread / 4) + d_block) / 8;
+                    const uint buf_ib = (c * cols_per_iter + col_tid) * qf_stride + k_block_idx;
+#if QUANT_AUXF == 1
+                    k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm, 0.0);
+#else
+                    k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm);
+#endif
+
+#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
+                    if (d_per_step == 8) {
+                        [[unroll]] for (uint32_t d = 0; d < 4; d++) {
+                            uint vui = kblocksh[buf_ib].qs[d];
+                            k_quants[d    ] = int32_t( vui       & 0x0F0F0F0F);
+                            k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
+#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
+                            uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF;
+                            uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF;
+                            k_quants[d    ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
+                            k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
+#endif
+                        }
+                    } else
+#endif
+                    {
+                        [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
+                            k_quants[d] = get_k_qs_shmem(buf_ib, (d_tid * (HSK_per_thread / 4) + d_block) % 8 + d);
+                        }
+                    }
+                } else {
+                    const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d_tid * (HSK_per_thread / 4) + d_block);
+                    const uint ib = coord / BLOCK_SIZE;
+                    const uint iqs = (coord % BLOCK_SIZE);
+
+#if QUANT_AUXF == 1
+                    k_dm = ACC_TYPEV2(get_k_d(ib, k_offset), 0.0);
+#else
+                    k_dm = ACC_TYPEV2(get_k_dm(ib, k_offset));
+#endif
+#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
+                    if (d_per_step == 8) {
+#if defined(DATA_A_Q5_0)
+                        uint qh = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qh[0],
+                                                 k_packed.k_data_packed16[k_offset + ib].qh[1]));
+#elif defined(DATA_A_Q5_1)
+                        uint qh = k_packed.k_data_packed16[k_offset + ib].qh;
+#endif
+                        [[unroll]] for (uint32_t d = 0; d < 4; d++) {
+#if defined(A_TYPE_PACKED32)
+                            uint vui = k_packed32.k_data_packed32[k_offset + ib].qs[d];
+#else
+                            uint vui = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0],
+                                                      k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1]));
+#endif
+                            k_quants[d    ] = int32_t( vui       & 0x0F0F0F0F);
+                            k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
+#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
+                            uint qh_lo = (qh >> (d * 4)) & 0xF;
+                            uint qh_hi = (qh >> (d * 4 + 16)) & 0xF;
+                            k_quants[d    ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
+                            k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
+#endif
+                        }
+                    } else
+#endif
+                    {
+                        [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
+                            k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset);
+                        }
+                    }
+                }
+
+                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                    const uint qib = tile_row(r) * qf_stride + (d_tid * (HSK_per_thread / 4) + d_block) / 8;
+                    const uint qiqs = (d_tid * (HSK_per_thread / 4) + d_block) % 8;
+
+                    int32_t acc = 0;
+                    [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
+                        acc += dotPacked4x8EXT(Qf[qib].qs[qiqs + d], k_quants[d]);
+                    }
+
+                    Sf[r][c] += ACC_TYPE(acc) * ACC_TYPE(Qf[qib].ds.x) * k_dm.x;
+                    if ((d_tid * (HSK_per_thread / 4) + d_block) % 8 == 0) {
+                        Sf[r][c] += k_dot_correction(qib, k_dm);
+                    }
+                }
+            }
+        }
+#endif // MMQ
 
         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
             // Compute sum across the D_split
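Note for reviewers on the correction term: for Q4_0 the raw nibbles q_k in [0,15] encode (q_k - 8) * d_k, so the dequantized dot product splits into the packed integer dot plus a term proportional to the precomputed ds.y = d_q * sum(q_q); Q5_0 behaves the same with offset 16, while Q4_1/Q5_1 replace the -offset * d_k factor with the additive minimum carried in k_dm.y. This is why a single k_dot_correction() per 32-wide block is exact. A numeric check of the Q4_0 identity (illustrative C++ sketch, not the shader):

    // sum_i (q_q[i] * d_q) * ((q_k[i] - 8) * d_k)
    //   = d_q * d_k * dot(q_q, q_k) - 8 * d_k * (d_q * sum(q_q))
    //   = ACC_TYPE(acc) * ds.x * k_dm.x + (-8) * ds.y * k_dm.x
    #include <cassert>
    #include <cmath>
    #include <cstdint>

    int main() {
        const int8_t  q_q[4] = { 12, -7, 3, 127 }; // quantized Q values
        const uint8_t q_k[4] = { 0, 15, 8, 3 };    // raw Q4_0 nibbles
        const float d_q = 0.013f, d_k = 0.07f;

        float ref = 0.0f, acc = 0.0f, sum_q = 0.0f;
        for (int i = 0; i < 4; i++) {
            ref   += (q_q[i] * d_q) * ((q_k[i] - 8) * d_k); // dequantize-then-multiply
            acc   += q_q[i] * q_k[i];                       // what dotPacked4x8EXT accumulates
            sum_q += q_q[i];
        }
        const float mmq = acc * d_q * d_k - 8.0f * (d_q * sum_q) * d_k;
        assert(std::fabs(ref - mmq) < 1e-5f);
        return 0;
    }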
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
index b30dee8687..6f34924691 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
@@ -89,6 +89,11 @@ layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
 layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
 #endif
 
+#if defined(A_TYPE_PACKED32)
+layout (binding = 1) readonly buffer K_PACKED32 {A_TYPE_PACKED32 k_data_packed32[];} k_packed32;
+layout (binding = 2) readonly buffer V_PACKED32 {A_TYPE_PACKED32 v_data_packed32[];} v_packed32;
+#endif
+
 #ifndef BLOCK_SIZE
 #define BLOCK_SIZE 1
 #endif
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl
new file mode 100644
index 0000000000..e14e62d546
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl
@@ -0,0 +1,149 @@
+#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
+int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
+#ifdef DATA_A_Q4_0
+    uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
+                              k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
+#else
+    uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4];
+#endif
+
+    uint shift = (iqs & 0x10) >> 2;
+    vui >>= shift;
+
+    return int32_t(vui & 0x0F0F0F0F);
+}
+#endif
+
+#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
+int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
+#ifdef DATA_A_Q5_0
+    uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
+                              k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
+    uint qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qh[0],
+                             k_packed.k_data_packed16[a_offset + ib].qh[1]));
+#else
+    uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4];
+    uint qh = k_packed.k_data_packed16[a_offset + ib].qh;
+#endif
+
+    uint shift = (iqs & 0x10) >> 2;
+    vui >>= shift;
+
+    uint qh_bits = (qh >> iqs) & 0xF;
+    return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
+}
+#endif
+
+#if defined(DATA_A_Q8_0)
+int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
+    return pack32(i16vec2(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2], k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1]));
+}
+#endif
+
+#if defined(DATA_A_IQ4_NL)
+int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
+    uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
+                              k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
+    uint shift = (iqs & 0x10) >> 2;
+    vui >>= shift;
+
+    u8vec4 idx = unpack8(vui & 0x0F0F0F0F);
+    return pack32(i8vec4(kvalues_iq4nl_const[idx.x],
+                         kvalues_iq4nl_const[idx.y],
+                         kvalues_iq4nl_const[idx.z],
+                         kvalues_iq4nl_const[idx.w]));
+}
+#endif
+
+#if QUANT_AUXF == 1
+FLOAT_TYPE get_k_d(uint ib, uint a_offset) {
+    return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d);
+}
+#else
+FLOAT_TYPEV2 get_k_dm(uint ib, uint a_offset) {
+    return FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + ib].dm);
+}
+#endif
+
+void k_block_to_shmem(const uint buf_ib, const uint global_ib, const uint iqs, const uint a_offset) {
+#if defined(DATA_A_Q4_0)
+    kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
+                                              k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
+#elif defined(DATA_A_Q4_1)
+    kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs];
+#elif defined(DATA_A_Q5_0)
+    kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
+                                              k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
+    if (iqs == 0) {
+        kblocksh[buf_ib].qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qh[0],
+                                             k_packed.k_data_packed16[a_offset + global_ib].qh[1]));
+    }
+#elif defined(DATA_A_Q5_1)
+    kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs];
+    if (iqs == 0) {
+        kblocksh[buf_ib].qh = k_packed.k_data_packed16[a_offset + global_ib].qh;
+    }
+#elif defined(DATA_A_Q8_0)
+    kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
+                                              k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
+#elif defined(DATA_A_IQ4_NL)
+    const uint qs = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
+                                   k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
+    const u8vec4 i_a0 = unpack8( qs       & 0x0F0F0F0F);
+    const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
+    kblocksh[buf_ib].qs[iqs    ] = pack32(i8vec4(kvalues_iq4nl_const[i_a0.x], kvalues_iq4nl_const[i_a0.y],
+                                                 kvalues_iq4nl_const[i_a0.z], kvalues_iq4nl_const[i_a0.w]));
+    kblocksh[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_iq4nl_const[i_a1.x], kvalues_iq4nl_const[i_a1.y],
+                                                 kvalues_iq4nl_const[i_a1.z], kvalues_iq4nl_const[i_a1.w]));
+#endif
+
+    if (iqs == 0) {
+#if QUANT_AUXF == 1
+        kblocksh[buf_ib].dm = FLOAT_TYPE(k_packed.k_data_packed16[a_offset + global_ib].d);
+#else
+        kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + global_ib].dm);
+#endif
+    }
+}
+
+int32_t get_k_qs_shmem(const uint buf_ib, const uint pos) {
+#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
+    uint sub = pos % 4;
+    uint shift = ((pos % 8) >= 4) ? 4 : 0;
+    return int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F);
+#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
+    uint sub = pos % 4;
+    uint shift = ((pos % 8) >= 4) ? 4 : 0;
+    int32_t result = int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F);
+    uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4)) & 0xF;
+    return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
+#elif defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
+    return kblocksh[buf_ib].qs[pos];
+#endif
+}
+
+ACC_TYPE k_dot_correction(const uint qib, const ACC_TYPEV2 k_dm) {
+#if defined(DATA_A_Q4_0)
+    return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
+#elif defined(DATA_A_Q5_0)
+    return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
+#elif defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
+    return ACC_TYPE(Qf[qib].ds.y) * k_dm.y;
+#else
+    return ACC_TYPE(0.0);
+#endif
+}
+
+void k_block_to_shmem_zero(const uint buf_ib, const uint iqs) {
+    kblocksh[buf_ib].qs[iqs] = 0;
+#if defined(DATA_A_IQ4_NL)
+    kblocksh[buf_ib].qs[iqs + 4] = 0;
+#endif
+    if (iqs == 0) {
+#if QUANT_AUXF == 1
+        kblocksh[buf_ib].dm = FLOAT_TYPE(0.0f);
+#else
+        kblocksh[buf_ib].dm = FLOAT_TYPEV2(0.0f);
+#endif
+    }
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
index c700f6e3f2..10552d013a 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
@@ -32,6 +32,12 @@ struct block_a_cache {
     int32_t qs[32/4];
     FLOAT_TYPE dm;
 };
+#elif defined(DATA_A_IQ4_NL)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+    int32_t qs[8];
+    FLOAT_TYPE dm;
+};
 #elif defined(DATA_A_MXFP4)
 #define QUANT_R_MMQ 2
 struct block_a_cache {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl
index 4239070af5..1fb592fb84 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl
@@ -1692,6 +1692,7 @@ struct block_iq4_nl_packed16
 #if defined(DATA_A_IQ4_NL)
 #define QUANT_K QUANT_K_IQ4_NL
 #define QUANT_R QUANT_R_IQ4_NL
+#define QUANT_AUXF 1
 #define A_TYPE block_iq4_nl
 #define A_TYPE_PACKED16 block_iq4_nl_packed16
 #endif
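Note for reviewers: the nibble/high-bit recombination used in get_k_qs(), get_k_qs_shmem() and the unrolled d_per_step == 8 path relies on (qh_bits * 0x02040810u) & 0x10101010u scattering bit i of the four Q5 high bits onto bit 4 of byte i. The four shifted copies of the constant occupy disjoint bit positions, so the multiply acts as an OR of shifts and can never carry. A brute-force C++ check (illustrative):

    // Verify: bit i of the 4 high bits must land at bit 4 of byte i, so it
    // can be OR'd onto the low nibbles to form the 5-bit value per byte lane.
    #include <cassert>
    #include <cstdint>

    int main() {
        for (uint32_t qh_bits = 0; qh_bits < 16; qh_bits++) {
            const uint32_t scattered = (qh_bits * 0x02040810u) & 0x10101010u;
            for (int i = 0; i < 4; i++) {
                const uint32_t expected = ((qh_bits >> i) & 1u) << (8 * i + 4);
                assert((scattered & (0x10u << (8 * i))) == expected);
            }
        }
        return 0;
    }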
"" : "_fp32")) + suffix; std::string out_path = join_paths(output_dir, name + ".spv"); if (input_filepath == "") { @@ -625,15 +625,16 @@ void process_shaders() { for (const bool& fp16 : {false, true}) { std::map base_dict; if (fp16) { - base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}}; + base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV2", "f16vec2"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}}; } else { - base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV4", "vec4"}}; + base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"FLOAT_TYPEV4", "vec4"}}; } // flash attention for (const bool& f16acc : {false, true}) { std::map fa_base_dict = base_dict; fa_base_dict["ACC_TYPE"] = fp16 && f16acc ? "float16_t" : "float"; + fa_base_dict["ACC_TYPEV2"] = fp16 && f16acc ? "f16vec2" : "vec2"; fa_base_dict["ACC_TYPEV4"] = fp16 && f16acc ? "f16vec4" : "vec4"; if (fp16 && f16acc) { fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)"; @@ -672,6 +673,12 @@ void process_shaders() { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc); +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (tname != "f32") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }, {"MMQ", "1"}}), fp16, false, false, f16acc, "_int8"); + } +#endif } } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index f41558902c..e54e6b2bb4 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -8580,7 +8580,7 @@ static std::vector> make_test_cases_eval() { for (int nb : { 1, 3, 32, 75, }) { for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) { if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue; - for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { + for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0, GGML_TYPE_IQ4_NL}) { if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72) continue; test_cases.emplace_back(new test_flash_attn_ext( hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV));