Separate cm2 TQ1/TQ2 dequantization from the generic tq helpers

Sponsored-by: Tether Inc.
Signed-off-by: Italo Nicola <italo.nicola@collabora.com>
This commit is contained in:
Italo Nicola 2026-02-20 13:42:22 -03:00
parent 47817aad47
commit db52ab40fc
2 changed files with 32 additions and 38 deletions

View File

@ -671,18 +671,19 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufTQ2
block_tq2_0 block;
};
#define TQ2_CM2 1
#include "tq_utils.comp"
#undef TQ2_CM2
float16_t dequantFuncTQ2_0(const in decodeBufTQ2_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const int val = tq2_dequantize(bl, idx);
const uint upper = idx / 128;
return d * float16_t(val);
const uint byte = (upper * 32) + (idx % 32);
const uint shift = ((idx % 128) / 32) * 2;
const int c = (int(bl.block.qs[byte]) >> shift) & 3;
return d * float16_t(c - 1);
}
#endif
@ -691,18 +692,34 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufTQ1
block_tq1_0 block;
};
#define TQ1_CM2 1
#include "tq_utils.comp"
#undef TQ1_CM2
float16_t dequantFuncTQ1_0(const in decodeBufTQ1_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const int val = tq1_dequantize(bl, idx);
const uint pow3[6] = uint[6](1, 3, 9, 27, 81, 243);
return d * float16_t(val);
uint val;
if (idx < 160) {
const uint trit = idx / 32;
const uint byte = idx % 32;
const uint q = uint(bl.block.qs[byte]);
val = (((q * pow3[trit]) & 255) * 3) / 256;
} else if (idx < 240) {
const uint relative_idx = idx - 160;
const uint trit = relative_idx / 16;
const uint byte = relative_idx % 16;
const uint q = uint(bl.block.qs[32 + byte]);
val = (((q * pow3[trit]) & 255) * 3) / 256;
} else {
const uint relative_idx = idx - 240;
const uint trit = relative_idx / 4;
const uint byte = relative_idx % 4;
const uint q = uint(bl.block.qh[byte]);
val = (((q * pow3[trit]) & 255) * 3) / 256;
}
return d * float16_t(val - 1);
}
#endif

View File

@ -1,64 +1,41 @@
#ifndef TQ_UTILS_COMP
#define TQ_UTILS_COMP
#ifndef TQ_UTILS_IMPL
#define TQ_UTILS_IMPL
#if defined(DATA_A_TQ2_0)
#if defined(TQ2_CM2)
int tq2_dequantize(const in decodeBufTQ2_0 bl, uint iqs) {
#else
int tq2_dequantize(uint ib, uint iqs) {
#endif
const uint upper = iqs / 128;
const uint byte = (upper * 32) + (iqs % 32);
const uint shift = ((iqs % 128) / 32) * 2;
#if defined(TQ2_CM2)
const int c = (int(bl.block.qs[byte]) >> shift) & 3;
#else
const int c = (int(data_a[ib].qs[byte]) >> shift) & 3;
#endif
return c - 1;
}
#endif
#if defined(DATA_A_TQ1_0)
#if defined(TQ1_CM2)
int tq1_dequantize(const in decodeBufTQ1_0 bl, uint iqs) {
#else
int tq1_dequantize(uint ib, uint iqs) {
#endif
const uint pow3[6] = uint[6](1, 3, 9, 27, 81, 243);
if (iqs < 160) {
const uint trit = iqs / 32;
const uint byte = iqs % 32;
#if defined(TQ1_CM2)
const uint q = uint(bl.block.qs[byte]);
#else
const uint q = uint(data_a[ib].qs[byte]);
#endif
const uint val = (((q * pow3[trit]) & 255) * 3) / 256;
return int(val) - 1;
} else if (iqs < 240) {
const uint relative_idx = iqs - 160;
const uint trit = relative_idx / 16;
const uint byte = relative_idx % 16;
#if defined(TQ1_CM2)
const uint q = uint(bl.block.qs[32 + byte]);
#else
const uint q = uint(data_a[ib].qs[32 + byte]);
#endif
const uint val = (((q * pow3[trit]) & 255) * 3) / 256;
return int(val) - 1;
} else {
const uint relative_idx = iqs - 240;
const uint trit = relative_idx / 4;
const uint byte = relative_idx % 4;
#if defined(TQ1_CM2)
const uint q = uint(bl.block.qh[byte]);
#else
const uint q = uint(data_a[ib].qh[byte]);
#endif
const uint val = (((q * pow3[trit]) & 255) * 3) / 256;
return int(val) - 1;
}