// llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/tq_utils.comp

#ifndef TQ_UTILS_COMP
#define TQ_UTILS_COMP
#if defined(DATA_A_TQ2_0)
// Decode one 2-bit weight from a TQ2_0 block.
//
// `iqs` is the element index inside the block. Elements are laid out in runs of
// 128: run r = iqs / 128 lives in bytes qs[32*r .. 32*r + 31], and element iqs is
// the 2-bit field at bit offset 2 * ((iqs / 32) % 4) of byte qs[32*r + iqs % 32].
// NOTE(review): presumably iqs < 256 (a single block); iqs is not range-checked here.
//
// With TQ2_CM2 defined the block is read through `bl.block`; otherwise it is read
// from `data_a[ib]`.
//
// Returns the 2-bit field minus one. A 2-bit field can hold 0..3 (result -1..2);
// well-formed blocks presumably only store 0..2, i.e. {-1, 0, +1} — confirm with
// the TQ2_0 quantizer.
#if defined(TQ2_CM2)
int tq2_dequantize(const in decodeBufTQ2_0 bl, uint iqs) {
#else
int tq2_dequantize(uint ib, uint iqs) {
#endif
    const uint run   = iqs / 128;      // which 32-byte group of qs
    const uint lane  = iqs % 32;       // byte within that group
    const uint plane = (iqs / 32) % 4; // which 2-bit pair inside the byte
    const uint idx   = run * 32 + lane;

#if defined(TQ2_CM2)
    const int packed = int(bl.block.qs[idx]);
#else
    const int packed = int(data_a[ib].qs[idx]);
#endif
    const int code = (packed >> (plane * 2)) & 3;
    return code - 1;
}
#endif
#if defined(DATA_A_TQ1_0)
// Decode one ternary weight from a TQ1_0 block.
//
// `iqs` is the element index inside the block, mapped as:
//   iqs   0..159 -> qs[iqs % 32],              trit iqs / 32         (5 trits per byte)
//   iqs 160..239 -> qs[32 + (iqs - 160) % 16], trit (iqs - 160) / 16 (5 trits per byte)
//   iqs 240..    -> qh[(iqs - 240) % 4],       trit (iqs - 240) / 4
// NOTE(review): assumes iqs < 256 (a single block) — iqs is not range-checked, and
// iqs >= 264 would index past the end of pow3; confirm callers stay in range.
//
// With TQ1_CM2 defined the block is read through `bl.block`; otherwise it is read
// from `data_a[ib]`.
//
// Returns the selected trit (0, 1 or 2) minus one, i.e. a value in {-1, 0, 1}.
#if defined(TQ1_CM2)
int tq1_dequantize(const in decodeBufTQ1_0 bl, uint iqs) {
#else
int tq1_dequantize(uint ib, uint iqs) {
#endif
    const uint pow3[6] = uint[6](1, 3, 9, 27, 81, 243);

    // The three regions differ only in which byte they read and which trit they
    // select; the decode itself is shared below instead of being repeated per branch.
    uint q;
    uint trit;
    if (iqs < 160) {
        trit = iqs / 32;
#if defined(TQ1_CM2)
        q = uint(bl.block.qs[iqs % 32]);
#else
        q = uint(data_a[ib].qs[iqs % 32]);
#endif
    } else if (iqs < 240) {
        const uint rel = iqs - 160;
        trit = rel / 16;
#if defined(TQ1_CM2)
        q = uint(bl.block.qs[32 + rel % 16]);
#else
        q = uint(data_a[ib].qs[32 + rel % 16]);
#endif
    } else {
        const uint rel = iqs - 240;
        trit = rel / 4;
#if defined(TQ1_CM2)
        q = uint(bl.block.qh[rel % 4]);
#else
        q = uint(data_a[ib].qh[rel % 4]);
#endif
    }

    // Multiply by 3^trit modulo 256 (the `& 255` emulates 8-bit wraparound) to bring
    // the wanted trit to the top of the byte, then scale by 3/256. Since the masked
    // value is <= 255, the result is at most 765 / 256 = 2, so val is 0, 1 or 2.
    const uint val = (((q * pow3[trit]) & 255) * 3) / 256;
    return int(val) - 1;
}
#endif
#endif