From edd4d9bca5dcb9be1ce66e618157d8787a0ea16f Mon Sep 17 00:00:00 2001 From: mkoker <132301062+mkoker@users.noreply.github.com> Date: Tue, 7 Apr 2026 07:41:29 -0400 Subject: [PATCH] vulkan: add FA dequant for q4_1, q5_0, q5_1, iq4_nl (#21029) Add dequantize4() implementations for Q4_1, Q5_0, Q5_1, and IQ4_NL in the flash attention base shader. Register them in the shader generator, pipeline creation, and enable in the scalar/coopmat1 FA support check. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 24 ++-- .../vulkan-shaders/flash_attn_base.glsl | 104 +++++++++++++++++- .../vulkan-shaders/vulkan-shaders-gen.cpp | 4 +- 3 files changed, 119 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 15ed5b2a79..19e7fbdaae 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3447,11 +3447,19 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, ) } else { CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32) CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32) CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32) } #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->coopmat1_fa_support) { @@ -3459,6 +3467,10 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1) CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT1, _cm1) } #endif #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) @@ -15331,11 +15343,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_F32: case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: - // supported in scalar and coopmat2 paths - break; case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: + case GGML_TYPE_IQ4_NL: + // supported in scalar and coopmat2 paths + break; // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently //case GGML_TYPE_Q2_K: //case GGML_TYPE_Q3_K: @@ -15350,12 +15363,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm //case GGML_TYPE_IQ3_XXS: //case GGML_TYPE_IQ3_S: //case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ4_NL: - // currently supported only in coopmat2 path - if (!coopmat2) { - return false; - } - break; + default: return false; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 172d38f034..b30dee8687 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -110,6 +110,97 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { #if defined(DATA_A_Q4_0) #define BLOCK_BYTE_SIZE 18 +#elif defined(DATA_A_Q4_1) +#define BLOCK_BYTE_SIZE 20 +#endif + +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + if (binding_idx == BINDING_IDX_K) { + uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); +#ifdef DATA_A_Q4_1 + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m); +#else + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); +#endif + } else { + uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); +#ifdef DATA_A_Q4_1 + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m); +#else + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); +#endif + } +} +#endif + +#if defined(DATA_A_Q5_0) +#define BLOCK_BYTE_SIZE 22 +#elif defined(DATA_A_Q5_1) +#define BLOCK_BYTE_SIZE 24 +#endif + +#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + if (binding_idx == BINDING_IDX_K) { + uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + +#ifdef DATA_A_Q5_1 + uint qh = k_packed.k_data_packed16[a_offset + ib].qh; +#else + uint qh = uint(k_packed.k_data_packed16[a_offset + ib].qh[0]) | (uint(k_packed.k_data_packed16[a_offset + ib].qh[1]) << 16); +#endif + FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f); + + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); +#ifdef DATA_A_Q5_1 + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m); +#else + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); +#endif + } else { + uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + +#ifdef DATA_A_Q5_1 + uint qh = v_packed.v_data_packed16[a_offset + ib].qh; +#else + uint qh = uint(v_packed.v_data_packed16[a_offset + ib].qh[0]) | (uint(v_packed.v_data_packed16[a_offset + ib].qh[1]) << 16); +#endif + FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f); + + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); +#ifdef DATA_A_Q5_1 + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m); +#else + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); +#endif + } +} +#endif + + +#if defined(DATA_A_IQ4_NL) +#define BLOCK_BYTE_SIZE 18 FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { if (binding_idx == BINDING_IDX_K) { @@ -119,7 +210,11 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { vui_lo >>= shift; vui_hi >>= shift; - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f)); + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4( + kvalues_iq4nl[vui_lo & 0xF], + kvalues_iq4nl[(vui_lo >> 8) & 0xF], + kvalues_iq4nl[vui_hi & 0xF], + kvalues_iq4nl[(vui_hi >> 8) & 0xF]); } else { uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); @@ -127,11 +222,14 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { vui_lo >>= shift; vui_hi >>= shift; - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f)); + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4( + kvalues_iq4nl[vui_lo & 0xF], + kvalues_iq4nl[(vui_lo >> 8) & 0xF], + kvalues_iq4nl[vui_hi & 0xF], + kvalues_iq4nl[(vui_hi >> 8) & 0xF]); } } #endif - #if defined(DATA_A_Q8_0) #define BLOCK_BYTE_SIZE 34 FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 8186dba36f..bf04f4822e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -655,7 +655,7 @@ void process_shaders() { if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); - } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { + } else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); @@ -666,7 +666,7 @@ void process_shaders() { if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc); - } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { + } else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc);