From cb999704fb5c5dc1763f2e692887229c809e14d7 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Fri, 26 Dec 2025 17:12:11 +0000 Subject: [PATCH] vulkan: small dequantization improvements (#18380) * iq4_xs * quants --- .../vulkan-shaders/dequant_funcs.glsl | 8 +---- .../vulkan-shaders/mul_mm_funcs.glsl | 35 ++++++++++--------- .../src/ggml-vulkan/vulkan-shaders/types.glsl | 28 +++++++++++---- 3 files changed, 42 insertions(+), 29 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 70ee542d96..376944f1e2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -401,13 +401,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3; const uint qshift = (iqs & 16) >> 2; - u8vec4 qs = u8vec4( - data_a[a_offset + ib].qs[iq + 0], - data_a[a_offset + ib].qs[iq + 1], - data_a[a_offset + ib].qs[iq + 2], - data_a[a_offset + ib].qs[iq + 3] - ); - qs = (qs >> qshift) & uint8_t(0xF); + const u8vec4 qs = unpack8((data_a_packed32[a_offset + ib].qs[iq/4] >> qshift) & 0x0F0F0F0F); const float dl = float(int(sl | (sh << 4)) - 32); return dl * vec4( diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 58ede04400..1a3531761a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -159,14 +159,16 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint is = iqs / 8; // 0..15 const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3 const uint qsshift = halfsplit * 2; // 0,2,4,6 - const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF) | (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4)); const float dl = float(data_a[ib].d) * float(us - 32); - buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)), - dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); + const vec2 qs = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> qsshift) & 0x0303).xy); + const vec2 hm = vec2(unpack8(((uint(data_a_packed16[ib].hmask[hmi / 2]) >> (4 * n + halfsplit)) & 0x0101 ^ 0x0101) << 2).xy); + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * (qs.x - hm.x), + dl * (qs.y - hm.y)); #elif defined(DATA_A_Q4_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -198,8 +200,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const float d = loadd.x * sc; const float m = -loadd.y * mbyte; - buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m), - fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); + const vec2 q = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F).xy); + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m), + fma(d, q.y, m)); #elif defined(DATA_A_Q5_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -213,8 +217,6 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 const uint qhi = (iqs % 16) * 2; // 0,2,4..30 - const uint8_t hm = uint8_t(1 << (iqs / 16)); - const vec2 loadd = vec2(data_a[ib].dm); const uint scidx0 = (is < 4) ? is : (is + 4); @@ -234,8 +236,12 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const float d = loadd.x * sc; const float m = -loadd.y * mbyte; - buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m), - fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m)); + const uint qs = (uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F; + const uint qh = ((uint(data_a_packed16[ib].qh[qhi / 2]) >> (iqs / 16)) & 0x0101) << 4; + const vec2 q = vec2(unpack8(qs | qh).xy); + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m), + fma(d, q.y, m)); #elif defined(DATA_A_Q6_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -394,11 +400,9 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const float d = float(data_a[ib].d); const uint qs = data_a[ib].qs[iqs]; - const uint signs = pack32(u8vec4( - data_a[ib].qs[is+0], - data_a[ib].qs[is+1], - data_a[ib].qs[is+2], - data_a[ib].qs[is+3] + const uint signs = pack32(u16vec2( + data_a_packed16[ib].qs[is/2], + data_a_packed16[ib].qs[is/2+1] )); const float db = d * 0.5 * (0.5 + (signs >> 28)); const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); @@ -443,8 +447,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3; const uint qshift = (idx & 8) >> 1; - u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]); - qs = (qs >> qshift) & uint8_t(0xF); + u8vec2 qs = unpack8((uint(data_a_packed16[ib].qs[iq/2]) >> qshift) & 0x0F0F).xy; const float d = float(data_a[ib].d); const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index 02578c77c4..402a2a8397 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -172,16 +172,12 @@ struct block_q8_0 float16_t d; int8_t qs[32]; }; + struct block_q8_0_packed16 { float16_t d; int16_t qs[32/2]; }; -struct block_q8_0_packed32 -{ - float16_t d; - int32_t qs[32/4]; -}; #if defined(DATA_A_Q8_0) #define QUANT_K QUANT_K_Q8_0 @@ -189,7 +185,6 @@ struct block_q8_0_packed32 #define QUANT_AUXF 1 #define A_TYPE block_q8_0 #define A_TYPE_PACKED16 block_q8_0_packed16 -#define A_TYPE_PACKED32 block_q8_0_packed32 #define DATA_A_QUANT_LEGACY #endif @@ -201,11 +196,13 @@ struct block_q8_1 f16vec2 ds; int8_t qs[32]; }; + struct block_q8_1_packed16 { f16vec2 ds; int16_t qs[16]; }; + struct block_q8_1_packed32 { f16vec2 ds; @@ -218,6 +215,7 @@ struct block_q8_1_x4 f16vec2 ds[4]; int32_t qs[32]; }; + struct block_q8_1_x4_packed128 { f16vec2 ds[4]; @@ -1346,10 +1344,28 @@ struct block_iq4_xs uint8_t qs[QUANT_K_IQ4_XS/2]; }; +struct block_iq4_xs_packed16 +{ + float16_t d; + uint16_t scales_h; + uint16_t scales_l[QUANT_K_IQ4_XS/128]; + uint16_t qs[QUANT_K_IQ4_XS/4]; +}; + +struct block_iq4_xs_packed32 +{ + float16_t d; + uint16_t scales_h; + uint32_t scales_l; + uint32_t qs[QUANT_K_IQ4_XS/8]; +}; + #if defined(DATA_A_IQ4_XS) #define QUANT_K QUANT_K_IQ4_XS #define QUANT_R QUANT_R_IQ4_XS #define A_TYPE block_iq4_xs +#define A_TYPE_PACKED16 block_iq4_xs_packed16 +#define A_TYPE_PACKED32 block_iq4_xs_packed32 #endif #define QUANT_K_IQ4_NL 32