From 5bb05ed21c63cee908f982299a9727dc28ab1b55 Mon Sep 17 00:00:00 2001 From: Tim Burke Date: Sat, 21 Mar 2026 13:37:09 -0400 Subject: [PATCH] Comment consistency pass and cleanup. --- common/arg.cpp | 2 - ggml/src/ggml-common.h | 135 +++++++++++----------------- ggml/src/ggml-cpu/arch-fallback.h | 1 - ggml/src/ggml-cpu/arch/arm/quants.c | 24 ++--- ggml/src/ggml-cpu/arch/x86/quants.c | 21 ++--- ggml/src/ggml-cpu/ops.cpp | 68 +++++--------- ggml/src/ggml-cpu/quants.c | 7 +- ggml/src/ggml-impl.h | 28 ++---- ggml/src/ggml-quants.c | 19 +--- ggml/src/ggml-quants.h | 70 ++------------- src/llama-kv-cache.cpp | 5 +- tests/test-backend-ops.cpp | 27 ++---- 12 files changed, 113 insertions(+), 294 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 6f61ad07ff..5f221b7263 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -404,8 +404,6 @@ const std::vector kv_cache_types = { }; static ggml_type kv_cache_type_from_str(const std::string & s) { - // Short aliases: "mxfp4" → E2M1, "mxfp6" → E2M3, "mxfp8" → E4M3. - // Full names (mxfp4_e2m1, mxfp8_e4m3, mxfp6_e2m3, etc.) match via ggml_type_name() below. if (s == "mxfp4") { return GGML_TYPE_MXFP4_E2M1; } diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index cc9a4a0aca..7308c3749b 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -71,7 +71,6 @@ typedef sycl::half2 ggml_half2; #define GGML_COMMON_DECL #endif -// Pure numeric constants needed by both DECL and IMPL sections. #define MXFP_HADAMARD_32_NORM 0.17677669529663689f // 1/sqrt(32) #if defined(GGML_COMMON_DECL) @@ -199,11 +198,8 @@ typedef struct { } block_q4_1; static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding"); -// MXFP E8M0 shared exponent constants (OCP MX v1.0 §5.3). -// EMAX_OFFSET: ceil(log2(max_finite)) for each element type — used to center the E8M0 scale. -// MSE_RANGE: search radius around round(log2(amax)). 
Tests 2*range+1 candidate exponents, -// picking the one that minimizes total round-trip quantization error per block. -// Inspired by "Four Over Six" (arXiv:2512.02010); generalized to all MX types. +// E8M0 shared exponent constants (OCP MX v1.0 SS5.3). +// EMAX_OFFSET = ceil(log2(max_finite)), MSE_RANGE = search radius for optimal scale. #define MXFP_E8M0_MSE_RANGE 2 #define MXFP4_E2M1_EMAX_OFFSET 2 // ceil(log2(6.0)) #define MXFP6_E2M3_EMAX_OFFSET 3 // ceil(log2(7.5)) @@ -211,11 +207,7 @@ static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4 #define MXFP8_E4M3_EMAX_OFFSET 8 // ceil(log2(448)) #define MXFP8_E5M2_EMAX_OFFSET 16 // ceil(log2(57344)) -// MXFP type properties — single source of truth for all backends. -// Bits per element, quantized bytes per block, and Hadamard rotation flag. -// USE_HADAMARD: 1 for types with >= 3-bit mantissa (E2M1, E4M3, E2M3). -// 0 for 2-bit mantissa types (E5M2, E3M2) where Hadamard provides -// no quality benefit and hurts models with D_head ≤ 64. +// MXFP type properties -- shared across all backends. #define MXFP_BITS_PER_ELEM_E2M1 4 #define MXFP_BITS_PER_ELEM_E4M3 8 #define MXFP_BITS_PER_ELEM_E5M2 8 @@ -239,52 +231,48 @@ static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4 // EXP_MASK = (1< float. GGML_TABLE_BEGIN(float, kvalues_mxfp6_e2m3, 64) 0.0f, 0.125f, 0.25f, 0.375f, 0.5f, 0.625f, 0.75f, 0.875f, 1.0f, 1.125f, 1.25f, 1.375f, 1.5f, 1.625f, 1.75f, 1.875f, @@ -1316,8 +1296,7 @@ GGML_TABLE_BEGIN(float, kvalues_mxfp6_e2m3, 64) -4.0f, -4.5f, -5.0f, -5.5f, -6.0f, -6.5f, -7.0f, -7.5f, GGML_TABLE_END() -// FP6 E3M2 dequantization LUT: 6-bit value → float (64 entries). -// Generated from ggml_mxfp_fp6_e3m2_to_float(). No NaN/Inf — all bit patterns are valid. +// FP6 E3M2 dequantization LUT: 6-bit value -> float. No NaN/Inf. 
GGML_TABLE_BEGIN(float, kvalues_mxfp6_e3m2, 64) 0.0f, 0.0625f, 0.125f, 0.1875f, 0.25f, 0.3125f, 0.375f, 0.4375f, 0.5f, 0.625f, 0.75f, 0.875f, 1.0f, 1.25f, 1.5f, 1.75f, @@ -1329,8 +1308,7 @@ GGML_TABLE_BEGIN(float, kvalues_mxfp6_e3m2, 64) -8.0f, -10.0f, -12.0f, -14.0f, -16.0f, -20.0f, -24.0f, -28.0f, GGML_TABLE_END() -// FP8 E4M3 dequantization LUT: byte → float (256 entries). -// Generated from ggml_mxfp_fp8_e4m3_to_float(). Entry 127 = 448 (max finite), 255 = NaN. +// FP8 E4M3 dequantization LUT: byte -> float. Entry 127 = 448 (max finite), 255 = NaN. GGML_TABLE_BEGIN(float, kvalues_mxfp8_e4m3, 256) 0.0f, 0.001953125f, 0.00390625f, 0.005859375f, 0.0078125f, 0.009765625f, 0.01171875f, 0.013671875f, 0.015625f, 0.017578125f, 0.01953125f, 0.021484375f, 0.0234375f, 0.025390625f, 0.02734375f, 0.029296875f, @@ -1366,8 +1344,7 @@ GGML_TABLE_BEGIN(float, kvalues_mxfp8_e4m3, 256) -256.0f, -288.0f, -320.0f, -352.0f, -384.0f, -416.0f, -448.0f, NAN, GGML_TABLE_END() -// FP8 E5M2 dequantization LUT: byte → float (256 entries). -// Generated from ggml_mxfp_fp8_e5m2_to_float(). Entries 124-127 = {Inf, NaN, NaN, NaN}. +// FP8 E5M2 dequantization LUT: byte -> float. Entries 124-127 = {Inf, NaN, NaN, NaN}. GGML_TABLE_BEGIN(float, kvalues_mxfp8_e5m2, 256) 0.0f, 1.525879e-05f, 3.051758e-05f, 4.577637e-05f, 6.103516e-05f, 7.629395e-05f, 9.155273e-05f, 1.068115e-04f, 1.220703e-04f, 1.525879e-04f, 1.831055e-04f, 2.136230e-04f, 2.441406e-04f, 3.051758e-04f, 3.662109e-04f, 4.272461e-04f, @@ -1403,16 +1380,10 @@ GGML_TABLE_BEGIN(float, kvalues_mxfp8_e5m2, 256) -32768.0f, -40960.0f, -49152.0f, -57344.0f, -INFINITY, NAN, NAN, NAN, GGML_TABLE_END() -// ------------------------------------------------------------------------------------------------------------------ -// Canonical MXFP element converters — portable IEEE-754 bit manipulation. -// Single source of truth for CPU, CUDA, HIP, MUSA, SYCL. Metal/Vulkan keep MSL/GLSL copies. 
-// ------------------------------------------------------------------------------------------------------------------ +// MXFP element converters -- portable IEEE-754 bit manipulation. #if defined(GGML_MXFP_FUNC) -// --- FP4 E2M1: [S(1) | E(2) | M(1)] — max normal = 6.0 --- -// Canonical converters using true E2M1 values {0, 0.5, 1, 1.5, 2, 3, 4, 6}. -// The int8 kvalues_mxfp4 LUT stores doubled values {0,1,2,3,4,6,8,12} for -// CPU/CUDA nibble-indexed integer arithmetic — that doubling is an implementation detail. +// FP4 E2M1: [S(1) | E(2) | M(1)], max normal = 6.0 GGML_MXFP_FUNC float ggml_mxfp_fp4_e2m1_to_float(uint8_t v) { const float sign = (v & 0x8) ? -1.0f : 1.0f; @@ -1427,8 +1398,6 @@ GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp4_e2m1(float x) { if (x < 0) { sign = 0x8; x = -x; } if (x == 0) return sign; if (x >= 6.0f) return sign | 0x7; // max finite - // Decision boundaries (midpoints of adjacent canonical values): - // {0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0} if (x < 0.25f) return sign | 0x0; // 0 else if (x < 0.75f) return sign | 0x1; // 0.5 else if (x < 1.25f) return sign | 0x2; // 1.0 @@ -1439,7 +1408,7 @@ GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp4_e2m1(float x) { else return sign | 0x7; // 6.0 } -// --- FP6 E2M3: [S(1) | E(2) | M(3)] — max normal = 7.5 --- +// FP6 E2M3: [S(1) | E(2) | M(3)], max normal = 7.5 GGML_MXFP_FUNC float ggml_mxfp_fp6_e2m3_to_float(uint8_t v) { const float sign = (v & 0x20) ? -1.0f : 1.0f; @@ -1474,7 +1443,7 @@ GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp6_e2m3(float x) { return sign | (uint8_t)(((f32_exp + 1) << 3) | mant); } -// --- FP6 E3M2: [S(1) | E(3) | M(2)] — max normal = 28.0, no NaN/Inf --- +// FP6 E3M2: [S(1) | E(3) | M(2)], max normal = 28.0, no NaN/Inf GGML_MXFP_FUNC float ggml_mxfp_fp6_e3m2_to_float(uint8_t v) { const float sign = (v & 0x20) ? 
-1.0f : 1.0f; @@ -1512,7 +1481,7 @@ GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp6_e3m2(float x) { return sign | (uint8_t)((biased_exp << 2) | mant); } -// --- FP8 E4M3: [S(1) | E(4) | M(3)] — bias=7, max finite=448 --- +// FP8 E4M3: [S(1) | E(4) | M(3)], bias=7, max finite=448 GGML_MXFP_FUNC float ggml_mxfp_fp8_e4m3_to_float(uint8_t v) { uint32_t sign = ((uint32_t)(v & 0x80)) << 24; @@ -1571,7 +1540,7 @@ GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp8_e4m3(float x) { return sign | (uint8_t)((e4m3_exp << 3) | mant3); } -// --- FP8 E5M2: [S(1) | E(5) | M(2)] — bias=15, max finite=57344 --- +// FP8 E5M2: [S(1) | E(5) | M(2)], bias=15, max finite=57344 GGML_MXFP_FUNC float ggml_mxfp_fp8_e5m2_to_float(uint8_t v) { uint32_t sign = ((uint32_t)(v & 0x80)) << 24; @@ -1629,7 +1598,7 @@ GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp8_e5m2(float x) { return sign | (uint8_t)((e5m2_exp << 2) | mant2); } -// --- FP6 packing/unpacking --- +// FP6 packing/unpacking // Pack 4 six-bit values into 3 bytes GGML_MXFP_FUNC void ggml_mxfp_pack_fp6x4(const uint8_t v[4], uint8_t out[3]) { @@ -1674,8 +1643,6 @@ GGML_MXFP_FUNC int ggml_mxfp_e8m0_base_estimate(float amax, int emax_offset) { } // Block-32 Walsh-Hadamard Transform, normalized by 1/sqrt(32). -// Spreads outlier energy across all elements sharing an E8M0 exponent, -// improving quantization quality (see QuaRot arXiv:2404.00456). GGML_MXFP_FUNC void ggml_mxfp_hadamard_32_inplace(GGML_MXFP_THREAD float * vals) { GGML_MXFP_UNROLL for (int stride = 1; stride < 32; stride *= 2) { diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 3f01c0b1c7..f622658918 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -2,7 +2,6 @@ #pragma once // Rename `_generic` functions if no native implementation is available. -// This effectively selects the generic implementation. 
#if defined(GGML_CPU_GENERIC) // quants.c diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index 6564131f2b..9ad9f29ae5 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -4134,21 +4134,14 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -// ── MXFP FP8/FP6 NEON helpers ────────────────────────────────────────────── -// Shared IEEE-754 bit reconstruction and FP6 unpacking used by vec_dot, -// dequantize_row, and SoA dequant functions. -// -// NEON requires vshlq_n_u32 to have a compile-time literal constant, so we use -// two separate helpers for FP8 (sign at bit 7, shift 24) and FP6 (sign at bit 5, -// shift 26) rather than a single parameterized function. +// MXFP FP8/FP6 NEON helpers +// Separate FP8/FP6 functions because NEON vshlq_n_u32 requires compile-time constants. #if defined(__ARM_NEON) -// Use shared mxfp_dequant_traits_t from ggml-common.h. #define mxfp_neon_traits_t mxfp_dequant_traits_t -// Dequantize 4 raw FP8 values (uint32x4_t) → 4 IEEE-754 floats. -// Sign bit at position 7, sign shift = 24. +// Dequantize 4 FP8 values to floats. static inline float32x4_t mxfp8_dequant_neon( const uint32x4_t v_raw, const uint32x4_t v_exp_mask, const uint32x4_t v_mant_mask, @@ -4172,8 +4165,7 @@ static inline float32x4_t mxfp8_dequant_neon( return vbslq_f32(is_sub, sub_val, normal); } -// Dequantize 4 raw FP6 values (uint32x4_t) → 4 IEEE-754 floats. -// Sign bit at position 5, sign shift = 26. +// Dequantize 4 FP6 values to floats. 
static inline float32x4_t mxfp6_dequant_neon( const uint32x4_t v_raw, const uint32x4_t v_exp_mask, const uint32x4_t v_mant_mask, @@ -4229,7 +4221,7 @@ static inline void widen_s8x8_to_f32x4x2(const int8_t * src, *hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16))); } -// ── MXFP FP8/FP6 vec_dot ────────────────────────────────────────────────── +// MXFP FP8/FP6 vec_dot static void ggml_vec_dot_mxfp8_q8_0_neon( int n, float * GGML_RESTRICT s, @@ -4319,7 +4311,7 @@ static void ggml_vec_dot_mxfp6_q8_0_neon( *s = vaddvq_f32(vaddq_f32(acc0, acc1)); } -// ── MXFP FP8/FP6 dequantize_row (AoS) ───────────────────────────────────── +// MXFP FP8/FP6 dequantize_row (AoS) static void dequantize_row_mxfp8_neon( const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k, @@ -4381,7 +4373,7 @@ static void dequantize_row_mxfp6_neon( } } -// ── MXFP SoA dequant (flash attention) ───────────────────────────────────── +// MXFP SoA dequant (flash attention) static void dequantize_row_mxfp8_soa_neon( const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, @@ -4492,7 +4484,7 @@ static void dequantize_row_mxfp4_soa_neon( #endif // __ARM_NEON -// ── Public dispatch functions ────────────────────────────────────────────── +// Public dispatch functions void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c index b00b1467d3..21b3fb4605 100644 --- a/ggml/src/ggml-cpu/arch/x86/quants.c +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -3819,18 +3819,13 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -// ── MXFP FP8/FP6 AVX2 helpers ────────────────────────────────────────────── -// Shared IEEE-754 bit reconstruction and FP6 unpacking used by vec_dot, -// dequantize_row, and SoA dequant functions. 
+// MXFP FP8/FP6 AVX2 helpers #if defined(__AVX2__) -// Use shared mxfp_dequant_traits_t from ggml-common.h. -// Aliases for readability within this file. #define mxfp_avx2_traits_t mxfp_dequant_traits_t -// Dequantize 8 raw MXFP values (widened to int32) → 8 IEEE-754 floats. -// Handles both normal and subnormal paths. Works for any FP6/FP8 format. +// Dequantize 8 FP8/FP6 values to floats. static inline __m256 mxfp_dequant_avx2( const __m256i v_raw, const __m256i v_exp_mask, const __m256i v_mant_mask, @@ -3872,9 +3867,9 @@ static inline __m256i unpack_fp6x8_avx2(const uint8_t * qs, int j) { return _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)unpacked)); } -// ── MXFP FP8/FP6 vec_dot ────────────────────────────────────────────────── +// MXFP FP8/FP6 vec_dot -// Unified FP8 × Q8_0 dot product (works for E4M3 and E5M2). +// FP8 x Q8_0 dot product (E4M3/E5M2). static void ggml_vec_dot_mxfp8_q8_0_avx2( int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, @@ -3915,7 +3910,7 @@ static void ggml_vec_dot_mxfp8_q8_0_avx2( *s = hsum_float_8(acc); } -// Unified FP6 × Q8_0 dot product (works for E2M3 and E3M2). +// FP6 x Q8_0 dot product (E2M3/E3M2). 
static void ggml_vec_dot_mxfp6_q8_0_avx2( int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, @@ -3955,7 +3950,7 @@ static void ggml_vec_dot_mxfp6_q8_0_avx2( *s = hsum_float_8(acc); } -// ── MXFP FP8/FP6 dequantize_row (AoS) ───────────────────────────────────── +// MXFP FP8/FP6 dequantize_row (AoS) static void dequantize_row_mxfp8_avx2( const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k, @@ -4016,7 +4011,7 @@ static void dequantize_row_mxfp6_avx2( } } -// ── MXFP SoA dequant (flash attention) ───────────────────────────────────── +// MXFP SoA dequant (flash attention) static void dequantize_row_mxfp8_soa_avx2( const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, @@ -4116,7 +4111,7 @@ static void dequantize_row_mxfp4_soa_avx2( #endif // __AVX2__ -// ── Public dispatch functions ────────────────────────────────────────────── +// Public dispatch functions void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 8ac3cd0912..a8f55efbed 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -4909,8 +4909,7 @@ void ggml_compute_forward_get_rows( //} } -// NEON-optimized Hadamard for ARM platforms; scalar fallback uses ggml_hadamard_32_inplace -// from ggml-quants.c (the reference implementation). +// NEON-optimized Hadamard; scalar fallback below #if defined(__ARM_NEON) static void hadamard_32_inplace(float vals[32]) { float32x4_t v0 = vld1q_f32(vals + 0); @@ -4978,13 +4977,11 @@ static void hadamard_32_inplace(float vals[32]) { vst1q_f32(vals + 28, vmulq_f32(v7, norm)); } #else -// Scalar fallback: delegate to reference implementation in ggml-quants.c static void hadamard_32_inplace(float vals[32]) { ggml_hadamard_32_inplace(vals); } #endif -// Apply Hadamard rotation to each 32-element block in a float buffer. 
static void ggml_apply_hadamard_blocks(float * data, int64_t n) { GGML_ASSERT(n % 32 == 0); for (int64_t i = 0; i < n; i += 32) { @@ -5031,8 +5028,6 @@ static void ggml_compute_forward_set_rows_f32( const int32_t apply_hadamard = ((const int32_t *)dst->op_params)[0]; - // For MXFP types, use SoA quantize (canonical FA layout). - // For non-MXFP types, use the standard AoS from_float. typedef void (*quantize_soa_fn)(const float *, void *, int64_t); quantize_soa_fn mxfp_soa_quantize = nullptr; ggml_from_float_t from_float = nullptr; @@ -5046,8 +5041,6 @@ static void ggml_compute_forward_set_rows_f32( break; } - // Pre-allocate Hadamard temp buffer once outside the hot loop (nc is constant). - // nc == n_embd_k_gqa which is bounded by model architecture (typically <= 8192). std::vector had_tmp; if (apply_hadamard) { had_tmp.resize(nc); @@ -8275,8 +8268,7 @@ void ggml_compute_forward_top_k( typedef void (*mxfp_soa_quantize_fn)(const float *, void *, int64_t); typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t); -// Shared MXFP dispatch parameters for flash attention. -// Populated once and used by both the one_chunk and tiled paths. +// MXFP dispatch parameters for flash attention. struct mxfp_fa_params { mxfp_soa_quantize_fn q_quantize; mxfp_soa_dequantize_fn k_dequantize; @@ -8287,9 +8279,6 @@ struct mxfp_fa_params { int64_t v_soa_elems; bool apply_hadamard; // Per-head SoA addressing (avoids dequanting all heads in multihead mode). - // qs_per_block: bytes of quantized data per 32-element block. - // head_qs_bytes: total qs bytes for one head (blocks_per_head * qs_per_block). - // head_e8m0_offset: byte offset from row start to e8m0 region. 
int k_qs_per_block; int v_qs_per_block; int k_head_qs_bytes; @@ -8455,9 +8444,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( ggml_vec_dot_t kq_vec_dot = nullptr; ggml_to_float_t v_to_float = nullptr; - if (is_mxfp_k) { - kq_vec_dot = nullptr; - } else { + if (!is_mxfp_k) { ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type; q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float; kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot; @@ -8472,20 +8459,15 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( int ith = params->ith; - // Dequant buffers for MXFP SoA — stack-allocated, no heap allocation in the hot path. - // In multihead mode, only dequant one head (DK or DV elements) instead of all heads. - // DK/DV are bounded by 1024 (asserted below for MXFP). + if (is_mxfp_k) { GGML_ASSERT(DK <= 1024); } + if (is_mxfp_v) { GGML_ASSERT(DV <= 1024); } + float k_dequant_buf[1024]; float v_dequant_buf[1024]; - // Per-head SoA temp buffer: holds [qs | e8m0] for one head in multihead mode. - // For DK=1024 with MXFP8: 32 blocks * 32 qs + 32 e8m0 = 1056 bytes. - // Stack-allocated since sizes are bounded by DK/DV <= 1024. - // Max: 1024/32 * 32(qs) + 1024/32 = 1056 bytes (MXFP8). - char k_head_soa[1088]; // 1056 rounded up for alignment + char k_head_soa[1088]; // max: DK=1024 MXFP8 -> 1056 bytes, rounded up char v_head_soa[1088]; - // Thread-local work buffers (constant across ir loop) float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); float * V32 = (VKQ32 + 1*DV); ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); @@ -8521,31 +8503,24 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( const int iv3 = iq3 / rv3; const int iv2 = iq2 / rv2; - // Precompute loop-invariant base pointer offsets for K and V. - // Only ic varies in the inner loop; head/batch offsets are constant. 
const size_t k_base_offset = ik2*nbk2 + ik3*nbk3; const size_t v_base_offset = iv2*nbv2 + iv3*nbv3; const char * k_base = (const char *) k->data + k_base_offset; const char * v_base = (const char *) v->data + v_base_offset; - // For multihead MXFP: precompute per-head SoA byte offsets (constant per query row). - // head_qs_start: byte offset to this head's qs blocks within the SoA row. - // head_e8m0_start: byte offset to this head's e8m0 scales within the SoA row. + // Per-head SoA byte offsets const int k_head_qs_start = mxfp.k_multihead ? ik2 * mxfp.k_head_qs_bytes : 0; const int k_head_e8m0_start = mxfp.k_multihead ? (int)mxfp.k_head_e8m0_offset + ik2 * mxfp.k_blocks_per_head : 0; const int v_head_qs_start = mxfp.v_multihead ? iv2 * mxfp.v_head_qs_bytes : 0; const int v_head_e8m0_start = mxfp.v_multihead ? (int)mxfp.v_head_e8m0_offset + iv2 * mxfp.v_blocks_per_head : 0; - // Multihead MXFP row base (without head offset) — only ic varies. const char * k_row_base = mxfp.k_multihead ? ((const char *) k->data + ik3*nbk3) : nullptr; const char * v_row_base = mxfp.v_multihead ? ((const char *) v->data + iv3*nbv3) : nullptr; const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); float Q_f32[1024]; if (is_mxfp_k) { - // Q preprocessing: Hadamard → SoA quantize → SoA dequant (round-trip). - // Captures the same quantization loss as K, matching GPU MMA semantics. - GGML_ASSERT(DK <= 1024); + // Q preprocessing: Hadamard + SoA round-trip captures same quantization loss as K. 
if (mxfp.apply_hadamard) { float q_tmp[1024]; memcpy(q_tmp, pq, DK * sizeof(float)); @@ -8557,7 +8532,6 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( mxfp.k_dequantize(Q_q, Q_f32, DK); } else { if (mxfp.apply_hadamard) { - GGML_ASSERT(DK <= 1024); float q_tmp[1024]; memcpy(q_tmp, pq, DK * sizeof(float)); ggml_apply_hadamard_blocks(q_tmp, DK); @@ -8581,8 +8555,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( if (is_mxfp_k) { if (mxfp.k_multihead) { - // Multihead: extract this head's SoA blocks into temp buffer, dequant only DK elements. - // Copy qs blocks then e8m0 scales for this head into contiguous [qs|e8m0] layout. + // Extract this head's SoA blocks const char * row = k_row_base + ic*nbk1; memcpy(k_head_soa, row + k_head_qs_start, mxfp.k_head_qs_bytes); memcpy(k_head_soa + mxfp.k_head_qs_bytes, row + k_head_e8m0_start, mxfp.k_blocks_per_head); @@ -8610,10 +8583,12 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( if (v_is_f16) { if (s > M) { + // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f M = s; ms = expf(Mold - M); ggml_vec_scale_f16(DV, VKQ16, ms); } else { + // no new maximum, ms == 1.0f, vs != 1.0f vs = expf(s - M); } @@ -8621,10 +8596,12 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) (v_base + ic*nbv1), vs); } else { if (s > M) { + // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f M = s; ms = expf(Mold - M); ggml_vec_scale_f32(DV, VKQ32, ms); } else { + // no new maximum, ms == 1.0f, vs != 1.0f vs = expf(s - M); } @@ -8643,6 +8620,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( v_to_float(v_base + ic*nbv1, V32, DV); ggml_vec_mad_f32(DV, VKQ32, V32, vs); } else { + // V is F32 ggml_vec_mad_f32(DV, VKQ32, (const float *) (v_base + ic*nbv1), vs); } } @@ -8782,12 +8760,12 @@ static void ggml_compute_forward_flash_attn_ext_tiled( static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q; static 
constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV; - // MXFP dequant scratch buffers — allocated once per thread, reused across all tiles. - // DK/DV bounded by 1024, so per-head dequant fits in stack buffers. + if (is_mxfp_k) { GGML_ASSERT(DK <= 1024); } + if (is_mxfp_v) { GGML_ASSERT(DV <= 1024); } + float k_dequant_buf[1024]; float v_dequant_buf[1024]; - // Per-head SoA temp buffers for multihead extraction (same as one_chunk path). char k_head_soa[1088]; char v_head_soa[1088]; @@ -8853,10 +8831,7 @@ static void ggml_compute_forward_flash_attn_ext_tiled( if (mxfp.apply_hadamard) { ggml_apply_hadamard_blocks(Q_f32 + tq * DK, DK); } - // SoA round-trip: quantize Q to SoA, then dequant back to float - const size_t q_soa_bytes = ggml_row_size(k->type, DK); - GGML_ASSERT(q_soa_bytes <= 2048); - uint8_t q_mxfp_buf[2048]; // max: DK=1024 * 33/32 = 1056 bytes (MXFP8) + uint8_t q_mxfp_buf[1088]; // max: DK=1024 MXFP8 -> 1056 bytes mxfp.q_quantize(Q_f32 + tq * DK, q_mxfp_buf, DK); mxfp.k_dequantize(q_mxfp_buf, Q_f32 + tq * DK, DK); } @@ -9234,10 +9209,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t dr = (nr + nchunk - 1) / nchunk; static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q; - // Tiled GEMM path: dequant K/V to float, then simd_gemm. - // Only for types that natively dequant to float (f32, f16, MXFP). - // Standard quant types (q8_0, q4_0) must use the scalar one_chunk path - // to preserve vec_dot semantics and produce identical results to master. + // Tiled path: f32, f16, and MXFP only (quant types use one_chunk) bool use_tiled = !use_ref && (q->type == GGML_TYPE_F32 && kv_is_f32_f16_or_mxfp && diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index 7303638c81..477bd07304 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -264,9 +264,7 @@ void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, *s = sumf; } -// Generic MXFP-to-Q8_0 dot product. 
Dequants one MX block (32 elements) -to float via the existing public dequantize_row functions, then dots -against Q8_0 int8 values. Reference implementation — not SIMD-optimized. +// Generic MXFP x Q8_0 dot product (scalar, not SIMD-optimized) static void ggml_vec_dot_mxfp_q8_0_impl( int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, size_t block_size, @@ -311,8 +309,7 @@ void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, (ggml_to_float_t)dequantize_row_mxfp6); } -// Generic (scalar) dequant wrappers — delegates to ggml-quants.c reference implementations. -// On x86/ARM, arch-specific SIMD versions override these via the fallback.h mapping. +// Generic dequant wrappers -- arch-specific SIMD versions override via fallback.h. void dequantize_row_mxfp8_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { dequantize_row_mxfp8(x, y, k); } diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 90f020b0e8..f9358b0432 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -431,15 +431,11 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) // E8M0 shared exponent to float: returns 2^(x - 127). -// Canonical source: ggml_mxfp_e8m0_to_fp32() in ggml-common.h. -// Kept here because ggml-impl.h cannot depend on ggml-common.h IMPL section. -// NaN (x == 0xFF) is not handled — callers guarantee valid exponents. static inline float ggml_e8m0_to_fp32(uint8_t x) { uint32_t bits; if (x == 0) { bits = 0x00400000; // denorm: 0.5 * 2^(-126) = 2^(-127) } else { - // normalized: exponent x placed in bits [30:23], mantissa = 0 → 2^(x-127) bits = (uint32_t) x << 23; } float result; @@ -447,9 +443,7 @@ static inline float ggml_e8m0_to_fp32(uint8_t x) { return result; } -// E8M0 to float/2: returns 2^(x - 128). Equal to ggml_e8m0_to_fp32(x) / 2. -// Useful with MXFP4 because the E2M1 kvalues table stores 2 * float_value. 
-// Canonical source: ggml_mxfp_e8m0_to_fp32_half() in ggml-common.h. +// E8M0 to float/2: returns 2^(x - 128). static inline float ggml_e8m0_to_fp32_half(uint8_t x) { uint32_t bits; if (x < 2) { @@ -467,8 +461,7 @@ static inline float ggml_e8m0_to_fp32_half(uint8_t x) { #define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x) #define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x) -// UE4M3 (unsigned E4M3): 4 exponent bits (bias 7), 3 mantissa bits, no sign. -// Range: [0, 448], with 0x7F = NaN treated as zero. +// UE4M3 (unsigned E4M3): 4 exponent bits (bias 7), 3 mantissa bits. // Returns value * 0.5 to match kvalues_mxfp4 convention (kvalues = 2 * E2M1_float). static inline float ggml_ue4m3_to_fp32(uint8_t x) { if (x == 0 || x == 0x7F) { @@ -487,20 +480,19 @@ static inline float ggml_ue4m3_to_fp32(uint8_t x) { return raw * 0.5f; } -// Float32 to UE4M3 with round-to-nearest-even. -// Clamps to [0, 448]. Max representable: 0x7E = (1 + 7/8) * 2^8 = 448. +// Float32 to UE4M3 with round-to-nearest. 
static inline uint8_t ggml_fp32_to_ue4m3(float x) { if (!(x > 0.0f)) { - return 0; // negative, zero, NaN → 0 + return 0; } if (x > 448.0f) { - x = 448.0f; // clamp to max representable + x = 448.0f; } uint32_t bits; memcpy(&bits, &x, 4); int fp32_exp = ((bits >> 23) & 0xFF) - 127; - int fp32_man = (bits >> 20) & 0x7; // top 3 mantissa bits - int ue4m3_exp = fp32_exp + 7; // rebias: FP32 bias 127 → UE4M3 bias 7 + int fp32_man = (bits >> 20) & 0x7; + int ue4m3_exp = fp32_exp + 7; if (ue4m3_exp <= 0) { // subnormal: value = man * 2^(-9), so man = round(x * 512) int man = (int) (x * 512.0f + 0.5f); @@ -513,17 +505,15 @@ static inline uint8_t ggml_fp32_to_ue4m3(float x) { return (uint8_t) man; } if (ue4m3_exp >= 15) { - return 0x7E; // max normal + return 0x7E; } - // round-to-nearest using bit 19 (first bit below UE4M3 mantissa) int round_bit = (bits >> 19) & 1; int ue4m3_man = fp32_man + round_bit; if (ue4m3_man > 7) { - // mantissa overflow → carry into exponent ue4m3_man = 0; ue4m3_exp++; if (ue4m3_exp >= 15) { - return 0x7E; // max normal + return 0x7E; } } return (uint8_t) ((ue4m3_exp << 3) | ue4m3_man); diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index e88435061d..1a99711401 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -259,16 +259,12 @@ void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_REST // ====================== MXFP element conversions (wrappers around ggml-common.h) -// FP8 E4M3: 1 sign, 4 exp (bias 7), 3 mantissa. Max finite: 448, NaN: 0x7F (saturated to max). float fp8_e4m3_to_float(uint8_t v) { return ggml_mxfp_fp8_e4m3_to_float(v); } uint8_t float_to_fp8_e4m3_rn(float x) { return ggml_mxfp_float_to_fp8_e4m3(x); } // ====================== MXFP quantization infrastructure // -// MSE-optimal E8M0 shared exponent: tests ±R candidates around round(log2(amax)) -// and picks whichever minimizes total round-trip quantization error per block. 
-// Improves on OCP MX v1.0 §5.3 floor(log2(amax)) by 0.05-0.2 PPL. -// Ref: OCP MX v1.0 spec, Four Over Six (arXiv:2512.02010) +// MSE-optimal E8M0: tests candidates around round(log2(amax)), picks lowest quantization error. typedef struct { int emax_offset; // type-specific offset to max representable exponent @@ -279,8 +275,7 @@ typedef struct { static inline int best_index_mxfp4(float x, float e); -// MXFP4 MSE error using decision boundary quantization with half-scale -// (kvalues_mxfp4 are doubled E2M1 values, so scale is halved to compensate) +// MSE error for MXFP4 (kvalues are doubled, so scale is halved) static float mse_error_mxfp4(float val, float inv_scale, float scale) { const float d = scale * 0.5f; const float inv_d = (d > 0.0f) ? 1.0f / d : 0.0f; @@ -301,7 +296,6 @@ static float mse_error_mxfp4(float val, float inv_scale, float scale) { static const mxfp_elem_traits_t mxfp4_traits = { MXFP4_E2M1_EMAX_OFFSET, NULL, NULL, mse_error_mxfp4 }; -// Find MSE-optimal E8M0 exponent by testing ±R candidates around round(log2(amax)) static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_elem_traits_t * traits) { float amax = 0.0f; for (int j = 0; j < qk; j++) { @@ -336,7 +330,6 @@ static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_ return (uint8_t)best_e; } -// Decision boundary quantization for kvalues_mxfp4 {0,1,2,3,4,6,8,12} static inline int best_index_mxfp4(float x, float e) { const float inv_e = (e > 0.0f) ? 1.0f / e : 0.0f; const float normalized = fabsf(x) * inv_e; @@ -566,11 +559,6 @@ void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_REST } // ====================== Hadamard rotation -// -// 32-element Walsh-Hadamard transform applied before MX quantization to spread -// outlier energy across the shared-exponent group. Orthogonal (H^T·H = I), so -// H(K)·H(Q) = K·Q — attention scores are preserved when both K and Q are rotated. -// Prior art: QuIP# (Tseng et al. 
2024), BRQ (Huang et al. 2024) void ggml_hadamard_32_inplace(float vals[32]) { ggml_mxfp_hadamard_32_inplace(vals); @@ -586,7 +574,6 @@ uint8_t float_to_fp8_e5m2_rn(float x) { return ggml_mxfp_float_to_fp8_e5m2(x); } void pack_fp6x4(const uint8_t v[4], uint8_t out[3]) { ggml_mxfp_pack_fp6x4(v, out); } void unpack_fp6x4(const uint8_t in[3], uint8_t v[4]) { ggml_mxfp_unpack_fp6x4(in, v); } -// round-trip quantization error for MSE-optimal exponent search static float mse_error_fp8_e4m3(float val, float inv_scale, float scale) { const float recon = fp8_e4m3_to_float(float_to_fp8_e4m3_rn(val * inv_scale)) * scale; const float err = val - recon; @@ -685,8 +672,6 @@ void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_REST } // ====================== SoA (Struct-of-Arrays) quantize/dequantize for flash attention -// -// Layout per row: [qs_block0|qs_block1|...|qs_blockN][e8m0_0|e8m0_1|...|e8m0_N] void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) { assert(k % QK_MXFP4 == 0); diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index c4d2ae86d3..7d848edffe 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -55,8 +55,7 @@ GGML_API void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_API void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -// SoA (Struct-of-Arrays) quantize/dequantize — canonical reference for flash attention. -// Layout: [qs contiguous][e8m0 contiguous] per row. Same total bytes as AoS. 
+// SoA quantize/dequantize for flash attention GGML_API void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); GGML_API void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); @@ -114,83 +113,26 @@ GGML_API size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_REST GGML_API size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -// -// MXFP element-level conversion functions (reference implementations) -// -// These implement the OCP Microscaling (MX) format element types as defined in: -// OCP Microscaling Formats (MX) Specification v1.0, Sep 2023 -// https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf -// -// Each MX block contains 32 elements sharing a single E8M0 exponent. The element -// types define the per-element mantissa format within each block. -// -// All converters use IEEE-754 bit manipulation for exact results (no floating-point -// rounding in the conversion itself). Quantization functions use round-to-nearest-even -// (RNE) per the MX specification. -// -// GPU backends (CUDA, Metal, Vulkan) provide their own optimized versions — these -// functions serve as the canonical reference and are used by the CPU backend. 
-// - -// FP8 E4M3: 1 sign, 4 exponent (bias 7), 3 mantissa bits -// Range: ±[2^-9, 448], NaN: exp=15 mant=7 (only NaN encoding) -// Ref: OCP MX v1.0 §4.2 +// MXFP element converters GGML_API float fp8_e4m3_to_float(uint8_t v); GGML_API uint8_t float_to_fp8_e4m3_rn(float x); -// FP8 E5M2: 1 sign, 5 exponent (bias 15), 2 mantissa bits -// Range: ±[2^-16, 57344], NaN/Inf: exp=31 (standard IEEE-like) -// Ref: OCP MX v1.0 §4.2 GGML_API float fp8_e5m2_to_float(uint8_t v); GGML_API uint8_t float_to_fp8_e5m2_rn(float x); -// FP6 E2M3: 1 sign, 2 exponent (bias 1), 3 mantissa bits -// Range: ±[2^-3, 7.5], stored as low 6 bits of a byte (00xxxxxx) -// MX format: NO NaN/Inf — all bit patterns are valid numbers -// Ref: OCP MX v1.0 §4.2 +// no NaN/Inf in FP6 — all bit patterns are valid numbers GGML_API float fp6_e2m3_to_float(uint8_t v); GGML_API uint8_t float_to_fp6_e2m3_rn(float x); -// FP6 E3M2: 1 sign, 3 exponent (bias 3), 2 mantissa bits -// Range: ±[2^-4, 28.0], stored as low 6 bits of a byte (00xxxxxx) -// MX format: NO NaN/Inf — exp=7 is a valid normal value (unlike IEEE-754) -// CRITICAL: subnormal scale is 2^(1-bias-m) = 2^(-4) = 1/16, NOT 1/4 -// Ref: OCP MX v1.0 §4.2 +// no NaN/Inf — exp=7 is a valid normal value (unlike IEEE-754) GGML_API float fp6_e3m2_to_float(uint8_t v); GGML_API uint8_t float_to_fp6_e3m2_rn(float x); -// FP6 tight packing: pack/unpack 4 six-bit values into/from 3 bytes -// Layout: v[0]=bits[5:0], v[1]=bits[11:6], v[2]=bits[17:12], v[3]=bits[23:18] -// Saves 25% memory vs byte-padded storage (24B vs 32B per MX block) +// Pack/unpack 4 six-bit values into 3 bytes GGML_API void pack_fp6x4(const uint8_t v[4], uint8_t out[3]); GGML_API void unpack_fp6x4(const uint8_t in[3], uint8_t v[4]); -// -// Hadamard rotation (reference scalar implementation) -// -// 32-element Walsh-Hadamard transform applied to MX blocks before quantization. 
-// Distributes outlier energy uniformly across the block, dramatically improving -// quantization quality for types with shared exponents. -// -// Mathematical property: H^T·H = I (orthogonal), so H(K)·H(Q) = K·Q. -// Flash attention applies matching rotation to Q, preserving attention scores exactly. -// -// Implementation: 5 butterfly stages (log2(32) = 5) + normalization by 1/sqrt(32). -// This is the standard "fast Walsh-Hadamard transform" with O(n log n) operations. -// -// Applied in set_rows (K cache quantization) and flash_attn (Q quantization). -// Skipped for MLA models (DK != DV) where V is a view of K — rotation would corrupt V. -// -// Empirical impact (PPL degradation WITHOUT rotation, Qwen3-Coder-30B): -// MXFP8 E4M3: +0.22, MXFP8 E5M2: +1.38, MXFP6 E2M3: +3.34, MXFP6 E3M2: +4.60 -// -// Prior art: QuIP# (Tseng et al. 2024), BRQ (Huang et al. 2024) use Hadamard for -// weight quantization. Our contribution applies it to KV cache quantization at the -// MX block boundary, where block-32 is optimal because it matches the shared exponent -// group size exactly. -// -// GPU backends provide optimized versions (CUDA warp shuffles, Metal SIMD groups). 
-// +// Block-32 Walsh-Hadamard transform, normalized by 1/sqrt(32) GGML_API void ggml_hadamard_32_inplace(float vals[32]); GGML_API void iq2xs_init_impl(enum ggml_type type); diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 0e501697b0..a890510edf 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -51,7 +51,6 @@ llama_kv_cache::llama_kv_cache( auto it = ctx_map.find(buft); if (it == ctx_map.end()) { ggml_init_params params = { - // 2 base tensors (K+V) + 2*n_stream view tensors /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, @@ -140,10 +139,10 @@ llama_kv_cache::llama_kv_cache( uint32_t n_embd_k_alloc = n_embd_k_gqa; const bool is_mxfp_k = ggml_is_type_mxfp(type_k); if (is_mxfp_k) { - const int qk = (int)ggml_blck_size(type_k); // 32 for all MXFP types + const int qk = (int)ggml_blck_size(type_k); GGML_ASSERT(n_embd_k_gqa % qk == 0 && "MXFP K cache requires n_embd_k_gqa divisible by block size"); const int blocks = (int)n_embd_k_gqa / qk; - const int blocks_aligned = (blocks + 15) & ~15; // align to 16 + const int blocks_aligned = (blocks + 15) & ~15; n_embd_k_alloc = (uint32_t)(blocks_aligned * qk); } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index d123211505..0e48e9e354 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -150,8 +150,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m } } -// MXFP SoA quantize/dequantize (from ggml-quants.h, which is internal to ggml -// and not in the test include path). Signatures must match ggml-quants.h exactly. 
+// MXFP SoA functions (internal to ggml, not in test include path) typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t); extern "C" { void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); @@ -162,10 +161,7 @@ extern "C" { void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); } -// Initialize an MXFP tensor with SoA (Struct-of-Arrays) layout. -// soa_bytes: byte width of one SoA region. Default 0 = ne[0] elements (one ggml row). -// For FA K/V tensors, pass nb[1] so that when heads are physically contiguous -// within one KV-position stride, the SoA region spans all heads (matching FA's read pattern). +// Initialize an MXFP tensor with SoA layout (soa_bytes = region width, 0 = one row). static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f, size_t soa_bytes = 0) { GGML_ASSERT(ggml_is_type_mxfp(tensor->type)); @@ -189,9 +185,6 @@ static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float std::uniform_real_distribution dist(min, max); std::vector region_f32(soa_elems); - // Iterate over logical SoA regions using tensor strides. - // Each SoA region is soa_bytes wide at the innermost stride level. - // Outer dimensions (those with stride > soa_bytes) are iterated explicitly. const size_t nb1 = tensor->nb[1]; const size_t nb2 = tensor->nb[2]; const size_t nb3 = tensor->nb[3]; @@ -199,18 +192,13 @@ static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float const int64_t ne2 = tensor->ne[2]; const int64_t ne3 = tensor->ne[3]; - // Determine iteration: if soa_bytes == nb1, iterate over (ne1 * ne2 * ne3) regions. - // If soa_bytes < nb1 (per-head), iterate over (ne1 * ne2 * ne3) regions with stride nb1. - // We use strides to compute offsets, handling views and permutations correctly. 
const int64_t heads_per_region = (int64_t)(soa_bytes / head_row_sz); GGML_ASSERT(soa_bytes % head_row_sz == 0 && "soa_bytes must be a multiple of head_row_sz"); - // For multi-head regions, we step by nb1 (KV-position stride) between regions. - // For per-head, we step through all dimensions. std::vector buf(ggml_nbytes(tensor), 0); if (heads_per_region > 1) { - // Multi-head SoA: iterate over (kv_positions * batches), each region = nb1 bytes + // Multi-head SoA: for (int64_t i3 = 0; i3 < ne3; i3++) { const int64_t n_groups = ne2 / heads_per_region; for (int64_t ig = 0; ig < n_groups; ig++) { @@ -222,7 +210,7 @@ static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float } } } else { - // Per-head SoA: one SoA region per ggml row + // Per-head SoA: for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i2 = 0; i2 < ne2; i2++) { for (int64_t i1 = 0; i1 < ne1; i1++) { @@ -328,7 +316,6 @@ static std::vector tensor_to_float(const ggml_tensor * t) { bool quantized = ggml_is_quantized(t->type); const bool is_mxfp = ggml_is_type_mxfp(t->type); - // SoA dequant for MXFP readback mxfp_soa_dequantize_fn mxfp_dequant_soa = nullptr; if (is_mxfp) { switch (t->type) { @@ -4051,7 +4038,6 @@ struct test_mul_mat_id : public test_case { return 5e-4; } - // Same Blackwell FP4 tolerance as test_mul_mat above. double max_nmse_err(ggml_backend_t backend) override { // for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance if (type_a == GGML_TYPE_MXFP4_E2M1 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) { @@ -6401,9 +6387,7 @@ struct test_flash_attn_ext : public test_case { } else if (strcmp(t->name, "m") == 0) { init_tensor_kq_mask(t); } else if ((strcmp(t->name, "k") == 0 || strcmp(t->name, "v") == 0) && ggml_is_type_mxfp(t->type)) { - // MXFP K/V tensors use SoA layout. 
Pass nb[1] (KV-position stride) as the - // SoA region width — when heads are physically contiguous within that stride, - // the FA kernel dequants the full multi-head region as one SoA block. + // MXFP K/V use SoA layout; nb[1] spans all heads in one KV-position stride init_tensor_mxfp_soa(t, -1.0f, 1.0f, t->nb[1]); } else { init_tensor_uniform(t); @@ -8778,7 +8762,6 @@ static std::vector> make_test_cases_eval() { for (ggml_type type_V : {GGML_TYPE_MXFP4_E2M1}) { if (type_K == type_V) continue; for (int nb : {1, 3, 32}) { - // hsk hsv nh nr23 kv nb mask sinks bias softcap prec type_K permute type_V test_cases.emplace_back(new test_flash_attn_ext( 128, 128, 4, {1, 1}, 512, nb, true, false, 0.0f, 0.0f, GGML_PREC_F32, type_K, {0, 1, 2, 3}, type_V)); }