From db52ab40fc093ec5018a6f4d35e5f4d16a6df8b8 Mon Sep 17 00:00:00 2001 From: Italo Nicola Date: Fri, 20 Feb 2026 13:42:22 -0300 Subject: [PATCH] Separate cm2 TQ1/TQ2 dequantization from the generic tq helpers Sponsored-by: Tether Inc. Signed-off-by: Italo Nicola --- .../vulkan-shaders/dequant_funcs_cm2.glsl | 41 +++++++++++++------ .../{tq_utils.comp => tq_utils.glsl} | 29 ++----------- 2 files changed, 32 insertions(+), 38 deletions(-) rename ggml/src/ggml-vulkan/vulkan-shaders/{tq_utils.comp => tq_utils.glsl} (68%) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl index dc99b26e0d..5a76aaad41 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -671,18 +671,19 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufTQ2 block_tq2_0 block; }; -#define TQ2_CM2 1 -#include "tq_utils.comp" -#undef TQ2_CM2 - float16_t dequantFuncTQ2_0(const in decodeBufTQ2_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const float16_t d = bl.block.d; const uint idx = coordInBlock[1]; - const int val = tq2_dequantize(bl, idx); + const uint upper = idx / 128; /* index of the 128-element group that holds idx */ - return d * float16_t(val); + const uint byte = (upper * 32) + (idx % 32); /* each group occupies 32 bytes of qs */ + const uint shift = ((idx % 128) / 32) * 2; /* selects one of the four 2-bit fields in that byte */ + + const int c = (int(bl.block.qs[byte]) >> shift) & 3; + + return d * float16_t(c - 1); /* stored field is biased by one; remove the bias, then scale by d */ } #endif @@ -691,18 +692,34 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufTQ1 block_tq1_0 block; }; -#define TQ1_CM2 1 -#include "tq_utils.comp" -#undef TQ1_CM2 - float16_t dequantFuncTQ1_0(const in decodeBufTQ1_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const float16_t d = bl.block.d; const uint idx = coordInBlock[1]; - const int val = tq1_dequantize(bl, idx); + const uint pow3[6] = uint[6](1, 3, 9, 27, 81, 243); /* 3^0 through 3^5 */ - return d * float16_t(val); + uint val; 
+ if (idx < 160) { + const uint trit = idx / 32; + const uint byte = idx % 32; + const uint q = uint(bl.block.qs[byte]); + val = (((q * pow3[trit]) & 255) * 3) / 256; /* move the wanted base-3 digit to the top of the low byte, then scale it to 0..2 */ + } else if (idx < 240) { + const uint relative_idx = idx - 160; + const uint trit = relative_idx / 16; + const uint byte = relative_idx % 16; + const uint q = uint(bl.block.qs[32 + byte]); + val = (((q * pow3[trit]) & 255) * 3) / 256; + } else { + const uint relative_idx = idx - 240; + const uint trit = relative_idx / 4; + const uint byte = relative_idx % 4; + const uint q = uint(bl.block.qh[byte]); + val = (((q * pow3[trit]) & 255) * 3) / 256; + } + + return d * float16_t(int(val) - 1); /* val is unsigned: convert to int before subtracting so a stored 0 yields -1 rather than wrapping to 4294967295 */ } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/tq_utils.comp b/ggml/src/ggml-vulkan/vulkan-shaders/tq_utils.glsl similarity index 68% rename from ggml/src/ggml-vulkan/vulkan-shaders/tq_utils.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/tq_utils.glsl index c4abde19a5..53e79bb75b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/tq_utils.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/tq_utils.glsl @@ -1,64 +1,41 @@ -#ifndef TQ_UTILS_COMP -#define TQ_UTILS_COMP +#ifndef TQ_UTILS_IMPL +#define TQ_UTILS_IMPL #if defined(DATA_A_TQ2_0) -#if defined(TQ2_CM2) -int tq2_dequantize(const in decodeBufTQ2_0 bl, uint iqs) { -#else int tq2_dequantize(uint ib, uint iqs) { -#endif const uint upper = iqs / 128; const uint byte = (upper * 32) + (iqs % 32); const uint shift = ((iqs % 128) / 32) * 2; -#if defined(TQ2_CM2) - const int c = (int(bl.block.qs[byte]) >> shift) & 3; -#else const int c = (int(data_a[ib].qs[byte]) >> shift) & 3; -#endif + return c - 1; } #endif #if defined(DATA_A_TQ1_0) -#if defined(TQ1_CM2) -int tq1_dequantize(const in decodeBufTQ1_0 bl, uint iqs) { -#else int tq1_dequantize(uint ib, uint iqs) { -#endif const uint pow3[6] = uint[6](1, 3, 9, 27, 81, 243); if (iqs < 160) { const uint trit = iqs / 32; const uint byte = iqs % 32; -#if defined(TQ1_CM2) - const uint q = uint(bl.block.qs[byte]); -#else const 
uint q = uint(data_a[ib].qs[byte]); -#endif const uint val = (((q * pow3[trit]) & 255) * 3) / 256; return int(val) - 1; } else if (iqs < 240) { const uint relative_idx = iqs - 160; const uint trit = relative_idx / 16; const uint byte = relative_idx % 16; -#if defined(TQ1_CM2) - const uint q = uint(bl.block.qs[32 + byte]); -#else const uint q = uint(data_a[ib].qs[32 + byte]); -#endif const uint val = (((q * pow3[trit]) & 255) * 3) / 256; return int(val) - 1; } else { const uint relative_idx = iqs - 240; const uint trit = relative_idx / 4; const uint byte = relative_idx % 4; -#if defined(TQ1_CM2) - const uint q = uint(bl.block.qh[byte]); -#else const uint q = uint(data_a[ib].qh[byte]); -#endif const uint val = (((q * pow3[trit]) & 255) * 3) / 256; return int(val) - 1; }