diff --git a/common/arg.cpp b/common/arg.cpp index 538d2a4b0a..07410574c9 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -387,6 +387,8 @@ const std::vector kv_cache_types = { GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, + GGML_TYPE_TBQ3_0, + GGML_TYPE_TBQ4_0, }; static ggml_type kv_cache_type_from_str(const std::string & s) { diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 669f66b650..ba3e8cc5ac 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -428,7 +428,9 @@ extern "C" { // GGML_TYPE_IQ4_NL_8_8 = 38, GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block) GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale) - GGML_TYPE_COUNT = 41, + GGML_TYPE_TBQ3_0 = 41, // TurboQuant 3-bit + GGML_TYPE_TBQ4_0 = 42, // TurboQuant 4-bit + GGML_TYPE_COUNT = 43, }; // precision @@ -465,6 +467,8 @@ extern "C" { GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors GGML_FTYPE_MOSTLY_NVFP4 = 26, // except 1d tensors + GGML_FTYPE_MOSTLY_TBQ3_0 = 27, // except 1d tensors + GGML_FTYPE_MOSTLY_TBQ4_0 = 28, // except 1d tensors }; // available tensor operations: diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 78853304d9..46bdfa0d93 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -205,6 +205,9 @@ add_library(ggml-base ggml-threading.h ggml-quants.c ggml-quants.h + ggml-turboq.c + ggml-turboq.h + ggml-turboq-tables.h gguf.cpp) set_target_properties(ggml-base PROPERTIES diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 92cf739e7a..f03a1c3a62 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -266,6 +266,22 @@ typedef struct { } block_tq2_0; static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0 block size/padding"); +// TurboQuant blocks + +// 3.0625 bpw +typedef struct { + uint8_t qs[QK_K * 3 / 8]; + ggml_half d; +} block_tbq3_0; +static_assert(sizeof(block_tbq3_0) == sizeof(ggml_half) + QK_K * 3 / 8, "wrong tbq3_0 block size/padding"); + +// 4.0625 bpw +typedef struct { + uint8_t qs[QK_K / 2]; + ggml_half d; +} block_tbq4_0; +static_assert(sizeof(block_tbq4_0) == sizeof(ggml_half) + QK_K / 2, "wrong tbq4_0 block size/padding"); + // // Super-block quantization structures // diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 41da829315..74b886b054 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -18,6 +18,8 @@ #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K +#define ggml_vec_dot_tbq3_0_q8_K_generic ggml_vec_dot_tbq3_0_q8_K +#define ggml_vec_dot_tbq4_0_q8_K_generic ggml_vec_dot_tbq4_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K #define ggml_vec_dot_q3_K_q8_K_generic ggml_vec_dot_q3_K_q8_K #define ggml_vec_dot_q4_K_q8_K_generic ggml_vec_dot_q4_K_q8_K @@ -70,6 +72,9 @@ #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64) +// quants.c +#define ggml_vec_dot_tbq3_0_q8_K_generic ggml_vec_dot_tbq3_0_q8_K +#define ggml_vec_dot_tbq4_0_q8_K_generic ggml_vec_dot_tbq4_0_q8_K // repack.cpp #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 @@ -82,6 +87,8 @@ #elif 
defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_tbq3_0_q8_K_generic ggml_vec_dot_tbq3_0_q8_K +#define ggml_vec_dot_tbq4_0_q8_K_generic ggml_vec_dot_tbq4_0_q8_K // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 @@ -114,6 +121,8 @@ #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K +#define ggml_vec_dot_tbq3_0_q8_K_generic ggml_vec_dot_tbq3_0_q8_K +#define ggml_vec_dot_tbq4_0_q8_K_generic ggml_vec_dot_tbq4_0_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 @@ -157,6 +166,8 @@ #define quantize_row_q8_K_generic quantize_row_q8_K #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K +#define ggml_vec_dot_tbq3_0_q8_K_generic ggml_vec_dot_tbq3_0_q8_K +#define ggml_vec_dot_tbq4_0_q8_K_generic ggml_vec_dot_tbq4_0_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 @@ -199,6 +210,8 @@ #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__riscv) // quants.c +#define ggml_vec_dot_tbq3_0_q8_K_generic ggml_vec_dot_tbq3_0_q8_K +#define ggml_vec_dot_tbq4_0_q8_K_generic ggml_vec_dot_tbq4_0_q8_K #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1 @@ -242,6 +255,8 @@ #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K +#define ggml_vec_dot_tbq3_0_q8_K_generic ggml_vec_dot_tbq3_0_q8_K +#define ggml_vec_dot_tbq4_0_q8_K_generic ggml_vec_dot_tbq4_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K #define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K #define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K @@ -292,6 +307,8 @@ #define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K +#define ggml_vec_dot_tbq3_0_q8_K_generic ggml_vec_dot_tbq3_0_q8_K +#define ggml_vec_dot_tbq4_0_q8_K_generic ggml_vec_dot_tbq4_0_q8_K #define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K #define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K #define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 7486acc2b5..a3d3e6d733 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -390,6 +390,18 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, }, + [GGML_TYPE_TBQ3_0] = { + .from_float = quantize_row_tbq3_0, + .vec_dot = ggml_vec_dot_tbq3_0_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, + [GGML_TYPE_TBQ4_0] = { + .from_float = quantize_row_tbq4_0, + .vec_dot = ggml_vec_dot_tbq4_0_q8_K, + .vec_dot_type = 
GGML_TYPE_Q8_K, + .nrows = 1, + }, [GGML_TYPE_I32] = { .from_float = (ggml_from_float_t) ggml_cpu_fp32_to_i32, }, diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 765ce07f06..17d0056aa9 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -11,6 +11,8 @@ #include #include #include +#include +#include // ggml_compute_forward_dup @@ -472,6 +474,33 @@ static void ggml_compute_forward_dup_bytes( } } +template +static inline void ggml_dup_from_float_row(const float * src, dst_t * dst, int64_t n) { + for (int64_t i = 0; i < n; ++i) { + dst[i] = (dst_t) src[i]; + } +} + +template<> +inline void ggml_dup_from_float_row(const float * src, float * dst, int64_t n) { + ggml_vec_cpy_f32(n, dst, src); +} + +template<> +inline void ggml_dup_from_float_row(const float * src, ggml_fp16_t * dst, int64_t n) { + for (int64_t i = 0; i < n; ++i) { + dst[i] = GGML_CPU_FP32_TO_FP16(src[i]); + } +} + +template<> +inline void ggml_dup_from_float_row(const float * src, ggml_bf16_t * dst, int64_t n) { + for (int64_t i = 0; i < n; ++i) { + dst[i] = GGML_FP32_TO_BF16(src[i]); + } +} + +template static void ggml_compute_forward_dup_from_q( const ggml_compute_params * params, ggml_tensor * dst) { @@ -517,9 +546,19 @@ static void ggml_compute_forward_dup_from_q( const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13; - dequantize_row_q( - (const void *) ((char *) src0->data + x_offset), - (float *) ((char *) dst->data + dst_offset), qk); + if constexpr (std::is_same_v) { + dequantize_row_q( + (const void *) ((char *) src0->data + x_offset), + (float *) ((char *) dst->data + dst_offset), qk); + } else { + std::vector tmp(qk); + + dequantize_row_q( + (const void *) ((char *) src0->data + x_offset), + tmp.data(), qk); + + ggml_dup_from_float_row(tmp.data(), (dst_t *) ((char *) dst->data + dst_offset), qk); + } } } @@ -564,9 +603,19 @@ void ggml_compute_forward_dup( } break; default: { - if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) { - ggml_compute_forward_dup_from_q(params, dst); - break; + if (ggml_is_quantized(src0->type)) { + if (dst->type == GGML_TYPE_F32) { + ggml_compute_forward_dup_from_q(params, dst); + break; + } + if (dst->type == GGML_TYPE_F16) { + ggml_compute_forward_dup_from_q(params, dst); + break; + } + if (dst->type == GGML_TYPE_BF16) { + ggml_compute_forward_dup_from_q(params, dst); + break; + } } GGML_ABORT("fatal error"); } @@ -678,6 +727,8 @@ void ggml_compute_forward_add( case GGML_TYPE_Q6_K: case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: + case GGML_TYPE_TBQ3_0: + case GGML_TYPE_TBQ4_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -1128,6 +1179,8 @@ void ggml_compute_forward_add1( case GGML_TYPE_Q6_K: case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: + case GGML_TYPE_TBQ3_0: + case GGML_TYPE_TBQ4_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -1257,6 +1310,8 @@ void ggml_compute_forward_acc( case GGML_TYPE_Q6_K: case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: + case GGML_TYPE_TBQ3_0: + case GGML_TYPE_TBQ4_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -4345,6 +4400,8 @@ void ggml_compute_forward_out_prod( case GGML_TYPE_Q6_K: case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: + case GGML_TYPE_TBQ3_0: + case GGML_TYPE_TBQ4_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -4621,6 +4678,8 @@ void ggml_compute_forward_set( case GGML_TYPE_Q6_K: case GGML_TYPE_TQ1_0: 
case GGML_TYPE_TQ2_0: + case GGML_TYPE_TBQ3_0: + case GGML_TYPE_TBQ4_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -4844,6 +4903,8 @@ void ggml_compute_forward_get_rows( case GGML_TYPE_Q6_K: case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: + case GGML_TYPE_TBQ3_0: + case GGML_TYPE_TBQ4_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -5569,6 +5630,8 @@ void ggml_compute_forward_clamp( case GGML_TYPE_Q6_K: case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: + case GGML_TYPE_TBQ3_0: + case GGML_TYPE_TBQ4_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index 7ebbb9c6f1..f5b0687122 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -108,6 +108,18 @@ void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, quantize_row_tq2_0_ref(x, y, k); } +void quantize_row_tbq3_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK_K == 0); + block_tbq3_0 * GGML_RESTRICT y = vy; + quantize_row_tbq3_0_ref(x, y, k); +} + +void quantize_row_tbq4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK_K == 0); + block_tbq4_0 * GGML_RESTRICT y = vy; + quantize_row_tbq4_0_ref(x, y, k); +} + //===================================== Q8_K ============================================== void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { @@ -456,6 +468,83 @@ void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, *s = sumf; } +// TurboQuant vec_dot falls back to dequantize-then-dot on CPU. + +#if defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L && !defined(__STDC_NO_THREADS__) +#define TURBOQ_VD_TL _Thread_local +#elif defined(__GNUC__) || defined(__clang__) +#define TURBOQ_VD_TL __thread +#elif defined(_MSC_VER) +#define TURBOQ_VD_TL __declspec(thread) +#else +#define TURBOQ_VD_TL +#endif + +static TURBOQ_VD_TL float * tbq_vd_buf = NULL; +static TURBOQ_VD_TL int64_t tbq_vd_buf_size = 0; + +static float * tbq_vd_get_scratch(int64_t n) { + if (n > tbq_vd_buf_size) { + free(tbq_vd_buf); + tbq_vd_buf = (float *)malloc(n * sizeof(float)); + tbq_vd_buf_size = n; + } + return tbq_vd_buf; +} + +void ggml_vec_dot_tbq3_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + float * tmp = tbq_vd_get_scratch(n); + dequantize_row_tbq3_0((const block_tbq3_0 *)vx, tmp, n); + + const block_q8_K * GGML_RESTRICT y = vy; + const int nb = n / QK_K; + + float sumf = 0.0f; + int64_t idx = 0; + for (int i = 0; i < nb; i++) { + const float d = y[i].d; + for (int j = 0; j < QK_K; j++) { + sumf += tmp[idx] * (d * y[i].qs[j]); + idx++; + } + } + + *s = sumf; +} + +void ggml_vec_dot_tbq4_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + float * tmp = tbq_vd_get_scratch(n); + dequantize_row_tbq4_0((const block_tbq4_0 *)vx, tmp, n); + + const block_q8_K * GGML_RESTRICT y = vy; + const int nb = n / QK_K; + + float sumf = 0.0f; + int64_t idx = 0; + for (int i = 0; i < nb; i++) { + const float d = y[i].d; + for (int j = 0; j < QK_K; j++) { + sumf += tmp[idx] * (d * 
y[i].qs[j]); + idx++; + } + } + + *s = sumf; +} + + void ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h index 3584aaa43e..c447fb4e4f 100644 --- a/ggml/src/ggml-cpu/quants.h +++ b/ggml/src/ggml-cpu/quants.h @@ -32,6 +32,9 @@ void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in void quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_tbq3_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_tbq4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); + void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -54,6 +57,9 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_tbq3_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_tbq4_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq2_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 48695a61ea..ddc6aae2e9 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -5399,6 +5399,15 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); } break; + case GGML_TYPE_TBQ3_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_tbq3_0, data, nb); + } break; + case GGML_TYPE_TBQ4_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_tbq4_0, data, nb); + } break; + case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index 00604f75c0..6719168c52 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -34,6 +34,9 @@ GGML_API void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_API void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_tq2_0_ref(const float * GGML_RESTRICT x, block_tq2_0 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_tbq3_0_ref(const float * GGML_RESTRICT x, block_tbq3_0 * GGML_RESTRICT y, int64_t k); +GGML_API void 
quantize_row_tbq4_0_ref(const float * GGML_RESTRICT x, block_tbq4_0 * GGML_RESTRICT y, int64_t k); + GGML_API void quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_iq4_nl_ref (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k); @@ -61,6 +64,9 @@ GGML_API void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GG GGML_API void dequantize_row_tq1_0(const block_tq1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_tq2_0(const block_tq2_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_tbq3_0(const block_tbq3_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_tbq4_0(const block_tbq4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + GGML_API void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_iq2_s (const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -85,6 +91,9 @@ GGML_API size_t quantize_iq3_s (const float * GGML_RESTRICT src, void * GGML_RE GGML_API size_t quantize_tq1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_tq2_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_tbq3_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_tbq4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); + GGML_API size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/ggml/src/ggml-turboq-tables.h b/ggml/src/ggml-turboq-tables.h new file mode 100644 index 0000000000..5524e039f1 --- /dev/null +++ b/ggml/src/ggml-turboq-tables.h @@ -0,0 +1,35 @@ +#pragma once + +// Lloyd-Max codebooks for the TurboQuant CPU path. 
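The arrays below are Lloyd-Max (minimum-MSE) quantizers for a standard Gaussian: each `turboq_boundaries_*` entry is the midpoint of the two codewords it separates, so the linear threshold scan used later in `quantize_scalar()` is equivalent to nearest-codeword search. A small standalone check of that property for the 3-bit tables (a hypothetical sketch, not part of the patch; it assumes this header is on the include path):

```c
// Hypothetical sanity check: every decision boundary should be the midpoint
// of its two neighbouring codewords, which is what makes a simple threshold
// scan an exact nearest-codeword search for these tables.
#include <math.h>
#include <stdio.h>
#include "ggml-turboq-tables.h"

int main(void) {
    for (int i = 0; i + 1 < 8; ++i) {
        const float mid = 0.5f * (turboq_codebook_3bit[i] + turboq_codebook_3bit[i + 1]);
        if (fabsf(mid - turboq_boundaries_3bit[i]) > 1e-3f) {
            printf("boundary %d (%f) is not the midpoint %f\n", i, turboq_boundaries_3bit[i], mid);
            return 1;
        }
    }
    printf("3-bit boundaries are codeword midpoints\n");
    return 0;
}
```

The same midpoint relationship holds for the 2-bit and 4-bit tables.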
+ +static const float turboq_codebook_2bit[4] = { + -1.5104f, -0.4528f, 0.4528f, 1.5104f, +}; + +static const float turboq_codebook_3bit[8] = { + -2.1520f, -1.3440f, -0.7560f, -0.2451f, + 0.2451f, 0.7560f, 1.3440f, 2.1520f, +}; + +static const float turboq_codebook_4bit[16] = { + -2.7326f, -2.0690f, -1.6180f, -1.2562f, + -0.9424f, -0.6568f, -0.3881f, -0.1284f, + 0.1284f, 0.3881f, 0.6568f, 0.9424f, + 1.2562f, 1.6180f, 2.0690f, 2.7326f, +}; + +static const float turboq_boundaries_2bit[3] = { + -0.9816f, 0.0000f, 0.9816f, +}; + +static const float turboq_boundaries_3bit[7] = { + -1.7480f, -1.0500f, -0.5006f, 0.0000f, + 0.5006f, 1.0500f, 1.7480f, +}; + +static const float turboq_boundaries_4bit[15] = { + -2.4008f, -1.8435f, -1.4371f, -1.0993f, + -0.7996f, -0.5225f, -0.2583f, 0.0000f, + 0.2583f, 0.5225f, 0.7996f, 1.0993f, + 1.4371f, 1.8435f, 2.4008f, +}; diff --git a/ggml/src/ggml-turboq.c b/ggml/src/ggml-turboq.c new file mode 100644 index 0000000000..58d260a214 --- /dev/null +++ b/ggml/src/ggml-turboq.c @@ -0,0 +1,682 @@ +// TurboQuant reference helpers for the CPU path. + +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" + +#include "ggml-turboq.h" +#include "ggml-turboq-tables.h" +#include "ggml-quants.h" +#include "ggml-impl.h" +#include "ggml.h" + +#include +#include +#include +#include + +#if defined(__AVX2__) +#include +#endif + +#if defined(__GNUC__) || defined(__clang__) +#define TURBOQ_TLS __thread +#elif defined(_MSC_VER) +#define TURBOQ_TLS __declspec(thread) +#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L && !defined(__STDC_NO_THREADS__) +#define TURBOQ_TLS _Thread_local +#else +#define TURBOQ_TLS +#endif + +static inline uint64_t splitmix64_next(uint64_t * state) { + uint64_t z = (*state += 0x9e3779b97f4a7c15ULL); + z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9ULL; + z = (z ^ (z >> 27)) * 0x94d049bb133111ebULL; + return z ^ (z >> 31); +} + +static void turboq_generate_gaussian(float * out, int64_t n, uint64_t seed) { + uint64_t state = seed; + int64_t i = 0; + for (; i + 1 < n; i += 2) { + // Generate two uniform (0,1) variates + double u1 = ((double)(splitmix64_next(&state) >> 11) + 0.5) / (double)(1ULL << 53); + double u2 = ((double)(splitmix64_next(&state) >> 11) + 0.5) / (double)(1ULL << 53); + double r = sqrt(-2.0 * log(u1)); + double th = 2.0 * 3.14159265358979323846 * u2; + out[i] = (float)(r * cos(th)); + out[i + 1] = (float)(r * sin(th)); + } + if (i < n) { + double u1 = ((double)(splitmix64_next(&state) >> 11) + 0.5) / (double)(1ULL << 53); + double u2 = ((double)(splitmix64_next(&state) >> 11) + 0.5) / (double)(1ULL << 53); + double r = sqrt(-2.0 * log(u1)); + double th = 2.0 * 3.14159265358979323846 * u2; + out[i] = (float)(r * cos(th)); + } +} + +// --------------------------------------------------------------------------- +// Householder QR decomposition (in-place, no LAPACK dependency) +// +// Input: A[d*d] stored column-major (A[i + j*d] = A_{i,j}) +// Output: Q[d*d] column-major orthogonal matrix, with Haar sign correction +// +// Uses Householder reflections: Q = H_1 * H_2 * ... * H_d where +// H_k = I - 2 * v_k * v_k^T / (v_k^T * v_k) +// --------------------------------------------------------------------------- + +// Compute Q from Householder QR of column-major matrix A[d×d]. +// A is modified in-place (becomes R on upper triangle, v below diagonal). +// Q is written to Q_out[d×d] column-major. +// Applies Haar sign correction: Q[:,j] *= sign(R[j,j]) so that Q is +// uniformly distributed on O(d) (Haar measure). 
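Since every Householder factor H_k is a reflection, the accumulated Q is orthogonal (Q^T Q = I), and the sign correction only flips whole columns, which preserves that property. A quick way to validate the routine below is to build a random Q and compare Q^T Q against the identity; a hedged sketch (a hypothetical static helper, which would have to be placed after `turboq_householder_qr` in this file so that both static functions it calls are already defined):

```c
// Hypothetical check, not part of the patch: generate a random rotation with
// the helpers in this file and verify that its columns are orthonormal,
// i.e. Q^T Q ~= I, which is all the quantizer relies on (norm preservation).
static int turboq_check_orthogonality(int64_t d, uint64_t seed) {
    float * A = (float *)malloc(d * d * sizeof(float));
    float * Q = (float *)malloc(d * d * sizeof(float));
    turboq_generate_gaussian(A, d * d, seed);
    turboq_householder_qr(A, Q, d);

    int ok = 1;
    for (int64_t i = 0; i < d && ok; ++i) {
        for (int64_t j = 0; j < d && ok; ++j) {
            float dot = 0.0f;
            for (int64_t k = 0; k < d; ++k) {
                dot += Q[k + i * d] * Q[k + j * d]; // column i . column j (column-major)
            }
            const float expect = (i == j) ? 1.0f : 0.0f;
            ok = fabsf(dot - expect) < 1e-3f;
        }
    }

    free(A);
    free(Q);
    return ok;
}
```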
+static void turboq_householder_qr(float * A, float * Q_out, int64_t d) { + float * tau = (float *)malloc(d * sizeof(float)); + // Store sign(R[k,k]) = -sign(alpha_k) for Haar correction + float * r_sign = (float *)malloc(d * sizeof(float)); + + for (int64_t k = 0; k < d; k++) { + // Compute norm of A[k:d, k] + float norm_sq = 0.0f; + for (int64_t i = k; i < d; i++) { + float val = A[i + k * d]; + norm_sq += val * val; + } + float norm = sqrtf(norm_sq); + + // Choose sign to avoid cancellation + float alpha = A[k + k * d]; + float sign_alpha = (alpha >= 0.0f) ? 1.0f : -1.0f; + float u1 = alpha + sign_alpha * norm; + + // R[k,k] = -sign(alpha) * norm, so sign(R[k,k]) = -sign(alpha) + r_sign[k] = -sign_alpha; + + // Compute tau = 2 / (v^T v) + float vtv = u1 * u1 + (norm_sq - alpha * alpha); + if (vtv < 1e-30f) { + tau[k] = 0.0f; + continue; + } + tau[k] = 2.0f / vtv; + + // Store v in A[k:d, k] + A[k + k * d] = u1; + + // Apply H_k to remaining columns A[k:d, k+1:d] + for (int64_t j = k + 1; j < d; j++) { + float dot = 0.0f; + dot += u1 * A[k + j * d]; + for (int64_t i = k + 1; i < d; i++) { + dot += A[i + k * d] * A[i + j * d]; + } + dot *= tau[k]; + A[k + j * d] -= dot * u1; + for (int64_t i = k + 1; i < d; i++) { + A[i + j * d] -= dot * A[i + k * d]; + } + } + } + + // Build Q by back-accumulation: Q = H_1 * H_2 * ... * H_{d-1} + memset(Q_out, 0, d * d * sizeof(float)); + for (int64_t i = 0; i < d; i++) { + Q_out[i + i * d] = 1.0f; + } + + for (int64_t k = d - 1; k >= 0; k--) { + if (tau[k] == 0.0f) continue; + float u1 = A[k + k * d]; + for (int64_t j = 0; j < d; j++) { + float dot = 0.0f; + dot += u1 * Q_out[k + j * d]; + for (int64_t i = k + 1; i < d; i++) { + dot += A[i + k * d] * Q_out[i + j * d]; + } + dot *= tau[k]; + Q_out[k + j * d] -= dot * u1; + for (int64_t i = k + 1; i < d; i++) { + Q_out[i + j * d] -= dot * A[i + k * d]; + } + } + } + + // Haar sign correction: Q[:,j] *= sign(R[j,j]) + // This ensures Q is uniformly distributed on O(d), not just SO(d). + // Reference: Mezzadri (2007), "How to Generate Random Matrices from the Classical Compact Groups" + for (int64_t j = 0; j < d; j++) { + if (r_sign[j] < 0.0f) { + for (int64_t i = 0; i < d; i++) { + Q_out[i + j * d] = -Q_out[i + j * d]; + } + } + } + + free(tau); + free(r_sign); +} + +// --------------------------------------------------------------------------- +// Rotation matrix cache +// +// For a given (dimension, seed) pair, generate and cache the d×d orthogonal Q. +// The cache is thread-local to avoid locks. In practice, all rows of a weight +// matrix share the same dimension, so the cache hit rate is ~100%. 
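Because of this cache, repeated calls with the same `(d, seed)` pair reuse one matrix, so the public wrappers further down reduce to `y = Q·x` and `x = Q^T·y` with that cached Q. A minimal round-trip sketch against the public API in ggml-turboq.h (a hypothetical test program, assuming it links against this file):

```c
// Hypothetical round-trip test: rotating a vector forward and then applying
// the inverse rotation should recover it up to float rounding, because
// Q^T (Q x) = x for an orthogonal Q.
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include "ggml-turboq.h"

int main(void) {
    enum { D = 128 };
    float x[D], y[D], xr[D];
    for (int i = 0; i < D; ++i) {
        x[i] = 0.01f * (float)(i - D/2); // arbitrary test data
    }

    const uint64_t seed = turboq_seed_from_row(0);
    turboq_rotate_forward(y, x, D, seed);
    turboq_rotate_inverse(xr, y, D, seed);

    float max_err = 0.0f;
    for (int i = 0; i < D; ++i) {
        const float err = fabsf(xr[i] - x[i]);
        if (err > max_err) {
            max_err = err;
        }
    }
    printf("max round-trip error: %g\n", max_err);
    return max_err < 1e-4f ? 0 : 1;
}
```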
+// --------------------------------------------------------------------------- + +static TURBOQ_TLS float * tl_Q = NULL; +static TURBOQ_TLS float * tl_Q_row = NULL; +static TURBOQ_TLS int64_t tl_Q_dim = 0; +static TURBOQ_TLS uint64_t tl_Q_seed = 0; + +static const float * turboq_get_rotation(int64_t d, uint64_t seed) { + if (tl_Q != NULL && tl_Q_dim == d && tl_Q_seed == seed) { + return tl_Q; + } + // Regenerate + free(tl_Q); + free(tl_Q_row); + tl_Q = (float *)malloc(d * d * sizeof(float)); + tl_Q_row = (float *)malloc(d * d * sizeof(float)); + tl_Q_dim = d; + tl_Q_seed = seed; + + // Generate d×d Gaussian random matrix (column-major) + float * A = (float *)malloc(d * d * sizeof(float)); + turboq_generate_gaussian(A, d * d, seed); + + // Compute QR, store Q in tl_Q + turboq_householder_qr(A, tl_Q, d); + + for (int64_t i = 0; i < d; ++i) { + for (int64_t j = 0; j < d; ++j) { + tl_Q_row[i * d + j] = tl_Q[i + j * d]; + } + } + + free(A); + return tl_Q; +} + +static const float * turboq_get_rotation_row(int64_t d, uint64_t seed) { + turboq_get_rotation(d, seed); + return tl_Q_row; +} + +// --------------------------------------------------------------------------- +// Projection matrix cache (for Q_prod QJL stage) +// +// S is a d×d random Gaussian matrix (NOT orthogonalized), used for QJL: +// qjl_signs = sign(S · residual) +// dequant: sqrt(pi/2)/d · gamma · S^T · signs +// Uses a different seed stream from the rotation matrix Q. +// --------------------------------------------------------------------------- + +static TURBOQ_TLS float * tl_S = NULL; +static TURBOQ_TLS float * tl_S_row = NULL; +static TURBOQ_TLS int64_t tl_S_dim = 0; +static TURBOQ_TLS uint64_t tl_S_seed = 0; + +static const float * turboq_get_projection(int64_t d, uint64_t seed) { + // Use a different seed stream for S vs Q + uint64_t s_seed = seed ^ 0x1234567890abcdefULL; + if (tl_S != NULL && tl_S_dim == d && tl_S_seed == s_seed) { + return tl_S; + } + free(tl_S); + free(tl_S_row); + tl_S = (float *)malloc(d * d * sizeof(float)); + tl_S_row = (float *)malloc(d * d * sizeof(float)); + tl_S_dim = d; + tl_S_seed = s_seed; + + // Generate d×d Gaussian random matrix (column-major), no QR + turboq_generate_gaussian(tl_S, d * d, s_seed); + + for (int64_t i = 0; i < d; ++i) { + for (int64_t j = 0; j < d; ++j) { + tl_S_row[i * d + j] = tl_S[i + j * d]; + } + } + + return tl_S; +} + +static const float * turboq_get_projection_row(int64_t d, uint64_t seed) { + turboq_get_projection(d, seed); + return tl_S_row; +} + +// --------------------------------------------------------------------------- +// Dense matrix-vector multiply: y = M * x (M is d×d column-major) +// --------------------------------------------------------------------------- + +static void matvec(float * y, const float * M, const float * x, int64_t d) { + for (int64_t i = 0; i < d; i++) { + float sum = 0.0f; + for (int64_t j = 0; j < d; j++) { + sum += M[i + j * d] * x[j]; // M[i,j] = M[i + j*d] (column-major) + } + y[i] = sum; + } +} + +#if defined(__AVX2__) +static inline float turboq_hsum_avx(__m256 v) { + __m128 lo = _mm256_castps256_ps128(v); + __m128 hi = _mm256_extractf128_ps(v, 1); + __m128 sum = _mm_add_ps(lo, hi); + sum = _mm_hadd_ps(sum, sum); + sum = _mm_hadd_ps(sum, sum); + return _mm_cvtss_f32(sum); +} +#endif + +static void matvec_row(float * y, const float * M, const float * x, int64_t d) { + for (int64_t i = 0; i < d; ++i) { + const float * row = M + i * d; + float sum = 0.0f; + int64_t j = 0; +#if defined(__AVX2__) + __m256 acc = _mm256_setzero_ps(); 
+ for (; j + 7 < d; j += 8) { + const __m256 mv = _mm256_loadu_ps(row + j); + const __m256 xv = _mm256_loadu_ps(x + j); +#if defined(__FMA__) + acc = _mm256_fmadd_ps(mv, xv, acc); +#else + acc = _mm256_add_ps(acc, _mm256_mul_ps(mv, xv)); +#endif + } + sum += turboq_hsum_avx(acc); +#endif + for (; j < d; ++j) { + sum += row[j] * x[j]; + } + y[i] = sum; + } +} + +// --------------------------------------------------------------------------- +// Dense matrix-transpose-vector multiply: y = M^T * x (M is d×d column-major) +// --------------------------------------------------------------------------- + +static void matvec_t(float * y, const float * M, const float * x, int64_t d) { + for (int64_t j = 0; j < d; j++) { + const float * col = M + j * d; + float sum = 0.0f; + int64_t i = 0; +#if defined(__AVX2__) + __m256 acc = _mm256_setzero_ps(); + for (; i + 7 < d; i += 8) { + const __m256 mv = _mm256_loadu_ps(col + i); + const __m256 xv = _mm256_loadu_ps(x + i); +#if defined(__FMA__) + acc = _mm256_fmadd_ps(mv, xv, acc); +#else + acc = _mm256_add_ps(acc, _mm256_mul_ps(mv, xv)); +#endif + } + sum += turboq_hsum_avx(acc); +#endif + for (; i < d; ++i) { + sum += col[i] * x[i]; // M^T[j,i] = M[i,j] = M[i + j*d] + } + y[j] = sum; + } +} + +// --------------------------------------------------------------------------- +// Public API (kept for compatibility, now wraps dense rotation) +// --------------------------------------------------------------------------- + +// The rotation matrix is a global parameter (same for all vectors), per the paper. +// This seed is used to deterministically generate both Q and S matrices. +uint64_t turboq_seed_from_row(int64_t row_idx) { + (void)row_idx; + return 0x517cc1b727220a95ULL; +} + +// Forward rotation: y = Q · x (paper Algorithm 1, line 5: y <- Pi . x) +void turboq_rotate_forward(float * y, const float * x, int64_t d, uint64_t seed) { + const float * Q = turboq_get_rotation_row(d, seed); + matvec_row(y, Q, x, d); +} + +// Inverse rotation: x = Q^T · y (paper Algorithm 1, line 10: x_tilde <- Pi^T . y_tilde) +void turboq_rotate_inverse(float * x, const float * y, int64_t d, uint64_t seed) { + const float * Q = turboq_get_rotation(d, seed); + matvec_t(x, Q, y, d); +} + +// --------------------------------------------------------------------------- +// Scratch buffer (thread-local, for temporary vectors) +// --------------------------------------------------------------------------- + +static TURBOQ_TLS float * tl_buf = NULL; +static TURBOQ_TLS int64_t tl_buf_size = 0; + +static float * turboq_get_scratch(int64_t n) { + if (n > tl_buf_size) { + free(tl_buf); + tl_buf = (float *)malloc(n * sizeof(float)); + tl_buf_size = n; + } + return tl_buf; +} + +// Second scratch buffer (needed when two temp vectors are required simultaneously, +// e.g. 
rotated-domain values + original-domain result in dequant) +static TURBOQ_TLS float * tl_buf2 = NULL; +static TURBOQ_TLS int64_t tl_buf2_size = 0; + +static float * turboq_get_scratch2(int64_t n) { + if (n > tl_buf2_size) { + free(tl_buf2); + tl_buf2 = (float *)malloc(n * sizeof(float)); + tl_buf2_size = n; + } + return tl_buf2; +} + +// Third scratch buffer (needed by Q_prod dequant which requires three simultaneous vectors: +// mse_rot, signs_f, and mse_unit) +static TURBOQ_TLS float * tl_buf3 = NULL; +static TURBOQ_TLS int64_t tl_buf3_size = 0; + +static float * turboq_get_scratch3(int64_t n) { + if (n > tl_buf3_size) { + free(tl_buf3); + tl_buf3 = (float *)malloc(n * sizeof(float)); + tl_buf3_size = n; + } + return tl_buf3; +} + +#define TURBOQ_KV_DIM 128 + +static inline float turboq_block_scale_up(void) { + return sqrtf((float) QK_K); +} + +static inline float turboq_block_scale_down(void) { + return 1.0f / turboq_block_scale_up(); +} + +static void turboq_rotate_block_forward(float * y, const float * x, uint64_t seed) { + const float * Q = turboq_get_rotation_row(TURBOQ_KV_DIM, seed); + + for (int64_t i = 0; i < QK_K; i += TURBOQ_KV_DIM) { + matvec_row(y + i, Q, x + i, TURBOQ_KV_DIM); + } +} + +static void turboq_rotate_block_inverse(float * x, const float * y, uint64_t seed) { + const float * Q = turboq_get_rotation(TURBOQ_KV_DIM, seed); + + for (int64_t i = 0; i < QK_K; i += TURBOQ_KV_DIM) { + matvec_t(x + i, Q, y + i, TURBOQ_KV_DIM); + } +} + +static void turboq_project_block(float * y, const float * x, uint64_t seed) { + const float * S = turboq_get_projection_row(TURBOQ_KV_DIM, seed); + + for (int64_t i = 0; i < QK_K; i += TURBOQ_KV_DIM) { + matvec_row(y + i, S, x + i, TURBOQ_KV_DIM); + } +} + +static void turboq_project_block_inverse(float * x, const float * y, uint64_t seed) { + const float * S = turboq_get_projection(TURBOQ_KV_DIM, seed); + + for (int64_t i = 0; i < QK_K; i += TURBOQ_KV_DIM) { + matvec_t(x + i, S, y + i, TURBOQ_KV_DIM); + } +} + +static void turboq_rotate_qk_forward(float * y, const float * x, uint64_t seed) { + const float * Q = turboq_get_rotation_row(QK_K, seed); + matvec_row(y, Q, x, QK_K); +} + +static void turboq_rotate_qk_inverse(float * x, const float * y, uint64_t seed) { + const float * Q = turboq_get_rotation(QK_K, seed); + matvec_t(x, Q, y, QK_K); +} + +static void turboq_project_qk(float * y, const float * x, uint64_t seed) { + const float * S = turboq_get_projection_row(QK_K, seed); + matvec_row(y, S, x, QK_K); +} + +static void turboq_project_qk_inverse(float * x, const float * y, uint64_t seed) { + const float * S = turboq_get_projection(QK_K, seed); + matvec_t(x, S, y, QK_K); +} + +// --------------------------------------------------------------------------- +// Scalar codebook quantization +// --------------------------------------------------------------------------- + +static inline uint8_t quantize_scalar(float val, const float * boundaries, int n_boundaries) { + for (int i = 0; i < n_boundaries; i++) { + if (val < boundaries[i]) { + return (uint8_t)i; + } + } + return (uint8_t)n_boundaries; +} + +static inline uint8_t quantize_scalar_3bit(float val) { + return quantize_scalar(val, turboq_boundaries_3bit, 7); +} + +static inline uint8_t quantize_scalar_2bit(float val) { + return quantize_scalar(val, turboq_boundaries_2bit, 3); +} + +static inline uint8_t quantize_scalar_4bit(float val) { + return quantize_scalar(val, turboq_boundaries_4bit, 15); +} + +// --------------------------------------------------------------------------- +// 
3-bit packing/unpacking +// --------------------------------------------------------------------------- + +static void pack_3bit(uint8_t * dst, const uint8_t * indices, int64_t n) { + int64_t full_groups = n / 8; + for (int64_t g = 0; g < full_groups; g++) { + const uint8_t * idx = indices + g * 8; + uint32_t bits = 0; + for (int j = 0; j < 8; j++) { + bits |= ((uint32_t)(idx[j] & 0x7)) << (j * 3); + } + dst[g * 3 + 0] = (uint8_t)(bits & 0xFF); + dst[g * 3 + 1] = (uint8_t)((bits >> 8) & 0xFF); + dst[g * 3 + 2] = (uint8_t)((bits >> 16) & 0xFF); + } +} + +static void unpack_3bit(uint8_t * indices, const uint8_t * src, int64_t n) { + int64_t full_groups = n / 8; + for (int64_t g = 0; g < full_groups; g++) { + uint32_t bits = (uint32_t)src[g * 3 + 0] + | ((uint32_t)src[g * 3 + 1] << 8) + | ((uint32_t)src[g * 3 + 2] << 16); + for (int j = 0; j < 8; j++) { + indices[g * 8 + j] = (uint8_t)((bits >> (j * 3)) & 0x7); + } + } +} + +// --------------------------------------------------------------------------- +// TBQ3_0: TurboQuant 3-bit +// --------------------------------------------------------------------------- + +void quantize_row_tbq3_0_ref(const float * GGML_RESTRICT x, block_tbq3_0 * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + float * unit = turboq_get_scratch(QK_K); + float * rotated = turboq_get_scratch2(QK_K); + const uint64_t seed = turboq_seed_from_row(0); + const float scale_up = turboq_block_scale_up(); + uint8_t indices[QK_K]; + + for (int64_t b = 0; b < nb; b++) { + const float * xb = x + b * QK_K; + + float norm_sq = 0.0f; + for (int64_t j = 0; j < QK_K; ++j) { + norm_sq += xb[j] * xb[j]; + } + + float norm = sqrtf(norm_sq); + if (norm < 1e-10f) { + norm = 1e-10f; + } + + for (int64_t j = 0; j < QK_K; ++j) { + unit[j] = xb[j] / norm; + } + + turboq_rotate_block_forward(rotated, unit, seed); + + for (int64_t j = 0; j < QK_K; j++) { + float val = rotated[j] * scale_up; + indices[j] = quantize_scalar_3bit(val); + } + pack_3bit(y[b].qs, indices, QK_K); + y[b].d = GGML_FP32_TO_FP16(norm); + } +} + +void dequantize_row_tbq3_0(const block_tbq3_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + float * rotated = turboq_get_scratch(QK_K); + float * unit_approx = turboq_get_scratch2(QK_K); + const uint64_t seed = turboq_seed_from_row(0); + const float scale_down = turboq_block_scale_down(); + uint8_t indices[QK_K]; + + for (int64_t b = 0; b < nb; b++) { + const float norm = GGML_FP16_TO_FP32(x[b].d); + + unpack_3bit(indices, x[b].qs, QK_K); + for (int64_t j = 0; j < QK_K; j++) { + rotated[j] = turboq_codebook_3bit[indices[j]] * scale_down; + } + + turboq_rotate_block_inverse(unit_approx, rotated, seed); + + for (int64_t j = 0; j < QK_K; ++j) { + y[b * QK_K + j] = unit_approx[j] * norm; + } + } +} + +size_t quantize_tbq3_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + (void)imatrix; + assert(n_per_row % QK_K == 0); + + const int64_t nb_per_row = n_per_row / QK_K; + const size_t row_size = nb_per_row * sizeof(block_tbq3_0); + + for (int64_t row = 0; row < nrows; row++) { + const float * row_src = src + row * n_per_row; + block_tbq3_0 * row_dst = (block_tbq3_0 *)((char *)dst + row * row_size); + quantize_row_tbq3_0_ref(row_src, row_dst, n_per_row); + } + return nrows * row_size; +} + +// --------------------------------------------------------------------------- +// TBQ4_0: TurboQuant 4-bit +// 
--------------------------------------------------------------------------- + +void quantize_row_tbq4_0_ref(const float * GGML_RESTRICT x, block_tbq4_0 * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + float * unit = turboq_get_scratch(QK_K); + float * rotated = turboq_get_scratch2(QK_K); + const uint64_t seed = turboq_seed_from_row(0); + const float scale_up = turboq_block_scale_up(); + + for (int64_t b = 0; b < nb; b++) { + const float * xb = x + b * QK_K; + + float norm_sq = 0.0f; + for (int64_t j = 0; j < QK_K; ++j) { + norm_sq += xb[j] * xb[j]; + } + + float norm = sqrtf(norm_sq); + if (norm < 1e-10f) { + norm = 1e-10f; + } + + for (int64_t j = 0; j < QK_K; ++j) { + unit[j] = xb[j] / norm; + } + + turboq_rotate_block_forward(rotated, unit, seed); + + memset(y[b].qs, 0, sizeof(y[b].qs)); + for (int64_t j = 0; j < QK_K; j++) { + float val = rotated[j] * scale_up; + uint8_t idx = quantize_scalar_4bit(val); + if (j % 2 == 0) { + y[b].qs[j / 2] = idx; + } else { + y[b].qs[j / 2] |= (idx << 4); + } + } + y[b].d = GGML_FP32_TO_FP16(norm); + } +} + +void dequantize_row_tbq4_0(const block_tbq4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + float * rotated = turboq_get_scratch(QK_K); + float * unit_approx = turboq_get_scratch2(QK_K); + const uint64_t seed = turboq_seed_from_row(0); + const float scale_down = turboq_block_scale_down(); + + for (int64_t b = 0; b < nb; b++) { + const float norm = GGML_FP16_TO_FP32(x[b].d); + + for (int64_t j = 0; j < QK_K; j++) { + uint8_t idx; + if (j % 2 == 0) { + idx = x[b].qs[j / 2] & 0x0F; + } else { + idx = (x[b].qs[j / 2] >> 4) & 0x0F; + } + rotated[j] = turboq_codebook_4bit[idx] * scale_down; + } + + turboq_rotate_block_inverse(unit_approx, rotated, seed); + + for (int64_t j = 0; j < QK_K; ++j) { + y[b * QK_K + j] = unit_approx[j] * norm; + } + } +} + +size_t quantize_tbq4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + (void)imatrix; + assert(n_per_row % QK_K == 0); + + const int64_t nb_per_row = n_per_row / QK_K; + const size_t row_size = nb_per_row * sizeof(block_tbq4_0); + + for (int64_t row = 0; row < nrows; row++) { + const float * row_src = src + row * n_per_row; + block_tbq4_0 * row_dst = (block_tbq4_0 *)((char *)dst + row * row_size); + quantize_row_tbq4_0_ref(row_src, row_dst, n_per_row); + } + return nrows * row_size; +} diff --git a/ggml/src/ggml-turboq.h b/ggml/src/ggml-turboq.h new file mode 100644 index 0000000000..e620e875e1 --- /dev/null +++ b/ggml/src/ggml-turboq.h @@ -0,0 +1,21 @@ +#pragma once + +// TurboQuant helpers used by the CPU quantizers. 
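The quantize path for both widths follows the same steps: compute the block's L2 norm, scale the 256 values to a unit vector, rotate each 128-wide half with the cached orthogonal Q, multiply by sqrt(QK_K) so the rotated coordinates are roughly unit-variance, map each coordinate to its nearest codeword, and store the norm as the fp16 block scale `d`; dequantization reverses these steps. A hedged round-trip sketch for TBQ4_0 using the reference entry points declared in ggml-quants.h (a hypothetical standalone test, assuming it is compiled together with ggml-turboq.c):

```c
// Hypothetical round-trip check: quantize one QK_K block to TBQ4_0 and
// dequantize it back; the reconstruction error is bounded by the 4-bit
// Lloyd-Max distortion of the rotated, unit-norm coordinates.
#include <math.h>
#include <stdio.h>
#include "ggml-quants.h" // declares the block types and the _ref entry points

int main(void) {
    float src[QK_K], out[QK_K];
    for (int i = 0; i < QK_K; ++i) {
        src[i] = sinf(0.1f * (float)i); // arbitrary smooth test data
    }

    block_tbq4_0 blk;
    quantize_row_tbq4_0_ref(src, &blk, QK_K);
    dequantize_row_tbq4_0(&blk, out, QK_K);

    float err = 0.0f, ref = 0.0f;
    for (int i = 0; i < QK_K; ++i) {
        err += (out[i] - src[i]) * (out[i] - src[i]);
        ref += src[i] * src[i];
    }
    printf("tbq4_0 relative rmse: %g\n", sqrtf(err / ref));
    return 0;
}
```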
+ +#include "ggml.h" + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +void turboq_rotate_forward(float * y, const float * x, int64_t d, uint64_t seed); + +void turboq_rotate_inverse(float * x, const float * y, int64_t d, uint64_t seed); + +uint64_t turboq_seed_from_row(int64_t row_idx); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e9b6720c0a..6d895068c5 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -904,6 +904,22 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .type_size = 0, .is_quantized = false, }, + [GGML_TYPE_TBQ3_0] = { + .type_name = "tbq3_0", + .blck_size = QK_K, + .type_size = sizeof(block_tbq3_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_tbq3_0, + .from_float_ref = (ggml_from_float_t) quantize_row_tbq3_0_ref, + }, + [GGML_TYPE_TBQ4_0] = { + .type_name = "tbq4_0", + .blck_size = QK_K, + .type_size = sizeof(block_tbq4_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_tbq4_0, + .from_float_ref = (ggml_from_float_t) quantize_row_tbq4_0_ref, + }, }; const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) { @@ -1389,6 +1405,8 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break; case GGML_FTYPE_MOSTLY_MXFP4: wtype = GGML_TYPE_MXFP4; break; case GGML_FTYPE_MOSTLY_NVFP4: wtype = GGML_TYPE_NVFP4; break; + case GGML_FTYPE_MOSTLY_TBQ3_0: wtype = GGML_TYPE_TBQ3_0; break; + case GGML_FTYPE_MOSTLY_TBQ4_0: wtype = GGML_TYPE_TBQ4_0; break; case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break; case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break; case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break; @@ -7666,6 +7684,8 @@ size_t ggml_quantize_chunk( case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_TQ1_0: result = quantize_tq1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_TQ2_0: result = quantize_tq2_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_TBQ3_0: result = quantize_tbq3_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_TBQ4_0: result = quantize_tbq4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/include/llama.h b/include/llama.h index a940f9d648..9452fcb708 100644 --- a/include/llama.h +++ b/include/llama.h @@ -154,6 +154,8 @@ extern "C" { LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38, // except 1d tensors LLAMA_FTYPE_MOSTLY_NVFP4 = 39, // except 1d tensors + LLAMA_FTYPE_MOSTLY_TBQ3_0 = 40, // except 1d tensors + LLAMA_FTYPE_MOSTLY_TBQ4_0 = 41, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a808e3e454..bd98cdb43d 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2944,10 +2944,15 @@ 
llama_context * llama_init_from_model( if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) { const uint32_t blck_size = ggml_blck_size(params.type_k); + const bool is_tbq_k = params.type_k == GGML_TYPE_TBQ3_0 || params.type_k == GGML_TYPE_TBQ4_0; + for (uint32_t il = 0; il < model->hparams.n_layer; ++il) { - if (model->hparams.n_embd_head_k(il) % blck_size != 0) { - LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n", - __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k(il)); + const uint32_t n_embd_k = is_tbq_k ? model->hparams.n_embd_k_gqa(il) : model->hparams.n_embd_head_k(il); + + if (n_embd_k % blck_size != 0) { + LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide %s=%u\n", + __func__, ggml_type_name(params.type_k), blck_size, + is_tbq_k ? "n_embd_k_gqa" : "n_embd_head_k", n_embd_k); return nullptr; } } @@ -2955,10 +2960,15 @@ llama_context * llama_init_from_model( if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) { const uint32_t blck_size = ggml_blck_size(params.type_v); + const bool is_tbq_v = params.type_v == GGML_TYPE_TBQ3_0 || params.type_v == GGML_TYPE_TBQ4_0; + for (uint32_t il = 0; il < model->hparams.n_layer; ++il) { - if (model->hparams.n_embd_head_v(il) % blck_size != 0) { - LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n", - __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v(il)); + const uint32_t n_embd_v = is_tbq_v ? model->hparams.n_embd_v_gqa(il) : model->hparams.n_embd_head_v(il); + + if (n_embd_v % blck_size != 0) { + LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide %s=%u\n", + __func__, ggml_type_name(params.type_v), blck_size, + is_tbq_v ? "n_embd_v_gqa" : "n_embd_head_v", n_embd_v); return nullptr; } } diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index c2833b75ce..d21a13ac62 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1790,19 +1790,50 @@ ggml_tensor * llm_graph_context::build_attn_mha( float kq_scale, int il) const { const bool v_trans = v->nb[1] > v->nb[2]; + const bool k_is_tbq = k->type == GGML_TYPE_TBQ3_0 || k->type == GGML_TYPE_TBQ4_0; + const bool v_is_tbq = v->type == GGML_TYPE_TBQ3_0 || v->type == GGML_TYPE_TBQ4_0; + const bool use_flash_attn = cparams.flash_attn && kq_b == nullptr; + const enum ggml_type tbq_attn_type = use_flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32; // split the batch into streams if needed - const auto n_stream = k->ne[3]; + const auto n_stream = k_is_tbq ? k->ne[2] : (v_is_tbq ? v->ne[2] : k->ne[3]); q = ggml_view_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream, q->nb[1], q->nb[2], q->nb[3]/n_stream, 0); + if (k_is_tbq) { + const int64_t n_head_kv = hparams.n_head_kv(il); + const int64_t n_embd_k_gqa = k->ne[0]; + + GGML_ASSERT(n_head_kv > 0); + GGML_ASSERT(n_embd_k_gqa % n_head_kv == 0); + + k = ggml_cast(ctx0, k, tbq_attn_type); + cb(k, use_flash_attn ? "k_tbq_f16" : "k_tbq_f32", il); + + k = ggml_reshape_4d(ctx0, k, n_embd_k_gqa / n_head_kv, n_head_kv, k->ne[1], k->ne[2]); + cb(k, "k_tbq_reshaped", il); + } + + if (v_is_tbq) { + const int64_t n_head_kv = hparams.n_head_kv(il); + const int64_t n_embd_v_gqa = v->ne[0]; + + GGML_ASSERT(n_head_kv > 0); + GGML_ASSERT(n_embd_v_gqa % n_head_kv == 0); + + v = ggml_cast(ctx0, v, tbq_attn_type); + cb(v, use_flash_attn ? 
"v_tbq_f16" : "v_tbq_f32", il); + + v = ggml_reshape_4d(ctx0, v, n_embd_v_gqa / n_head_kv, n_head_kv, v->ne[1], v->ne[2]); + cb(v, "v_tbq_reshaped", il); + } + q = ggml_permute(ctx0, q, 0, 2, 1, 3); k = ggml_permute(ctx0, k, 0, 2, 1, 3); v = ggml_permute(ctx0, v, 0, 2, 1, 3); ggml_tensor * cur; - const bool use_flash_attn = cparams.flash_attn && kq_b == nullptr; if (use_flash_attn) { GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet"); diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 5f57ba9e1d..fe6517d505 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1032,6 +1032,14 @@ ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_k const uint32_t ns = sinfo.s1 - sinfo.s0 + 1; + if (k->type == GGML_TYPE_TBQ3_0 || k->type == GGML_TYPE_TBQ4_0) { + return ggml_view_3d(ctx, k, + n_embd_k_gqa, n_kv, ns, + ggml_row_size(k->type, n_embd_k_gqa), + ggml_row_size(k->type, n_embd_k_gqa*kv_size), + ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0); + } + return ggml_view_4d(ctx, k, hparams.n_embd_head_k(il), hparams.n_head_kv(il), n_kv, ns, ggml_row_size(k->type, hparams.n_embd_head_k(il)), @@ -1053,6 +1061,14 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k const uint32_t ns = sinfo.s1 - sinfo.s0 + 1; + if (v->type == GGML_TYPE_TBQ3_0 || v->type == GGML_TYPE_TBQ4_0) { + return ggml_view_3d(ctx, v, + n_embd_v_gqa, n_kv, ns, + ggml_row_size(v->type, n_embd_v_gqa), + ggml_row_size(v->type, n_embd_v_gqa*kv_size), + ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0); + } + if (!v_trans) { // note: v->nb[1] <= v->nb[2] return ggml_view_4d(ctx, v, diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 67e1056c53..b2023cb271 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -382,7 +382,9 @@ static ggml_type tensor_type_fallback(quantize_state_impl & qs, const ggml_tenso case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_TQ1_0: - case GGML_TYPE_TQ2_0: return_type = GGML_TYPE_Q4_0; break; + case GGML_TYPE_TQ2_0: + case GGML_TYPE_TBQ3_0: + case GGML_TYPE_TBQ4_0: return_type = GGML_TYPE_Q4_0; break; case GGML_TYPE_Q4_K: return_type = GGML_TYPE_Q5_0; break; case GGML_TYPE_Q5_K: return_type = GGML_TYPE_Q5_1; break; case GGML_TYPE_Q6_K: return_type = GGML_TYPE_Q8_0; break; @@ -482,6 +484,9 @@ static ggml_type llama_tensor_get_type_impl(quantize_state_impl & qs, ggml_type else if (ftype == LLAMA_FTYPE_MOSTLY_TQ1_0 || ftype == LLAMA_FTYPE_MOSTLY_TQ2_0) { new_type = GGML_TYPE_Q4_K; } + else if (ftype == LLAMA_FTYPE_MOSTLY_TBQ3_0 || ftype == LLAMA_FTYPE_MOSTLY_TBQ4_0) { + new_type = GGML_TYPE_Q4_K; + } } } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) { @@ -815,6 +820,8 @@ static ggml_type llama_ftype_get_default_type(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q6_K: return GGML_TYPE_Q6_K; case LLAMA_FTYPE_MOSTLY_TQ1_0: return GGML_TYPE_TQ1_0; case LLAMA_FTYPE_MOSTLY_TQ2_0: return GGML_TYPE_TQ2_0; + case LLAMA_FTYPE_MOSTLY_TBQ3_0: return GGML_TYPE_TBQ3_0; + case LLAMA_FTYPE_MOSTLY_TBQ4_0: return GGML_TYPE_TBQ4_0; case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return GGML_TYPE_IQ2_XXS; case LLAMA_FTYPE_MOSTLY_IQ2_XS: return GGML_TYPE_IQ2_XS; case LLAMA_FTYPE_MOSTLY_IQ2_S: return GGML_TYPE_IQ2_XS; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 781c621d93..325ded79d3 100644 
--- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7318,6 +7318,11 @@ static const ggml_type other_types[] = { GGML_TYPE_BF16, }; +static const ggml_type turboq_types[] = { + GGML_TYPE_TBQ3_0, + GGML_TYPE_TBQ4_0, +}; + #ifdef _MSC_VER // Workaround long compile time with msvc #pragma optimize("", off) @@ -7388,6 +7393,11 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_get_rows(GGML_TYPE_I32, 256, 5, 4, b, 1, v)); } } + for (ggml_type type : turboq_types) { + for (bool v : {false, true}) { + test_cases.emplace_back(new test_get_rows(type, 256, 5, 4, 1, 1, v)); + } + } test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_F32, 1, 8, 2, 1, false)); for (ggml_type type : all_types) { @@ -7398,6 +7408,11 @@ static std::vector> make_test_cases_eval() { for (bool v : {false, true}) { test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_I32, 256, 5, 4, 1, v)); } + for (ggml_type type : turboq_types) { + for (bool v : {false, true}) { + test_cases.emplace_back(new test_get_rows_back(type, 256, 5, 4, 1, v)); + } + } test_cases.emplace_back(new test_set_rows(GGML_TYPE_F32, GGML_TYPE_I64, { 1, 8, 1, 3 }, { 1, 1 }, 2, false)); test_cases.emplace_back(new test_set_rows(GGML_TYPE_F32, GGML_TYPE_I32, { 1, 8, 1, 3 }, { 1, 1 }, 2, false)); @@ -7417,6 +7432,12 @@ static std::vector> make_test_cases_eval() { } } } + for (ggml_type type : turboq_types) { + for (bool v : {false, true}) { + test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 256, 5, 1, 3 }, { 1, 1 }, 1, v)); + test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 256, 11, 1, 1 }, { 2, 3 }, 7, v)); + } + } for (int mode : { GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_VISION }) { for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) { @@ -7783,11 +7804,31 @@ static std::vector> make_test_cases_eval() { } } for (ggml_type type_src : all_types) { - for (ggml_type type_dst : {GGML_TYPE_F32}) { + if (!ggml_is_quantized(type_src)) { + continue; + } + test_cases.emplace_back(new test_cpy(type_src, GGML_TYPE_F32, {256, 4, 4, 4})); + test_cases.emplace_back(new test_cpy(type_src, GGML_TYPE_F32, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows + } + for (ggml_type type : turboq_types) { + test_cases.emplace_back(new test_cpy(type, GGML_TYPE_F32, {256, 4, 4, 4})); + test_cases.emplace_back(new test_cpy(type, GGML_TYPE_F32, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows + } + for (ggml_type type_src : all_types) { + if (!ggml_is_quantized(type_src)) { + continue; + } + for (ggml_type type_dst : {GGML_TYPE_F16, GGML_TYPE_BF16}) { test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4})); test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows } } + for (ggml_type type : turboq_types) { + for (ggml_type type_dst : {GGML_TYPE_F16, GGML_TYPE_BF16}) { + test_cases.emplace_back(new test_cpy(type, type_dst, {256, 4, 4, 4})); + test_cases.emplace_back(new test_cpy(type, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows + } + } for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) { for (ggml_type type_dst : {GGML_TYPE_F16, GGML_TYPE_F32}) { test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {1, 0, 2, 3})); // cpy not-contiguous @@ -7807,6 +7848,10 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); test_cases.emplace_back(new 
test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0})); test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0})); + for (ggml_type type : turboq_types) { + test_cases.emplace_back(new test_cpy(type, type, {256, 4, 4, 4})); + test_cases.emplace_back(new test_cpy(type, type, {256, 2, 3, 4}, {0, 2, 1, 3})); + } for (ggml_type type_dst : { GGML_TYPE_F32, GGML_TYPE_I32, GGML_TYPE_F16, GGML_TYPE_BF16 }) { for (bool use_view_slice : { true, false }) { @@ -8890,6 +8935,9 @@ static std::vector> make_test_cases_perf() { } } } + for (ggml_type type : turboq_types) { + test_cases.emplace_back(new test_flash_attn_ext(128, 128, 8, {1, 1}, 512, 1, true, false, 0, 0, GGML_PREC_F32, type)); + } for (int col : {8192, 16384, 32768, 65536, 131072, 262144, 524288}) { for (int rows : {1, 4, 16}){ diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp index a8fb192623..cc50457bc1 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -2,9 +2,13 @@ #include "ggml.h" #include "ggml-cpu.h" +#include "../ggml/src/ggml-quants.h" +#include "../ggml/src/ggml-turboq.h" +#include "../ggml/src/ggml-turboq-tables.h" #undef NDEBUG #include +#include #include #include #include @@ -20,11 +24,13 @@ constexpr float MAX_QUANTIZATION_TOTAL_ERROR_TERNARY = 0.01f; constexpr float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075f; constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040f; constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS = 0.0050f; +constexpr float MAX_QUANTIZATION_TOTAL_ERROR_TBQ4 = 0.0025f; constexpr float MAX_QUANTIZATION_TOTAL_ERROR_FP4 = 0.0030f; constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f; constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f; constexpr float MAX_DOT_PRODUCT_ERROR_FP4 = 0.03f; constexpr float MAX_DOT_PRODUCT_ERROR_TERNARY = 0.15f; +constexpr float MAX_DOT_PRODUCT_ERROR_TBQ3 = 0.05f; static const char* RESULT_STR[] = {"ok", "FAILED"}; @@ -100,6 +106,41 @@ static float dot_product_error(const ggml_type_traits * qfns, const ggml_type_tr return fabsf(result - dot_ref) / test_size; } +static bool test_turboq_vec_dot_dispatch() { + for (ggml_type type : { GGML_TYPE_TBQ3_0, GGML_TYPE_TBQ4_0 }) { + const auto * qfns_cpu = ggml_get_type_traits_cpu(type); + if (qfns_cpu->vec_dot == nullptr || qfns_cpu->vec_dot_type != GGML_TYPE_Q8_K) { + return false; + } + } + + return true; +} + +static bool test_tbq3_codebook() { + static const float expected[8] = { + -2.1520f, -1.3440f, -0.7560f, -0.2451f, + 0.2451f, 0.7560f, 1.3440f, 2.1520f, + }; + + for (int i = 0; i < 8; ++i) { + if (fabsf(turboq_codebook_3bit[i] - expected[i]) > 1e-4f) { + return false; + } + } + + return true; +} + +static bool test_tbq3_norm_scaling() { + std::vector x(QK_K, 1.0f); + block_tbq3_0 block = {}; + + quantize_row_tbq3_0_ref(x.data(), &block, QK_K); + + return fabsf(ggml_fp16_to_fp32(block.d) - 16.0f) < 1e-3f; +} + int main(int argc, char * argv[]) { bool verbose = false; const size_t test_size = 32 * 128; @@ -127,6 +168,24 @@ int main(int argc, char * argv[]) { int num_failed = 0; bool failed = false; + failed = !test_turboq_vec_dot_dispatch(); + num_failed += failed; + if (failed || verbose) { + printf("%5s vec_dot dispatch: %s\n", "tbq*", RESULT_STR[failed]); + } + + failed = !test_tbq3_codebook(); + num_failed += failed; + if (failed || verbose) { + printf("%5s codebook values: %s\n", "tbq3", RESULT_STR[failed]); + } + + failed = !test_tbq3_norm_scaling(); + num_failed += failed; + if (failed || 
verbose) { + printf("%5s norm scaling: %s\n", "tbq3", RESULT_STR[failed]); + } + for (int i = 0; i < GGML_TYPE_COUNT; i++) { ggml_type type = (ggml_type) i; const auto * qfns = ggml_get_type_traits(type); @@ -152,6 +211,8 @@ int main(int argc, char * argv[]) { type == GGML_TYPE_Q3_K ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS : type == GGML_TYPE_IQ3_S ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS : type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS : + type == GGML_TYPE_TBQ3_0 ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS : + type == GGML_TYPE_TBQ4_0 ? MAX_QUANTIZATION_TOTAL_ERROR_TBQ4 : type == GGML_TYPE_NVFP4 ? MAX_QUANTIZATION_TOTAL_ERROR_FP4 : MAX_QUANTIZATION_TOTAL_ERROR; failed = !(total_error < max_quantization_error); num_failed += failed; @@ -172,6 +233,8 @@ int main(int argc, char * argv[]) { ? MAX_DOT_PRODUCT_ERROR_LOWBIT : type == GGML_TYPE_TQ1_0 || type == GGML_TYPE_TQ2_0 ? MAX_DOT_PRODUCT_ERROR_TERNARY + : type == GGML_TYPE_TBQ3_0 + ? MAX_DOT_PRODUCT_ERROR_TBQ3 : type == GGML_TYPE_NVFP4 ? MAX_DOT_PRODUCT_ERROR_FP4 : MAX_DOT_PRODUCT_ERROR; diff --git a/tools/cli/README.md b/tools/cli/README.md index 840976a884..e336a909fc 100644 --- a/tools/cli/README.md +++ b/tools/cli/README.md @@ -52,8 +52,8 @@ | `-kvo, --kv-offload, -nkvo, --no-kv-offload` | whether to enable KV cache offloading (default: enabled)
(env: LLAMA_ARG_KV_OFFLOAD) | | `--repack, -nr, --no-repack` | whether to enable weight repacking (default: enabled)
(env: LLAMA_ARG_REPACK) | | `--no-host` | bypass host buffer allowing extra buffers to be used
(env: LLAMA_ARG_NO_HOST) | -| `-ctk, --cache-type-k TYPE` | KV cache data type for K
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K) | -| `-ctv, --cache-type-v TYPE` | KV cache data type for V
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V) | +| `-ctk, --cache-type-k TYPE` | KV cache data type for K
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1, tbq3_0, tbq4_0
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K) | +| `-ctv, --cache-type-v TYPE` | KV cache data type for V
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1, tbq3_0, tbq4_0
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V) | | `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)
(env: LLAMA_ARG_DEFRAG_THOLD) | | `-np, --parallel N` | number of parallel sequences to decode (default: 1)
(env: LLAMA_ARG_N_PARALLEL) | | `--mlock` | force system to keep model in RAM rather than swapping or compressing
(env: LLAMA_ARG_MLOCK) |
@@ -97,8 +97,8 @@
| `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:<br/>
- 0: generic output
- 1: error
- 2: warning
- 3: info
- 4: debug
(default: 3)

(env: LLAMA_LOG_VERBOSITY) | | `--log-prefix` | Enable prefix in log messages
(env: LLAMA_LOG_PREFIX) | | `--log-timestamps` | Enable timestamps in log messages
(env: LLAMA_LOG_TIMESTAMPS) | -| `-ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K_DRAFT) | -| `-ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V_DRAFT) | +| `-ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1, tbq3_0, tbq4_0
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K_DRAFT) | +| `-ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1, tbq3_0, tbq4_0
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V_DRAFT) |

### Sampling params

diff --git a/tools/completion/README.md b/tools/completion/README.md
index 25884ed92d..9539fb4878 100644
--- a/tools/completion/README.md
+++ b/tools/completion/README.md
@@ -135,8 +135,8 @@ llama-completion.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1
| `-kvo, --kv-offload, -nkvo, --no-kv-offload` | whether to enable KV cache offloading (default: enabled)<br/>
(env: LLAMA_ARG_KV_OFFLOAD) | | `--repack, -nr, --no-repack` | whether to enable weight repacking (default: enabled)
(env: LLAMA_ARG_REPACK) | | `--no-host` | bypass host buffer allowing extra buffers to be used
(env: LLAMA_ARG_NO_HOST) | -| `-ctk, --cache-type-k TYPE` | KV cache data type for K
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K) | -| `-ctv, --cache-type-v TYPE` | KV cache data type for V
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V) | +| `-ctk, --cache-type-k TYPE` | KV cache data type for K
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1, tbq3_0, tbq4_0
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K) | +| `-ctv, --cache-type-v TYPE` | KV cache data type for V
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1, tbq3_0, tbq4_0
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V) | | `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)
(env: LLAMA_ARG_DEFRAG_THOLD) | | `-np, --parallel N` | number of parallel sequences to decode (default: 1)
(env: LLAMA_ARG_N_PARALLEL) | | `--mlock` | force system to keep model in RAM rather than swapping or compressing
(env: LLAMA_ARG_MLOCK) |
@@ -180,8 +180,8 @@ llama-completion.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1
| `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:<br/>
- 0: generic output
- 1: error
- 2: warning
- 3: info
- 4: debug
(default: 3)

(env: LLAMA_LOG_VERBOSITY) | | `--log-prefix` | Enable prefix in log messages
(env: LLAMA_LOG_PREFIX) | | `--log-timestamps` | Enable timestamps in log messages
(env: LLAMA_LOG_TIMESTAMPS) | -| `-ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K_DRAFT) | -| `-ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V_DRAFT) | +| `-ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1, tbq3_0, tbq4_0
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K_DRAFT) | +| `-ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1, tbq3_0, tbq4_0
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V_DRAFT) |

### Sampling params

diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp
index 0a23f69853..560d7061a9 100644
--- a/tools/llama-bench/llama-bench.cpp
+++ b/tools/llama-bench/llama-bench.cpp
@@ -483,6 +483,12 @@ static ggml_type ggml_type_from_name(const std::string & s) {
     if (s == "iq4_nl") {
         return GGML_TYPE_IQ4_NL;
     }
+    if (s == "tbq3_0") {
+        return GGML_TYPE_TBQ3_0;
+    }
+    if (s == "tbq4_0") {
+        return GGML_TYPE_TBQ4_0;
+    }
     return GGML_TYPE_COUNT;
 }
diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp
index b727c9dd39..faeff14eb9 100644
--- a/tools/quantize/quantize.cpp
+++ b/tools/quantize/quantize.cpp
@@ -42,6 +42,8 @@ static const std::vector QUANT_OPTIONS = {
     { "IQ1_M",   LLAMA_FTYPE_MOSTLY_IQ1_M,   " 1.75 bpw quantization", },
     { "TQ1_0",   LLAMA_FTYPE_MOSTLY_TQ1_0,   " 1.69 bpw ternarization", },
     { "TQ2_0",   LLAMA_FTYPE_MOSTLY_TQ2_0,   " 2.06 bpw ternarization", },
+    { "TBQ3_0",  LLAMA_FTYPE_MOSTLY_TBQ3_0,  " 3.06 bpw TurboQuant", },
+    { "TBQ4_0",  LLAMA_FTYPE_MOSTLY_TBQ4_0,  " 4.06 bpw TurboQuant", },
     { "Q2_K",    LLAMA_FTYPE_MOSTLY_Q2_K,    " 2.96G, +3.5199 ppl @ Llama-3-8B", },
     { "Q2_K_S",  LLAMA_FTYPE_MOSTLY_Q2_K_S,  " 2.96G, +3.1836 ppl @ Llama-3-8B", },
     { "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", },
diff --git a/tools/server/README.md b/tools/server/README.md
index 1bd8201689..f374f97f8e 100644
--- a/tools/server/README.md
+++ b/tools/server/README.md
@@ -69,8 +69,8 @@ For the full list of features, please refer to [server's changelog](https://gith
| `-kvo, --kv-offload, -nkvo, --no-kv-offload` | whether to enable KV cache offloading (default: enabled)<br/>
(env: LLAMA_ARG_KV_OFFLOAD) | | `--repack, -nr, --no-repack` | whether to enable weight repacking (default: enabled)
(env: LLAMA_ARG_REPACK) | | `--no-host` | bypass host buffer allowing extra buffers to be used
(env: LLAMA_ARG_NO_HOST) | -| `-ctk, --cache-type-k TYPE` | KV cache data type for K
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K) | -| `-ctv, --cache-type-v TYPE` | KV cache data type for V
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V) | +| `-ctk, --cache-type-k TYPE` | KV cache data type for K
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1, tbq3_0, tbq4_0
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K) | +| `-ctv, --cache-type-v TYPE` | KV cache data type for V
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1, tbq3_0, tbq4_0
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V) | | `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)
(env: LLAMA_ARG_DEFRAG_THOLD) | | `--mlock` | force system to keep model in RAM rather than swapping or compressing
(env: LLAMA_ARG_MLOCK) | | `--mmap, --no-mmap` | whether to memory-map model. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)
(env: LLAMA_ARG_MMAP) |
@@ -113,8 +113,8 @@ For the full list of features, please refer to [server's changelog](https://gith
| `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:<br/>
- 0: generic output
- 1: error
- 2: warning
- 3: info
- 4: debug
(default: 3)

(env: LLAMA_LOG_VERBOSITY) | | `--log-prefix` | Enable prefix in log messages
(env: LLAMA_LOG_PREFIX) | | `--log-timestamps` | Enable timestamps in log messages
(env: LLAMA_LOG_TIMESTAMPS) | -| `-ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K_DRAFT) | -| `-ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V_DRAFT) | +| `-ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1, tbq3_0, tbq4_0
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K_DRAFT) | +| `-ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1, tbq3_0, tbq4_0
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V_DRAFT) | ### Sampling params