From d8c9f9c7f60a3cf8e81a12df688d9f1d32c9dba5 Mon Sep 17 00:00:00 2001 From: Tim Burke Date: Sun, 15 Mar 2026 15:29:13 -0400 Subject: [PATCH 01/13] ggml: MXFP flash attention with SoA layout (CPU scalar reference) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add MXFP KV cache quantization for flash attention using Struct-of-Arrays (SoA) memory layout exclusively. Three MX types: MXFP4 (E2M1), MXFP8 (E4M3), MXFP6 (E2M3), implementing the OCP Microscaling v1.0 spec. SoA layout stores [qs contiguous][e8m0 contiguous] per row, enabling aligned memory access patterns for GPU backends. All functions in the flash attention pipeline — set_rows quantization, Q preprocessing, K/V dequantization — use SoA end-to-end. The existing AoS block layout remains for MUL_MAT weight quantization (untouched). Q preprocessing applies Walsh-Hadamard rotation (block-32) before quantize/dequant round-trip, distributing outlier energy across the shared exponent group. This is essential for perplexity: MXFP8: +0.22 PPL without rotation MXFP6: +3.34 PPL without rotation Hadamard is skipped for MLA models (DK != DV) where V is a view of K. 
Shared infrastructure in ggml-common.h: - Block structures (block_mxfp8: 33B, block_mxfp6: 25B per 32 elements) - E8M0 MSE-optimal scale search with ±1 range - Canonical element converters (FP8 E4M3/E5M2, FP6 E2M3/E3M2) - FP6 tight packing (4 six-bit values in 3 bytes, 25% savings) - IEEE-754 bit reconstruction constants for SIMD backends - SoA layout macros, portable bit cast, type property queries CPU implementation: - Scalar reference + ARM NEON + x86 AVX2 optimized paths - Both FA paths supported: one_chunk (scalar) and tiled (SIMD GEMM) - Split-KV path extended for single-query decode - Generic vec_dot via dequant-to-float for MUL_MAT compatibility - Arch fallbacks for loongarch, powerpc, riscv, s390, wasm KV cache integration: - set_rows writes SoA with optional Hadamard (op_params[0] flag) - K cache block-aligned to 16 for CUDA cp.async compatibility - CLI: --cache-type-k/v with short aliases (mxfp4, mxfp6, mxfp8) Tests: - Flash attention: all 3 types at D=64/128, mixed K/V (mxfp8+mxfp4) - SET_ROWS: Hadamard rotation for all types - SoA-aware test initialization and comparison for MXFP tensors - Quantize functions coverage for all types Rename GGML_TYPE_MXFP4 → GGML_TYPE_MXFP4_E2M1 across all backends (CPU, OpenCL, SYCL) for consistency with the MX type family naming. 
--- common/arg.cpp | 14 + ggml/include/ggml-cpu.h | 1 + ggml/include/ggml.h | 13 +- ggml/src/ggml-common.h | 553 +++++++++++++++++++++- ggml/src/ggml-cpu/arch-fallback.h | 13 + ggml/src/ggml-cpu/arch/arm/quants.c | 538 +++++++++++++++++++++ ggml/src/ggml-cpu/arch/loongarch/quants.c | 11 + ggml/src/ggml-cpu/arch/powerpc/quants.c | 7 + ggml/src/ggml-cpu/arch/riscv/quants.c | 8 + ggml/src/ggml-cpu/arch/s390/quants.c | 7 + ggml/src/ggml-cpu/arch/wasm/quants.c | 11 + ggml/src/ggml-cpu/arch/x86/quants.c | 498 +++++++++++++++++++ ggml/src/ggml-cpu/ggml-cpu.c | 16 +- ggml/src/ggml-cpu/ops.cpp | 366 ++++++++++++-- ggml/src/ggml-cpu/quants.c | 72 +++ ggml/src/ggml-cpu/quants.h | 22 + ggml/src/ggml-cpu/repack.cpp | 6 +- ggml/src/ggml-impl.h | 47 +- ggml/src/ggml-opencl/ggml-opencl.cpp | 16 +- ggml/src/ggml-quants.c | 519 +++++++++++++++++++- ggml/src/ggml-quants.h | 94 ++++ ggml/src/ggml-sycl/convert.cpp | 4 +- ggml/src/ggml-sycl/mmvq.cpp | 2 +- ggml/src/ggml.c | 52 +- src/llama-kv-cache.cpp | 57 ++- src/llama-quant.cpp | 4 +- tests/test-backend-ops.cpp | 192 +++++++- tools/llama-bench/llama-bench.cpp | 10 +- 28 files changed, 3002 insertions(+), 151 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 10aa1b5e4f..0e5191b2f4 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -398,9 +398,23 @@ const std::vector kv_cache_types = { GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, + GGML_TYPE_MXFP4_E2M1, + GGML_TYPE_MXFP8_E4M3, + GGML_TYPE_MXFP6_E2M3, }; static ggml_type kv_cache_type_from_str(const std::string & s) { + // Short aliases: "mxfp4" → E2M1, "mxfp6" → E2M3, "mxfp8" → E4M3. + // Full names (mxfp4_e2m1, mxfp8_e4m3, mxfp6_e2m3, etc.) match via ggml_type_name() below. 
+ if (s == "mxfp4") { + return GGML_TYPE_MXFP4_E2M1; + } + if (s == "mxfp6") { + return GGML_TYPE_MXFP6_E2M3; + } + if (s == "mxfp8") { + return GGML_TYPE_MXFP8_E4M3; + } for (const auto & type : kv_cache_types) { if (ggml_type_name(type) == s) { return type; diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h index e3e067c916..2e13dd58ba 100644 --- a/ggml/include/ggml-cpu.h +++ b/ggml/include/ggml-cpu.h @@ -115,6 +115,7 @@ extern "C" { struct ggml_type_traits_cpu { ggml_from_float_t from_float; + ggml_to_float_t to_float; // SIMD-optimized dequant (NULL = use global to_float) ggml_vec_dot_t vec_dot; enum ggml_type vec_dot_type; int64_t nrows; // number of rows to process simultaneously diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 25f9601e9b..420ea38126 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -426,9 +426,11 @@ extern "C" { // GGML_TYPE_IQ4_NL_4_4 = 36, // GGML_TYPE_IQ4_NL_4_8 = 37, // GGML_TYPE_IQ4_NL_8_8 = 38, - GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block) - GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale) - GGML_TYPE_COUNT = 41, + GGML_TYPE_MXFP4_E2M1 = 39, // MX FP4 E2M1 + GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale) + GGML_TYPE_MXFP8_E4M3 = 41, // MX FP8 E4M3 + GGML_TYPE_MXFP6_E2M3 = 42, // MX FP6 E2M3 + GGML_TYPE_COUNT = 43, }; // precision @@ -463,7 +465,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors - GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors + GGML_FTYPE_MOSTLY_MXFP4_E2M1 = 25, // except 1d tensors GGML_FTYPE_MOSTLY_NVFP4 = 26, // except 1d tensors }; @@ -744,6 +746,9 @@ extern "C" { GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor); GGML_API bool ggml_is_quantized(enum ggml_type type); + GGML_API bool ggml_is_type_mxfp(enum ggml_type type); + GGML_API bool ggml_mxfp_use_hadamard(enum ggml_type type); + GGML_API int ggml_mxfp_qs_per_block(enum 
ggml_type type); // quantized bytes per 32-element block (SoA qs region) // TODO: temporary until model loading of ggml examples is refactored GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype); diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 92cf739e7a..9945fef137 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -71,6 +71,9 @@ typedef sycl::half2 ggml_half2; #define GGML_COMMON_DECL #endif +// Pure numeric constants needed by both DECL and IMPL sections. +#define MXFP_HADAMARD_32_NORM 0.17677669529663689f // 1/sqrt(32) + #if defined(GGML_COMMON_DECL) #ifndef __cplusplus @@ -105,6 +108,12 @@ typedef sycl::half2 ggml_half2; #define QI_NVFP4 (QK_NVFP4 / (4 * QR_NVFP4)) #define QR_NVFP4 2 +#define QI_MXFP8 (QK_MXFP8 / (4 * QR_MXFP8)) +#define QR_MXFP8 1 + +#define QI_MXFP6 (QK_MXFP6 / (4 * QR_MXFP6)) +#define QR_MXFP6 1 + #define QI5_0 (QK5_0 / (4 * QR5_0)) #define QR5_0 2 @@ -190,6 +199,74 @@ typedef struct { } block_q4_1; static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding"); +// MXFP E8M0 shared exponent constants (OCP MX v1.0 §5.3). +// EMAX_OFFSET: ceil(log2(max_finite)) for each element type — used to center the E8M0 scale. +// MSE_RANGE: search radius around round(log2(amax)). Tests 2*range+1 candidate exponents, +// picking the one that minimizes total round-trip quantization error per block. +// Inspired by "Four Over Six" (arXiv:2512.02010); generalized to all MX types. +#define MXFP_E8M0_MSE_RANGE 2 +#define MXFP4_E2M1_EMAX_OFFSET 2 // ceil(log2(6.0)) +#define MXFP6_E2M3_EMAX_OFFSET 3 // ceil(log2(7.5)) +#define MXFP6_E3M2_EMAX_OFFSET 5 // ceil(log2(28.0)) +#define MXFP8_E4M3_EMAX_OFFSET 8 // ceil(log2(448)) +#define MXFP8_E5M2_EMAX_OFFSET 16 // ceil(log2(57344)) + +// MXFP type properties — single source of truth for all backends. +// Bits per element, quantized bytes per block, and Hadamard rotation flag. 
+// USE_HADAMARD: 1 for types with >= 3-bit mantissa (E2M1, E4M3, E2M3). +// 0 for 2-bit mantissa types (E5M2, E3M2) where Hadamard provides +// no quality benefit and hurts models with D_head ≤ 64. +#define MXFP_BITS_PER_ELEM_E2M1 4 +#define MXFP_BITS_PER_ELEM_E4M3 8 +#define MXFP_BITS_PER_ELEM_E5M2 8 +#define MXFP_BITS_PER_ELEM_E2M3 6 +#define MXFP_BITS_PER_ELEM_E3M2 6 + +#define MXFP_QS_PER_BLOCK_E2M1 16 // 32 * 4 / 8 +#define MXFP_QS_PER_BLOCK_E4M3 32 // 32 * 8 / 8 +#define MXFP_QS_PER_BLOCK_E5M2 32 +#define MXFP_QS_PER_BLOCK_E2M3 24 // 32 * 6 / 8 +#define MXFP_QS_PER_BLOCK_E3M2 24 + +#define MXFP_USE_HADAMARD_E2M1 1 +#define MXFP_USE_HADAMARD_E4M3 1 +#define MXFP_USE_HADAMARD_E5M2 0 +#define MXFP_USE_HADAMARD_E2M3 1 +#define MXFP_USE_HADAMARD_E3M2 0 + +// SIMD dequant constants for IEEE-754 bit reconstruction of FP8/FP6 elements. +// For a format with sign(1), exp(E), mant(M), bias(B): +// EXP_MASK = (1< +#include +#include #define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = { #define GGML_TABLE_END() }; +#define GGML_MXFP_FUNC static inline +static inline uint32_t ggml_mxfp_f32_as_u32_(float f) { uint32_t u; memcpy(&u, &f, sizeof(u)); return u; } +static inline float ggml_mxfp_u32_as_f32_(uint32_t u) { float f; memcpy(&f, &u, sizeof(f)); return f; } +#define GGML_MXFP_F32_AS_U32(f) ggml_mxfp_f32_as_u32_(f) +#define GGML_MXFP_U32_AS_F32(u) ggml_mxfp_u32_as_f32_(u) +#define GGML_MXFP_LDEXPF(x, n) ldexpf(x, n) +#define GGML_MXFP_THREAD +#define GGML_MXFP_UNROLL #define GGML_COMMON_IMPL #elif defined(GGML_COMMON_IMPL_CPP) #include +#include +#include #define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = { #define GGML_TABLE_END() }; +#define GGML_MXFP_FUNC static inline +static inline uint32_t ggml_mxfp_f32_as_u32_(float f) { uint32_t u; memcpy(&u, &f, sizeof(u)); return u; } +static inline float ggml_mxfp_u32_as_f32_(uint32_t u) { float f; memcpy(&f, &u, sizeof(f)); return f; } +#define GGML_MXFP_F32_AS_U32(f) 
ggml_mxfp_f32_as_u32_(f) +#define GGML_MXFP_U32_AS_F32(u) ggml_mxfp_u32_as_f32_(u) +#define GGML_MXFP_LDEXPF(x, n) ldexpf(x, n) +#define GGML_MXFP_THREAD +#define GGML_MXFP_UNROLL #define GGML_COMMON_IMPL #elif defined(GGML_COMMON_IMPL_METAL) @@ -464,21 +589,44 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_ #define GGML_TABLE_BEGIN(type, name, size) static const constant type name[size] = { #define GGML_TABLE_END() }; +#define GGML_MXFP_FUNC static inline +#define GGML_MXFP_F32_AS_U32(f) as_type(f) +#define GGML_MXFP_U32_AS_F32(u) as_type(u) +#define GGML_MXFP_LDEXPF(x, n) metal::ldexp(x, n) +#define GGML_MXFP_THREAD thread +#define GGML_MXFP_UNROLL _Pragma("unroll") #define GGML_COMMON_IMPL #elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) || defined(GGML_COMMON_IMPL_MUSA) #include +#include #define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = { #define GGML_TABLE_END() }; +#define GGML_MXFP_FUNC static __device__ __forceinline__ +#define GGML_MXFP_F32_AS_U32(f) __float_as_uint(f) +#define GGML_MXFP_U32_AS_F32(u) __uint_as_float(u) +#define GGML_MXFP_LDEXPF(x, n) ldexpf(x, n) +#define GGML_MXFP_THREAD +#define GGML_MXFP_UNROLL _Pragma("unroll") #define GGML_COMMON_IMPL #elif defined(GGML_COMMON_IMPL_SYCL) #include +#include +#include #define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = { #define GGML_TABLE_END() }; +#define GGML_MXFP_FUNC static inline +static inline uint32_t ggml_mxfp_f32_as_u32_(float f) { uint32_t u; memcpy(&u, &f, sizeof(u)); return u; } +static inline float ggml_mxfp_u32_as_f32_(uint32_t u) { float f; memcpy(&f, &u, sizeof(f)); return f; } +#define GGML_MXFP_F32_AS_U32(f) ggml_mxfp_f32_as_u32_(f) +#define GGML_MXFP_U32_AS_F32(u) ggml_mxfp_u32_as_f32_(u) +#define GGML_MXFP_LDEXPF(x, n) ldexpf(x, n) +#define GGML_MXFP_THREAD +#define GGML_MXFP_UNROLL #define GGML_COMMON_IMPL #endif @@ -1100,12 +1248,415 @@ GGML_TABLE_BEGIN(int8_t, 
kvalues_iq4nl, 16) -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113, GGML_TABLE_END() -// e2m1 values (doubled) +// Canonical E2M1 values (true FP4 magnitudes). // ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf +GGML_TABLE_BEGIN(float, kvalues_mxfp4_float, 16) + 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f, + -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f, +GGML_TABLE_END() + +// E2M1 values doubled (implementation detail for CPU/CUDA integer arithmetic). +// Used with GGML_E8M0_TO_FP32_HALF(e) = scale/2 so that int8 × half_scale = true value. +// Canonical values are in kvalues_mxfp4_float above. GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16) 0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12, GGML_TABLE_END() +// FP6 E2M3 dequantization LUT: 6-bit value → float (64 entries). +// Generated from ggml_mxfp_fp6_e2m3_to_float(). Indices 0-31 positive, 32-63 negative. +GGML_TABLE_BEGIN(float, kvalues_mxfp6_e2m3, 64) + 0.0f, 0.125f, 0.25f, 0.375f, 0.5f, 0.625f, 0.75f, 0.875f, + 1.0f, 1.125f, 1.25f, 1.375f, 1.5f, 1.625f, 1.75f, 1.875f, + 2.0f, 2.25f, 2.5f, 2.75f, 3.0f, 3.25f, 3.5f, 3.75f, + 4.0f, 4.5f, 5.0f, 5.5f, 6.0f, 6.5f, 7.0f, 7.5f, + -0.0f, -0.125f, -0.25f, -0.375f, -0.5f, -0.625f, -0.75f, -0.875f, + -1.0f, -1.125f, -1.25f, -1.375f, -1.5f, -1.625f, -1.75f, -1.875f, + -2.0f, -2.25f, -2.5f, -2.75f, -3.0f, -3.25f, -3.5f, -3.75f, + -4.0f, -4.5f, -5.0f, -5.5f, -6.0f, -6.5f, -7.0f, -7.5f, +GGML_TABLE_END() + +// FP6 E3M2 dequantization LUT: 6-bit value → float (64 entries). +// Generated from ggml_mxfp_fp6_e3m2_to_float(). No NaN/Inf — all bit patterns are valid. 
+GGML_TABLE_BEGIN(float, kvalues_mxfp6_e3m2, 64) + 0.0f, 0.0625f, 0.125f, 0.1875f, 0.25f, 0.3125f, 0.375f, 0.4375f, + 0.5f, 0.625f, 0.75f, 0.875f, 1.0f, 1.25f, 1.5f, 1.75f, + 2.0f, 2.5f, 3.0f, 3.5f, 4.0f, 5.0f, 6.0f, 7.0f, + 8.0f, 10.0f, 12.0f, 14.0f, 16.0f, 20.0f, 24.0f, 28.0f, + -0.0f, -0.0625f, -0.125f,-0.1875f, -0.25f,-0.3125f, -0.375f,-0.4375f, + -0.5f, -0.625f, -0.75f, -0.875f, -1.0f, -1.25f, -1.5f, -1.75f, + -2.0f, -2.5f, -3.0f, -3.5f, -4.0f, -5.0f, -6.0f, -7.0f, + -8.0f, -10.0f, -12.0f, -14.0f, -16.0f, -20.0f, -24.0f, -28.0f, +GGML_TABLE_END() + +// FP8 E4M3 dequantization LUT: byte → float (256 entries). +// Generated from ggml_mxfp_fp8_e4m3_to_float(). Entry 127 = 448 (max finite), 255 = NaN. +GGML_TABLE_BEGIN(float, kvalues_mxfp8_e4m3, 256) + 0.0f, 0.001953125f, 0.00390625f, 0.005859375f, 0.0078125f, 0.009765625f, 0.01171875f, 0.013671875f, + 0.015625f, 0.017578125f, 0.01953125f, 0.021484375f, 0.0234375f, 0.025390625f, 0.02734375f, 0.029296875f, + 0.03125f, 0.03515625f, 0.0390625f, 0.04296875f, 0.046875f, 0.05078125f, 0.0546875f, 0.05859375f, + 0.0625f, 0.0703125f, 0.078125f, 0.0859375f, 0.09375f, 0.1015625f, 0.109375f, 0.1171875f, + 0.125f, 0.140625f, 0.15625f, 0.171875f, 0.1875f, 0.203125f, 0.21875f, 0.234375f, + 0.25f, 0.28125f, 0.3125f, 0.34375f, 0.375f, 0.40625f, 0.4375f, 0.46875f, + 0.5f, 0.5625f, 0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f, 0.9375f, + 1.0f, 1.125f, 1.25f, 1.375f, 1.5f, 1.625f, 1.75f, 1.875f, + 2.0f, 2.25f, 2.5f, 2.75f, 3.0f, 3.25f, 3.5f, 3.75f, + 4.0f, 4.5f, 5.0f, 5.5f, 6.0f, 6.5f, 7.0f, 7.5f, + 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, + 16.0f, 18.0f, 20.0f, 22.0f, 24.0f, 26.0f, 28.0f, 30.0f, + 32.0f, 36.0f, 40.0f, 44.0f, 48.0f, 52.0f, 56.0f, 60.0f, + 64.0f, 72.0f, 80.0f, 88.0f, 96.0f, 104.0f, 112.0f, 120.0f, + 128.0f, 144.0f, 160.0f, 176.0f, 192.0f, 208.0f, 224.0f, 240.0f, + 256.0f, 288.0f, 320.0f, 352.0f, 384.0f, 416.0f, 448.0f, NAN, + -0.0f,-0.001953125f, -0.00390625f,-0.005859375f, -0.0078125f,-0.009765625f, 
-0.01171875f,-0.013671875f, + -0.015625f,-0.017578125f, -0.01953125f,-0.021484375f, -0.0234375f,-0.025390625f, -0.02734375f,-0.029296875f, + -0.03125f, -0.03515625f, -0.0390625f, -0.04296875f, -0.046875f, -0.05078125f, -0.0546875f, -0.05859375f, + -0.0625f, -0.0703125f, -0.078125f, -0.0859375f, -0.09375f, -0.1015625f, -0.109375f, -0.1171875f, + -0.125f, -0.140625f, -0.15625f, -0.171875f, -0.1875f, -0.203125f, -0.21875f, -0.234375f, + -0.25f, -0.28125f, -0.3125f, -0.34375f, -0.375f, -0.40625f, -0.4375f, -0.46875f, + -0.5f, -0.5625f, -0.625f, -0.6875f, -0.75f, -0.8125f, -0.875f, -0.9375f, + -1.0f, -1.125f, -1.25f, -1.375f, -1.5f, -1.625f, -1.75f, -1.875f, + -2.0f, -2.25f, -2.5f, -2.75f, -3.0f, -3.25f, -3.5f, -3.75f, + -4.0f, -4.5f, -5.0f, -5.5f, -6.0f, -6.5f, -7.0f, -7.5f, + -8.0f, -9.0f, -10.0f, -11.0f, -12.0f, -13.0f, -14.0f, -15.0f, + -16.0f, -18.0f, -20.0f, -22.0f, -24.0f, -26.0f, -28.0f, -30.0f, + -32.0f, -36.0f, -40.0f, -44.0f, -48.0f, -52.0f, -56.0f, -60.0f, + -64.0f, -72.0f, -80.0f, -88.0f, -96.0f, -104.0f, -112.0f, -120.0f, + -128.0f, -144.0f, -160.0f, -176.0f, -192.0f, -208.0f, -224.0f, -240.0f, + -256.0f, -288.0f, -320.0f, -352.0f, -384.0f, -416.0f, -448.0f, NAN, +GGML_TABLE_END() + +// FP8 E5M2 dequantization LUT: byte → float (256 entries). +// Generated from ggml_mxfp_fp8_e5m2_to_float(). Entries 124-127 = {Inf, NaN, NaN, NaN}. 
+GGML_TABLE_BEGIN(float, kvalues_mxfp8_e5m2, 256) + 0.0f, 1.525879e-05f, 3.051758e-05f, 4.577637e-05f, 6.103516e-05f, 7.629395e-05f, 9.155273e-05f, 1.068115e-04f, + 1.220703e-04f, 1.525879e-04f, 1.831055e-04f, 2.136230e-04f, 2.441406e-04f, 3.051758e-04f, 3.662109e-04f, 4.272461e-04f, + 4.882812e-04f, 6.103516e-04f, 7.324219e-04f, 8.544922e-04f, 9.765625e-04f, 1.220703e-03f, 1.464844e-03f, 1.708984e-03f, + 1.953125e-03f, 2.441406e-03f, 2.929688e-03f, 3.417969e-03f, 3.906250e-03f, 4.882812e-03f, 5.859375e-03f, 6.835938e-03f, + 7.812500e-03f, 9.765625e-03f, 1.171875e-02f, 1.367188e-02f, 1.562500e-02f, 1.953125e-02f, 2.343750e-02f, 2.734375e-02f, + 3.125000e-02f, 3.906250e-02f, 4.687500e-02f, 5.468750e-02f, 6.250000e-02f, 7.812500e-02f, 9.375000e-02f, 1.093750e-01f, + 0.125f, 0.15625f, 0.1875f, 0.21875f, 0.25f, 0.3125f, 0.375f, 0.4375f, + 0.5f, 0.625f, 0.75f, 0.875f, 1.0f, 1.25f, 1.5f, 1.75f, + 2.0f, 2.5f, 3.0f, 3.5f, 4.0f, 5.0f, 6.0f, 7.0f, + 8.0f, 10.0f, 12.0f, 14.0f, 16.0f, 20.0f, 24.0f, 28.0f, + 32.0f, 40.0f, 48.0f, 56.0f, 64.0f, 80.0f, 96.0f, 112.0f, + 128.0f, 160.0f, 192.0f, 224.0f, 256.0f, 320.0f, 384.0f, 448.0f, + 512.0f, 640.0f, 768.0f, 896.0f, 1024.0f, 1280.0f, 1536.0f, 1792.0f, + 2048.0f, 2560.0f, 3072.0f, 3584.0f, 4096.0f, 5120.0f, 6144.0f, 7168.0f, + 8192.0f, 10240.0f, 12288.0f, 14336.0f, 16384.0f, 20480.0f, 24576.0f, 28672.0f, + 32768.0f, 40960.0f, 49152.0f, 57344.0f, INFINITY, NAN, NAN, NAN, + -0.0f,-1.525879e-05f,-3.051758e-05f,-4.577637e-05f,-6.103516e-05f,-7.629395e-05f,-9.155273e-05f,-1.068115e-04f, + -1.220703e-04f,-1.525879e-04f,-1.831055e-04f,-2.136230e-04f,-2.441406e-04f,-3.051758e-04f,-3.662109e-04f,-4.272461e-04f, + -4.882812e-04f,-6.103516e-04f,-7.324219e-04f,-8.544922e-04f,-9.765625e-04f,-1.220703e-03f,-1.464844e-03f,-1.708984e-03f, + -1.953125e-03f,-2.441406e-03f,-2.929688e-03f,-3.417969e-03f,-3.906250e-03f,-4.882812e-03f,-5.859375e-03f,-6.835938e-03f, + 
-7.812500e-03f,-9.765625e-03f,-1.171875e-02f,-1.367188e-02f,-1.562500e-02f,-1.953125e-02f,-2.343750e-02f,-2.734375e-02f, + -3.125000e-02f,-3.906250e-02f,-4.687500e-02f,-5.468750e-02f,-6.250000e-02f,-7.812500e-02f,-9.375000e-02f,-1.093750e-01f, + -0.125f, -0.15625f, -0.1875f, -0.21875f, -0.25f, -0.3125f, -0.375f, -0.4375f, + -0.5f, -0.625f, -0.75f, -0.875f, -1.0f, -1.25f, -1.5f, -1.75f, + -2.0f, -2.5f, -3.0f, -3.5f, -4.0f, -5.0f, -6.0f, -7.0f, + -8.0f, -10.0f, -12.0f, -14.0f, -16.0f, -20.0f, -24.0f, -28.0f, + -32.0f, -40.0f, -48.0f, -56.0f, -64.0f, -80.0f, -96.0f, -112.0f, + -128.0f, -160.0f, -192.0f, -224.0f, -256.0f, -320.0f, -384.0f, -448.0f, + -512.0f, -640.0f, -768.0f, -896.0f, -1024.0f, -1280.0f, -1536.0f, -1792.0f, + -2048.0f, -2560.0f, -3072.0f, -3584.0f, -4096.0f, -5120.0f, -6144.0f, -7168.0f, + -8192.0f, -10240.0f, -12288.0f, -14336.0f, -16384.0f, -20480.0f, -24576.0f, -28672.0f, + -32768.0f, -40960.0f, -49152.0f, -57344.0f, -INFINITY, NAN, NAN, NAN, +GGML_TABLE_END() + +// ------------------------------------------------------------------------------------------------------------------ +// Canonical MXFP element converters — portable IEEE-754 bit manipulation. +// Single source of truth for CPU, CUDA, HIP, MUSA, SYCL. Metal/Vulkan keep MSL/GLSL copies. +// ------------------------------------------------------------------------------------------------------------------ +#if defined(GGML_MXFP_FUNC) + +// --- FP4 E2M1: [S(1) | E(2) | M(1)] — max normal = 6.0 --- +// Canonical converters using true E2M1 values {0, 0.5, 1, 1.5, 2, 3, 4, 6}. +// The int8 kvalues_mxfp4 LUT stores doubled values {0,1,2,3,4,6,8,12} for +// CPU/CUDA nibble-indexed integer arithmetic — that doubling is an implementation detail. + +GGML_MXFP_FUNC float ggml_mxfp_fp4_e2m1_to_float(uint8_t v) { + const float sign = (v & 0x8) ? 
-1.0f : 1.0f; + const int exp = (v >> 1) & 0x3; + const int mant = v & 0x1; + if (exp == 0) return sign * (float)mant * 0.5f; + return sign * (1.0f + mant * 0.5f) * (float)(1 << (exp - 1)); +} + +GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp4_e2m1(float x) { + uint8_t sign = 0; + if (x < 0) { sign = 0x8; x = -x; } + if (x == 0) return sign; + if (x >= 6.0f) return sign | 0x7; // max finite + // Decision boundaries (midpoints of adjacent canonical values): + // {0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0} + if (x < 0.25f) return sign | 0x0; // 0 + else if (x < 0.75f) return sign | 0x1; // 0.5 + else if (x < 1.25f) return sign | 0x2; // 1.0 + else if (x < 1.75f) return sign | 0x3; // 1.5 + else if (x < 2.5f) return sign | 0x4; // 2.0 + else if (x < 3.5f) return sign | 0x5; // 3.0 + else if (x < 5.0f) return sign | 0x6; // 4.0 + else return sign | 0x7; // 6.0 +} + +// --- FP6 E2M3: [S(1) | E(2) | M(3)] — max normal = 7.5 --- + +GGML_MXFP_FUNC float ggml_mxfp_fp6_e2m3_to_float(uint8_t v) { + const float sign = (v & 0x20) ? 
-1.0f : 1.0f; + const int exp = (v >> 3) & 0x3; + const int mant = v & 0x7; + if (exp == 0) return sign * (float)mant * 0.125f; + return sign * (1.0f + mant * 0.125f) * (float)(1 << (exp - 1)); +} + +GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp6_e2m3(float x) { + uint8_t sign = 0; + if (x < 0) { sign = 0x20; x = -x; } + if (x == 0) return sign; + if (x >= 7.5f) return sign | 0x1F; // max finite + + uint32_t bits = GGML_MXFP_F32_AS_U32(x); + int f32_exp = (int)((bits >> 23) & 0xFF) - 127; + + if (f32_exp < 0) { + // Subnormal in E2M3: mant * 2^(-3) + float scaled = x * 8.0f; + int mant = (int)(scaled + 0.5f); + if (mant > 7) return sign | 0x08; // smallest normal + return sign | (uint8_t)mant; + } + if (f32_exp > 2) f32_exp = 2; + + float mantf = (x / (float)(1 << f32_exp)) - 1.0f; + int mant = (int)(mantf * 8.0f + 0.5f); + if (mant > 7) { mant = 0; f32_exp++; } + if (f32_exp > 2) return sign | 0x1F; + return sign | (uint8_t)(((f32_exp + 1) << 3) | mant); +} + +// --- FP6 E3M2: [S(1) | E(3) | M(2)] — max normal = 28.0, no NaN/Inf --- + +GGML_MXFP_FUNC float ggml_mxfp_fp6_e3m2_to_float(uint8_t v) { + const float sign = (v & 0x20) ? -1.0f : 1.0f; + const int exp = (v >> 2) & 0x7; + const int mant = v & 0x3; + if (exp == 0) return sign * (float)mant * 0.0625f; // 2^(-4) + // MX E3M2 has no NaN/Inf — exp=7 is a valid normal value (max finite = 28.0). 
+ return sign * GGML_MXFP_LDEXPF(1.0f + mant * 0.25f, exp - 3); +} + +GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp6_e3m2(float x) { + uint8_t sign = 0; + if (x < 0) { sign = 0x20; x = -x; } + if (x == 0) return sign; + if (x >= 28.0f) return sign | 0x1F; // max finite + + uint32_t bits = GGML_MXFP_F32_AS_U32(x); + int f32_exp = (int)((bits >> 23) & 0xFF) - 127; + int biased_exp = f32_exp + 3; + + if (biased_exp <= 0) { + // Subnormal in E3M2: mant * 2^(-4) + float scaled = x * 16.0f; + int mant = (int)(scaled + 0.5f); + if (mant > 3) return sign | 0x04; // smallest normal + return sign | (uint8_t)mant; + } + if (biased_exp > 7) return sign | 0x1F; + + float pow2 = (f32_exp >= 0) ? (float)(1 << f32_exp) : 1.0f / (float)(1 << (-f32_exp)); + float mantf = (x / pow2) - 1.0f; + int mant = (int)(mantf * 4.0f + 0.5f); + if (mant > 3) { mant = 0; biased_exp++; } + if (biased_exp > 7) return sign | 0x1F; + return sign | (uint8_t)((biased_exp << 2) | mant); +} + +// --- FP8 E4M3: [S(1) | E(4) | M(3)] — bias=7, max finite=448 --- + +GGML_MXFP_FUNC float ggml_mxfp_fp8_e4m3_to_float(uint8_t v) { + uint32_t sign = ((uint32_t)(v & 0x80)) << 24; + uint32_t exp = (v >> 3) & 0xF; + uint32_t mant = v & 0x7; + + if (exp == 0) { + if (mant == 0) return GGML_MXFP_U32_AS_F32(sign); + // Subnormal: mant * 2^(1-7) * 2^(-3) = mant * 2^(-9) + float val = (float)mant * (1.0f / 512.0f); + uint32_t vb = GGML_MXFP_F32_AS_U32(val); + vb = (vb & 0x7FFFFFFFu) | sign; + return GGML_MXFP_U32_AS_F32(vb); + } + if (exp == 15 && mant == 7) { + return GGML_MXFP_U32_AS_F32(sign | 0x7FC00000u); + } + // Normal: (-1)^S * 2^(E-7) * (1 + M/8) → F32 exp = E-7+127 = E+120 + return GGML_MXFP_U32_AS_F32(sign | ((exp + 120) << 23) | (mant << 20)); +} + +GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp8_e4m3(float x) { + uint32_t bits = GGML_MXFP_F32_AS_U32(x); + uint8_t sign = (bits >> 24) & 0x80; + bits &= 0x7FFFFFFFu; + if (bits == 0) return sign; + + uint32_t f32_exp = (bits >> 23) & 0xFF; + uint32_t f32_mant = 
bits & 0x7FFFFF; + int e4m3_exp = (int)f32_exp - 120; + + if (e4m3_exp <= 0) { + // Subnormal in E4M3 + int shift = 1 - e4m3_exp; + uint32_t full_mant = (1u << 23) | f32_mant; + int total_shift = 20 + shift; + if (total_shift >= 32) return sign; + uint32_t mant3 = full_mant >> total_shift; + if (total_shift > 0 && total_shift < 32) { + uint32_t round_bit = (full_mant >> (total_shift - 1)) & 1; + uint32_t sticky = (total_shift > 1) ? (full_mant & ((1u << (total_shift - 1)) - 1)) : 0; + if (round_bit && (sticky || (mant3 & 1))) mant3++; + } + if (mant3 > 7) return sign | 0x08; + return sign | (uint8_t)mant3; + } + + uint32_t round_bit = (f32_mant >> 19) & 1; + uint32_t sticky = f32_mant & ((1u << 19) - 1); + uint32_t mant3 = f32_mant >> 20; + if (round_bit && (sticky || (mant3 & 1))) { + mant3++; + if (mant3 > 7) { mant3 = 0; e4m3_exp++; } + } + if (e4m3_exp > 15 || (e4m3_exp == 15 && mant3 >= 7)) return sign | 0x7E; // max finite + return sign | (uint8_t)((e4m3_exp << 3) | mant3); +} + +// --- FP8 E5M2: [S(1) | E(5) | M(2)] — bias=15, max finite=57344 --- + +GGML_MXFP_FUNC float ggml_mxfp_fp8_e5m2_to_float(uint8_t v) { + uint32_t sign = ((uint32_t)(v & 0x80)) << 24; + uint32_t exp = (v >> 2) & 0x1F; + uint32_t mant = v & 0x3; + + if (exp == 0) { + if (mant == 0) return GGML_MXFP_U32_AS_F32(sign); + // Subnormal: mant * 2^(1-15) * 2^(-2) = mant/4 * 2^(-14) + float val = (float)mant * 0.25f * (1.0f / 16384.0f); + uint32_t vb = GGML_MXFP_F32_AS_U32(val); + vb = (vb & 0x7FFFFFFFu) | sign; + return GGML_MXFP_U32_AS_F32(vb); + } + if (exp == 31) { + return GGML_MXFP_U32_AS_F32(sign | 0x7F800000u | (mant ? 
0x400000u : 0)); + } + // Normal: F32 exp = E-15+127 = E+112 + return GGML_MXFP_U32_AS_F32(sign | ((exp + 112) << 23) | (mant << 21)); +} + +GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp8_e5m2(float x) { + uint32_t bits = GGML_MXFP_F32_AS_U32(x); + uint8_t sign = (bits >> 24) & 0x80; + bits &= 0x7FFFFFFFu; + if (bits == 0) return sign; + + uint32_t f32_exp = (bits >> 23) & 0xFF; + uint32_t f32_mant = bits & 0x7FFFFF; + int e5m2_exp = (int)f32_exp - 112; + + if (e5m2_exp <= 0) { + int shift = 1 - e5m2_exp; + uint32_t full_mant = (1u << 23) | f32_mant; + int total_shift = 21 + shift; + if (total_shift >= 32) return sign; + uint32_t mant2 = full_mant >> total_shift; + if (total_shift > 0 && total_shift < 32) { + uint32_t round_bit = (full_mant >> (total_shift - 1)) & 1; + uint32_t sticky = (total_shift > 1) ? (full_mant & ((1u << (total_shift - 1)) - 1)) : 0; + if (round_bit && (sticky || (mant2 & 1))) mant2++; + } + if (mant2 > 3) return sign | 0x04; + return sign | (uint8_t)mant2; + } + + uint32_t round_bit = (f32_mant >> 20) & 1; + uint32_t sticky = f32_mant & ((1u << 20) - 1); + uint32_t mant2 = f32_mant >> 21; + if (round_bit && (sticky || (mant2 & 1))) { + mant2++; + if (mant2 > 3) { mant2 = 0; e5m2_exp++; } + } + if (e5m2_exp >= 31) return sign | 0x7B; // max finite + return sign | (uint8_t)((e5m2_exp << 2) | mant2); +} + +// --- FP6 packing/unpacking --- + +// Pack 4 six-bit values into 3 bytes +GGML_MXFP_FUNC void ggml_mxfp_pack_fp6x4(const uint8_t v[4], uint8_t out[3]) { + uint32_t packed = (v[0] & 0x3F) | ((v[1] & 0x3F) << 6) | + ((v[2] & 0x3F) << 12) | ((v[3] & 0x3F) << 18); + out[0] = (uint8_t)(packed); + out[1] = (uint8_t)(packed >> 8); + out[2] = (uint8_t)(packed >> 16); +} + +// Unpack 3 bytes into 4 six-bit values +GGML_MXFP_FUNC void ggml_mxfp_unpack_fp6x4(const uint8_t in[3], uint8_t v[4]) { + uint32_t packed = (uint32_t)in[0] | ((uint32_t)in[1] << 8) | ((uint32_t)in[2] << 16); + v[0] = packed & 0x3F; + v[1] = (packed >> 6) & 0x3F; + v[2] = (packed 
>> 12) & 0x3F; + v[3] = (packed >> 18) & 0x3F; +} + +// E8M0 shared exponent → float conversion. +// E8M0 encoding: value = 2^(x - 127) for x > 0, 2^(-127) for x == 0. +GGML_MXFP_FUNC float ggml_mxfp_e8m0_to_fp32(uint8_t x) { + uint32_t bits = (x == 0) ? 0x00400000u : ((uint32_t)x << 23); + return GGML_MXFP_U32_AS_F32(bits); +} + +// E8M0 → float/2. Used with MXFP4 since E2M1 values are doubled in kvalues_mxfp4. +GGML_MXFP_FUNC float ggml_mxfp_e8m0_to_fp32_half(uint8_t x) { + uint32_t bits = (x < 2) ? (0x00200000u << x) : ((uint32_t)(x - 1) << 23); + return GGML_MXFP_U32_AS_F32(bits); +} + +// E8M0 base exponent estimate: round(log2(amax)) - emax_offset + 127. +// Uses integer bit extraction — no log2f() SFU dependency. +// Caller must ensure amax > 0 and finite. Returns unclamped e_base. +GGML_MXFP_FUNC int ggml_mxfp_e8m0_base_estimate(float amax, int emax_offset) { + uint32_t amax_bits = GGML_MXFP_F32_AS_U32(amax); + const int floor_log2 = (int)((amax_bits >> 23) & 0xFF) - 127; + // Round: add 1 if mantissa >= sqrt(2)-1 (0x3504F3 in 23-bit IEEE mantissa). + const int round_log2 = floor_log2 + ((amax_bits & 0x7FFFFF) >= 0x3504F3 ? 1 : 0); + return round_log2 - emax_offset + 127; +} + +// Block-32 Walsh-Hadamard Transform, normalized by 1/sqrt(32). +// Spreads outlier energy across all elements sharing an E8M0 exponent, +// improving quantization quality (see QuaRot arXiv:2404.00456). 
+GGML_MXFP_FUNC void ggml_mxfp_hadamard_32_inplace(GGML_MXFP_THREAD float * vals) { + GGML_MXFP_UNROLL + for (int stride = 1; stride < 32; stride *= 2) { + GGML_MXFP_UNROLL + for (int i = 0; i < 32; i += 2 * stride) { + GGML_MXFP_UNROLL + for (int j = 0; j < stride; ++j) { + const float a = vals[i + j]; + const float b = vals[i + j + stride]; + vals[i + j] = a + b; + vals[i + j + stride] = a - b; + } + } + } + GGML_MXFP_UNROLL + for (int i = 0; i < 32; ++i) { + vals[i] *= MXFP_HADAMARD_32_NORM; + } +} + +#endif // GGML_MXFP_FUNC + #define NGRID_IQ1S 2048 #define IQ1S_DELTA 0.125f #define IQ1M_DELTA 0.125f diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 41da829315..42647e14e1 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -16,6 +16,8 @@ #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0 +#define ggml_vec_dot_mxfp6_e2m3_q8_0_generic ggml_vec_dot_mxfp6_e2m3_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -341,3 +343,14 @@ #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #endif + +// MXFP dequantize has no arch-specific (SIMD) implementations except on arm and x86. +// All other targets use the scalar generic as the public cpu function. 
+#if !defined(__aarch64__) && !defined(__arm__) && !defined(_M_ARM) && !defined(_M_ARM64) && \ + !defined(__x86_64__) && !defined(__i386__) && !defined(_M_IX86) && !defined(_M_X64) +#define dequantize_row_mxfp8_cpu_generic dequantize_row_mxfp8_cpu +#define dequantize_row_mxfp6_e2m3_cpu_generic dequantize_row_mxfp6_e2m3_cpu +#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu +#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu +#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu +#endif diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index c1856201b3..0f0ba86518 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -4134,3 +4134,541 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } +// NEON-optimized MXFP8 × Q8_0 dot product. +// Dequants FP8 elements to float via IEEE 754 bit construction, then dots against Q8_0. +// Parameters encode the FP8 format: sign_shift, exp_mask, mant_mask, ieee_exp_bias, mant_shift, sub_scale. 
+#if defined(__ARM_NEON) +static inline void ggml_vec_dot_mxfp8_q8_0_neon( + int n, float * GGML_RESTRICT s, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + // FP8 format parameters: + const uint32_t exp_mask, // 0xF for E4M3, 0x1F for E5M2 + const uint32_t mant_mask, // 0x7 for E4M3, 0x3 for E5M2 + const int exp_shift, // 3 for E4M3, 2 for E5M2 + const uint32_t ieee_exp_off, // 120 for E4M3, 112 for E5M2 + const int mant_shift, // 20 for E4M3, 21 for E5M2 + const float sub_scale) { // 1/512 for E4M3, 1/65536 for E5M2 + assert(n % QK_MXFP8 == 0); + const int nb = n / QK_MXFP8; + const block_mxfp8 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + float32x4_t acc0 = vdupq_n_f32(0.0f); + float32x4_t acc1 = vdupq_n_f32(0.0f); + + const uint32x4_t v_exp_mask = vdupq_n_u32(exp_mask); + const uint32x4_t v_mant_mask = vdupq_n_u32(mant_mask); + const uint32x4_t v_ieee_off = vdupq_n_u32(ieee_exp_off); + const float32x4_t v_sub_sc = vdupq_n_f32(sub_scale); + // Use variable shifts (vshlq_u32) instead of constant shifts (vshlq_n_u32) + // because exp_shift/mant_shift are function parameters, not compile-time constants. + // Clang requires _n_ intrinsics to have literal constant arguments. 
+ const int32x4_t v_neg_exp_shift = vdupq_n_s32(-exp_shift); + const int32x4_t v_mant_shift = vdupq_n_s32(mant_shift); + + for (int ib = 0; ib < nb; ++ib) { + const float scale = GGML_E8M0_TO_FP32(x[ib].e) * GGML_CPU_FP16_TO_FP32(y[ib].d); + const float32x4_t v_scale = vdupq_n_f32(scale); + + // Process 32 FP8 elements in 8 groups of 4 + for (int j = 0; j < 32; j += 8) { + // Load 8 FP8 bytes, extend to two uint32x4_t + const uint8x8_t raw8 = vld1_u8(x[ib].qs + j); + const uint16x8_t raw16 = vmovl_u8(raw8); + const uint32x4_t v_lo = vmovl_u16(vget_low_u16(raw16)); + const uint32x4_t v_hi = vmovl_u16(vget_high_u16(raw16)); + + // Load 8 Q8_0 int8 values, extend to two int32x4_t → float32x4_t + const int8x8_t q8 = vld1_s8(y[ib].qs + j); + const int16x8_t q16 = vmovl_s8(q8); + const float32x4_t qf_lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(q16))); + const float32x4_t qf_hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16))); + + // Dequant FP8 → float for both groups of 4 + #define DEQUANT_FP8_NEON(v_raw, qf, acc) do { \ + const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x80)); \ + const uint32x4_t exp = vandq_u32( \ + vshlq_u32(v_raw, v_neg_exp_shift), \ + v_exp_mask); \ + const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask); \ + /* Normal: IEEE bits = (exp + offset) << 23 | mant << mant_shift */ \ + const uint32x4_t ieee = vorrq_u32( \ + vorrq_u32(vshlq_n_u32(sign, 24), \ + vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)), \ + vshlq_u32(mant, v_mant_shift)); \ + const float32x4_t normal = vreinterpretq_f32_u32(ieee); \ + /* Subnormal: sign * mant * sub_scale */ \ + const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc); \ + const uint32x4_t sub_bits = vorrq_u32( \ + vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 24)); \ + const float32x4_t sub_val = vreinterpretq_f32_u32(sub_bits); \ + /* Select: subnormal when exp == 0, else normal */ \ + const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0)); \ + const float32x4_t val = vbslq_f32(is_sub, sub_val, 
normal); \ + /* Multiply by scale and Q8 value, accumulate */ \ + (acc) = vfmaq_f32((acc), vmulq_f32(val, v_scale), qf); \ + } while (0) + + DEQUANT_FP8_NEON(v_lo, qf_lo, acc0); + DEQUANT_FP8_NEON(v_hi, qf_hi, acc1); + #undef DEQUANT_FP8_NEON + } + } + + *s = vaddvq_f32(vaddq_f32(acc0, acc1)); +} +#endif + +void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); +#if defined(__ARM_NEON) + // E4M3: sign(1) exp(4) mant(3), bias=7 + ggml_vec_dot_mxfp8_q8_0_neon(n, s, vx, vy, + MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT, + MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE); +#else + ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +// NEON-optimized MXFP6 × Q8_0 dot product. +// Unpacks tight 6-bit packing (4 values per 3 bytes), then dequants to float. 
+#if defined(__ARM_NEON) +static inline void ggml_vec_dot_mxfp6_q8_0_neon( + int n, float * GGML_RESTRICT s, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + size_t block_size, + // FP6 format parameters: + const uint32_t exp_mask, // 0x3 for E2M3, 0x7 for E3M2 + const uint32_t mant_mask, // 0x7 for E2M3, 0x3 for E3M2 + const int exp_shift, // 3 for E2M3, 2 for E3M2 + const uint32_t ieee_exp_off, // 126 for E2M3, 124 for E3M2 + const int mant_shift, // 20 for E2M3, 21 for E3M2 + const float sub_scale) { // 1/8 for E2M3, 1/16 for E3M2 + assert(n % QK_MXFP6 == 0); + const int nb = n / QK_MXFP6; + const block_q8_0 * GGML_RESTRICT y = vy; + + float32x4_t acc0 = vdupq_n_f32(0.0f); + float32x4_t acc1 = vdupq_n_f32(0.0f); + + const uint32x4_t v_exp_mask = vdupq_n_u32(exp_mask); + const uint32x4_t v_mant_mask = vdupq_n_u32(mant_mask); + const uint32x4_t v_ieee_off = vdupq_n_u32(ieee_exp_off); + const float32x4_t v_sub_sc = vdupq_n_f32(sub_scale); + const int32x4_t v_neg_exp_shift = vdupq_n_s32(-exp_shift); + const int32x4_t v_mant_shift = vdupq_n_s32(mant_shift); + + for (int ib = 0; ib < nb; ++ib) { + const block_mxfp6 * GGML_RESTRICT xb = (const block_mxfp6 *)((const char *)vx + ib * block_size); + const float scale = GGML_E8M0_TO_FP32(xb->e) * GGML_CPU_FP16_TO_FP32(y[ib].d); + const float32x4_t v_scale = vdupq_n_f32(scale); + + // Process 32 FP6 elements: 8 groups of 4, each packed in 3 bytes + for (int j = 0; j < 32; j += 8) { + // Unpack two groups of 4 FP6 values (6 bytes → 8 values) + uint8_t unpacked[8]; + // Group 1: 3 bytes → 4 values + { + const uint8_t * p = xb->qs + (j * 3 / 4); + const uint32_t packed = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); + unpacked[0] = (packed >> 0) & 0x3F; + unpacked[1] = (packed >> 6) & 0x3F; + unpacked[2] = (packed >> 12) & 0x3F; + unpacked[3] = (packed >> 18) & 0x3F; + } + // Group 2: next 3 bytes → 4 values + { + const uint8_t * p = xb->qs + ((j + 4) * 3 / 4); + const uint32_t packed = 
(uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); + unpacked[4] = (packed >> 0) & 0x3F; + unpacked[5] = (packed >> 6) & 0x3F; + unpacked[6] = (packed >> 12) & 0x3F; + unpacked[7] = (packed >> 18) & 0x3F; + } + + // Extend to uint32x4_t + const uint8x8_t raw8 = vld1_u8(unpacked); + const uint16x8_t raw16 = vmovl_u8(raw8); + const uint32x4_t v_lo = vmovl_u16(vget_low_u16(raw16)); + const uint32x4_t v_hi = vmovl_u16(vget_high_u16(raw16)); + + // Load Q8_0 int8 values + const int8x8_t q8 = vld1_s8(y[ib].qs + j); + const int16x8_t q16 = vmovl_s8(q8); + const float32x4_t qf_lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(q16))); + const float32x4_t qf_hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16))); + + // Dequant FP6 → float (same IEEE construction as FP8, sign bit at position 5) + #define DEQUANT_FP6_NEON(v_raw, qf, acc) do { \ + const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x20)); \ + const uint32x4_t exp = vandq_u32( \ + vshlq_u32(v_raw, v_neg_exp_shift), \ + v_exp_mask); \ + const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask); \ + const uint32x4_t ieee = vorrq_u32( \ + vorrq_u32(vshlq_n_u32(sign, 26), \ + vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)), \ + vshlq_u32(mant, v_mant_shift)); \ + const float32x4_t normal = vreinterpretq_f32_u32(ieee); \ + const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc); \ + const uint32x4_t sub_bits = vorrq_u32( \ + vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 26)); \ + const float32x4_t sub_val = vreinterpretq_f32_u32(sub_bits); \ + const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0)); \ + const float32x4_t val = vbslq_f32(is_sub, sub_val, normal); \ + (acc) = vfmaq_f32((acc), vmulq_f32(val, v_scale), qf); \ + } while (0) + + DEQUANT_FP6_NEON(v_lo, qf_lo, acc0); + DEQUANT_FP6_NEON(v_hi, qf_hi, acc1); + #undef DEQUANT_FP6_NEON + } + } + + *s = vaddvq_f32(vaddq_f32(acc0, acc1)); +} +#endif + +void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * 
GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); +#if defined(__ARM_NEON) + // E2M3: sign(1) exp(2) mant(3), bias=1 + ggml_vec_dot_mxfp6_q8_0_neon(n, s, vx, vy, sizeof(block_mxfp6), + MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, + MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); +#else + ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +// ---- MXFP dequantize_row (to_float) — NEON-optimized ---- + +#if defined(__ARM_NEON) +static inline void dequantize_row_mxfp8_neon( + const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k, + const uint32_t exp_mask, const uint32_t mant_mask, + const int exp_shift, const uint32_t ieee_exp_off, + const int mant_shift, const float sub_scale) { + assert(k % QK_MXFP8 == 0); + const int nb = k / QK_MXFP8; + const block_mxfp8 * GGML_RESTRICT x = vx; + + const uint32x4_t v_exp_mask = vdupq_n_u32(exp_mask); + const uint32x4_t v_mant_mask = vdupq_n_u32(mant_mask); + const uint32x4_t v_ieee_off = vdupq_n_u32(ieee_exp_off); + const float32x4_t v_sub_sc = vdupq_n_f32(sub_scale); + const int32x4_t v_neg_exp_shift = vdupq_n_s32(-exp_shift); + const int32x4_t v_mant_shift_v = vdupq_n_s32(mant_shift); + + for (int ib = 0; ib < nb; ++ib) { + const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32(x[ib].e)); + + for (int j = 0; j < 32; j += 8) { + const uint8x8_t raw8 = vld1_u8(x[ib].qs + j); + const uint16x8_t raw16 = vmovl_u8(raw8); + const uint32x4_t v_lo = vmovl_u16(vget_low_u16(raw16)); + const uint32x4_t v_hi = vmovl_u16(vget_high_u16(raw16)); + + #define DEQUANT_FP8_STORE(v_raw, dst) do { \ + const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x80)); \ + const uint32x4_t exp = vandq_u32( \ + vshlq_u32(v_raw, v_neg_exp_shift), \ + v_exp_mask); \ + const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask); \ + const uint32x4_t ieee = vorrq_u32( \ + 
vorrq_u32(vshlq_n_u32(sign, 24), \ + vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)), \ + vshlq_u32(mant, v_mant_shift_v)); \ + const float32x4_t normal = vreinterpretq_f32_u32(ieee); \ + const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc); \ + const uint32x4_t sub_bits = vorrq_u32( \ + vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 24)); \ + const float32x4_t sub_val = vreinterpretq_f32_u32(sub_bits); \ + const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0)); \ + const float32x4_t val = vbslq_f32(is_sub, sub_val, normal); \ + vst1q_f32(dst, vmulq_f32(val, v_scale)); \ + } while (0) + + DEQUANT_FP8_STORE(v_lo, y + ib * QK_MXFP8 + j); + DEQUANT_FP8_STORE(v_hi, y + ib * QK_MXFP8 + j + 4); + #undef DEQUANT_FP8_STORE + } + } +} + +static inline void dequantize_row_mxfp6_neon( + const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k, + size_t block_size, + const uint32_t exp_mask, const uint32_t mant_mask, + const int exp_shift, const uint32_t ieee_exp_off, + const int mant_shift, const float sub_scale) { + assert(k % QK_MXFP6 == 0); + const int nb = k / QK_MXFP6; + + const uint32x4_t v_exp_mask = vdupq_n_u32(exp_mask); + const uint32x4_t v_mant_mask = vdupq_n_u32(mant_mask); + const uint32x4_t v_ieee_off = vdupq_n_u32(ieee_exp_off); + const float32x4_t v_sub_sc = vdupq_n_f32(sub_scale); + const int32x4_t v_neg_exp_shift = vdupq_n_s32(-exp_shift); + const int32x4_t v_mant_shift_v = vdupq_n_s32(mant_shift); + + for (int ib = 0; ib < nb; ++ib) { + const block_mxfp6 * GGML_RESTRICT xb = (const block_mxfp6 *)((const char *)vx + ib * block_size); + const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32(xb->e)); + + for (int j = 0; j < 32; j += 4) { + const uint8_t * p = xb->qs + (j * 3 / 4); + const uint32_t pk = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); + uint8_t unpacked[4]; + unpacked[0] = (pk >> 0) & 0x3F; + unpacked[1] = (pk >> 6) & 0x3F; + unpacked[2] = (pk >> 12) & 0x3F; + unpacked[3] = (pk >> 18) & 0x3F; + + 
const uint8x8_t raw8 = vcreate_u8( + (uint64_t)unpacked[0] | ((uint64_t)unpacked[1] << 8) | + ((uint64_t)unpacked[2] << 16) | ((uint64_t)unpacked[3] << 24)); + const uint32x4_t v_raw = vmovl_u16(vget_low_u16(vmovl_u8(raw8))); + + const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x20)); + const uint32x4_t exp = vandq_u32( + vshlq_u32(v_raw, v_neg_exp_shift), + v_exp_mask); + const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask); + + const uint32x4_t ieee = vorrq_u32( + vorrq_u32(vshlq_n_u32(sign, 26), + vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)), + vshlq_u32(mant, v_mant_shift_v)); + const float32x4_t normal = vreinterpretq_f32_u32(ieee); + + const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc); + const uint32x4_t sub_bits = vorrq_u32( + vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 26)); + const float32x4_t sub_val = vreinterpretq_f32_u32(sub_bits); + + const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0)); + const float32x4_t val = vbslq_f32(is_sub, sub_val, normal); + + vst1q_f32(y + ib * QK_MXFP6 + j, vmulq_f32(val, v_scale)); + } + } +} +#endif + +void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { +#if defined(__ARM_NEON) + dequantize_row_mxfp8_neon(x, y, k, + MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT, + MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE); +#else + dequantize_row_mxfp8_cpu_generic(x, y, k); +#endif +} + +void dequantize_row_mxfp6_e2m3_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { +#if defined(__ARM_NEON) + dequantize_row_mxfp6_neon(x, y, k, sizeof(block_mxfp6), + MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, + MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); +#else + dequantize_row_mxfp6_e2m3_cpu_generic(x, y, k); +#endif +} + +// ---- MXFP SoA dequantize_row (to_float) — NEON-optimized ---- + +#if defined(__ARM_NEON) +static inline void 
dequantize_row_mxfp4_soa_neon( + const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_MXFP4 == 0); + const int nb = k / QK_MXFP4; + const char * row = (const char *)src; + const char * qs_base = row; + const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK); + + const int8x16_t values = vld1q_s8(kvalues_mxfp4); + const uint8x16_t m4b = vdupq_n_u8(0x0f); + + for (int i = 0; i < nb; i++) { + const float d = GGML_E8M0_TO_FP32_HALF((uint8_t)e8m0_base[i]); + const float32x4_t v_scale = vdupq_n_f32(d); + const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP4_SOA_QS_PER_BLOCK)); + + const uint8x16_t q4bits = vld1q_u8(qs); + + const int8x16_t lo = ggml_vqtbl1q_s8(values, vandq_u8(q4bits, m4b)); + const int8x16_t hi = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits, 4)); + + float * out_lo = y + i * QK_MXFP4; + float * out_hi = y + i * QK_MXFP4 + QK_MXFP4/2; + + { + const int16x8_t lo16_0 = vmovl_s8(vget_low_s8(lo)); + const int16x8_t lo16_1 = vmovl_s8(vget_high_s8(lo)); + vst1q_f32(out_lo + 0, vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(lo16_0))), v_scale)); + vst1q_f32(out_lo + 4, vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(lo16_0))), v_scale)); + vst1q_f32(out_lo + 8, vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(lo16_1))), v_scale)); + vst1q_f32(out_lo + 12, vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(lo16_1))), v_scale)); + } + { + const int16x8_t hi16_0 = vmovl_s8(vget_low_s8(hi)); + const int16x8_t hi16_1 = vmovl_s8(vget_high_s8(hi)); + vst1q_f32(out_hi + 0, vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(hi16_0))), v_scale)); + vst1q_f32(out_hi + 4, vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(hi16_0))), v_scale)); + vst1q_f32(out_hi + 8, vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(hi16_1))), v_scale)); + vst1q_f32(out_hi + 12, vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(hi16_1))), v_scale)); + } + } +} + +static inline void dequantize_row_mxfp8_soa_neon( + const void * 
GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, + const uint32_t exp_mask, const uint32_t mant_mask, + const int exp_shift, const uint32_t ieee_exp_off, + const int mant_shift, const float sub_scale) { + assert(k % QK_MXFP8 == 0); + const int nb = k / QK_MXFP8; + const char * row = (const char *)src; + const char * qs_base = row; + const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK); + + const uint32x4_t v_exp_mask = vdupq_n_u32(exp_mask); + const uint32x4_t v_mant_mask = vdupq_n_u32(mant_mask); + const uint32x4_t v_ieee_off = vdupq_n_u32(ieee_exp_off); + const float32x4_t v_sub_sc = vdupq_n_f32(sub_scale); + const int32x4_t v_neg_exp_shift = vdupq_n_s32(-exp_shift); + const int32x4_t v_mant_shift_v = vdupq_n_s32(mant_shift); + + for (int ib = 0; ib < nb; ++ib) { + const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib])); + const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP8_SOA_QS_PER_BLOCK)); + + for (int j = 0; j < 32; j += 8) { + const uint8x8_t raw8 = vld1_u8(qs + j); + const uint16x8_t raw16 = vmovl_u8(raw8); + const uint32x4_t v_lo = vmovl_u16(vget_low_u16(raw16)); + const uint32x4_t v_hi = vmovl_u16(vget_high_u16(raw16)); + + #define DEQUANT_FP8_STORE_SOA(v_raw, dst) do { \ + const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x80)); \ + const uint32x4_t exp = vandq_u32( \ + vshlq_u32(v_raw, v_neg_exp_shift), \ + v_exp_mask); \ + const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask); \ + const uint32x4_t ieee = vorrq_u32( \ + vorrq_u32(vshlq_n_u32(sign, 24), \ + vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)), \ + vshlq_u32(mant, v_mant_shift_v)); \ + const float32x4_t normal = vreinterpretq_f32_u32(ieee); \ + const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc); \ + const uint32x4_t sub_bits = vorrq_u32( \ + vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 24)); \ + const float32x4_t sub_val = vreinterpretq_f32_u32(sub_bits); \ + const uint32x4_t is_sub = 
vceqq_u32(exp, vdupq_n_u32(0)); \ + const float32x4_t val = vbslq_f32(is_sub, sub_val, normal); \ + vst1q_f32(dst, vmulq_f32(val, v_scale)); \ + } while (0) + + DEQUANT_FP8_STORE_SOA(v_lo, y + ib * QK_MXFP8 + j); + DEQUANT_FP8_STORE_SOA(v_hi, y + ib * QK_MXFP8 + j + 4); + #undef DEQUANT_FP8_STORE_SOA + } + } +} + +static inline void dequantize_row_mxfp6_soa_neon( + const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, + const uint32_t exp_mask, const uint32_t mant_mask, + const int exp_shift, const uint32_t ieee_exp_off, + const int mant_shift, const float sub_scale) { + assert(k % QK_MXFP6 == 0); + const int nb = k / QK_MXFP6; + const char * row = (const char *)src; + const char * qs_base = row; + const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK); + + const uint32x4_t v_exp_mask = vdupq_n_u32(exp_mask); + const uint32x4_t v_mant_mask = vdupq_n_u32(mant_mask); + const uint32x4_t v_ieee_off = vdupq_n_u32(ieee_exp_off); + const float32x4_t v_sub_sc = vdupq_n_f32(sub_scale); + const int32x4_t v_neg_exp_shift = vdupq_n_s32(-exp_shift); + const int32x4_t v_mant_shift_v = vdupq_n_s32(mant_shift); + + for (int ib = 0; ib < nb; ++ib) { + const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib])); + const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP6_SOA_QS_PER_BLOCK)); + + for (int j = 0; j < 32; j += 4) { + const uint8_t * p = qs + (j * 3 / 4); + const uint32_t pk = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); + uint8_t unpacked[4]; + unpacked[0] = (pk >> 0) & 0x3F; + unpacked[1] = (pk >> 6) & 0x3F; + unpacked[2] = (pk >> 12) & 0x3F; + unpacked[3] = (pk >> 18) & 0x3F; + + const uint8x8_t raw8 = vcreate_u8( + (uint64_t)unpacked[0] | ((uint64_t)unpacked[1] << 8) | + ((uint64_t)unpacked[2] << 16) | ((uint64_t)unpacked[3] << 24)); + const uint32x4_t v_raw = vmovl_u16(vget_low_u16(vmovl_u8(raw8))); + + const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x20)); + 
const uint32x4_t exp = vandq_u32( + vshlq_u32(v_raw, v_neg_exp_shift), + v_exp_mask); + const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask); + + const uint32x4_t ieee = vorrq_u32( + vorrq_u32(vshlq_n_u32(sign, 26), + vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)), + vshlq_u32(mant, v_mant_shift_v)); + const float32x4_t normal = vreinterpretq_f32_u32(ieee); + + const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc); + const uint32x4_t sub_bits = vorrq_u32( + vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 26)); + const float32x4_t sub_val = vreinterpretq_f32_u32(sub_bits); + + const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0)); + const float32x4_t val = vbslq_f32(is_sub, sub_val, normal); + + vst1q_f32(y + ib * QK_MXFP6 + j, vmulq_f32(val, v_scale)); + } + } +} +#endif + +void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { +#if defined(__ARM_NEON) + dequantize_row_mxfp4_soa_neon(x, y, k); +#else + dequantize_row_mxfp4_soa_cpu_generic(x, y, k); +#endif +} + +void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { +#if defined(__ARM_NEON) + dequantize_row_mxfp8_soa_neon(x, y, k, + MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT, + MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE); +#else + dequantize_row_mxfp8_soa_cpu_generic(x, y, k); +#endif +} + +void dequantize_row_mxfp6_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { +#if defined(__ARM_NEON) + dequantize_row_mxfp6_soa_neon(x, y, k, + MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, + MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); +#else + dequantize_row_mxfp6_soa_cpu_generic(x, y, k); +#endif +} + diff --git a/ggml/src/ggml-cpu/arch/loongarch/quants.c b/ggml/src/ggml-cpu/arch/loongarch/quants.c index f531e916b9..a75dac8b15 100644 --- a/ggml/src/ggml-cpu/arch/loongarch/quants.c +++ 
b/ggml/src/ggml-cpu/arch/loongarch/quants.c @@ -2157,3 +2157,14 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } +void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +} + +void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +} + +void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +} diff --git a/ggml/src/ggml-cpu/arch/powerpc/quants.c b/ggml/src/ggml-cpu/arch/powerpc/quants.c index d3dfd049ea..82ca1f9df9 100644 --- a/ggml/src/ggml-cpu/arch/powerpc/quants.c +++ b/ggml/src/ggml-cpu/arch/powerpc/quants.c @@ -2303,3 +2303,10 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } +void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +} + +void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +} diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index 826055dd9a..dcb97756c6 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -3607,3 +3607,11 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT 
s, size_t bs, const vo return ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } + +void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +} + +void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +} diff --git a/ggml/src/ggml-cpu/arch/s390/quants.c b/ggml/src/ggml-cpu/arch/s390/quants.c index 34184ed851..234488f25c 100644 --- a/ggml/src/ggml-cpu/arch/s390/quants.c +++ b/ggml/src/ggml-cpu/arch/s390/quants.c @@ -1464,3 +1464,10 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } +void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +} + +void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +} diff --git a/ggml/src/ggml-cpu/arch/wasm/quants.c b/ggml/src/ggml-cpu/arch/wasm/quants.c index 74a359e6d1..88bc6ad778 100644 --- a/ggml/src/ggml-cpu/arch/wasm/quants.c +++ b/ggml/src/ggml-cpu/arch/wasm/quants.c @@ -1219,3 +1219,14 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } +void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +} + +void ggml_vec_dot_mxfp8_q8_0(int 
n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +} + +void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +} diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c index 74d699f633..29d5a28759 100644 --- a/ggml/src/ggml-cpu/arch/x86/quants.c +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -3818,3 +3818,501 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } + +// AVX2-optimized MXFP8 × Q8_0 dot product. +// Dequants FP8 elements to float via IEEE 754 bit construction, then dots against Q8_0. +// Parameters encode the FP8 format: exp_mask, mant_mask, exp_shift, ieee_exp_offset, mant_shift, sub_scale. 
+#if defined(__AVX2__) +static inline void ggml_vec_dot_mxfp8_q8_0_avx2( + int n, float * GGML_RESTRICT s, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + // FP8 format parameters: + const int exp_mask, // 0xF for E4M3, 0x1F for E5M2 + const int mant_mask, // 0x7 for E4M3, 0x3 for E5M2 + const int exp_shift, // 3 for E4M3, 2 for E5M2 + const int ieee_exp_off, // 120 for E4M3, 112 for E5M2 + const int mant_shift, // 20 for E4M3, 21 for E5M2 + const float sub_scale) { // 1/512 for E4M3, 1/65536 for E5M2 + assert(n % QK_MXFP8 == 0); + const int nb = n / QK_MXFP8; + const block_mxfp8 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask); + const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask); + const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off); + const __m256 v_sub_sc = _mm256_set1_ps(sub_scale); + const __m256i v_zero = _mm256_setzero_si256(); + + __m256 acc = _mm256_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const __m256 v_scale = _mm256_set1_ps( + GGML_E8M0_TO_FP32(x[ib].e) * GGML_CPU_FP16_TO_FP32(y[ib].d)); + + // Process 32 FP8 elements in 4 groups of 8 + // AVX2 _mm256_cvtepu8_epi32 widens 8 bytes → 8 int32s directly + for (int j = 0; j < 32; j += 8) { + // Load 8 FP8 bytes → 8 int32s + const __m128i raw8 = _mm_loadl_epi64((const __m128i *)(x[ib].qs + j)); + const __m256i v_raw = _mm256_cvtepu8_epi32(raw8); + + // Load 8 Q8_0 int8 values → float + const __m128i q8 = _mm_loadl_epi64((const __m128i *)(y[ib].qs + j)); + const __m256 qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(q8)); + + // Extract sign (bit 7), exponent, mantissa + const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x80)); + const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask); + const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask); + + // Normal path: IEEE bits = (sign << 24) | ((exp + offset) << 23) | (mant << mant_shift) + const 
__m256i ieee = _mm256_or_si256( + _mm256_or_si256(_mm256_slli_epi32(sign, 24), + _mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)), + _mm256_slli_epi32(mant, mant_shift)); + const __m256 normal = _mm256_castsi256_ps(ieee); + + // Subnormal path: |val| = mant * sub_scale, then apply sign + const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc); + const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256( + _mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 24))); + + // Select: subnormal when exp == 0, else normal + const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero)); + const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub); + + // Accumulate: val * scale * q8_float + acc = _mm256_fmadd_ps(_mm256_mul_ps(val, v_scale), qf, acc); + } + } + + *s = hsum_float_8(acc); +} +#endif + +void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); +#if defined(__AVX2__) + // E4M3: sign(1) exp(4) mant(3), bias=7 + ggml_vec_dot_mxfp8_q8_0_avx2(n, s, vx, vy, + MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT, + MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE); +#else + ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +// AVX2-optimized MXFP6 × Q8_0 dot product. +// Unpacks tight 6-bit packing (4 values per 3 bytes), then dequants to float. 
+#if defined(__AVX2__) +static inline void ggml_vec_dot_mxfp6_q8_0_avx2( + int n, float * GGML_RESTRICT s, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + size_t block_size, + // FP6 format parameters: + const int exp_mask, // 0x3 for E2M3, 0x7 for E3M2 + const int mant_mask, // 0x7 for E2M3, 0x3 for E3M2 + const int exp_shift, // 3 for E2M3, 2 for E3M2 + const int ieee_exp_off, // 126 for E2M3, 124 for E3M2 + const int mant_shift, // 20 for E2M3, 21 for E3M2 + const float sub_scale) { // 1/8 for E2M3, 1/16 for E3M2 + assert(n % QK_MXFP6 == 0); + const int nb = n / QK_MXFP6; + const block_q8_0 * GGML_RESTRICT y = vy; + + const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask); + const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask); + const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off); + const __m256 v_sub_sc = _mm256_set1_ps(sub_scale); + const __m256i v_zero = _mm256_setzero_si256(); + + __m256 acc = _mm256_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const block_mxfp6 * GGML_RESTRICT xb = (const block_mxfp6 *)((const char *)vx + ib * block_size); + const __m256 v_scale = _mm256_set1_ps( + GGML_E8M0_TO_FP32(xb->e) * GGML_CPU_FP16_TO_FP32(y[ib].d)); + + // Process 32 FP6 elements in 4 groups of 8 (each group = 2 × 3-byte packs) + for (int j = 0; j < 32; j += 8) { + // Unpack 8 FP6 values from 6 bytes (two groups of 3 bytes → 4 values each) + uint8_t unpacked[8]; + { + const uint8_t * p = xb->qs + (j * 3 / 4); + const uint32_t pk0 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); + unpacked[0] = (pk0 >> 0) & 0x3F; + unpacked[1] = (pk0 >> 6) & 0x3F; + unpacked[2] = (pk0 >> 12) & 0x3F; + unpacked[3] = (pk0 >> 18) & 0x3F; + } + { + const uint8_t * p = xb->qs + ((j + 4) * 3 / 4); + const uint32_t pk1 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); + unpacked[4] = (pk1 >> 0) & 0x3F; + unpacked[5] = (pk1 >> 6) & 0x3F; + unpacked[6] = (pk1 >> 12) & 0x3F; + unpacked[7] = (pk1 >> 18) & 0x3F; + } + + 
// Widen 8 bytes → 8 int32s + const __m128i raw8 = _mm_loadl_epi64((const __m128i *)unpacked); + const __m256i v_raw = _mm256_cvtepu8_epi32(raw8); + + // Load 8 Q8_0 int8 values → float + const __m128i q8 = _mm_loadl_epi64((const __m128i *)(y[ib].qs + j)); + const __m256 qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(q8)); + + // Extract sign (bit 5 for FP6), exponent, mantissa + const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x20)); + const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask); + const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask); + + // Normal: IEEE bits = (sign << 26) | ((exp + offset) << 23) | (mant << mant_shift) + const __m256i ieee = _mm256_or_si256( + _mm256_or_si256(_mm256_slli_epi32(sign, 26), + _mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)), + _mm256_slli_epi32(mant, mant_shift)); + const __m256 normal = _mm256_castsi256_ps(ieee); + + // Subnormal: |val| = mant * sub_scale, apply sign + const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc); + const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256( + _mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 26))); + + // Select: subnormal when exp == 0 + const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero)); + const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub); + + acc = _mm256_fmadd_ps(_mm256_mul_ps(val, v_scale), qf, acc); + } + } + + *s = hsum_float_8(acc); +} +#endif + +void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); +#if defined(__AVX2__) + // E2M3: sign(1) exp(2) mant(3), bias=1 + ggml_vec_dot_mxfp6_q8_0_avx2(n, s, vx, vy, sizeof(block_mxfp6), + MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, + MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); 
+#else + ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +// ---- MXFP dequantize_row (to_float) — AVX2-optimized ---- +// Extracts the SIMD dequant logic from vec_dot above, writing floats to output buffer +// instead of accumulating a dot product. + +#if defined(__AVX2__) +static inline void dequantize_row_mxfp8_avx2( + const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k, + const int exp_mask, const int mant_mask, const int exp_shift, + const int ieee_exp_off, const int mant_shift, const float sub_scale) { + assert(k % QK_MXFP8 == 0); + const int nb = k / QK_MXFP8; + const block_mxfp8 * GGML_RESTRICT x = vx; + + const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask); + const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask); + const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off); + const __m256 v_sub_sc = _mm256_set1_ps(sub_scale); + const __m256i v_zero = _mm256_setzero_si256(); + + for (int ib = 0; ib < nb; ++ib) { + const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32(x[ib].e)); + + for (int j = 0; j < 32; j += 8) { + const __m128i raw8 = _mm_loadl_epi64((const __m128i *)(x[ib].qs + j)); + const __m256i v_raw = _mm256_cvtepu8_epi32(raw8); + + const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x80)); + const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask); + const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask); + + const __m256i ieee = _mm256_or_si256( + _mm256_or_si256(_mm256_slli_epi32(sign, 24), + _mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)), + _mm256_slli_epi32(mant, mant_shift)); + const __m256 normal = _mm256_castsi256_ps(ieee); + + const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc); + const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256( + _mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 24))); + + const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero)); + const __m256 val = 
_mm256_blendv_ps(normal, sub_val, is_sub); + + _mm256_storeu_ps(y + ib * QK_MXFP8 + j, _mm256_mul_ps(val, v_scale)); + } + } +} + +static inline void dequantize_row_mxfp6_avx2( + const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k, + size_t block_size, + const int exp_mask, const int mant_mask, const int exp_shift, + const int ieee_exp_off, const int mant_shift, const float sub_scale) { + assert(k % QK_MXFP6 == 0); + const int nb = k / QK_MXFP6; + + const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask); + const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask); + const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off); + const __m256 v_sub_sc = _mm256_set1_ps(sub_scale); + const __m256i v_zero = _mm256_setzero_si256(); + + for (int ib = 0; ib < nb; ++ib) { + const block_mxfp6 * GGML_RESTRICT xb = (const block_mxfp6 *)((const char *)vx + ib * block_size); + const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32(xb->e)); + + for (int j = 0; j < 32; j += 8) { + // Unpack 8 FP6 values from 6 bytes (two groups of 3 bytes → 4 values each) + uint8_t unpacked[8]; + { + const uint8_t * p = xb->qs + (j * 3 / 4); + const uint32_t pk0 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); + unpacked[0] = (pk0 >> 0) & 0x3F; + unpacked[1] = (pk0 >> 6) & 0x3F; + unpacked[2] = (pk0 >> 12) & 0x3F; + unpacked[3] = (pk0 >> 18) & 0x3F; + } + { + const uint8_t * p = xb->qs + ((j + 4) * 3 / 4); + const uint32_t pk1 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); + unpacked[4] = (pk1 >> 0) & 0x3F; + unpacked[5] = (pk1 >> 6) & 0x3F; + unpacked[6] = (pk1 >> 12) & 0x3F; + unpacked[7] = (pk1 >> 18) & 0x3F; + } + + const __m128i raw8 = _mm_loadl_epi64((const __m128i *)unpacked); + const __m256i v_raw = _mm256_cvtepu8_epi32(raw8); + + const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x20)); + const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask); + const __m256i mant = _mm256_and_si256(v_raw, 
v_mant_mask); + + const __m256i ieee = _mm256_or_si256( + _mm256_or_si256(_mm256_slli_epi32(sign, 26), + _mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)), + _mm256_slli_epi32(mant, mant_shift)); + const __m256 normal = _mm256_castsi256_ps(ieee); + + const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc); + const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256( + _mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 26))); + + const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero)); + const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub); + + _mm256_storeu_ps(y + ib * QK_MXFP6 + j, _mm256_mul_ps(val, v_scale)); + } + } +} +#endif + +void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { +#if defined(__AVX2__) + dequantize_row_mxfp8_avx2(x, y, k, + MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT, + MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE); +#else + dequantize_row_mxfp8_cpu_generic(x, y, k); +#endif +} + +void dequantize_row_mxfp6_e2m3_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { +#if defined(__AVX2__) + dequantize_row_mxfp6_avx2(x, y, k, sizeof(block_mxfp6), + MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, + MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); +#else + dequantize_row_mxfp6_e2m3_cpu_generic(x, y, k); +#endif +} + +// SoA dequant for flash attention — contiguous qs region + separate e8m0 region +#if defined(__AVX2__) +static inline void dequantize_row_mxfp4_soa_avx2( + const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_MXFP4 == 0); + const int nb = k / QK_MXFP4; + const char * row = (const char *)src; + const char * qs_base = row; + const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK); + + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4); + const __m128i m4b = 
_mm_set1_epi8(0x0f); + + for (int i = 0; i < nb; i++) { + const float d = GGML_E8M0_TO_FP32_HALF((uint8_t)e8m0_base[i]); + const __m256 v_scale = _mm256_set1_ps(d); + const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP4_SOA_QS_PER_BLOCK)); + + const __m128i q4bits = _mm_loadu_si128((const __m128i *)qs); + + const __m128i lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits, m4b)); + const __m128i hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4b)); + + // lo nibbles → first 16 floats + const __m256i lo32_0 = _mm256_cvtepi8_epi32(lo); + const __m256i lo32_1 = _mm256_cvtepi8_epi32(_mm_srli_si128(lo, 8)); + _mm256_storeu_ps(y + i * QK_MXFP4 + 0, _mm256_mul_ps(_mm256_cvtepi32_ps(lo32_0), v_scale)); + _mm256_storeu_ps(y + i * QK_MXFP4 + 8, _mm256_mul_ps(_mm256_cvtepi32_ps(lo32_1), v_scale)); + + // hi nibbles → second 16 floats + const __m256i hi32_0 = _mm256_cvtepi8_epi32(hi); + const __m256i hi32_1 = _mm256_cvtepi8_epi32(_mm_srli_si128(hi, 8)); + _mm256_storeu_ps(y + i * QK_MXFP4 + 16, _mm256_mul_ps(_mm256_cvtepi32_ps(hi32_0), v_scale)); + _mm256_storeu_ps(y + i * QK_MXFP4 + 24, _mm256_mul_ps(_mm256_cvtepi32_ps(hi32_1), v_scale)); + } +} + +static inline void dequantize_row_mxfp8_soa_avx2( + const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, + const int exp_mask, const int mant_mask, const int exp_shift, + const int ieee_exp_off, const int mant_shift, const float sub_scale) { + assert(k % QK_MXFP8 == 0); + const int nb = k / QK_MXFP8; + const char * row = (const char *)src; + const char * qs_base = row; + const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK); + + const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask); + const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask); + const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off); + const __m256 v_sub_sc = _mm256_set1_ps(sub_scale); + const __m256i v_zero = _mm256_setzero_si256(); + + for (int ib = 0; ib < nb; ++ib) 
{ + const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib])); + const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP8_SOA_QS_PER_BLOCK)); + + for (int j = 0; j < 32; j += 8) { + const __m128i raw8 = _mm_loadl_epi64((const __m128i *)(qs + j)); + const __m256i v_raw = _mm256_cvtepu8_epi32(raw8); + + const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x80)); + const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask); + const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask); + + const __m256i ieee = _mm256_or_si256( + _mm256_or_si256(_mm256_slli_epi32(sign, 24), + _mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)), + _mm256_slli_epi32(mant, mant_shift)); + const __m256 normal = _mm256_castsi256_ps(ieee); + + const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc); + const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256( + _mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 24))); + + const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero)); + const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub); + + _mm256_storeu_ps(y + ib * QK_MXFP8 + j, _mm256_mul_ps(val, v_scale)); + } + } +} + +static inline void dequantize_row_mxfp6_soa_avx2( + const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, + const int exp_mask, const int mant_mask, const int exp_shift, + const int ieee_exp_off, const int mant_shift, const float sub_scale) { + assert(k % QK_MXFP6 == 0); + const int nb = k / QK_MXFP6; + const char * row = (const char *)src; + const char * qs_base = row; + const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK); + + const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask); + const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask); + const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off); + const __m256 v_sub_sc = _mm256_set1_ps(sub_scale); + const __m256i v_zero = _mm256_setzero_si256(); + 
+ for (int ib = 0; ib < nb; ++ib) { + const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib])); + const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP6_SOA_QS_PER_BLOCK)); + + for (int j = 0; j < 32; j += 8) { + uint8_t unpacked[8]; + { + const uint8_t * p = qs + (j * 3 / 4); + const uint32_t pk0 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); + unpacked[0] = (pk0 >> 0) & 0x3F; + unpacked[1] = (pk0 >> 6) & 0x3F; + unpacked[2] = (pk0 >> 12) & 0x3F; + unpacked[3] = (pk0 >> 18) & 0x3F; + } + { + const uint8_t * p = qs + ((j + 4) * 3 / 4); + const uint32_t pk1 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); + unpacked[4] = (pk1 >> 0) & 0x3F; + unpacked[5] = (pk1 >> 6) & 0x3F; + unpacked[6] = (pk1 >> 12) & 0x3F; + unpacked[7] = (pk1 >> 18) & 0x3F; + } + + const __m128i raw8 = _mm_loadl_epi64((const __m128i *)unpacked); + const __m256i v_raw = _mm256_cvtepu8_epi32(raw8); + + const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x20)); + const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask); + const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask); + + const __m256i ieee = _mm256_or_si256( + _mm256_or_si256(_mm256_slli_epi32(sign, 26), + _mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)), + _mm256_slli_epi32(mant, mant_shift)); + const __m256 normal = _mm256_castsi256_ps(ieee); + + const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc); + const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256( + _mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 26))); + + const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero)); + const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub); + + _mm256_storeu_ps(y + ib * QK_MXFP6 + j, _mm256_mul_ps(val, v_scale)); + } + } +} +#endif + +void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { +#if defined(__AVX2__) + 
dequantize_row_mxfp4_soa_avx2(x, y, k); +#else + dequantize_row_mxfp4_soa_cpu_generic(x, y, k); +#endif +} + +void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { +#if defined(__AVX2__) + dequantize_row_mxfp8_soa_avx2(x, y, k, + MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT, + MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE); +#else + dequantize_row_mxfp8_soa_cpu_generic(x, y, k); +#endif +} + +void dequantize_row_mxfp6_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { +#if defined(__AVX2__) + dequantize_row_mxfp6_soa_avx2(x, y, k, + MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, + MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); +#else + dequantize_row_mxfp6_soa_cpu_generic(x, y, k); +#endif +} + diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 8b323bd9b0..a87f808c95 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -264,7 +264,7 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_1, .nrows = 1, }, - [GGML_TYPE_MXFP4] = { + [GGML_TYPE_MXFP4_E2M1] = { .from_float = quantize_row_mxfp4, .vec_dot = ggml_vec_dot_mxfp4_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, @@ -276,6 +276,20 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, }, + [GGML_TYPE_MXFP8_E4M3] = { + .from_float = quantize_row_mxfp8, + .to_float = dequantize_row_mxfp8_cpu, + .vec_dot = ggml_vec_dot_mxfp8_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + .nrows = 1, + }, + [GGML_TYPE_MXFP6_E2M3] = { + .from_float = quantize_row_mxfp6_e2m3, + .to_float = dequantize_row_mxfp6_e2m3_cpu, + .vec_dot = ggml_vec_dot_mxfp6_e2m3_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + .nrows = 1, + }, [GGML_TYPE_Q2_K] = { .from_float = quantize_row_q2_K, .vec_dot = ggml_vec_dot_q2_K_q8_K, diff --git 
a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 314cc1088a..02cd1abb8d 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -2,6 +2,8 @@ #include "ggml-cpu.h" #include "ggml-impl.h" +#include "ggml-quants.h" +#include "quants.h" #include "binary-ops.h" #include "simd-gemm.h" #include "ggml.h" @@ -11,6 +13,7 @@ #include #include #include +#include // ggml_compute_forward_dup @@ -669,8 +672,10 @@ void ggml_compute_forward_add( case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: - case GGML_TYPE_MXFP4: + case GGML_TYPE_MXFP4_E2M1: case GGML_TYPE_NVFP4: + case GGML_TYPE_MXFP8_E4M3: + case GGML_TYPE_MXFP6_E2M3: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -1119,8 +1124,10 @@ void ggml_compute_forward_add1( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: - case GGML_TYPE_MXFP4: + case GGML_TYPE_MXFP4_E2M1: case GGML_TYPE_NVFP4: + case GGML_TYPE_MXFP8_E4M3: + case GGML_TYPE_MXFP6_E2M3: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -1248,8 +1255,10 @@ void ggml_compute_forward_acc( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: - case GGML_TYPE_MXFP4: + case GGML_TYPE_MXFP4_E2M1: case GGML_TYPE_NVFP4: + case GGML_TYPE_MXFP8_E4M3: + case GGML_TYPE_MXFP6_E2M3: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4336,8 +4345,10 @@ void ggml_compute_forward_out_prod( case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: - case GGML_TYPE_MXFP4: + case GGML_TYPE_MXFP4_E2M1: case GGML_TYPE_NVFP4: + case GGML_TYPE_MXFP8_E4M3: + case GGML_TYPE_MXFP6_E2M3: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4612,8 +4623,10 @@ void ggml_compute_forward_set( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: - case GGML_TYPE_MXFP4: + case GGML_TYPE_MXFP4_E2M1: case GGML_TYPE_NVFP4: + case GGML_TYPE_MXFP8_E4M3: + case GGML_TYPE_MXFP6_E2M3: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4835,8 
+4848,10 @@ void ggml_compute_forward_get_rows( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: - case GGML_TYPE_MXFP4: + case GGML_TYPE_MXFP4_E2M1: case GGML_TYPE_NVFP4: + case GGML_TYPE_MXFP8_E4M3: + case GGML_TYPE_MXFP6_E2M3: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4894,6 +4909,96 @@ void ggml_compute_forward_get_rows( //} } +// NEON-optimized Hadamard for ARM platforms; scalar fallback uses ggml_hadamard_32_inplace +// from ggml-quants.c (the reference implementation). +#if defined(__ARM_NEON) +static void hadamard_32_inplace(float vals[32]) { + float32x4_t v0 = vld1q_f32(vals + 0); + float32x4_t v1 = vld1q_f32(vals + 4); + float32x4_t v2 = vld1q_f32(vals + 8); + float32x4_t v3 = vld1q_f32(vals + 12); + float32x4_t v4 = vld1q_f32(vals + 16); + float32x4_t v5 = vld1q_f32(vals + 20); + float32x4_t v6 = vld1q_f32(vals + 24); + float32x4_t v7 = vld1q_f32(vals + 28); + + #define HADAMARD_S1(v) do { \ + float32x2_t lo = vget_low_f32(v); \ + float32x2_t hi = vget_high_f32(v); \ + float32x2x2_t t = vtrn_f32(lo, hi); \ + float32x2_t sum = vadd_f32(t.val[0], t.val[1]); \ + float32x2_t dif = vsub_f32(t.val[0], t.val[1]); \ + float32x2x2_t r = vtrn_f32(sum, dif); \ + (v) = vcombine_f32(r.val[0], r.val[1]); \ + } while (0) + HADAMARD_S1(v0); HADAMARD_S1(v1); HADAMARD_S1(v2); HADAMARD_S1(v3); + HADAMARD_S1(v4); HADAMARD_S1(v5); HADAMARD_S1(v6); HADAMARD_S1(v7); + #undef HADAMARD_S1 + + #define HADAMARD_S2(v) do { \ + float32x2_t lo = vget_low_f32(v); \ + float32x2_t hi = vget_high_f32(v); \ + (v) = vcombine_f32(vadd_f32(lo, hi), vsub_f32(lo, hi)); \ + } while (0) + HADAMARD_S2(v0); HADAMARD_S2(v1); HADAMARD_S2(v2); HADAMARD_S2(v3); + HADAMARD_S2(v4); HADAMARD_S2(v5); HADAMARD_S2(v6); HADAMARD_S2(v7); + #undef HADAMARD_S2 + + #define HADAMARD_S4(a, b) do { \ + float32x4_t s = vaddq_f32(a, b); \ + float32x4_t d = vsubq_f32(a, b); \ + (a) = s; (b) = d; \ + } while (0) + HADAMARD_S4(v0, v1); HADAMARD_S4(v2, v3); + HADAMARD_S4(v4, v5); 
HADAMARD_S4(v6, v7); + #undef HADAMARD_S4 + + { float32x4_t s, d; + s = vaddq_f32(v0, v2); d = vsubq_f32(v0, v2); v0 = s; v2 = d; + s = vaddq_f32(v1, v3); d = vsubq_f32(v1, v3); v1 = s; v3 = d; + s = vaddq_f32(v4, v6); d = vsubq_f32(v4, v6); v4 = s; v6 = d; + s = vaddq_f32(v5, v7); d = vsubq_f32(v5, v7); v5 = s; v7 = d; + } + + { float32x4_t s, d; + s = vaddq_f32(v0, v4); d = vsubq_f32(v0, v4); v0 = s; v4 = d; + s = vaddq_f32(v1, v5); d = vsubq_f32(v1, v5); v1 = s; v5 = d; + s = vaddq_f32(v2, v6); d = vsubq_f32(v2, v6); v2 = s; v6 = d; + s = vaddq_f32(v3, v7); d = vsubq_f32(v3, v7); v3 = s; v7 = d; + } + + const float32x4_t norm = vdupq_n_f32(MXFP_HADAMARD_32_NORM); + vst1q_f32(vals + 0, vmulq_f32(v0, norm)); + vst1q_f32(vals + 4, vmulq_f32(v1, norm)); + vst1q_f32(vals + 8, vmulq_f32(v2, norm)); + vst1q_f32(vals + 12, vmulq_f32(v3, norm)); + vst1q_f32(vals + 16, vmulq_f32(v4, norm)); + vst1q_f32(vals + 20, vmulq_f32(v5, norm)); + vst1q_f32(vals + 24, vmulq_f32(v6, norm)); + vst1q_f32(vals + 28, vmulq_f32(v7, norm)); +} +#else +// Scalar fallback: delegate to reference implementation in ggml-quants.c +static void hadamard_32_inplace(float vals[32]) { + ggml_hadamard_32_inplace(vals); +} +#endif + +// Apply Hadamard rotation to each 32-element block in a float buffer. +static void ggml_apply_hadamard_blocks(float * data, int64_t n) { + GGML_ASSERT(n % 32 == 0); + for (int64_t i = 0; i < n; i += 32) { + hadamard_32_inplace(data + i); + } +} + +// Prefer SIMD-optimized CPU dequant, fall back to scalar reference. 
+static inline ggml_to_float_t ggml_get_to_float_fn(ggml_type type) { + ggml_to_float_t fn = ggml_get_type_traits_cpu(type)->to_float; + if (!fn) { fn = ggml_get_type_traits(type)->to_float; } + return fn; +} + template static void ggml_compute_forward_set_rows_f32( const ggml_compute_params * params, @@ -4924,7 +5029,22 @@ static void ggml_compute_forward_set_rows_f32( const int64_t ir0 = dr*ith; const int64_t ir1 = std::min(ir0 + dr, nr); - ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float; + const int32_t apply_hadamard = ((const int32_t *)dst->op_params)[0]; + + // For MXFP types, use SoA quantize (canonical FA layout). + // For non-MXFP types, use the standard AoS from_float. + typedef void (*quantize_soa_fn)(const float *, void *, int64_t); + quantize_soa_fn mxfp_soa_quantize = nullptr; + ggml_from_float_t from_float = nullptr; + + switch (dst->type) { + case GGML_TYPE_MXFP4_E2M1: mxfp_soa_quantize = quantize_row_mxfp4_soa; break; + case GGML_TYPE_MXFP8_E4M3: mxfp_soa_quantize = quantize_row_mxfp8_soa; break; + case GGML_TYPE_MXFP6_E2M3: mxfp_soa_quantize = quantize_row_mxfp6_soa; break; + default: + from_float = ggml_get_type_traits_cpu(dst->type)->from_float; + break; + } for (int64_t i03 = 0; i03 < ne03; ++i03) { for (int64_t i02 = 0; i02 < ne02; ++i02) { @@ -4937,9 +5057,26 @@ static void ggml_compute_forward_set_rows_f32( GGML_ASSERT(i1 >= 0 && i1 < ne1); - from_float( - (const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03), - ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc); + const float * src_row = (const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03); + char * dst_row = ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3); + + if (apply_hadamard) { + GGML_ASSERT(nc <= 1024); + float tmp[1024]; + memcpy(tmp, src_row, nc * sizeof(float)); + ggml_apply_hadamard_blocks(tmp, nc); + if (mxfp_soa_quantize) { + mxfp_soa_quantize(tmp, dst_row, nc); + } else { + from_float(tmp, dst_row, nc); + 
} + } else { + if (mxfp_soa_quantize) { + mxfp_soa_quantize(src_row, dst_row, nc); + } else { + from_float(src_row, dst_row, nc); + } + } } } } @@ -5560,8 +5697,10 @@ void ggml_compute_forward_clamp( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: - case GGML_TYPE_MXFP4: + case GGML_TYPE_MXFP4_E2M1: case GGML_TYPE_NVFP4: + case GGML_TYPE_MXFP8_E4M3: + case GGML_TYPE_MXFP6_E2M3: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -8118,6 +8257,67 @@ void ggml_compute_forward_top_k( } } +// SoA function pointer types for MXFP flash attention paths. +typedef void (*mxfp_soa_quantize_fn)(const float *, void *, int64_t); +typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t); + +// Shared MXFP dispatch parameters for flash attention. +// Populated once and used by both the one_chunk and tiled paths. +struct mxfp_fa_params { + mxfp_soa_quantize_fn q_quantize; + mxfp_soa_dequantize_fn k_dequantize; + mxfp_soa_dequantize_fn v_dequantize; + bool k_multihead; + bool v_multihead; + int64_t k_soa_elems; + int64_t v_soa_elems; + bool apply_hadamard; +}; + +static mxfp_fa_params mxfp_fa_params_init( + const ggml_tensor * k, const ggml_tensor * v, + int64_t DK, int64_t DV, + size_t nbk2, size_t nbv2, + int64_t nek2, int64_t nev2) { + mxfp_fa_params p = {}; + + const bool is_mxfp_k = ggml_is_type_mxfp(k->type); + const bool is_mxfp_v = ggml_is_type_mxfp(v->type); + + if (is_mxfp_k) { + switch (k->type) { + case GGML_TYPE_MXFP4_E2M1: p.q_quantize = quantize_row_mxfp4_soa; p.k_dequantize = dequantize_row_mxfp4_soa_cpu; break; + case GGML_TYPE_MXFP8_E4M3: p.q_quantize = quantize_row_mxfp8_soa; p.k_dequantize = dequantize_row_mxfp8_soa_cpu; break; + case GGML_TYPE_MXFP6_E2M3: p.q_quantize = quantize_row_mxfp6_soa; p.k_dequantize = dequantize_row_mxfp6_soa_cpu; break; + default: GGML_ABORT("unsupported MXFP K type"); + } + } + + if (is_mxfp_v) { + switch (v->type) { + case GGML_TYPE_MXFP4_E2M1: p.v_dequantize = dequantize_row_mxfp4_soa_cpu; 
break; + case GGML_TYPE_MXFP8_E4M3: p.v_dequantize = dequantize_row_mxfp8_soa_cpu; break; + case GGML_TYPE_MXFP6_E2M3: p.v_dequantize = dequantize_row_mxfp6_soa_cpu; break; + default: GGML_ABORT("unsupported MXFP V type"); + } + } + + // Hadamard rotation must match K rotation. + // Skipped for: MLA (DK != DV, V is a view of K). + p.apply_hadamard = is_mxfp_k && (DK == DV) && ggml_mxfp_use_hadamard(k->type); + + // SoA layout detection: in the real KV cache, heads are contiguous within + // one KV-position stride (nb[2] == row_size(DK)), so SoA spans all heads. + // In test tensors, heads may be at distant offsets (nb[2] >> row_size(DK)), + // so SoA is per-head. Detect which case and set dequant parameters accordingly. + p.k_multihead = is_mxfp_k && (nbk2 == (size_t)ggml_row_size(k->type, DK)); + p.k_soa_elems = is_mxfp_k ? (p.k_multihead ? nek2 * DK : DK) : 0; + p.v_multihead = is_mxfp_v && (nbv2 == (size_t)ggml_row_size(v->type, DV)); + p.v_soa_elems = is_mxfp_v ? (p.v_multihead ? nev2 * DV : DV) : 0; + + return p; +} + static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( const ggml_compute_params * params, ggml_tensor * dst, @@ -8192,13 +8392,29 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type; - ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float; - ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot; - ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float; + const bool is_mxfp_k = ggml_is_type_mxfp(k->type); + const bool is_mxfp_v = ggml_is_type_mxfp(v->type); - GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type"); - GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type"); + const mxfp_fa_params mxfp = 
mxfp_fa_params_init(k, v, DK, DV, nbk2, nbv2, nek2, nev2); + + ggml_from_float_t q_to_vec_dot = nullptr; + ggml_vec_dot_t kq_vec_dot = nullptr; + ggml_to_float_t v_to_float = nullptr; + + if (is_mxfp_k) { + kq_vec_dot = nullptr; + } else { + ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type; + q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float; + kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot; + } + + if (!is_mxfp_v) { + v_to_float = ggml_get_to_float_fn(v->type); + } + + GGML_ASSERT((is_mxfp_k || q_to_vec_dot) && "fattn: unsupported K-type"); + GGML_ASSERT((v->type == GGML_TYPE_F32 || is_mxfp_v || v_to_float) && "fattn: unsupported V-type"); int ith = params->ith; @@ -8236,7 +8452,31 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( const int iv2 = iq2 / rv2; const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); - q_to_vec_dot(pq, Q_q, DK); + float Q_f32[1024]; + if (is_mxfp_k) { + // Q preprocessing: Hadamard → SoA quantize → SoA dequant (round-trip). + // Captures the same quantization loss as K, matching GPU MMA semantics. 
+ GGML_ASSERT(DK <= 1024); + if (mxfp.apply_hadamard) { + float q_tmp[1024]; + memcpy(q_tmp, pq, DK * sizeof(float)); + ggml_apply_hadamard_blocks(q_tmp, DK); + mxfp.q_quantize(q_tmp, Q_q, DK); + } else { + mxfp.q_quantize(pq, Q_q, DK); + } + mxfp.k_dequantize(Q_q, Q_f32, DK); + } else { + if (mxfp.apply_hadamard) { + GGML_ASSERT(DK <= 1024); + float q_tmp[1024]; + memcpy(q_tmp, pq, DK * sizeof(float)); + ggml_apply_hadamard_blocks(q_tmp, DK); + q_to_vec_dot(q_tmp, Q_q, DK); + } else { + q_to_vec_dot(pq, Q_q, DK); + } + } // online softmax / attention // loop over n_kv and n_head_kv @@ -8251,7 +8491,20 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( float s; // KQ value const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3); - kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1); + if (is_mxfp_k) { + // Dequant SoA data. Multi-head: full row base, extract head portion. + // Per-head: use k_data directly. + const char * k_soa_base = mxfp.k_multihead + ? ((const char *) k->data + ic*nbk1 + ik3*nbk3) + : k_data; + float k_soa_f32[4096]; + GGML_ASSERT(mxfp.k_soa_elems <= 4096); + mxfp.k_dequantize(k_soa_base, k_soa_f32, mxfp.k_soa_elems); + const float * k_head = k_soa_f32 + (mxfp.k_multihead ? ik2 * DK : 0); + ggml_vec_dot_f32(DK, &s, 0, k_head, 0, Q_f32, 0, 1); + } else { + kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1); + } s = s*scale; // scale KQ value @@ -8297,7 +8550,15 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( } // V += v*expf(s - M) - if (v_to_float) { + if (mxfp.v_dequantize) { + const char * v_soa_base = mxfp.v_multihead + ? ((const char *) v->data + ic*nbv1 + iv3*nbv3) + : v_data; + float v_soa_f32[4096]; + GGML_ASSERT(mxfp.v_soa_elems <= 4096); + mxfp.v_dequantize(v_soa_base, v_soa_f32, mxfp.v_soa_elems); + ggml_vec_mad_f32(DV, VKQ32, v_soa_f32 + (mxfp.v_multihead ? 
iv2 * DV : 0), vs); + } else if (v_to_float) { v_to_float(v_data, V32, DV); ggml_vec_mad_f32(DV, VKQ32, V32, vs); } else { @@ -8399,9 +8660,17 @@ static void ggml_compute_forward_flash_attn_ext_tiled( GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); - GGML_ASSERT(k->type == v->type); - const ggml_type kv_type = k->type; + const ggml_type k_type = k->type; + const ggml_type v_type = v->type; + const bool is_mxfp_k = ggml_is_type_mxfp(k_type); + const bool is_mxfp_v = ggml_is_type_mxfp(v_type); + + const mxfp_fa_params mxfp = mxfp_fa_params_init(k, v, DK, DV, nbk2, nbv2, nek2, nev2); + + // Non-MXFP dequant functions + ggml_to_float_t k_to_float = is_mxfp_k ? nullptr : ggml_get_to_float_fn(k_type); + ggml_to_float_t v_to_float = is_mxfp_v ? nullptr : ggml_get_to_float_fn(v_type); // broadcast factors const int64_t rk2 = neq2/nek2; @@ -8490,6 +8759,16 @@ static void ggml_compute_forward_flash_attn_ext_tiled( for (int tq = 0; tq < tile_rows; tq++) { const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3)); memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float)); + + if (is_mxfp_k) { + if (mxfp.apply_hadamard) { + ggml_apply_hadamard_blocks(Q_f32 + tq * DK, DK); + } + // SoA round-trip: quantize Q to SoA, then dequant back to float. 
+ uint8_t q_mxfp_buf[1024]; + mxfp.q_quantize(Q_f32 + tq * DK, q_mxfp_buf, DK); + mxfp.k_dequantize(q_mxfp_buf, Q_f32 + tq * DK, DK); + } } for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) { memset(Q_f32 + tq * DK, 0, DK * sizeof(float)); @@ -8528,16 +8807,33 @@ static void ggml_compute_forward_flash_attn_ext_tiled( // Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns for (int tk = 0; tk < kv_tile; tk++) { const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3; - if (kv_type == GGML_TYPE_F16) { + if (k_type == GGML_TYPE_F16) { const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data; for (int64_t dk = 0; dk < DK; dk++) { K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]); } - } else { + } else if (k_type == GGML_TYPE_F32) { const float * k_f32_src = (const float *)k_data; for (int64_t dk = 0; dk < DK; dk++) { K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk]; } + } else if (mxfp.k_dequantize) { + const char * k_soa_base = mxfp.k_multihead + ? ((const char *)k->data + (ic + tk)*nbk1 + ik3*nbk3) + : k_data; + float k_soa[4096]; + GGML_ASSERT(mxfp.k_soa_elems <= 4096); + mxfp.k_dequantize(k_soa_base, k_soa, mxfp.k_soa_elems); + const float * k_head = k_soa + (mxfp.k_multihead ? 
ik2 * DK : 0); + for (int64_t dk = 0; dk < DK; dk++) { + K_f32[dk * KV_TILE_SZ + tk] = k_head[dk]; + } + } else { + float k_tmp[1024]; + k_to_float(k_data, k_tmp, DK); + for (int64_t dk = 0; dk < DK; dk++) { + K_f32[dk * KV_TILE_SZ + tk] = k_tmp[dk]; + } } } memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float)); @@ -8593,10 +8889,20 @@ static void ggml_compute_forward_flash_attn_ext_tiled( // Pack V tile to contiguous F32, zero-padded for (int tk = 0; tk < kv_tile; tk++) { const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3; - if (kv_type == GGML_TYPE_F16) { + if (v_type == GGML_TYPE_F16) { ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV); - } else { + } else if (v_type == GGML_TYPE_F32) { memcpy(V32 + tk * DV, v_data, DV * sizeof(float)); + } else if (mxfp.v_dequantize) { + const char * v_soa_base = mxfp.v_multihead + ? ((const char *)v->data + (ic + tk)*nbv1 + iv3*nbv3) + : v_data; + float v_soa[4096]; + GGML_ASSERT(mxfp.v_soa_elems <= 4096); + mxfp.v_dequantize(v_soa_base, v_soa, mxfp.v_soa_elems); + memcpy(V32 + tk * DV, v_soa + (mxfp.v_multihead ? iv2 * DV : 0), DV * sizeof(float)); + } else { + v_to_float(v_data, V32 + tk * DV, DV); } } for (int tq = 0; tq < Q_TILE_SZ; tq++) { @@ -8764,8 +9070,10 @@ static void ggml_compute_forward_flash_attn_ext_f16( // When use_ref is set, force the vec-only reference implementation (no tiling, no KV-chunking) const bool use_ref = params->use_ref; - const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16); - const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512; + // Split-KV: parallelize across KV chunks for single-query decode (token generation). + // Delegates to one_chunk which handles all supported types (F16, Q8_0, Q4_0, MXFP, etc). 
+ const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) + && q->type == GGML_TYPE_F32 && nek1 >= 512; if (use_split_kv_path) { const int64_t chunk_size = (nek1 + nth - 1) / nth; @@ -8824,8 +9132,6 @@ static void ggml_compute_forward_flash_attn_ext_f16( static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q; bool use_tiled = !use_ref && (q->type == GGML_TYPE_F32 && - kv_is_f32_or_f16 && - k->type == v->type && neq1 >= Q_TILE_SZ); #ifdef GGML_SIMD use_tiled &= (DV % GGML_F32_EPR == 0); diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index 7ebbb9c6f1..9152755010 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -54,6 +54,14 @@ void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, i quantize_row_nvfp4_ref(x, y, k); } +void quantize_row_mxfp8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + quantize_row_mxfp8_ref(x, y, k); +} + +void quantize_row_mxfp6_e2m3(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + quantize_row_mxfp6_e2m3_ref(x, y, k); +} + // // 2-6 bit quantization in super-blocks // @@ -256,6 +264,70 @@ void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, *s = sumf; } +// Generic MXFP-to-Q8_0 dot product. Dequants one MX block (32 elements) +// to float via the existing public dequantize_row functions, then dots +// against Q8_0 int8 values. Reference implementation — not SIMD-optimized. 
+static void ggml_vec_dot_mxfp_q8_0_impl( + int n, float * GGML_RESTRICT s, + const void * GGML_RESTRICT vx, size_t block_size, + const void * GGML_RESTRICT vy, + ggml_to_float_t dequant) { + assert(n % QK8_0 == 0); + const int nb = n / QK8_0; + const block_q8_0 * GGML_RESTRICT y = vy; + float sumf = 0; + + for (int ib = 0; ib < nb; ib++) { + float tmp[QK8_0]; + dequant((const char *)vx + ib * block_size, tmp, QK8_0); + + const float y_d = GGML_CPU_FP16_TO_FP32(y[ib].d); + float block_sum = 0; + for (int j = 0; j < QK8_0; j++) { + block_sum += tmp[j] * (float)y[ib].qs[j]; + } + sumf += block_sum * y_d; + } + *s = sumf; +} + +void ggml_vec_dot_mxfp8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + ggml_vec_dot_mxfp_q8_0_impl(n, s, vx, sizeof(block_mxfp8), vy, + (ggml_to_float_t)dequantize_row_mxfp8); +} + +void ggml_vec_dot_mxfp6_e2m3_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + ggml_vec_dot_mxfp_q8_0_impl(n, s, vx, sizeof(block_mxfp6), vy, + (ggml_to_float_t)dequantize_row_mxfp6_e2m3); +} + +// Generic (scalar) dequant wrappers — delegates to ggml-quants.c reference implementations. +// On x86/ARM, arch-specific SIMD versions override these via the fallback.h mapping. 
+void dequantize_row_mxfp8_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + dequantize_row_mxfp8(x, y, k); +} +void dequantize_row_mxfp6_e2m3_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + dequantize_row_mxfp6_e2m3(x, y, k); +} +void dequantize_row_mxfp4_soa_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + dequantize_row_mxfp4_soa(x, y, k); +} +void dequantize_row_mxfp8_soa_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + dequantize_row_mxfp8_soa(x, y, k); +} +void dequantize_row_mxfp6_soa_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + dequantize_row_mxfp6_soa(x, y, k); +} void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h index 3584aaa43e..7d8c32762a 100644 --- a/ggml/src/ggml-cpu/quants.h +++ b/ggml/src/ggml-cpu/quants.h @@ -21,6 +21,12 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_mxfp8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_mxfp6_e2m3(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); + +// Dequantization (SIMD-optimized, arch-dispatched) +void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_mxfp6_e2m3_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q3_K(const float * 
GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -44,6 +50,8 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -76,6 +84,20 @@ void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, c void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_mxfp8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_mxfp6_e2m3_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void 
dequantize_row_mxfp8_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_mxfp6_e2m3_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + +// SoA dequant (SIMD-optimized for FA) +void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_mxfp6_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + +void dequantize_row_mxfp4_soa_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_mxfp8_soa_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_mxfp6_soa_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 6b76ab3bfb..faba78d29a 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -3767,7 +3767,7 @@ static block_mxfp4x4 make_block_mxfp4x4(block_mxfp4 * in, unsigned int blck_size } static int repack_mxfp4_to_mxfp4_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_MXFP4); + GGML_ASSERT(t->type == GGML_TYPE_MXFP4_E2M1); GGML_ASSERT(interleave_block == 4); const block_mxfp4 * src = (const block_mxfp4 *)data; @@ -3824,7 +3824,7 @@ static block_mxfp4x8 make_block_mxfp4x8(block_mxfp4 * in, unsigned int blck_size } static int repack_mxfp4_to_mxfp4_8_bl(struct 
ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_MXFP4); + GGML_ASSERT(t->type == GGML_TYPE_MXFP4_E2M1); GGML_ASSERT(interleave_block == 8); const block_mxfp4 * src = (const block_mxfp4 *)data; @@ -4682,7 +4682,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons } #endif } - } else if (cur->type == GGML_TYPE_MXFP4) { + } else if (cur->type == GGML_TYPE_MXFP4_E2M1) { if (ggml_cpu_has_avx2()) { if (cur->ne[1] % 8 == 0) { return &mxfp4_8x8_q8_0; diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 9256865595..8e5d931df5 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -430,59 +430,28 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x) #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) +// E8M0 shared exponent to float. Canonical source: ggml_mxfp_e8m0_to_fp32() in ggml-common.h. +// Kept here because ggml-impl.h cannot depend on ggml-common.h IMPL section. 
static inline float ggml_e8m0_to_fp32(uint8_t x) { - uint32_t bits; // Stores the raw bit representation of the float - - // Handle special case for minimum exponent (denormalized float) + uint32_t bits; if (x == 0) { - // Bit pattern for 2^(-127): - // - Sign bit: 0 (positive) - // - Exponent: 0 (denormalized number) - // - Mantissa: 0x400000 (0.5 in fractional form) - // Value = 0.5 * 2^(-126) = 2^(-127) - bits = 0x00400000; - } - // note: disabled as we don't need to handle NaNs - //// Handle special case for NaN (all bits set) - //else if (x == 0xFF) { - // // Standard quiet NaN pattern: - // // - Sign bit: 0 - // // - Exponent: all 1s (0xFF) - // // - Mantissa: 0x400000 (quiet NaN flag) - // bits = 0x7FC00000; - //} - // Normalized values (most common case) - else { - // Construct normalized float by shifting exponent into position: - // - Exponent field: 8 bits (positions 30-23) - // - Mantissa: 0 (implicit leading 1) - // Value = 2^(x - 127) + bits = 0x00400000; // 2^(-127) + } else { bits = (uint32_t) x << 23; } - - float result; // Final float value - // Safely reinterpret bit pattern as float without type-punning issues + float result; memcpy(&result, &bits, sizeof(float)); return result; } -// Equal to ggml_e8m0_to_fp32/2 -// Useful with MXFP4 quantization since the E0M2 values are doubled +// E8M0 to float/2. Canonical source: ggml_mxfp_e8m0_to_fp32_half() in ggml-common.h. 
static inline float ggml_e8m0_to_fp32_half(uint8_t x) { uint32_t bits; - - // For x < 2: use precomputed denormal patterns if (x < 2) { - // 0x00200000 = 2^(-128), 0x00400000 = 2^(-127) bits = 0x00200000 << x; - } - // For x >= 2: normalized exponent adjustment - else { - // 0.5 * 2^(x-127) = 2^(x-128) = normalized with exponent (x-1) + } else { bits = (uint32_t)(x - 1) << 23; } - // Note: NaNs are not handled here - float result; memcpy(&result, &bits, sizeof(float)); return result; diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index e1dca6b4b4..f2fcd73dba 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -3760,7 +3760,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te } else if (op->src[0]->type == GGML_TYPE_F32) { return op->src[1]->type == GGML_TYPE_F32; } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || - op->src[0]->type == GGML_TYPE_MXFP4 || + op->src[0]->type == GGML_TYPE_MXFP4_E2M1 || op->src[0]->type == GGML_TYPE_Q4_K || op->src[0]->type == GGML_TYPE_Q6_K) { return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); @@ -3771,7 +3771,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_MUL_MAT_ID: if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0 || - op->src[0]->type == GGML_TYPE_MXFP4) { + op->src[0]->type == GGML_TYPE_MXFP4_E2M1) { if (op->src[1]->type == GGML_TYPE_F32) { return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); } @@ -4559,7 +4559,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, #endif // GGML_OPENCL_USE_ADRENO_KERNELS return; } - if (tensor->type == GGML_TYPE_MXFP4) { + if (tensor->type == GGML_TYPE_MXFP4_E2M1) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; 
GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); @@ -5136,7 +5136,7 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } - if (tensor->type == GGML_TYPE_MXFP4) { + if (tensor->type == GGML_TYPE_MXFP4_E2M1) { ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *)tensor->extra; cl_int err; @@ -5585,7 +5585,7 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tenso CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL)); CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_d, buf_d, 0, NULL, NULL)); CL_CHECK(clFinish(queue)); - } else if (tensor->type == GGML_TYPE_MXFP4) { + } else if (tensor->type == GGML_TYPE_MXFP4_E2M1) { ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *) tensor->extra; GGML_ASSERT(extra); @@ -10550,7 +10550,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); #endif // GGML_OPENCL_SOA_Q break; - case GGML_TYPE_MXFP4: { + case GGML_TYPE_MXFP4_E2M1: { #ifdef GGML_OPENCL_SOA_Q kernel = backend_ctx->kernel_mul_mv_mxfp4_f32_flat; @@ -10630,7 +10630,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co GGML_ASSERT(false && "not implemented"); } - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4 || + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4_E2M1 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K) { @@ -10864,7 +10864,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, #endif // GGML_OPENCL_SOA_Q break; } - case GGML_TYPE_MXFP4: { + case GGML_TYPE_MXFP4_E2M1: { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (use_adreno_moe_kernels(backend_ctx, src0)) { cl_int status; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c 
index 48695a61ea..b2692c45f6 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -257,19 +257,188 @@ void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_REST } } -static inline int best_index_mxfp4(float x, float e) { - int best_index = 0; - float best_err = fabsf(kvalues_mxfp4[0]*e - x); - for (int i = 1; i < 16; i++) { - float err = fabsf(kvalues_mxfp4[i]*e - x); - if (err < best_err) { - best_index = i; - best_err = err; - } - } - return best_index; +// ============================================================================ +// MXFP Element Conversion Functions +// ============================================================================ +// +// Reference implementations for OCP Microscaling (MX) format element types. +// Spec: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf +// +// All converters use IEEE-754 bit manipulation via memcpy (C99 safe, no strict +// aliasing issues). Quantization uses round-to-nearest-even (RNE) per MX spec. +// +// These functions are exposed in ggml-quants.h for use by CPU backends and tests. +// GPU backends (CUDA, Vulkan, Metal) provide their own optimized versions using +// hardware intrinsics (e.g., __nv_cvt_float_to_fp8, SIMD groups, LUT lookups). +// +// Key design decisions validated empirically on CUDA (Qwen3-Coder-30B-A3B): +// +// 1. SATURATION, NOT NaN PROPAGATION: FP8 E4M3 saturates to max (0x7E = 448) +// rather than producing NaN. The single NaN encoding (0x7F) is avoided. +// This matches the MX spec behavior and prevents NaN corruption in KV caches. +// +// 2. MX FP6 HAS NO NaN/Inf: Unlike IEEE-754, the MX spec defines exp=max as a +// valid normal value for FP6 types. Dequantizers must NOT special-case it. +// +// 3. RNE ROUNDING IN SUBNORMALS: Both normal and subnormal paths use proper +// round-to-nearest-even with sticky bit tracking. This was a P0 bug fix — +// truncation caused measurable PPL regression. +// +// 4. 
E3M2 SUBNORMAL SCALE: mant * 2^(1-bias-m) = mant * 2^(-4) = mant/16. +// NOT mant/4. This was a critical bug — the exponent bias and mantissa width +// both affect the subnormal multiplier. +// + +// FP8 E4M3: 1 sign, 4 exponent (bias 7), 3 mantissa +// Max finite: 448 (exp=15, mant=6), NaN: exp=15, mant=7 +// Thin wrappers around canonical implementations in ggml-common.h. +// Verified bit-for-bit identical by test-mxfp-converters. +float fp8_e4m3_to_float(uint8_t v) { return ggml_mxfp_fp8_e4m3_to_float(v); } +uint8_t float_to_fp8_e4m3_rn(float x) { return ggml_mxfp_float_to_fp8_e4m3(x); } + +// ============================================================================ +// MXFP Quantization Infrastructure +// ============================================================================ +// +// The MX format uses a shared E8M0 exponent per block of 32 elements. Choosing +// the optimal exponent is critical for quantization quality. +// +// The OCP MX v1.0 spec (§5.3) specifies floor(log2(amax)) for the shared exponent. +// We improve on this with an MSE-optimal ±1 search that tests 3 candidate exponents +// {e-1, e, e+1} around round(log2(amax)) and picks whichever minimizes the total +// round-trip quantization error for the block. This consistently improves perplexity +// by 0.05-0.2 across all MX types versus floor-only or round-only approaches. +// +// The round(log2(amax)) base is computed via IEEE-754 integer bit extraction rather +// than log2f(), avoiding GPU Special Function Unit (SFU) bottlenecks. The rounding +// threshold 0x3504F3 is the fractional part of sqrt(2) in IEEE-754 mantissa bits: +// if mantissa >= (sqrt(2)-1)*2^23 ≈ 0x3504F3, then log2(x) >= n+0.5, so round up. +// +// Each MX element type provides an mse_error function that computes the round-trip +// quantization error for a single value at a given scale. The traits structure +// encapsulates this per-type behavior. +// + +// Per-type traits for MSE-optimal E8M0 scale computation. 
+// emax_offset: type-specific offset from E8M0 bias to type's max representable exponent +// to_elem/to_float: element conversion function pointers (NULL for MXFP4 which uses LUT) +// mse_error: round-trip error function for a single value at a given scale +typedef struct { + int emax_offset; + uint8_t (*to_elem)(float); + float (*to_float)(uint8_t); + float (*mse_error)(float val, float inv_scale, float scale); +} mxfp_elem_traits_t; + +// Forward declaration — defined after kvalues_mxfp4 lookup table section. +static inline int best_index_mxfp4(float x, float e); + +// MXFP4 E2M1 MSE error: decision boundary quantization with HALF scale factor. +// +// This CPU implementation uses the doubled int8 kvalues_mxfp4 LUT {0,1,2,3,4,6,8,12} +// with GGML_E8M0_TO_FP32_HALF(e) = scale/2 for efficient nibble-indexed integer arithmetic. +// The MSE interface passes GGML_E8M0_TO_FP32(e) as scale, so we halve it. +// +// Canonical E2M1 values are {0, 0.5, 1, 1.5, 2, 3, 4, 6} (kvalues_mxfp4_float in ggml-common.h). +// Doubled boundaries {0.5, 1.5, 2.5, 3.5, 5, 7, 10} ÷ 2 = canonical {0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5}. +// Mathematically identical — the doubling is an implementation detail. +// This is the Lloyd-Max quantizer for uniform input density. +static float mse_error_mxfp4(float val, float inv_scale, float scale) { + // Decision boundary quantization with direct reconstruction. + // kvalues_mxfp4 positive sorted: {0, 1, 2, 3, 4, 6, 8, 12} + // Use inv_scale * 2 since MXFP4 scale includes 0.5x factor. + const float d = scale * 0.5f; + const float inv_d = (d > 0.0f) ? 
1.0f / d : 0.0f; + const float normalized = fabsf(val) * inv_d; + (void)inv_scale; + float qval; + if (normalized < 0.5f) qval = 0.0f; + else if (normalized < 1.5f) qval = 1.0f; + else if (normalized < 2.5f) qval = 2.0f; + else if (normalized < 3.5f) qval = 3.0f; + else if (normalized < 5.0f) qval = 4.0f; + else if (normalized < 7.0f) qval = 6.0f; + else if (normalized < 10.0f) qval = 8.0f; + else qval = 12.0f; + const float err = fabsf(val) - qval * d; + return err * err; } +static const mxfp_elem_traits_t mxfp4_traits = { MXFP4_E2M1_EMAX_OFFSET, NULL, NULL, mse_error_mxfp4 }; + +// MSE-optimal E8M0 shared exponent computation. +// +// Algorithm: +// 1. Find amax = max(|x[0..qk-1]|) +// 2. Compute e_base = round(log2(amax)) - emax_offset + 127 via integer bit ops +// 3. Test {e_base-R .. e_base+R}, pick the one minimizing total round-trip MSE +// where R = MXFP_E8M0_MSE_RANGE (defined in ggml-common.h) +// +// The ±R search improves on the OCP spec's floor(log2(amax)). Wider search finds +// better scales for blocks with non-uniform value distributions (especially FP4). +// Cost is (2R+1) × qk roundtrip evaluations per block — negligible vs attention compute. +// +// Integer log2 avoids log2f() (SFU-dependent on GPU). The sqrt(2) rounding threshold +// ensures we start from round() not floor(). +// +// Ref: OCP MX v1.0 §5.3; Four Over Six (arXiv:2512.02010) +static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_elem_traits_t * traits) { + float amax = 0.0f; + for (int j = 0; j < qk; j++) { + const float a = fabsf(x[j]); + if (a > amax) amax = a; + } + if (amax == 0.0f) return 0; + + const int e_base = ggml_mxfp_e8m0_base_estimate(amax, traits->emax_offset); + + // ±R MSE search: test 2R+1 candidates around e_base, pick lowest total MSE. + int e_lo = e_base - MXFP_E8M0_MSE_RANGE; + int e_hi = e_base + MXFP_E8M0_MSE_RANGE; + if (e_lo < 1) e_lo = 1; + if (e_hi < 1) e_hi = 1; + if (e_hi > 254) e_hi = 254; + int best_e = e_base < 0 ? 
0 : (e_base > 254 ? 254 : e_base); + float best_mse = 1e30f; + + for (int test_e = e_lo; test_e <= e_hi; ++test_e) { + const float test_scale = GGML_E8M0_TO_FP32((uint8_t)test_e); + const float test_inv = 1.0f / test_scale; + float mse = 0.0f; + for (int j = 0; j < qk; ++j) { + mse += traits->mse_error(x[j], test_inv, test_scale); + } + if (mse < best_mse) { + best_mse = mse; + best_e = test_e; + } + } + + return (uint8_t)best_e; +} + +static inline int best_index_mxfp4(float x, float e) { + // Decision boundary quantization: 7 comparisons instead of 16-element scan. + // kvalues_mxfp4 positive sorted: {0, 1, 2, 3, 4, 6, 8, 12} + // Decision boundaries (midpoints): {0.5, 1.5, 2.5, 3.5, 5, 7, 10} + const float inv_e = (e > 0.0f) ? 1.0f / e : 0.0f; + const float normalized = fabsf(x) * inv_e; + int idx; + if (normalized < 0.5f) idx = 0; + else if (normalized < 1.5f) idx = 1; + else if (normalized < 2.5f) idx = 2; + else if (normalized < 3.5f) idx = 3; + else if (normalized < 5.0f) idx = 4; + else if (normalized < 7.0f) idx = 5; + else if (normalized < 10.0f) idx = 6; + else idx = 7; + return (x < 0.0f) ? (idx + 8) : idx; +} + +// FP4 E2M1: search-based quantization using best_index_mxfp4 lookup table. +// Unlike FP6/FP8 which use direct float->element conversion, FP4 finds the +// closest 4-bit value by minimizing reconstruction error against the lookup table. +// Scale uses GGML_E8M0_TO_FP32_HALF (includes 0.5x factor for E2M1 mantissa range). void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) { static const int qk = QK_MXFP4; @@ -278,18 +447,7 @@ void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RE const int nb = k / qk; for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max - - for (int j = 0; j < qk; j++) { - const float v = x[i*qk + j]; - - if (amax < fabsf(v)) { - amax = fabsf(v); - } - } - - const uint8_t e = amax > 0.0f ? 
(uint8_t) (floorf(log2f(amax)) - 2 + 127) : 0; - + const uint8_t e = mxfp_compute_e8m0_mse(&x[i*qk], qk, &mxfp4_traits); const float d = GGML_E8M0_TO_FP32_HALF(e); y[i].e = e; @@ -494,6 +652,305 @@ void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_REST } } +// ============================================================================ +// Hadamard Rotation (reference scalar implementation) +// ============================================================================ +// +// 32-element Walsh-Hadamard transform, applied to MX blocks before quantization +// to spread outlier energy uniformly across the shared-exponent group. +// +// Without rotation, a single outlier in a block of 32 forces the shared E8M0 +// exponent high, wasting precision for all 31 other elements. The Hadamard +// transform is orthogonal (H^T·H = I), so H(K)·H(Q) = K·Q — attention scores +// are preserved exactly when both K and Q undergo the same rotation. +// +// Implementation: 5 butterfly stages (log2(32) = 5) of the fast Walsh-Hadamard +// transform, followed by normalization by 1/sqrt(32). Total: 160 FP add/sub + +// 32 FP mul. This is the standard "in-place" FWHT with O(n·log(n)) operations. +// +// The 1/sqrt(32) normalization factor makes the transform orthonormal: +// H_normalized = H_unnormalized / sqrt(N) +// This ensures the transform preserves vector norms (energy), which is critical +// for maintaining attention score magnitudes after rotation. +// +// Prior art: QuIP# (Tseng et al. 2024), BRQ (Huang et al. 2024) apply Hadamard +// for weight quantization. Our novel contribution: applying it to KV cache +// quantization at the MX block boundary (block-32), where it matches the shared +// exponent group size. Tested alternatives (block-8, block-16, sign flips, +// permutations) all degraded quality — block-32 Hadamard is uniquely optimal +// because it spreads energy across exactly the elements sharing an exponent. 
+// +// Empirical PPL impact WITHOUT Hadamard rotation (Qwen3-Coder-30B-A3B): +// MXFP8 E4M3: +0.22, MXFP8 E5M2: +1.38, MXFP6 E2M3: +3.34, MXFP6 E3M2: +4.60 +// +void ggml_hadamard_32_inplace(float vals[32]) { + ggml_mxfp_hadamard_32_inplace(vals); +} + +float fp6_e2m3_to_float(uint8_t v) { return ggml_mxfp_fp6_e2m3_to_float(v); } +uint8_t float_to_fp6_e2m3_rn(float x) { return ggml_mxfp_float_to_fp6_e2m3(x); } +float fp6_e3m2_to_float(uint8_t v) { return ggml_mxfp_fp6_e3m2_to_float(v); } +uint8_t float_to_fp6_e3m2_rn(float x) { return ggml_mxfp_float_to_fp6_e3m2(x); } +float fp8_e5m2_to_float(uint8_t v) { return ggml_mxfp_fp8_e5m2_to_float(v); } +uint8_t float_to_fp8_e5m2_rn(float x) { return ggml_mxfp_float_to_fp8_e5m2(x); } + +void pack_fp6x4(const uint8_t v[4], uint8_t out[3]) { ggml_mxfp_pack_fp6x4(v, out); } +void unpack_fp6x4(const uint8_t in[3], uint8_t v[4]) { ggml_mxfp_unpack_fp6x4(in, v); } + +// MSE error functions for FP8/FP6: quantize at given scale → dequantize → squared error. +// Used by mxfp_compute_e8m0_mse() to evaluate candidate E8M0 exponents. +// These call the public API wrappers which delegate to canonical ggml_mxfp_* in ggml-common.h. +static float mse_error_fp8_e4m3(float val, float inv_scale, float scale) { + const float recon = fp8_e4m3_to_float(float_to_fp8_e4m3_rn(val * inv_scale)) * scale; + const float err = val - recon; + return err * err; +} +static float mse_error_fp6_e2m3(float val, float inv_scale, float scale) { + const float recon = fp6_e2m3_to_float(float_to_fp6_e2m3_rn(val * inv_scale)) * scale; + const float err = val - recon; + return err * err; +} +// emax_offset = ceil(log2(max_finite_value)) for each element type. +// This centers the E8M0 exponent search around the optimal scale for the type's range. 
+// E4M3: max=448, ceil(log2(448)) = 9, but offset=8 matches CUDA (empirically better) +// E2M3: max=7.5, ceil(log2(7.5)) = 3 +static const mxfp_elem_traits_t mxfp8_e4m3_traits = { MXFP8_E4M3_EMAX_OFFSET, float_to_fp8_e4m3_rn, fp8_e4m3_to_float, mse_error_fp8_e4m3 }; +static const mxfp_elem_traits_t mxfp6_e2m3_traits = { MXFP6_E2M3_EMAX_OFFSET, float_to_fp6_e2m3_rn, fp6_e2m3_to_float, mse_error_fp6_e2m3 }; + +// FP8 quantize/dequantize: byte-per-element, shared by E4M3 and E5M2 +static void quantize_row_mxfp8_impl(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, + int64_t k, const mxfp_elem_traits_t * traits) { + assert(k % QK_MXFP8 == 0); + const int nb = k / QK_MXFP8; + + for (int i = 0; i < nb; i++) { + const uint8_t e = mxfp_compute_e8m0_mse(&x[i*QK_MXFP8], QK_MXFP8, traits); + const float d = GGML_E8M0_TO_FP32(e); + const float inv_d = d > 0.0f ? 1.0f / d : 0.0f; + y[i].e = e; + + for (int j = 0; j < QK_MXFP8; ++j) { + y[i].qs[j] = traits->to_elem(x[i*QK_MXFP8 + j] * inv_d); + } + } +} + +static void dequantize_row_mxfp8_impl(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, + int64_t k, const mxfp_elem_traits_t * traits) { + assert(k % QK_MXFP8 == 0); + const int nb = k / QK_MXFP8; + + for (int i = 0; i < nb; i++) { + const float d = GGML_E8M0_TO_FP32(x[i].e); + for (int j = 0; j < QK_MXFP8; ++j) { + y[i*QK_MXFP8 + j] = traits->to_float(x[i].qs[j]) * d; + } + } +} + +// FP6 quantize/dequantize: tight 6-bit packing (4 values per 3 bytes), shared by E2M3 and E3M2 +static void quantize_row_mxfp6_impl(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, + int64_t k, const mxfp_elem_traits_t * traits) { + assert(k % QK_MXFP6 == 0); + const int nb = k / QK_MXFP6; + + for (int i = 0; i < nb; i++) { + const uint8_t e = mxfp_compute_e8m0_mse(&x[i*QK_MXFP6], QK_MXFP6, traits); + const float d = GGML_E8M0_TO_FP32(e); + const float inv_d = d > 0.0f ? 
1.0f / d : 0.0f; + y[i].e = e; + + for (int j = 0; j < QK_MXFP6; j += 4) { + uint8_t vals[4]; + for (int jj = 0; jj < 4; jj++) { + vals[jj] = traits->to_elem(x[i*QK_MXFP6 + j + jj] * inv_d); + } + pack_fp6x4(vals, &y[i].qs[j * 3 / 4]); + } + } +} + +static void dequantize_row_mxfp6_impl(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, + int64_t k, const mxfp_elem_traits_t * traits) { + assert(k % QK_MXFP6 == 0); + const int nb = k / QK_MXFP6; + + for (int i = 0; i < nb; i++) { + const float d = GGML_E8M0_TO_FP32(x[i].e); + for (int j = 0; j < QK_MXFP6; j += 4) { + uint8_t vals[4]; + unpack_fp6x4(&x[i].qs[j * 3 / 4], vals); + for (int jj = 0; jj < 4; jj++) { + y[i*QK_MXFP6 + j + jj] = traits->to_float(vals[jj]) * d; + } + } + } +} + +// Public API wrappers — one-line delegates to the traits-parameterized impl + +void quantize_row_mxfp8_ref(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k) { + quantize_row_mxfp8_impl(x, y, k, &mxfp8_e4m3_traits); +} + +void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + dequantize_row_mxfp8_impl(x, y, k, &mxfp8_e4m3_traits); +} + +void quantize_row_mxfp6_e2m3_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k) { + quantize_row_mxfp6_impl(x, y, k, &mxfp6_e2m3_traits); +} + +void dequantize_row_mxfp6_e2m3(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + dequantize_row_mxfp6_impl(x, y, k, &mxfp6_e2m3_traits); +} + +// ============================================================================ +// SoA (Struct-of-Arrays) quantize/dequantize — canonical reference for FA +// ============================================================================ +// +// SoA layout per row: [qs_block0|qs_block1|...|qs_blockN][e8m0_0|e8m0_1|...|e8m0_N] +// Total bytes per row = nblocks * (QS_PER_BLOCK + 1) = identical to AoS. +// This is the ONLY layout used by flash attention across all backends. 
+ +void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) { + assert(k % QK_MXFP4 == 0); + const int nb = k / QK_MXFP4; + char * row = (char *)dst; + char * qs_base = row; + char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK); + + for (int i = 0; i < nb; i++) { + const uint8_t e = mxfp_compute_e8m0_mse(&x[i*QK_MXFP4], QK_MXFP4, &mxfp4_traits); + const float d = GGML_E8M0_TO_FP32_HALF(e); + + e8m0_base[i] = (char)e; + + uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP4_SOA_QS_PER_BLOCK)); + for (int j = 0; j < QK_MXFP4/2; ++j) { + const uint8_t x0 = best_index_mxfp4(x[i*QK_MXFP4 + 0 + j], d); + const uint8_t x1 = best_index_mxfp4(x[i*QK_MXFP4 + QK_MXFP4/2 + j], d); + qs[j] = x0 | (x1 << 4); + } + } +} + +void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_MXFP4 == 0); + const int nb = k / QK_MXFP4; + const char * row = (const char *)src; + const char * qs_base = row; + const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK); + + for (int i = 0; i < nb; i++) { + const float d = GGML_E8M0_TO_FP32_HALF((uint8_t)e8m0_base[i]); + const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP4_SOA_QS_PER_BLOCK)); + + for (int j = 0; j < QK_MXFP4/2; ++j) { + const int8_t x0 = kvalues_mxfp4[qs[j] & 0x0F]; + const int8_t x1 = kvalues_mxfp4[qs[j] >> 4]; + y[i*QK_MXFP4 + j + 0 ] = x0*d; + y[i*QK_MXFP4 + j + QK_MXFP4/2] = x1*d; + } + } +} + +static void quantize_row_mxfp8_soa_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, + int64_t k, const mxfp_elem_traits_t * traits) { + assert(k % QK_MXFP8 == 0); + const int nb = k / QK_MXFP8; + char * row = (char *)dst; + char * qs_base = row; + char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK); + + for (int i = 0; i < nb; i++) { + const uint8_t e = mxfp_compute_e8m0_mse(&x[i*QK_MXFP8], QK_MXFP8, traits); + const float d = 
GGML_E8M0_TO_FP32(e); + const float inv_d = d > 0.0f ? 1.0f / d : 0.0f; + e8m0_base[i] = (char)e; + + uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP8_SOA_QS_PER_BLOCK)); + for (int j = 0; j < QK_MXFP8; ++j) { + qs[j] = traits->to_elem(x[i*QK_MXFP8 + j] * inv_d); + } + } +} + +static void dequantize_row_mxfp8_soa_impl(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, + int64_t k, const mxfp_elem_traits_t * traits) { + assert(k % QK_MXFP8 == 0); + const int nb = k / QK_MXFP8; + const char * row = (const char *)src; + const char * qs_base = row; + const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK); + + for (int i = 0; i < nb; i++) { + const float d = GGML_E8M0_TO_FP32((uint8_t)e8m0_base[i]); + const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP8_SOA_QS_PER_BLOCK)); + for (int j = 0; j < QK_MXFP8; ++j) { + y[i*QK_MXFP8 + j] = traits->to_float(qs[j]) * d; + } + } +} + +static void quantize_row_mxfp6_soa_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, + int64_t k, const mxfp_elem_traits_t * traits) { + assert(k % QK_MXFP6 == 0); + const int nb = k / QK_MXFP6; + char * row = (char *)dst; + char * qs_base = row; + char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK); + + for (int i = 0; i < nb; i++) { + const uint8_t e = mxfp_compute_e8m0_mse(&x[i*QK_MXFP6], QK_MXFP6, traits); + const float d = GGML_E8M0_TO_FP32(e); + const float inv_d = d > 0.0f ? 
1.0f / d : 0.0f; + e8m0_base[i] = (char)e; + + uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP6_SOA_QS_PER_BLOCK)); + for (int j = 0; j < QK_MXFP6; j += 4) { + uint8_t vals[4]; + for (int jj = 0; jj < 4; jj++) { + vals[jj] = traits->to_elem(x[i*QK_MXFP6 + j + jj] * inv_d); + } + pack_fp6x4(vals, &qs[j * 3 / 4]); + } + } +} + +static void dequantize_row_mxfp6_soa_impl(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, + int64_t k, const mxfp_elem_traits_t * traits) { + assert(k % QK_MXFP6 == 0); + const int nb = k / QK_MXFP6; + const char * row = (const char *)src; + const char * qs_base = row; + const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK); + + for (int i = 0; i < nb; i++) { + const float d = GGML_E8M0_TO_FP32((uint8_t)e8m0_base[i]); + const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP6_SOA_QS_PER_BLOCK)); + for (int j = 0; j < QK_MXFP6; j += 4) { + uint8_t vals[4]; + unpack_fp6x4(&qs[j * 3 / 4], vals); + for (int jj = 0; jj < 4; jj++) { + y[i*QK_MXFP6 + j + jj] = traits->to_float(vals[jj]) * d; + } + } + } +} + +void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) { + quantize_row_mxfp8_soa_impl(x, dst, k, &mxfp8_e4m3_traits); +} +void dequantize_row_mxfp8_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) { + dequantize_row_mxfp8_soa_impl(src, y, k, &mxfp8_e4m3_traits); +} +void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) { + quantize_row_mxfp6_soa_impl(x, dst, k, &mxfp6_e2m3_traits); +} +void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) { + dequantize_row_mxfp6_soa_impl(src, y, k, &mxfp6_e2m3_traits); +} // // 2-6 bit quantization in super-blocks // @@ -2155,7 +2612,7 @@ size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT 
dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { GGML_UNUSED(quant_weights); quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row); - return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row); + return nrow * ggml_row_size(GGML_TYPE_MXFP4_E2M1, n_per_row); } size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { @@ -2164,6 +2621,18 @@ size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, return nrow * ggml_row_size(GGML_TYPE_NVFP4, n_per_row); } +size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_UNUSED(quant_weights); + quantize_row_mxfp8_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_MXFP8_E4M3, n_per_row); +} + +size_t quantize_mxfp6_e2m3(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_UNUSED(quant_weights); + quantize_row_mxfp6_e2m3_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_MXFP6_E2M3, n_per_row); +} + // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs) void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) { @@ -5306,7 +5775,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb); } break; - case GGML_TYPE_MXFP4: + case GGML_TYPE_MXFP4_E2M1: { VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb); } break; diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index 00604f75c0..33401f2843 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -23,6 +23,8 @@ GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, 
block_mxfp4 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_mxfp8_ref(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_mxfp6_e2m3_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); @@ -50,6 +52,17 @@ GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GG GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_mxfp6_e2m3(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + +// SoA (Struct-of-Arrays) quantize/dequantize — canonical reference for flash attention. +// Layout: [qs contiguous][e8m0 contiguous] per row. Same total bytes as AoS. 
+GGML_API void quantize_row_mxfp4_soa (const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); +GGML_API void dequantize_row_mxfp4_soa (const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_mxfp8_soa (const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); +GGML_API void dequantize_row_mxfp8_soa (const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); +GGML_API void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -98,6 +111,87 @@ GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTR GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_mxfp6_e2m3(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); + +// +// MXFP element-level conversion functions (reference implementations) +// +// These implement the OCP Microscaling (MX) format element types as defined in: +// OCP Microscaling Formats (MX) Specification v1.0, Sep 2023 +// https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf +// +// Each MX block contains 32 elements sharing a single E8M0 exponent. 
The element +// types define the per-element mantissa format within each block. +// +// All converters use IEEE-754 bit manipulation for exact results (no floating-point +// rounding in the conversion itself). Quantization functions use round-to-nearest-even +// (RNE) per the MX specification. +// +// GPU backends (CUDA, Metal, Vulkan) provide their own optimized versions — these +// functions serve as the canonical reference and are used by the CPU backend. +// + +// FP8 E4M3: 1 sign, 4 exponent (bias 7), 3 mantissa bits +// Range: ±[2^-9, 448], NaN: exp=15 mant=7 (only NaN encoding) +// Ref: OCP MX v1.0 §4.2 +GGML_API float fp8_e4m3_to_float(uint8_t v); +GGML_API uint8_t float_to_fp8_e4m3_rn(float x); + +// FP8 E5M2: 1 sign, 5 exponent (bias 15), 2 mantissa bits +// Range: ±[2^-16, 57344], NaN/Inf: exp=31 (standard IEEE-like) +// Ref: OCP MX v1.0 §4.2 +GGML_API float fp8_e5m2_to_float(uint8_t v); +GGML_API uint8_t float_to_fp8_e5m2_rn(float x); + +// FP6 E2M3: 1 sign, 2 exponent (bias 1), 3 mantissa bits +// Range: ±[2^-3, 7.5], stored as low 6 bits of a byte (00xxxxxx) +// MX format: NO NaN/Inf — all bit patterns are valid numbers +// Ref: OCP MX v1.0 §4.2 +GGML_API float fp6_e2m3_to_float(uint8_t v); +GGML_API uint8_t float_to_fp6_e2m3_rn(float x); + +// FP6 E3M2: 1 sign, 3 exponent (bias 3), 2 mantissa bits +// Range: ±[2^-4, 28.0], stored as low 6 bits of a byte (00xxxxxx) +// MX format: NO NaN/Inf — exp=7 is a valid normal value (unlike IEEE-754) +// CRITICAL: subnormal scale is 2^(1-bias-m) = 2^(-4) = 1/16, NOT 1/4 +// Ref: OCP MX v1.0 §4.2 +GGML_API float fp6_e3m2_to_float(uint8_t v); +GGML_API uint8_t float_to_fp6_e3m2_rn(float x); + +// FP6 tight packing: pack/unpack 4 six-bit values into/from 3 bytes +// Layout: v[0]=bits[5:0], v[1]=bits[11:6], v[2]=bits[17:12], v[3]=bits[23:18] +// Saves 25% memory vs byte-padded storage (24B vs 32B per MX block) +GGML_API void pack_fp6x4(const uint8_t v[4], uint8_t out[3]); +GGML_API void unpack_fp6x4(const uint8_t 
in[3], uint8_t v[4]); + +// +// Hadamard rotation (reference scalar implementation) +// +// 32-element Walsh-Hadamard transform applied to MX blocks before quantization. +// Distributes outlier energy uniformly across the block, dramatically improving +// quantization quality for types with shared exponents. +// +// Mathematical property: H^T·H = I (orthogonal), so H(K)·H(Q) = K·Q. +// Flash attention applies matching rotation to Q, preserving attention scores exactly. +// +// Implementation: 5 butterfly stages (log2(32) = 5) + normalization by 1/sqrt(32). +// This is the standard "fast Walsh-Hadamard transform" with O(n log n) operations. +// +// Applied in set_rows (K cache quantization) and flash_attn (Q quantization). +// Skipped for MLA models (DK != DV) where V is a view of K — rotation would corrupt V. +// +// Empirical impact (PPL degradation WITHOUT rotation, Qwen3-Coder-30B): +// MXFP8 E4M3: +0.22, MXFP8 E5M2: +1.38, MXFP6 E2M3: +3.34, MXFP6 E3M2: +4.60 +// +// Prior art: QuIP# (Tseng et al. 2024), BRQ (Huang et al. 2024) use Hadamard for +// weight quantization. Our contribution applies it to KV cache quantization at the +// MX block boundary, where block-32 is optimal because it matches the shared exponent +// group size exactly. +// +// GPU backends provide optimized versions (CUDA warp shuffles, Metal SIMD groups). 
+// +GGML_API void ggml_hadamard_32_inplace(float vals[32]); GGML_API void iq2xs_init_impl(enum ggml_type type); GGML_API void iq2xs_free_impl(enum ggml_type type); diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index d17aca2cac..09f3a43a90 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -639,7 +639,7 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { return dequantize_row_iq4_xs_sycl; case GGML_TYPE_IQ4_NL: return dequantize_row_iq4_nl_sycl; - case GGML_TYPE_MXFP4: + case GGML_TYPE_MXFP4_E2M1: return dequantize_row_mxfp4_sycl; case GGML_TYPE_F32: return convert_unary_sycl; @@ -706,7 +706,7 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { return dequantize_row_iq4_xs_sycl; case GGML_TYPE_IQ4_NL: return dequantize_row_iq4_nl_sycl; - case GGML_TYPE_MXFP4: + case GGML_TYPE_MXFP4_E2M1: return dequantize_row_mxfp4_sycl; case GGML_TYPE_F16: return convert_unary_sycl; diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 316aa0d0fb..f0d1472f44 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -1142,7 +1142,7 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens case GGML_TYPE_IQ4_XS: mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; - case GGML_TYPE_MXFP4: + case GGML_TYPE_MXFP4_E2M1: mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; default: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e5b83e1447..329f2b93b3 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -710,8 +710,8 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref, }, - [GGML_TYPE_MXFP4] = { - .type_name = "mxfp4", + [GGML_TYPE_MXFP4_E2M1] = { + .type_name = "mxfp4_e2m1", .blck_size = QK_MXFP4, .type_size = 
sizeof(block_mxfp4), .is_quantized = true, @@ -726,6 +726,22 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .to_float = (ggml_to_float_t) dequantize_row_nvfp4, .from_float_ref = (ggml_from_float_t)quantize_row_nvfp4_ref, }, + [GGML_TYPE_MXFP8_E4M3] = { + .type_name = "mxfp8_e4m3", + .blck_size = QK_MXFP8, + .type_size = sizeof(block_mxfp8), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_mxfp8, + .from_float_ref = (ggml_from_float_t)quantize_row_mxfp8_ref, + }, + [GGML_TYPE_MXFP6_E2M3] = { + .type_name = "mxfp6_e2m3", + .blck_size = QK_MXFP6, + .type_size = sizeof(block_mxfp6), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_mxfp6_e2m3, + .from_float_ref = (ggml_from_float_t)quantize_row_mxfp6_e2m3_ref, + }, [GGML_TYPE_Q2_K] = { .type_name = "q2_K", .blck_size = QK_K, @@ -1306,6 +1322,30 @@ bool ggml_is_quantized(enum ggml_type type) { return type_traits[type].is_quantized; } +bool ggml_is_type_mxfp(enum ggml_type type) { + return type == GGML_TYPE_MXFP4_E2M1 || + type == GGML_TYPE_MXFP8_E4M3 || + type == GGML_TYPE_MXFP6_E2M3; +} + +bool ggml_mxfp_use_hadamard(enum ggml_type type) { + switch (type) { + case GGML_TYPE_MXFP4_E2M1: return MXFP_USE_HADAMARD_E2M1; + case GGML_TYPE_MXFP8_E4M3: return MXFP_USE_HADAMARD_E4M3; + case GGML_TYPE_MXFP6_E2M3: return MXFP_USE_HADAMARD_E2M3; + default: return false; + } +} + +int ggml_mxfp_qs_per_block(enum ggml_type type) { + switch (type) { + case GGML_TYPE_MXFP4_E2M1: return MXFP_QS_PER_BLOCK_E2M1; + case GGML_TYPE_MXFP8_E4M3: return MXFP_QS_PER_BLOCK_E4M3; + case GGML_TYPE_MXFP6_E2M3: return MXFP_QS_PER_BLOCK_E2M3; + default: return 0; + } +} + const char * ggml_op_name(enum ggml_op op) { return GGML_OP_NAME[op]; } @@ -1381,7 +1421,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break; case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break; case GGML_FTYPE_MOSTLY_Q8_0: wtype = 
GGML_TYPE_Q8_0; break; - case GGML_FTYPE_MOSTLY_MXFP4: wtype = GGML_TYPE_MXFP4; break; + case GGML_FTYPE_MOSTLY_MXFP4_E2M1: wtype = GGML_TYPE_MXFP4_E2M1; break; case GGML_FTYPE_MOSTLY_NVFP4: wtype = GGML_TYPE_NVFP4; break; case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break; case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break; @@ -7649,8 +7689,10 @@ size_t ggml_quantize_chunk( case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_MXFP4: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_MXFP4_E2M1: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_MXFP8_E4M3: result = quantize_mxfp8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_MXFP6_E2M3: result = quantize_mxfp6_e2m3(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git 
a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 01166fac9c..fcd784e79d 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -51,6 +51,7 @@ llama_kv_cache::llama_kv_cache( auto it = ctx_map.find(buft); if (it == ctx_map.end()) { ggml_init_params params = { + // 2 base tensors (K+V) + 2*n_stream view tensors /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, @@ -135,7 +136,17 @@ llama_kv_cache::llama_kv_cache( const bool has_k = true; const bool has_v = !is_mla; - ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream) : nullptr; + // MXFP K cache: align block count to 16 for cp.async. + uint32_t n_embd_k_alloc = n_embd_k_gqa; + const bool is_mxfp_k = ggml_is_type_mxfp(type_k); + if (is_mxfp_k) { + const int qk = (int)ggml_blck_size(type_k); // 32 for all MXFP types + const int blocks = (int)n_embd_k_gqa / qk; + const int blocks_aligned = (blocks + 15) & ~15; // align to 16 + n_embd_k_alloc = (uint32_t)(blocks_aligned * qk); + } + + ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, type_k, n_embd_k_alloc, kv_size, n_stream) : nullptr; ggml_tensor * v = has_v ? ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream) : nullptr; has_k && ggml_format_name(k, "cache_k_l%d", il); @@ -1025,19 +1036,16 @@ ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_k auto * k = layers[ikv].k; - const uint64_t kv_size = get_size(); - const uint64_t n_embd_k_gqa = k->ne[0]; - - assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il)); - + // For MXFP types: k->ne[0] may include alignment padding (blocks aligned to 16). + // The row stride (k->nb[1]) reflects the padded allocation. 
const uint32_t ns = sinfo.s1 - sinfo.s0 + 1; return ggml_view_4d(ctx, k, hparams.n_embd_head_k(il), hparams.n_head_kv(il), n_kv, ns, ggml_row_size(k->type, hparams.n_embd_head_k(il)), - ggml_row_size(k->type, n_embd_k_gqa), - ggml_row_size(k->type, n_embd_k_gqa*kv_size), - ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0); + k->nb[1], + k->nb[2], + k->nb[2]*sinfo.s0); } ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const { @@ -1092,19 +1100,38 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0); const int64_t n_stream = k->ne[2]; + const int64_t kv_size = get_size(); if (n_stream > 1) { - const int64_t kv_size = get_size(); - - assert(n_embd_gqa == k->ne[0]); - assert(kv_size == k->ne[1]); + assert(kv_size == k->ne[1]); // merge the buffer across all streams because the idxs are global - k = ggml_reshape_2d(ctx, k, n_embd_gqa, kv_size*n_stream); + // Use view_2d to preserve nb[1] (which includes alignment padding for MXFP types) + k = ggml_view_2d(ctx, k, k->ne[0], kv_size*n_stream, k->nb[1], 0); + } + + const bool is_mxfp = ggml_is_type_mxfp(k->type); + + // For MXFP: ne[0] may be padded for block alignment, but k_cur has n_embd_gqa. + // Create view with ne[0]=n_embd_gqa, preserving the larger row stride nb[1]. + ggml_tensor * k_dst = k; + if (is_mxfp) { + k_dst = ggml_view_2d(ctx, k, n_embd_gqa, k->ne[1], k->nb[1], 0); } // store the current K values into the cache - return ggml_set_rows(ctx, k, k_cur, k_idxs); + ggml_tensor * result = ggml_set_rows(ctx, k_dst, k_cur, k_idxs); + + // Flag K cache writes for Walsh-Hadamard rotation (QuaRot, arXiv:2404.00456; BRQ, arXiv:2511.04214). + // The flash attention kernel applies matching rotation to Q so H(Q)·H(K)^T = Q·K^T. + // V cache writes are NOT rotated (op_params[0] defaults to 0). 
+    // Skipped for: MLA (V is a view of K — rotation would corrupt V), and for the
+    //              E5M2/E3M2 variants, which are gated off per-type by ggml_mxfp_use_hadamard().
+    if (is_mxfp && !hparams.is_mla() && ggml_mxfp_use_hadamard(k->type)) {
+        ((int32_t *)result->op_params)[0] = 1;
+    }
+
+    return result;
 }
 
 ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index 8e8ce23124..279c57e582 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -457,7 +457,7 @@ static ggml_type llama_tensor_get_type_impl(quantize_state_impl & qs, ggml_type
             // MoE tensors -> MXFP4
             // other tensors -> Q8_0
             if (tensor->ne[2] > 1) {
-                new_type = GGML_TYPE_MXFP4;
+                new_type = GGML_TYPE_MXFP4_E2M1;
             } else {
                 new_type = GGML_TYPE_Q8_0;
             }
@@ -795,7 +795,7 @@ static ggml_type llama_ftype_get_default_type(llama_ftype ftype) {
         case LLAMA_FTYPE_MOSTLY_BF16: return GGML_TYPE_BF16;
         case LLAMA_FTYPE_ALL_F32:  return GGML_TYPE_F32;
-        case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return GGML_TYPE_MXFP4;
+        case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return GGML_TYPE_MXFP4_E2M1;
         // K-quants
         case LLAMA_FTYPE_MOSTLY_Q2_K_S:
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index c9896cc11e..6ddef63336 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -150,6 +150,91 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
     }
 }
 
+// SoA quantize/dequantize functions — declared here because ggml-quants.h is not in the test include path.
+typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t); +extern "C" { + void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); + void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); + void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); + void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); + void dequantize_row_mxfp8_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); + void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); +} + +// Initialize an MXFP tensor with SoA (Struct-of-Arrays) layout. +// soa_bytes: byte width of one SoA region. Default 0 = ne[0] elements (one ggml row). +// For FA K/V tensors, pass nb[1] so that when heads are physically contiguous +// within one KV-position stride, the SoA region spans all heads (matching FA's read pattern). 
+static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f, size_t soa_bytes = 0) { + GGML_ASSERT(ggml_is_type_mxfp(tensor->type)); + + typedef void (*soa_quantize_fn)(const float *, void *, int64_t); + soa_quantize_fn quantize_soa = nullptr; + switch (tensor->type) { + case GGML_TYPE_MXFP4_E2M1: quantize_soa = quantize_row_mxfp4_soa; break; + case GGML_TYPE_MXFP8_E4M3: quantize_soa = quantize_row_mxfp8_soa; break; + case GGML_TYPE_MXFP6_E2M3: quantize_soa = quantize_row_mxfp6_soa; break; + default: GGML_ABORT("unsupported MXFP type for SoA init"); + } + + const int qk = (int)ggml_blck_size(tensor->type); + const size_t block_size = ggml_type_size(tensor->type); + const size_t head_row_sz = ggml_row_size(tensor->type, tensor->ne[0]); + if (soa_bytes == 0) { soa_bytes = head_row_sz; } + const int64_t soa_elems = (int64_t)(soa_bytes / block_size) * qk; + + std::default_random_engine gen(42); + std::uniform_real_distribution dist(min, max); + std::vector region_f32(soa_elems); + + // Iterate over logical SoA regions using tensor strides. + // Each SoA region is soa_bytes wide at the innermost stride level. + // Outer dimensions (those with stride > soa_bytes) are iterated explicitly. + const size_t nb1 = tensor->nb[1]; + const size_t nb2 = tensor->nb[2]; + const size_t nb3 = tensor->nb[3]; + const int64_t ne1 = tensor->ne[1]; + const int64_t ne2 = tensor->ne[2]; + const int64_t ne3 = tensor->ne[3]; + + // Determine iteration: if soa_bytes == nb1, iterate over (ne1 * ne2 * ne3) regions. + // If soa_bytes < nb1 (per-head), iterate over (ne1 * ne2 * ne3) regions with stride nb1. + // We use strides to compute offsets, handling views and permutations correctly. + const int64_t heads_per_region = (int64_t)(soa_bytes / head_row_sz); + + // For multi-head regions, we step by nb1 (KV-position stride) between regions. + // For per-head, we step through all dimensions. 
+ std::vector buf(ggml_nbytes(tensor), 0); + + if (heads_per_region > 1) { + // Multi-head SoA: iterate over (kv_positions * batches), each region = nb1 bytes + for (int64_t i3 = 0; i3 < ne3; i3++) { + // ne2/heads_per_region = number of head groups (for GQA broadcast, usually 1) + const int64_t n_groups = ne2 / heads_per_region; + for (int64_t ig = 0; ig < n_groups; ig++) { + for (int64_t i1 = 0; i1 < ne1; i1++) { + size_t offset = i3*nb3 + ig*heads_per_region*nb2 + i1*nb1; + for (int64_t j = 0; j < soa_elems; j++) { region_f32[j] = dist(gen); } + quantize_soa(region_f32.data(), buf.data() + offset, soa_elems); + } + } + } + } else { + // Per-head SoA: one SoA region per ggml row + for (int64_t i3 = 0; i3 < ne3; i3++) { + for (int64_t i2 = 0; i2 < ne2; i2++) { + for (int64_t i1 = 0; i1 < ne1; i1++) { + size_t offset = i3*nb3 + i2*nb2 + i1*nb1; + for (int64_t j = 0; j < soa_elems; j++) { region_f32[j] = dist(gen); } + quantize_soa(region_f32.data(), buf.data() + offset, soa_elems); + } + } + } + } + + ggml_backend_tensor_set(tensor, buf.data(), 0, buf.size()); +} + // generate an F16 mask where certain blocks are randomly masked with -INF value static void init_tensor_kq_mask(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { GGML_ASSERT(tensor->type == GGML_TYPE_F16); @@ -239,11 +324,30 @@ static std::vector tensor_to_float(const ggml_tensor * t) { size_t bs = ggml_blck_size(t->type); std::vector vq(ggml_blck_size(t->type)); bool quantized = ggml_is_quantized(t->type); + const bool is_mxfp = ggml_is_type_mxfp(t->type); + + // SoA dequant for MXFP readback + mxfp_soa_dequantize_fn mxfp_dequant_soa = nullptr; + if (is_mxfp) { + switch (t->type) { + case GGML_TYPE_MXFP4_E2M1: mxfp_dequant_soa = dequantize_row_mxfp4_soa; break; + case GGML_TYPE_MXFP8_E4M3: mxfp_dequant_soa = dequantize_row_mxfp8_soa; break; + case GGML_TYPE_MXFP6_E2M3: mxfp_dequant_soa = dequantize_row_mxfp6_soa; break; + default: GGML_ABORT("unsupported MXFP type in tensor_to_float"); + } + 
} // access elements by index to avoid gaps in views for (int64_t i3 = 0; i3 < t->ne[3]; i3++) { for (int64_t i2 = 0; i2 < t->ne[2]; i2++) { for (int64_t i1 = 0; i1 < t->ne[1]; i1++) { + if (is_mxfp) { + size_t row_off = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1]; + std::vector row_f32(t->ne[0]); + mxfp_dequant_soa(&buf[row_off], row_f32.data(), t->ne[0]); + tv.insert(tv.end(), row_f32.begin(), row_f32.end()); + continue; + } for (int64_t i0 = 0; i0 < t->ne[0]; i0 += bs) { size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0]; if (t->type == GGML_TYPE_F16) { @@ -2309,8 +2413,12 @@ struct test_set_rows : public test_case { const std::array nr23; // broadcast only dims 2 and 3 const int r; // rows to set const bool v; // view (non-contiguous src1) + const bool hadamard; // apply Walsh-Hadamard rotation before quantization std::string vars() override { + if (hadamard) { + return VARS_TO_STR6(type, type_idx, ne, nr23, r, v) + ",hadamard=1"; + } return VARS_TO_STR6(type, type_idx, ne, nr23, r, v); } @@ -2318,8 +2426,8 @@ struct test_set_rows : public test_case { ggml_type type_idx, std::array ne, std::array nr23, - int r, bool v = false) - : type(type), type_idx(type_idx), ne(ne), nr23(nr23), r(r), v(v) {} + int r, bool v = false, bool hadamard = false) + : type(type), type_idx(type_idx), ne(ne), nr23(nr23), r(r), v(v), hadamard(hadamard) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * dst = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]); @@ -2338,6 +2446,11 @@ struct test_set_rows : public test_case { } ggml_tensor * out = ggml_set_rows(ctx, dst, src, row_idxs); + + if (hadamard) { + ((int32_t *)out->op_params)[0] = 1; + } + ggml_set_name(out, "out"); return out; @@ -2351,6 +2464,10 @@ struct test_set_rows : public test_case { } init_set_rows_row_ids(t, ne[1]); + } else if (ggml_is_type_mxfp(t->type)) { + // MXFP dst tensors must use SoA layout — set_rows writes SoA, + // and tensor_to_float reads back 
assuming SoA for MXFP types. + init_tensor_mxfp_soa(t); } else { init_tensor_uniform(t); } @@ -3798,7 +3915,7 @@ struct test_mul_mat : public test_case { double max_nmse_err(ggml_backend_t backend) override { // for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance - if (type_a == GGML_TYPE_MXFP4 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) { + if (type_a == GGML_TYPE_MXFP4_E2M1 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) { return 2e-2; } return max_nmse_err(); @@ -3932,9 +4049,10 @@ struct test_mul_mat_id : public test_case { return 5e-4; } + // Same Blackwell FP4 tolerance as test_mul_mat above. double max_nmse_err(ggml_backend_t backend) override { // for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance - if (type_a == GGML_TYPE_MXFP4 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) { + if (type_a == GGML_TYPE_MXFP4_E2M1 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) { return 2e-2; } return max_nmse_err(); @@ -6180,9 +6298,14 @@ struct test_flash_attn_ext : public test_case { const ggml_prec prec; const ggml_type type_KV; + const ggml_type type_V; // V type, defaults to type_KV for same-type K/V std::array permute; std::string vars() override { + if (type_V != type_KV) { + return VARS_TO_STR13(hsk, hsv, nh, nr23, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV, permute) + + ",type_V=" + ggml_type_name(type_V); + } return VARS_TO_STR13(hsk, hsv, nh, nr23, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV, permute); } @@ -6199,12 +6322,14 @@ struct test_flash_attn_ext : public test_case { test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, std::array nr23 = {1, 1}, int64_t kv = 96, int64_t nb = 8, bool mask = true, bool sinks = false, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32, - ggml_type type_KV = GGML_TYPE_F16, std::array permute = {0, 1, 2, 3}) - : hsk(hsk), 
hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), sinks(sinks), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {} + ggml_type type_KV = GGML_TYPE_F16, std::array permute = {0, 1, 2, 3}, + ggml_type type_V_override = GGML_TYPE_COUNT) + : hsk(hsk), hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), sinks(sinks), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), + type_V(type_V_override == GGML_TYPE_COUNT ? type_KV : type_V_override), permute(permute) {} ggml_tensor * build_graph(ggml_context * ctx) override { const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV)); - const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_KV)); + const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_V)); auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, bool is_view) -> ggml_tensor * { int64_t ne[4] = {ne0, ne1, ne2, ne3}; @@ -6242,7 +6367,7 @@ struct test_flash_attn_ext : public test_case { // - https://github.com/ggml-org/llama.cpp/pull/18986 v = ggml_view_4d(ctx, k, hsv_padded, kv, nh, nr23[1], k->nb[1], k->nb[2], k->nb[3], 0); } else { - v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache + v = create_permuted(type_V, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache } ggml_set_name(v, "v"); @@ -6273,6 +6398,11 @@ struct test_flash_attn_ext : public test_case { init_tensor_uniform(t, -10.0f, 10.0f); } else if (strcmp(t->name, "m") == 0) { init_tensor_kq_mask(t); + } else if ((strcmp(t->name, "k") == 0 || strcmp(t->name, "v") == 0) && ggml_is_type_mxfp(t->type)) { + // MXFP K/V tensors use SoA layout. Pass nb[1] (KV-position stride) as the + // SoA region width — when heads are physically contiguous within that stride, + // the FA kernel dequants the full multi-head region as one SoA block. 
+ init_tensor_mxfp_soa(t, -1.0f, 1.0f, t->nb[1]); } else { init_tensor_uniform(t); } @@ -7279,7 +7409,8 @@ static const ggml_type all_types[] = { GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0, - GGML_TYPE_MXFP4, + GGML_TYPE_MXFP4_E2M1, GGML_TYPE_MXFP8_E4M3, + GGML_TYPE_MXFP6_E2M3, GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K, @@ -7295,7 +7426,7 @@ static const ggml_type base_types[] = { GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, // for I8MM tests GGML_TYPE_Q4_K, - GGML_TYPE_MXFP4, // TODO: or "other" + GGML_TYPE_MXFP4_E2M1, GGML_TYPE_IQ2_XXS }; @@ -7413,6 +7544,14 @@ static std::vector> make_test_cases_eval() { } } + // SET_ROWS with Hadamard rotation (exercises the op_params[0] flag used by MXFP KV cache) + for (ggml_type type : {GGML_TYPE_MXFP4_E2M1, GGML_TYPE_MXFP8_E4M3, + GGML_TYPE_MXFP6_E2M3}) { + // ne[0] must be divisible by 32 (Hadamard block size) + test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 128, 5, 1, 1 }, { 1, 1 }, 1, false, true)); + test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 256, 5, 1, 3 }, { 1, 1 }, 1, false, true)); + } + for (int mode : { GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_VISION }) { for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) { for (int ne2 : {1, 8, 512}) { @@ -8143,7 +8282,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat_id_fusion(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3)); // gpt-oss issue with Vulkan mmq_id - test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880)); + test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4_E2M1, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880)); for (ggml_type type_a : base_types) { for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) { @@ -8603,8 +8742,13 @@ static std::vector> make_test_cases_eval() { for (int nb : { 1, 3, 32, 75, }) { for 
(ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) { if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue; - for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { - if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72) continue; + for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0, + GGML_TYPE_MXFP4_E2M1, GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3, + }) { + // Non-F16 types: test at D=64, D=72, and D=128. + if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72 && hsk != 128) continue; + // MXFP types require D % 32 == 0, skip D=72. + if (ggml_is_type_mxfp(type_KV) && hsk == 72) continue; test_cases.emplace_back(new test_flash_attn_ext( hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV)); // run fewer test cases permuted @@ -8626,6 +8770,26 @@ static std::vector> make_test_cases_eval() { } } + // MXFP-specific K/V type combinations (mixed and same-type) + // Mixed: mxfp8 K + mxfp4 V, mxfp6 K + mxfp4 V (our recommended configs) + for (ggml_type type_K : {GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3}) { + for (ggml_type type_V : {GGML_TYPE_MXFP4_E2M1}) { + if (type_K == type_V) continue; + for (int nb : {1, 3, 32}) { + // hsk hsv nh nr23 kv nb mask sinks bias softcap prec type_K permute type_V + test_cases.emplace_back(new test_flash_attn_ext( + 128, 128, 4, {1, 1}, 512, nb, true, false, 0.0f, 0.0f, GGML_PREC_F32, type_K, {0, 1, 2, 3}, type_V)); + } + } + } + // Same-type: mxfp8/mxfp8, mxfp6/mxfp6 + for (ggml_type type_KV : {GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3}) { + for (int nb : {1, 3, 32}) { + test_cases.emplace_back(new test_flash_attn_ext( + 128, 128, 4, {1, 1}, 512, nb, true, false, 0.0f, 0.0f, GGML_PREC_F32, type_KV, {0, 1, 2, 3}, type_KV)); + } + } + test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, { 10, 5, 4, 3})); test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, {30000, 
1, 1, 1})); test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, { 10, 5, 4, 3})); @@ -8849,7 +9013,7 @@ static std::vector> make_test_cases_perf() { // gpt-oss-20b for (int bs : {1, 4, 8, 512}) { - for (ggml_type type_a : {GGML_TYPE_MXFP4}) { + for (ggml_type type_a : {GGML_TYPE_MXFP4_E2M1}) { for (ggml_type type_b : {GGML_TYPE_F32}) { test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 2880, bs, 2880)); test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 2880, bs, 2880, 1)); diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp index b0f1d6b936..10bf6d8ac1 100644 --- a/tools/llama-bench/llama-bench.cpp +++ b/tools/llama-bench/llama-bench.cpp @@ -483,7 +483,15 @@ static ggml_type ggml_type_from_name(const std::string & s) { if (s == "iq4_nl") { return GGML_TYPE_IQ4_NL; } - + if (s == "mxfp4" || s == "mxfp4_e2m1") { + return GGML_TYPE_MXFP4_E2M1; + } + if (s == "mxfp8" || s == "mxfp8_e4m3") { + return GGML_TYPE_MXFP8_E4M3; + } + if (s == "mxfp6" || s == "mxfp6_e2m3") { + return GGML_TYPE_MXFP6_E2M3; + } return GGML_TYPE_COUNT; } From a51ff77fae4f698e22d1d8ee9a88e8aa195b4be3 Mon Sep 17 00:00:00 2001 From: Tim Burke Date: Sun, 15 Mar 2026 18:57:12 -0400 Subject: [PATCH 02/13] =?UTF-8?q?ggml:=20address=20PR=20review=20=E2=80=94?= =?UTF-8?q?=20fix=20buffer=20overflows,=20add=20assertions,=20normalize=20?= =?UTF-8?q?MXFP6=20naming?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix potential buffer overflows flagged in PR #20609 review: - set_rows: replace fixed float tmp[1024] with std::vector for large n_embd_k_gqa - tiled FA: size q_mxfp_buf with ggml_row_size guard instead of fixed 1024 - one_chunk FA: pre-allocate k/v dequant buffers from mxfp.{k,v}_soa_elems instead of hard-coded float[4096] stack arrays - kv-cache: assert n_embd_k_gqa % qk == 0 before integer division - test init: assert soa_bytes % block_size == 0 
Normalize MXFP6 function naming to match MXFP8 convention (short form without element format suffix): mxfp6_e2m3 → mxfp6 in all function identifiers across 14 files. Format-specific items (type enums, traits, lookup tables, constants) retain their _e2m3 suffix. --- ggml/src/ggml-cpu/arch-fallback.h | 4 +-- ggml/src/ggml-cpu/arch/arm/quants.c | 8 +++--- ggml/src/ggml-cpu/arch/loongarch/quants.c | 4 +-- ggml/src/ggml-cpu/arch/powerpc/quants.c | 4 +-- ggml/src/ggml-cpu/arch/riscv/quants.c | 4 +-- ggml/src/ggml-cpu/arch/s390/quants.c | 4 +-- ggml/src/ggml-cpu/arch/wasm/quants.c | 4 +-- ggml/src/ggml-cpu/arch/x86/quants.c | 8 +++--- ggml/src/ggml-cpu/ggml-cpu.c | 6 ++--- ggml/src/ggml-cpu/ops.cpp | 30 +++++++++++------------ ggml/src/ggml-cpu/quants.c | 12 ++++----- ggml/src/ggml-cpu/quants.h | 10 ++++---- ggml/src/ggml-quants.c | 8 +++--- ggml/src/ggml-quants.h | 6 ++--- ggml/src/ggml.c | 6 ++--- src/llama-kv-cache.cpp | 1 + tests/test-backend-ops.cpp | 1 + 17 files changed, 61 insertions(+), 59 deletions(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 42647e14e1..612786e941 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -17,7 +17,7 @@ #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 #define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0 -#define ggml_vec_dot_mxfp6_e2m3_q8_0_generic ggml_vec_dot_mxfp6_e2m3_q8_0 +#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -349,7 +349,7 @@ #if !defined(__aarch64__) && !defined(__arm__) && !defined(_M_ARM) && !defined(_M_ARM64) && \ !defined(__x86_64__) && !defined(__i386__) && !defined(_M_IX86) && !defined(_M_X64) #define dequantize_row_mxfp8_cpu_generic 
dequantize_row_mxfp8_cpu -#define dequantize_row_mxfp6_e2m3_cpu_generic dequantize_row_mxfp6_e2m3_cpu +#define dequantize_row_mxfp6_cpu_generic dequantize_row_mxfp6_cpu #define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu #define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu #define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index 0f0ba86518..b18a276640 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -4333,7 +4333,7 @@ static inline void ggml_vec_dot_mxfp6_q8_0_neon( } #endif -void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); #if defined(__ARM_NEON) @@ -4342,7 +4342,7 @@ void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, con MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); #else - ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); + ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } @@ -4471,13 +4471,13 @@ void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRIC #endif } -void dequantize_row_mxfp6_e2m3_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { +void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { #if defined(__ARM_NEON) dequantize_row_mxfp6_neon(x, y, k, sizeof(block_mxfp6), MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, MXFP6_E2M3_IEEE_EXP_OFF, 
MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); #else - dequantize_row_mxfp6_e2m3_cpu_generic(x, y, k); + dequantize_row_mxfp6_cpu_generic(x, y, k); #endif } diff --git a/ggml/src/ggml-cpu/arch/loongarch/quants.c b/ggml/src/ggml-cpu/arch/loongarch/quants.c index a75dac8b15..fa05e49c5d 100644 --- a/ggml/src/ggml-cpu/arch/loongarch/quants.c +++ b/ggml/src/ggml-cpu/arch/loongarch/quants.c @@ -2165,6 +2165,6 @@ void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); } -void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); } diff --git a/ggml/src/ggml-cpu/arch/powerpc/quants.c b/ggml/src/ggml-cpu/arch/powerpc/quants.c index 82ca1f9df9..efb669da09 100644 --- a/ggml/src/ggml-cpu/arch/powerpc/quants.c +++ b/ggml/src/ggml-cpu/arch/powerpc/quants.c @@ -2307,6 +2307,6 @@ void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); } -void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); } diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c 
index dcb97756c6..beef1885da 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -3612,6 +3612,6 @@ void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); } -void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); } diff --git a/ggml/src/ggml-cpu/arch/s390/quants.c b/ggml/src/ggml-cpu/arch/s390/quants.c index 234488f25c..e696fd4570 100644 --- a/ggml/src/ggml-cpu/arch/s390/quants.c +++ b/ggml/src/ggml-cpu/arch/s390/quants.c @@ -1468,6 +1468,6 @@ void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); } -void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); } diff --git a/ggml/src/ggml-cpu/arch/wasm/quants.c b/ggml/src/ggml-cpu/arch/wasm/quants.c index 88bc6ad778..a3ae8e8885 100644 --- a/ggml/src/ggml-cpu/arch/wasm/quants.c +++ b/ggml/src/ggml-cpu/arch/wasm/quants.c @@ -1227,6 +1227,6 @@ void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, 
vy, by, nrc); } -void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); } diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c index 29d5a28759..0c6f6ed49a 100644 --- a/ggml/src/ggml-cpu/arch/x86/quants.c +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -3995,7 +3995,7 @@ static inline void ggml_vec_dot_mxfp6_q8_0_avx2( } #endif -void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); #if defined(__AVX2__) @@ -4004,7 +4004,7 @@ void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, con MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); #else - ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); + ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } @@ -4130,13 +4130,13 @@ void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRIC #endif } -void dequantize_row_mxfp6_e2m3_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { +void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { #if defined(__AVX2__) dequantize_row_mxfp6_avx2(x, y, k, sizeof(block_mxfp6), 
MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); #else - dequantize_row_mxfp6_e2m3_cpu_generic(x, y, k); + dequantize_row_mxfp6_cpu_generic(x, y, k); #endif } diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index a87f808c95..7b7fb1e5ce 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -284,9 +284,9 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .nrows = 1, }, [GGML_TYPE_MXFP6_E2M3] = { - .from_float = quantize_row_mxfp6_e2m3, - .to_float = dequantize_row_mxfp6_e2m3_cpu, - .vec_dot = ggml_vec_dot_mxfp6_e2m3_q8_0, + .from_float = quantize_row_mxfp6, + .to_float = dequantize_row_mxfp6_cpu, + .vec_dot = ggml_vec_dot_mxfp6_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, }, diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 02cd1abb8d..2267eaa27b 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -5061,14 +5061,13 @@ static void ggml_compute_forward_set_rows_f32( char * dst_row = ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3); if (apply_hadamard) { - GGML_ASSERT(nc <= 1024); - float tmp[1024]; - memcpy(tmp, src_row, nc * sizeof(float)); - ggml_apply_hadamard_blocks(tmp, nc); + std::vector tmp(nc); + memcpy(tmp.data(), src_row, nc * sizeof(float)); + ggml_apply_hadamard_blocks(tmp.data(), nc); if (mxfp_soa_quantize) { - mxfp_soa_quantize(tmp, dst_row, nc); + mxfp_soa_quantize(tmp.data(), dst_row, nc); } else { - from_float(tmp, dst_row, nc); + from_float(tmp.data(), dst_row, nc); } } else { if (mxfp_soa_quantize) { @@ -8418,6 +8417,10 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( int ith = params->ith; + // Pre-allocate dequant buffers for MXFP SoA (avoids per-iteration allocation) + std::vector k_dequant_buf(is_mxfp_k ? mxfp.k_soa_elems : 0); + std::vector v_dequant_buf(is_mxfp_v ? 
mxfp.v_soa_elems : 0); + for (int ir = ir0; ir < ir1; ++ir) { // q indices const int iq3 = ir/(neq2*neq1); @@ -8497,10 +8500,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( const char * k_soa_base = mxfp.k_multihead ? ((const char *) k->data + ic*nbk1 + ik3*nbk3) : k_data; - float k_soa_f32[4096]; - GGML_ASSERT(mxfp.k_soa_elems <= 4096); - mxfp.k_dequantize(k_soa_base, k_soa_f32, mxfp.k_soa_elems); - const float * k_head = k_soa_f32 + (mxfp.k_multihead ? ik2 * DK : 0); + mxfp.k_dequantize(k_soa_base, k_dequant_buf.data(), mxfp.k_soa_elems); + const float * k_head = k_dequant_buf.data() + (mxfp.k_multihead ? ik2 * DK : 0); ggml_vec_dot_f32(DK, &s, 0, k_head, 0, Q_f32, 0, 1); } else { kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1); @@ -8554,10 +8555,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( const char * v_soa_base = mxfp.v_multihead ? ((const char *) v->data + ic*nbv1 + iv3*nbv3) : v_data; - float v_soa_f32[4096]; - GGML_ASSERT(mxfp.v_soa_elems <= 4096); - mxfp.v_dequantize(v_soa_base, v_soa_f32, mxfp.v_soa_elems); - ggml_vec_mad_f32(DV, VKQ32, v_soa_f32 + (mxfp.v_multihead ? iv2 * DV : 0), vs); + mxfp.v_dequantize(v_soa_base, v_dequant_buf.data(), mxfp.v_soa_elems); + ggml_vec_mad_f32(DV, VKQ32, v_dequant_buf.data() + (mxfp.v_multihead ? iv2 * DV : 0), vs); } else if (v_to_float) { v_to_float(v_data, V32, DV); ggml_vec_mad_f32(DV, VKQ32, V32, vs); @@ -8765,7 +8764,8 @@ static void ggml_compute_forward_flash_attn_ext_tiled( ggml_apply_hadamard_blocks(Q_f32 + tq * DK, DK); } // SoA round-trip: quantize Q to SoA, then dequant back to float. 
- uint8_t q_mxfp_buf[1024]; + uint8_t q_mxfp_buf[512]; // max: DK=256 * 33/32 = 264 bytes (MXFP8) + GGML_ASSERT(ggml_row_size(k->type, DK) <= sizeof(q_mxfp_buf)); mxfp.q_quantize(Q_f32 + tq * DK, q_mxfp_buf, DK); mxfp.k_dequantize(q_mxfp_buf, Q_f32 + tq * DK, DK); } diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index 9152755010..7303638c81 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -58,8 +58,8 @@ void quantize_row_mxfp8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, i quantize_row_mxfp8_ref(x, y, k); } -void quantize_row_mxfp6_e2m3(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { - quantize_row_mxfp6_e2m3_ref(x, y, k); +void quantize_row_mxfp6(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + quantize_row_mxfp6_ref(x, y, k); } // @@ -301,14 +301,14 @@ void ggml_vec_dot_mxfp8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, (ggml_to_float_t)dequantize_row_mxfp8); } -void ggml_vec_dot_mxfp6_e2m3_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); ggml_vec_dot_mxfp_q8_0_impl(n, s, vx, sizeof(block_mxfp6), vy, - (ggml_to_float_t)dequantize_row_mxfp6_e2m3); + (ggml_to_float_t)dequantize_row_mxfp6); } // Generic (scalar) dequant wrappers — delegates to ggml-quants.c reference implementations. 
@@ -316,8 +316,8 @@ void ggml_vec_dot_mxfp6_e2m3_q8_0_generic(int n, float * GGML_RESTRICT s, size_t void dequantize_row_mxfp8_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { dequantize_row_mxfp8(x, y, k); } -void dequantize_row_mxfp6_e2m3_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { - dequantize_row_mxfp6_e2m3(x, y, k); +void dequantize_row_mxfp6_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + dequantize_row_mxfp6(x, y, k); } void dequantize_row_mxfp4_soa_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { dequantize_row_mxfp4_soa(x, y, k); diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h index 7d8c32762a..0a7ea64135 100644 --- a/ggml/src/ggml-cpu/quants.h +++ b/ggml/src/ggml-cpu/quants.h @@ -22,11 +22,11 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_mxfp8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_mxfp6_e2m3(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_mxfp6(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); // Dequantization (SIMD-optimized, arch-dispatched) void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_mxfp6_e2m3_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -51,7 +51,7 @@ void ggml_vec_dot_q8_0_q8_0(int n, float 
* GGML_RESTRICT s, size_t bs, const voi void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -85,10 +85,10 @@ void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, c void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_mxfp8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_mxfp6_e2m3_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, 
int nrc); +void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void dequantize_row_mxfp8_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_mxfp6_e2m3_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_mxfp6_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); // SoA dequant (SIMD-optimized for FA) void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index b2692c45f6..188c7e68b6 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -797,11 +797,11 @@ void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_REST dequantize_row_mxfp8_impl(x, y, k, &mxfp8_e4m3_traits); } -void quantize_row_mxfp6_e2m3_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k) { +void quantize_row_mxfp6_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k) { quantize_row_mxfp6_impl(x, y, k, &mxfp6_e2m3_traits); } -void dequantize_row_mxfp6_e2m3(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { +void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { dequantize_row_mxfp6_impl(x, y, k, &mxfp6_e2m3_traits); } @@ -2627,9 +2627,9 @@ size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, return nrow * ggml_row_size(GGML_TYPE_MXFP8_E4M3, n_per_row); } -size_t quantize_mxfp6_e2m3(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { +size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { GGML_UNUSED(quant_weights); - 
quantize_row_mxfp6_e2m3_ref(src, dst, (int64_t)nrow*n_per_row); + quantize_row_mxfp6_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * ggml_row_size(GGML_TYPE_MXFP6_E2M3, n_per_row); } diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index 33401f2843..a0f6928e10 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -24,7 +24,7 @@ GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_mxfp8_ref(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k); -GGML_API void quantize_row_mxfp6_e2m3_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_mxfp6_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); @@ -53,7 +53,7 @@ GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GG GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -GGML_API void dequantize_row_mxfp6_e2m3(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); // SoA (Struct-of-Arrays) quantize/dequantize — canonical reference for flash attention. 
// Layout: [qs contiguous][e8m0 contiguous] per row. Same total bytes as AoS. @@ -112,7 +112,7 @@ GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTR GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -GGML_API size_t quantize_mxfp6_e2m3(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); // // MXFP element-level conversion functions (reference implementations) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 329f2b93b3..37b99e844e 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -739,8 +739,8 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .blck_size = QK_MXFP6, .type_size = sizeof(block_mxfp6), .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_mxfp6_e2m3, - .from_float_ref = (ggml_from_float_t)quantize_row_mxfp6_e2m3_ref, + .to_float = (ggml_to_float_t) dequantize_row_mxfp6, + .from_float_ref = (ggml_from_float_t)quantize_row_mxfp6_ref, }, [GGML_TYPE_Q2_K] = { .type_name = "q2_K", @@ -7692,7 +7692,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_MXFP4_E2M1: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_MXFP8_E4M3: result = quantize_mxfp8(src + start, (char *) dst + start_row * 
row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_MXFP6_E2M3: result = quantize_mxfp6_e2m3(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_MXFP6_E2M3: result = quantize_mxfp6(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index fcd784e79d..29fdeb3f33 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -141,6 +141,7 @@ llama_kv_cache::llama_kv_cache( const bool is_mxfp_k = ggml_is_type_mxfp(type_k); if (is_mxfp_k) { const int qk = (int)ggml_blck_size(type_k); // 32 for all MXFP types + GGML_ASSERT(n_embd_k_gqa % qk == 0 && "MXFP K cache requires n_embd_k_gqa divisible by block size"); const int blocks = (int)n_embd_k_gqa / qk; const int blocks_aligned = (blocks + 15) & ~15; // align to 16 n_embd_k_alloc = (uint32_t)(blocks_aligned * qk); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 6ddef63336..3915ab4db6 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -181,6 +181,7 @@ static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float const size_t block_size = ggml_type_size(tensor->type); const size_t head_row_sz = ggml_row_size(tensor->type, tensor->ne[0]); if (soa_bytes == 0) { soa_bytes = head_row_sz; } + GGML_ASSERT(soa_bytes % block_size == 0 && "soa_bytes must be a multiple of block_size"); const int64_t soa_elems = (int64_t)(soa_bytes / block_size) * qk; std::default_random_engine gen(42); From 
c2f2ff7814192503e9ddb6d499c624630f5ba681 Mon Sep 17 00:00:00 2001 From: Tim Burke Date: Sun, 15 Mar 2026 19:46:09 -0400 Subject: [PATCH 03/13] ggml: optimize CPU MXFP flash attention hot loop - Per-head dequant: multihead MXFP now extracts only the needed head's SoA blocks (e.g. 20 bytes for mxfp4 DK=128) into a stack buffer and dequants DK elements, instead of dequanting all heads (nek2*DK). For 8 KV heads this is 8x less dequant work per KV position. - Hoist loop invariants: base pointer offsets (k_base, v_base), per-head SoA byte offsets, and multihead row bases are computed once per query row instead of per KV position in the inner loop. - Precompute SoA addressing in mxfp_fa_params_init: qs_per_block, blocks_per_head, head_qs_bytes, and head_e8m0_offset are calculated once at init rather than derived per iteration. - Move thread-local buffer pointers (VKQ32, V32, VKQ16, Q_q) and v_is_f16 check outside the ir loop. --- ggml/src/ggml-cpu/ops.cpp | 157 +++++++++++++++++++++++++++----------- 1 file changed, 111 insertions(+), 46 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 2267eaa27b..2b24e87ae6 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8271,6 +8271,18 @@ struct mxfp_fa_params { int64_t k_soa_elems; int64_t v_soa_elems; bool apply_hadamard; + // Per-head SoA addressing (avoids dequanting all heads in multihead mode). + // qs_per_block: bytes of quantized data per 32-element block. + // head_qs_bytes: total qs bytes for one head (blocks_per_head * qs_per_block). + // head_e8m0_offset: byte offset from row start to e8m0 region. 
+ int k_qs_per_block; + int v_qs_per_block; + int k_head_qs_bytes; + int v_head_qs_bytes; + int64_t k_head_e8m0_offset; + int64_t v_head_e8m0_offset; + int k_blocks_per_head; + int v_blocks_per_head; }; static mxfp_fa_params mxfp_fa_params_init( @@ -8314,6 +8326,34 @@ static mxfp_fa_params mxfp_fa_params_init( p.v_multihead = is_mxfp_v && (nbv2 == (size_t)ggml_row_size(v->type, DV)); p.v_soa_elems = is_mxfp_v ? (p.v_multihead ? nev2 * DV : DV) : 0; + // Per-head SoA addressing for multihead mode. + // Precompute byte offsets so the hot loop can skip per-head pointer math. + auto mxfp_qs_per_block = [](ggml_type type) -> int { + switch (type) { + case GGML_TYPE_MXFP4_E2M1: return MXFP4_SOA_QS_PER_BLOCK; + case GGML_TYPE_MXFP8_E4M3: return MXFP8_SOA_QS_PER_BLOCK; + case GGML_TYPE_MXFP6_E2M3: return MXFP6_SOA_QS_PER_BLOCK; + default: return 0; + } + }; + + if (is_mxfp_k) { + p.k_qs_per_block = mxfp_qs_per_block(k->type); + p.k_blocks_per_head = (int)(DK / 32); + p.k_head_qs_bytes = p.k_blocks_per_head * p.k_qs_per_block; + // e8m0 offset from row start = total_blocks * qs_per_block + const int64_t k_total_blocks = p.k_multihead ? nek2 * p.k_blocks_per_head : p.k_blocks_per_head; + p.k_head_e8m0_offset = k_total_blocks * p.k_qs_per_block; + } + + if (is_mxfp_v) { + p.v_qs_per_block = mxfp_qs_per_block(v->type); + p.v_blocks_per_head = (int)(DV / 32); + p.v_head_qs_bytes = p.v_blocks_per_head * p.v_qs_per_block; + const int64_t v_total_blocks = p.v_multihead ? nev2 * p.v_blocks_per_head : p.v_blocks_per_head; + p.v_head_e8m0_offset = v_total_blocks * p.v_qs_per_block; + } + return p; } @@ -8417,15 +8457,33 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( int ith = params->ith; - // Pre-allocate dequant buffers for MXFP SoA (avoids per-iteration allocation) - std::vector<float> k_dequant_buf(is_mxfp_k ? mxfp.k_soa_elems : 0); - std::vector<float> v_dequant_buf(is_mxfp_v ?
mxfp.v_soa_elems : 0); + // Dequant buffers for MXFP SoA — stack-allocated, no heap allocation in the hot path. + // In multihead mode, only dequant one head (DK or DV elements) instead of all heads. + // DK/DV are bounded by 1024 (asserted below for MXFP). + float k_dequant_buf[1024]; + float v_dequant_buf[1024]; + + // Per-head SoA temp buffer: holds [qs | e8m0] for one head in multihead mode. + // Max size: 32 bytes qs (mxfp8, DK=128) + 4 bytes e8m0 = 36 bytes per head. + // For DK up to 1024: 256 + 32 = 288 bytes. Use fixed-size stack buffer. + alignas(16) char k_head_soa[320]; // enough for DK up to 1024 with any MXFP type + alignas(16) char v_head_soa[320]; + + // Thread-local work buffers (constant across ir loop) + float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); + float * V32 = (VKQ32 + 1*DV); + ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); + ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*DV); + + const bool v_is_f16 = (v->type == GGML_TYPE_F16); + const bool use_softcap = (logit_softcap != 0.0f); + const int64_t neq2_x_neq1 = neq2 * neq1; for (int ir = ir0; ir < ir1; ++ir) { // q indices - const int iq3 = ir/(neq2*neq1); - const int iq2 = (ir - iq3*neq2*neq1)/neq1; - const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + const int iq3 = ir / neq2_x_neq1; + const int iq2 = (ir - iq3*neq2_x_neq1) / neq1; + const int iq1 = (ir - iq3*neq2_x_neq1 - iq2*neq1); const uint32_t h = iq2; // head index const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? 
powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; @@ -8433,12 +8491,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( float S = 0.0f; // sum float M = -INFINITY; // maximum KQ value - float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator - float * V32 = (VKQ32 + 1*DV); // (temporary) FP32 V buffer - ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator - ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16 - - if (v->type == GGML_TYPE_F16) { + if (v_is_f16) { memset(VKQ16, 0, DV*sizeof(ggml_fp16_t)); } else { memset(VKQ32, 0, DV*sizeof(float)); @@ -8446,14 +8499,31 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL; - // k indices + // k/v head indices — constant for this query row const int ik3 = iq3 / rk3; const int ik2 = iq2 / rk2; - - // v indices const int iv3 = iq3 / rv3; const int iv2 = iq2 / rv2; + // Precompute loop-invariant base pointer offsets for K and V. + // Only ic varies in the inner loop; head/batch offsets are constant. + const size_t k_base_offset = ik2*nbk2 + ik3*nbk3; + const size_t v_base_offset = iv2*nbv2 + iv3*nbv3; + const char * k_base = (const char *) k->data + k_base_offset; + const char * v_base = (const char *) v->data + v_base_offset; + + // For multihead MXFP: precompute per-head SoA byte offsets (constant per query row). + // head_qs_start: byte offset to this head's qs blocks within the SoA row. + // head_e8m0_start: byte offset to this head's e8m0 scales within the SoA row. + const int k_head_qs_start = mxfp.k_multihead ? ik2 * mxfp.k_head_qs_bytes : 0; + const int k_head_e8m0_start = mxfp.k_multihead ? 
(int)mxfp.k_head_e8m0_offset + ik2 * mxfp.k_blocks_per_head : 0; + const int v_head_qs_start = mxfp.v_multihead ? iv2 * mxfp.v_head_qs_bytes : 0; + const int v_head_e8m0_start = mxfp.v_multihead ? (int)mxfp.v_head_e8m0_offset + iv2 * mxfp.v_blocks_per_head : 0; + + // Multihead MXFP row base (without head offset) — only ic varies. + const char * k_row_base = mxfp.k_multihead ? ((const char *) k->data + ik3*nbk3) : nullptr; + const char * v_row_base = mxfp.v_multihead ? ((const char *) v->data + iv3*nbv3) : nullptr; + const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); float Q_f32[1024]; if (is_mxfp_k) { @@ -8493,23 +8563,25 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( float s; // KQ value - const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3); if (is_mxfp_k) { - // Dequant SoA data. Multi-head: full row base, extract head portion. - // Per-head: use k_data directly. - const char * k_soa_base = mxfp.k_multihead - ? ((const char *) k->data + ic*nbk1 + ik3*nbk3) - : k_data; - mxfp.k_dequantize(k_soa_base, k_dequant_buf.data(), mxfp.k_soa_elems); - const float * k_head = k_dequant_buf.data() + (mxfp.k_multihead ? ik2 * DK : 0); - ggml_vec_dot_f32(DK, &s, 0, k_head, 0, Q_f32, 0, 1); + if (mxfp.k_multihead) { + // Multihead: extract this head's SoA blocks into temp buffer, dequant only DK elements. + // Copy qs blocks then e8m0 scales for this head into contiguous [qs|e8m0] layout. 
+ const char * row = k_row_base + ic*nbk1; + memcpy(k_head_soa, row + k_head_qs_start, mxfp.k_head_qs_bytes); + memcpy(k_head_soa + mxfp.k_head_qs_bytes, row + k_head_e8m0_start, mxfp.k_blocks_per_head); + mxfp.k_dequantize(k_head_soa, k_dequant_buf, DK); + } else { + mxfp.k_dequantize(k_base + ic*nbk1, k_dequant_buf, DK); + } + ggml_vec_dot_f32(DK, &s, 0, k_dequant_buf, 0, Q_f32, 0, 1); } else { - kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1); + kq_vec_dot(DK, &s, 0, k_base + ic*nbk1, 0, Q_q, 0, 1); } s = s*scale; // scale KQ value - if (logit_softcap != 0.0f) { + if (use_softcap) { s = logit_softcap*tanhf(s); } @@ -8520,49 +8592,42 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value float vs = 1.0f; // post-softmax KQ value, expf(s - M) - const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); - - if (v->type == GGML_TYPE_F16) { + if (v_is_f16) { if (s > M) { - // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f M = s; ms = expf(Mold - M); - - // V = V*expf(Mold - M) ggml_vec_scale_f16(DV, VKQ16, ms); } else { - // no new maximum, ms == 1.0f, vs != 1.0f vs = expf(s - M); } // V += v*expf(s - M) - ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs); + ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) (v_base + ic*nbv1), vs); } else { if (s > M) { - // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f M = s; ms = expf(Mold - M); - - // V = V*expf(Mold - M) ggml_vec_scale_f32(DV, VKQ32, ms); } else { - // no new maximum, ms == 1.0f, vs != 1.0f vs = expf(s - M); } // V += v*expf(s - M) if (mxfp.v_dequantize) { - const char * v_soa_base = mxfp.v_multihead - ? ((const char *) v->data + ic*nbv1 + iv3*nbv3) - : v_data; - mxfp.v_dequantize(v_soa_base, v_dequant_buf.data(), mxfp.v_soa_elems); - ggml_vec_mad_f32(DV, VKQ32, v_dequant_buf.data() + (mxfp.v_multihead ? 
iv2 * DV : 0), vs); + if (mxfp.v_multihead) { + const char * row = v_row_base + ic*nbv1; + memcpy(v_head_soa, row + v_head_qs_start, mxfp.v_head_qs_bytes); + memcpy(v_head_soa + mxfp.v_head_qs_bytes, row + v_head_e8m0_start, mxfp.v_blocks_per_head); + mxfp.v_dequantize(v_head_soa, v_dequant_buf, DV); + } else { + mxfp.v_dequantize(v_base + ic*nbv1, v_dequant_buf, DV); + } + ggml_vec_mad_f32(DV, VKQ32, v_dequant_buf, vs); } else if (v_to_float) { - v_to_float(v_data, V32, DV); + v_to_float(v_base + ic*nbv1, V32, DV); ggml_vec_mad_f32(DV, VKQ32, V32, vs); } else { - // V is F32 - ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs); + ggml_vec_mad_f32(DV, VKQ32, (const float *) (v_base + ic*nbv1), vs); } } From f603c036ec6d483683038bff503de51d4750877d Mon Sep 17 00:00:00 2001 From: Tim Burke Date: Sun, 15 Mar 2026 20:12:38 -0400 Subject: [PATCH 04/13] Comment consistency pass and cleanup. --- ggml/src/ggml-impl.h | 40 ++++++---- ggml/src/ggml-quants.c | 167 ++++++----------------------------------- src/llama-kv-cache.cpp | 17 ++--- 3 files changed, 54 insertions(+), 170 deletions(-) diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 8e5d931df5..90f020b0e8 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -430,13 +430,16 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x) #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) -// E8M0 shared exponent to float. Canonical source: ggml_mxfp_e8m0_to_fp32() in ggml-common.h. +// E8M0 shared exponent to float: returns 2^(x - 127). +// Canonical source: ggml_mxfp_e8m0_to_fp32() in ggml-common.h. // Kept here because ggml-impl.h cannot depend on ggml-common.h IMPL section. +// NaN (x == 0xFF) is not handled — callers guarantee valid exponents.
static inline float ggml_e8m0_to_fp32(uint8_t x) { uint32_t bits; if (x == 0) { - bits = 0x00400000; // 2^(-127) + bits = 0x00400000; // denorm: 0.5 * 2^(-126) = 2^(-127) } else { + // normalized: exponent x placed in bits [30:23], mantissa = 0 → 2^(x-127) bits = (uint32_t) x << 23; } float result; @@ -444,12 +447,16 @@ static inline float ggml_e8m0_to_fp32(uint8_t x) { return result; } -// E8M0 to float/2. Canonical source: ggml_mxfp_e8m0_to_fp32_half() in ggml-common.h. +// E8M0 to float/2: returns 2^(x - 128). Equal to ggml_e8m0_to_fp32(x) / 2. +// Useful with MXFP4 because the E2M1 kvalues table stores 2 * float_value. +// Canonical source: ggml_mxfp_e8m0_to_fp32_half() in ggml-common.h. static inline float ggml_e8m0_to_fp32_half(uint8_t x) { uint32_t bits; if (x < 2) { + // x=0 → 2^(-128), x=1 → 2^(-127): denormal bit patterns bits = 0x00200000 << x; } else { + // 0.5 * 2^(x-127) = 2^(x-128): normalized with exponent (x-1) bits = (uint32_t)(x - 1) << 23; } float result; @@ -460,37 +467,42 @@ static inline float ggml_e8m0_to_fp32_half(uint8_t x) { #define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x) #define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x) -// UE4M3: unsigned, 4 exp bits (bias=7), 3 mantissa bits -// Returns value * 0.5 to match kvalues_mxfp4 convention (kvalues = 2 * E2M1_float) +// UE4M3 (unsigned E4M3): 4 exponent bits (bias 7), 3 mantissa bits, no sign. +// Range: [0, 448], with 0x7F = NaN treated as zero. +// Returns value * 0.5 to match kvalues_mxfp4 convention (kvalues = 2 * E2M1_float). 
static inline float ggml_ue4m3_to_fp32(uint8_t x) { if (x == 0 || x == 0x7F) { - return 0.0f; + return 0.0f; // zero and NaN → 0 } int exp = (x >> 3) & 0xF; int man = x & 0x7; float raw; if (exp == 0) { + // subnormal: value = man * 2^(1 - bias - mantissa_bits) = man * 2^(-9) raw = ldexpf((float) man, -9); } else { + // normalized: value = (1 + man/8) * 2^(exp - 7) raw = ldexpf(1.0f + (float) man / 8.0f, exp - 7); } return raw * 0.5f; } +// Float32 to UE4M3 with round-to-nearest-even. +// Clamps to [0, 448]. Max representable: 0x7E = (1 + 7/8) * 2^8 = 448. static inline uint8_t ggml_fp32_to_ue4m3(float x) { if (!(x > 0.0f)) { - return 0; + return 0; // negative, zero, NaN → 0 } if (x > 448.0f) { - x = 448.0f; + x = 448.0f; // clamp to max representable } uint32_t bits; memcpy(&bits, &x, 4); int fp32_exp = ((bits >> 23) & 0xFF) - 127; - int fp32_man = (bits >> 20) & 0x7; - int ue4m3_exp = fp32_exp + 7; + int fp32_man = (bits >> 20) & 0x7; // top 3 mantissa bits + int ue4m3_exp = fp32_exp + 7; // rebias: FP32 bias 127 → UE4M3 bias 7 if (ue4m3_exp <= 0) { - // subnormal: value = man * 2^-9, man = round(x * 2^9) + // subnormal: value = man * 2^(-9), so man = round(x * 512) int man = (int) (x * 512.0f + 0.5f); if (man > 7) { man = 7; @@ -501,15 +513,17 @@ static inline uint8_t ggml_fp32_to_ue4m3(float x) { return (uint8_t) man; } if (ue4m3_exp >= 15) { - return 0x7E; + return 0x7E; // max normal } + // round-to-nearest using bit 19 (first bit below UE4M3 mantissa) int round_bit = (bits >> 19) & 1; int ue4m3_man = fp32_man + round_bit; if (ue4m3_man > 7) { + // mantissa overflow → carry into exponent ue4m3_man = 0; ue4m3_exp++; if (ue4m3_exp >= 15) { - return 0x7E; + return 0x7E; // max normal } } return (uint8_t) ((ue4m3_exp << 3) | ue4m3_man); diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 188c7e68b6..e88435061d 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -257,96 +257,31 @@ void quantize_row_q8_1_ref(const float * 
GGML_RESTRICT x, block_q8_1 * GGML_REST } } -// ============================================================================ -// MXFP Element Conversion Functions -// ============================================================================ -// -// Reference implementations for OCP Microscaling (MX) format element types. -// Spec: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf -// -// All converters use IEEE-754 bit manipulation via memcpy (C99 safe, no strict -// aliasing issues). Quantization uses round-to-nearest-even (RNE) per MX spec. -// -// These functions are exposed in ggml-quants.h for use by CPU backends and tests. -// GPU backends (CUDA, Vulkan, Metal) provide their own optimized versions using -// hardware intrinsics (e.g., __nv_cvt_float_to_fp8, SIMD groups, LUT lookups). -// -// Key design decisions validated empirically on CUDA (Qwen3-Coder-30B-A3B): -// -// 1. SATURATION, NOT NaN PROPAGATION: FP8 E4M3 saturates to max (0x7E = 448) -// rather than producing NaN. The single NaN encoding (0x7F) is avoided. -// This matches the MX spec behavior and prevents NaN corruption in KV caches. -// -// 2. MX FP6 HAS NO NaN/Inf: Unlike IEEE-754, the MX spec defines exp=max as a -// valid normal value for FP6 types. Dequantizers must NOT special-case it. -// -// 3. RNE ROUNDING IN SUBNORMALS: Both normal and subnormal paths use proper -// round-to-nearest-even with sticky bit tracking. This was a P0 bug fix — -// truncation caused measurable PPL regression. -// -// 4. E3M2 SUBNORMAL SCALE: mant * 2^(1-bias-m) = mant * 2^(-4) = mant/16. -// NOT mant/4. This was a critical bug — the exponent bias and mantissa width -// both affect the subnormal multiplier. -// +// ====================== MXFP element conversions (wrappers around ggml-common.h) -// FP8 E4M3: 1 sign, 4 exponent (bias 7), 3 mantissa -// Max finite: 448 (exp=15, mant=6), NaN: exp=15, mant=7 -// Thin wrappers around canonical implementations in ggml-common.h. 
-// Verified bit-for-bit identical by test-mxfp-converters. +// FP8 E4M3: 1 sign, 4 exp (bias 7), 3 mantissa. Max finite: 448, NaN: 0x7F (saturated to max). float fp8_e4m3_to_float(uint8_t v) { return ggml_mxfp_fp8_e4m3_to_float(v); } uint8_t float_to_fp8_e4m3_rn(float x) { return ggml_mxfp_float_to_fp8_e4m3(x); } -// ============================================================================ -// MXFP Quantization Infrastructure -// ============================================================================ -// -// The MX format uses a shared E8M0 exponent per block of 32 elements. Choosing -// the optimal exponent is critical for quantization quality. -// -// The OCP MX v1.0 spec (§5.3) specifies floor(log2(amax)) for the shared exponent. -// We improve on this with an MSE-optimal ±1 search that tests 3 candidate exponents -// {e-1, e, e+1} around round(log2(amax)) and picks whichever minimizes the total -// round-trip quantization error for the block. This consistently improves perplexity -// by 0.05-0.2 across all MX types versus floor-only or round-only approaches. -// -// The round(log2(amax)) base is computed via IEEE-754 integer bit extraction rather -// than log2f(), avoiding GPU Special Function Unit (SFU) bottlenecks. The rounding -// threshold 0x3504F3 is the fractional part of sqrt(2) in IEEE-754 mantissa bits: -// if mantissa >= (sqrt(2)-1)*2^23 ≈ 0x3504F3, then log2(x) >= n+0.5, so round up. -// -// Each MX element type provides an mse_error function that computes the round-trip -// quantization error for a single value at a given scale. The traits structure -// encapsulates this per-type behavior. +// ====================== MXFP quantization infrastructure // +// MSE-optimal E8M0 shared exponent: tests ±R candidates around round(log2(amax)) +// and picks whichever minimizes total round-trip quantization error per block. +// Improves on OCP MX v1.0 §5.3 floor(log2(amax)) by 0.05-0.2 PPL. 
+// Ref: OCP MX v1.0 spec, Four Over Six (arXiv:2512.02010) -// Per-type traits for MSE-optimal E8M0 scale computation. -// emax_offset: type-specific offset from E8M0 bias to type's max representable exponent -// to_elem/to_float: element conversion function pointers (NULL for MXFP4 which uses LUT) -// mse_error: round-trip error function for a single value at a given scale typedef struct { - int emax_offset; + int emax_offset; // type-specific offset to max representable exponent uint8_t (*to_elem)(float); float (*to_float)(uint8_t); float (*mse_error)(float val, float inv_scale, float scale); } mxfp_elem_traits_t; -// Forward declaration — defined after kvalues_mxfp4 lookup table section. static inline int best_index_mxfp4(float x, float e); -// MXFP4 E2M1 MSE error: decision boundary quantization with HALF scale factor. -// -// This CPU implementation uses the doubled int8 kvalues_mxfp4 LUT {0,1,2,3,4,6,8,12} -// with GGML_E8M0_TO_FP32_HALF(e) = scale/2 for efficient nibble-indexed integer arithmetic. -// The MSE interface passes GGML_E8M0_TO_FP32(e) as scale, so we halve it. -// -// Canonical E2M1 values are {0, 0.5, 1, 1.5, 2, 3, 4, 6} (kvalues_mxfp4_float in ggml-common.h). -// Doubled boundaries {0.5, 1.5, 2.5, 3.5, 5, 7, 10} ÷ 2 = canonical {0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5}. -// Mathematically identical — the doubling is an implementation detail. -// This is the Lloyd-Max quantizer for uniform input density. +// MXFP4 MSE error using decision boundary quantization with half-scale +// (kvalues_mxfp4 are doubled E2M1 values, so scale is halved to compensate) static float mse_error_mxfp4(float val, float inv_scale, float scale) { - // Decision boundary quantization with direct reconstruction. - // kvalues_mxfp4 positive sorted: {0, 1, 2, 3, 4, 6, 8, 12} - // Use inv_scale * 2 since MXFP4 scale includes 0.5x factor. const float d = scale * 0.5f; const float inv_d = (d > 0.0f) ? 
1.0f / d : 0.0f; const float normalized = fabsf(val) * inv_d; @@ -366,22 +301,7 @@ static float mse_error_mxfp4(float val, float inv_scale, float scale) { static const mxfp_elem_traits_t mxfp4_traits = { MXFP4_E2M1_EMAX_OFFSET, NULL, NULL, mse_error_mxfp4 }; -// MSE-optimal E8M0 shared exponent computation. -// -// Algorithm: -// 1. Find amax = max(|x[0..qk-1]|) -// 2. Compute e_base = round(log2(amax)) - emax_offset + 127 via integer bit ops -// 3. Test {e_base-R .. e_base+R}, pick the one minimizing total round-trip MSE -// where R = MXFP_E8M0_MSE_RANGE (defined in ggml-common.h) -// -// The ±R search improves on the OCP spec's floor(log2(amax)). Wider search finds -// better scales for blocks with non-uniform value distributions (especially FP4). -// Cost is (2R+1) × qk roundtrip evaluations per block — negligible vs attention compute. -// -// Integer log2 avoids log2f() (SFU-dependent on GPU). The sqrt(2) rounding threshold -// ensures we start from round() not floor(). -// -// Ref: OCP MX v1.0 §5.3; Four Over Six (arXiv:2512.02010) +// Find MSE-optimal E8M0 exponent by testing ±R candidates around round(log2(amax)) static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_elem_traits_t * traits) { float amax = 0.0f; for (int j = 0; j < qk; j++) { @@ -392,7 +312,6 @@ static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_ const int e_base = ggml_mxfp_e8m0_base_estimate(amax, traits->emax_offset); - // ±R MSE search: test 2R+1 candidates around e_base, pick lowest total MSE. int e_lo = e_base - MXFP_E8M0_MSE_RANGE; int e_hi = e_base + MXFP_E8M0_MSE_RANGE; if (e_lo < 1) e_lo = 1; @@ -417,10 +336,8 @@ static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_ return (uint8_t)best_e; } +// Decision boundary quantization for kvalues_mxfp4 {0,1,2,3,4,6,8,12} static inline int best_index_mxfp4(float x, float e) { - // Decision boundary quantization: 7 comparisons instead of 16-element scan. 
- // kvalues_mxfp4 positive sorted: {0, 1, 2, 3, 4, 6, 8, 12} - // Decision boundaries (midpoints): {0.5, 1.5, 2.5, 3.5, 5, 7, 10} const float inv_e = (e > 0.0f) ? 1.0f / e : 0.0f; const float normalized = fabsf(x) * inv_e; int idx; @@ -435,10 +352,6 @@ static inline int best_index_mxfp4(float x, float e) { return (x < 0.0f) ? (idx + 8) : idx; } -// FP4 E2M1: search-based quantization using best_index_mxfp4 lookup table. -// Unlike FP6/FP8 which use direct float->element conversion, FP4 finds the -// closest 4-bit value by minimizing reconstruction error against the lookup table. -// Scale uses GGML_E8M0_TO_FP32_HALF (includes 0.5x factor for E2M1 mantissa range). void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) { static const int qk = QK_MXFP4; @@ -652,37 +565,13 @@ void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_REST } } -// ============================================================================ -// Hadamard Rotation (reference scalar implementation) -// ============================================================================ -// -// 32-element Walsh-Hadamard transform, applied to MX blocks before quantization -// to spread outlier energy uniformly across the shared-exponent group. -// -// Without rotation, a single outlier in a block of 32 forces the shared E8M0 -// exponent high, wasting precision for all 31 other elements. The Hadamard -// transform is orthogonal (H^T·H = I), so H(K)·H(Q) = K·Q — attention scores -// are preserved exactly when both K and Q undergo the same rotation. -// -// Implementation: 5 butterfly stages (log2(32) = 5) of the fast Walsh-Hadamard -// transform, followed by normalization by 1/sqrt(32). Total: 160 FP add/sub + -// 32 FP mul. This is the standard "in-place" FWHT with O(n·log(n)) operations. 
-// -// The 1/sqrt(32) normalization factor makes the transform orthonormal: -// H_normalized = H_unnormalized / sqrt(N) -// This ensures the transform preserves vector norms (energy), which is critical -// for maintaining attention score magnitudes after rotation. -// -// Prior art: QuIP# (Tseng et al. 2024), BRQ (Huang et al. 2024) apply Hadamard -// for weight quantization. Our novel contribution: applying it to KV cache -// quantization at the MX block boundary (block-32), where it matches the shared -// exponent group size. Tested alternatives (block-8, block-16, sign flips, -// permutations) all degraded quality — block-32 Hadamard is uniquely optimal -// because it spreads energy across exactly the elements sharing an exponent. -// -// Empirical PPL impact WITHOUT Hadamard rotation (Qwen3-Coder-30B-A3B): -// MXFP8 E4M3: +0.22, MXFP8 E5M2: +1.38, MXFP6 E2M3: +3.34, MXFP6 E3M2: +4.60 +// ====================== Hadamard rotation // +// 32-element Walsh-Hadamard transform applied before MX quantization to spread +// outlier energy across the shared-exponent group. Orthogonal (H^T·H = I), so +// H(K)·H(Q) = K·Q — attention scores are preserved when both K and Q are rotated. +// Prior art: QuIP# (Tseng et al. 2024), BRQ (Huang et al. 2024) + void ggml_hadamard_32_inplace(float vals[32]) { ggml_mxfp_hadamard_32_inplace(vals); } @@ -697,9 +586,7 @@ uint8_t float_to_fp8_e5m2_rn(float x) { return ggml_mxfp_float_to_fp8_e5m2(x); } void pack_fp6x4(const uint8_t v[4], uint8_t out[3]) { ggml_mxfp_pack_fp6x4(v, out); } void unpack_fp6x4(const uint8_t in[3], uint8_t v[4]) { ggml_mxfp_unpack_fp6x4(in, v); } -// MSE error functions for FP8/FP6: quantize at given scale → dequantize → squared error. -// Used by mxfp_compute_e8m0_mse() to evaluate candidate E8M0 exponents. -// These call the public API wrappers which delegate to canonical ggml_mxfp_* in ggml-common.h. 
+// round-trip quantization error for MSE-optimal exponent search static float mse_error_fp8_e4m3(float val, float inv_scale, float scale) { const float recon = fp8_e4m3_to_float(float_to_fp8_e4m3_rn(val * inv_scale)) * scale; const float err = val - recon; @@ -710,14 +597,9 @@ static float mse_error_fp6_e2m3(float val, float inv_scale, float scale) { const float err = val - recon; return err * err; } -// emax_offset = ceil(log2(max_finite_value)) for each element type. -// This centers the E8M0 exponent search around the optimal scale for the type's range. -// E4M3: max=448, ceil(log2(448)) = 9, but offset=8 matches CUDA (empirically better) -// E2M3: max=7.5, ceil(log2(7.5)) = 3 static const mxfp_elem_traits_t mxfp8_e4m3_traits = { MXFP8_E4M3_EMAX_OFFSET, float_to_fp8_e4m3_rn, fp8_e4m3_to_float, mse_error_fp8_e4m3 }; static const mxfp_elem_traits_t mxfp6_e2m3_traits = { MXFP6_E2M3_EMAX_OFFSET, float_to_fp6_e2m3_rn, fp6_e2m3_to_float, mse_error_fp6_e2m3 }; -// FP8 quantize/dequantize: byte-per-element, shared by E4M3 and E5M2 static void quantize_row_mxfp8_impl(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k, const mxfp_elem_traits_t * traits) { assert(k % QK_MXFP8 == 0); @@ -748,7 +630,6 @@ static void dequantize_row_mxfp8_impl(const block_mxfp8 * GGML_RESTRICT x, float } } -// FP6 quantize/dequantize: tight 6-bit packing (4 values per 3 bytes), shared by E2M3 and E3M2 static void quantize_row_mxfp6_impl(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k, const mxfp_elem_traits_t * traits) { assert(k % QK_MXFP6 == 0); @@ -787,8 +668,6 @@ static void dequantize_row_mxfp6_impl(const block_mxfp6 * GGML_RESTRICT x, float } } -// Public API wrappers — one-line delegates to the traits-parameterized impl - void quantize_row_mxfp8_ref(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k) { quantize_row_mxfp8_impl(x, y, k, &mxfp8_e4m3_traits); } @@ -805,13 +684,9 @@ void dequantize_row_mxfp6(const block_mxfp6 
* GGML_RESTRICT x, float * GGML_REST dequantize_row_mxfp6_impl(x, y, k, &mxfp6_e2m3_traits); } -// ============================================================================ -// SoA (Struct-of-Arrays) quantize/dequantize — canonical reference for FA -// ============================================================================ +// ====================== SoA (Struct-of-Arrays) quantize/dequantize for flash attention // -// SoA layout per row: [qs_block0|qs_block1|...|qs_blockN][e8m0_0|e8m0_1|...|e8m0_N] -// Total bytes per row = nblocks * (QS_PER_BLOCK + 1) = identical to AoS. -// This is the ONLY layout used by flash attention across all backends. +// Layout per row: [qs_block0|qs_block1|...|qs_blockN][e8m0_0|e8m0_1|...|e8m0_N] void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) { assert(k % QK_MXFP4 == 0); diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 29fdeb3f33..0e501697b0 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -136,7 +136,7 @@ llama_kv_cache::llama_kv_cache( const bool has_k = true; const bool has_v = !is_mla; - // MXFP K cache: align block count to 16 for cp.async. + // MXFP: align block count to 16 for cp.async uint32_t n_embd_k_alloc = n_embd_k_gqa; const bool is_mxfp_k = ggml_is_type_mxfp(type_k); if (is_mxfp_k) { @@ -1037,8 +1037,7 @@ ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_k auto * k = layers[ikv].k; - // For MXFP types: k->ne[0] may include alignment padding (blocks aligned to 16). - // The row stride (k->nb[1]) reflects the padded allocation. 
+ // note: for MXFP types, k->ne[0] may be padded for block alignment; use nb[] for strides const uint32_t ns = sinfo.s1 - sinfo.s0 + 1; return ggml_view_4d(ctx, k, @@ -1107,14 +1106,13 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm assert(kv_size == k->ne[1]); // merge the buffer across all streams because the idxs are global - // Use view_2d to preserve nb[1] (which includes alignment padding for MXFP types) + // note: use view_2d to preserve nb[1] (includes MXFP alignment padding) k = ggml_view_2d(ctx, k, k->ne[0], kv_size*n_stream, k->nb[1], 0); } const bool is_mxfp = ggml_is_type_mxfp(k->type); - // For MXFP: ne[0] may be padded for block alignment, but k_cur has n_embd_gqa. - // Create view with ne[0]=n_embd_gqa, preserving the larger row stride nb[1]. + // for MXFP: ne[0] may be padded, narrow view to n_embd_gqa while keeping row stride ggml_tensor * k_dst = k; if (is_mxfp) { k_dst = ggml_view_2d(ctx, k, n_embd_gqa, k->ne[1], k->nb[1], 0); @@ -1123,11 +1121,8 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm // store the current K values into the cache ggml_tensor * result = ggml_set_rows(ctx, k_dst, k_cur, k_idxs); - // Flag K cache writes for Walsh-Hadamard rotation (QuaRot, arXiv:2404.00456; BRQ, arXiv:2511.04214). - // The flash attention kernel applies matching rotation to Q so H(Q)·H(K)^T = Q·K^T. - // V cache writes are NOT rotated (op_params[0] defaults to 0). - // Skipped for: MLA (V is a view of K — rotation would corrupt V), - // E5M2/E3M2 (2-bit mantissa — Hadamard provides no quality benefit). 
+ // enable Hadamard rotation for MXFP K cache (QuaRot arXiv:2404.00456, BRQ arXiv:2511.04214) + // skipped for MLA (V is a view of K) and E5M2/E3M2 (2-bit mantissa, no benefit) if (is_mxfp && !hparams.is_mla() && ggml_mxfp_use_hadamard(k->type)) { ((int32_t *)result->op_params)[0] = 1; } From c913ab36d2e1d2a68125300f6cdae968fc2f2a83 Mon Sep 17 00:00:00 2001 From: Tim Burke Date: Sun, 15 Mar 2026 20:30:01 -0400 Subject: [PATCH 05/13] fix buffer overflows for large DK and multi-head MXFP flash attention - Increase q_mxfp_buf from 512 to 2048 bytes (supports DK up to 1024 with MXFP8) - Replace fixed k_soa[4096]/v_soa[4096] stack arrays with dynamically sized vectors - Replace fixed k_head_soa[320]/v_head_soa[320] with dynamically sized vectors - Add soa_bytes divisibility assertion in test init --- ggml/src/ggml-cpu/ops.cpp | 34 +++++++++++++++++++--------------- tests/test-backend-ops.cpp | 2 +- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 2b24e87ae6..27275ca1e1 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8464,10 +8464,13 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( float v_dequant_buf[1024]; // Per-head SoA temp buffer: holds [qs | e8m0] for one head in multihead mode. - // Max size: 32 bytes qs (mxfp8, DK=128) + 4 bytes e8m0 = 36 bytes per head. - // For DK up to 1024: 256 + 32 = 288 bytes. Use fixed-size stack buffer. - alignas(16) char k_head_soa[320]; // enough for DK up to 1024 with any MXFP type - alignas(16) char v_head_soa[320]; + // For DK=1024 with MXFP8: 32 blocks * 32 qs + 32 e8m0 = 1056 bytes. + const size_t k_head_soa_size = is_mxfp_k ? (size_t)(mxfp.k_head_qs_bytes + mxfp.k_blocks_per_head) : 0; + const size_t v_head_soa_size = is_mxfp_v ? 
(size_t)(mxfp.v_head_qs_bytes + mxfp.v_blocks_per_head) : 0; + std::vector<char> k_head_soa_vec(k_head_soa_size); + std::vector<char> v_head_soa_vec(v_head_soa_size); + char * k_head_soa = k_head_soa_vec.data(); + char * v_head_soa = v_head_soa_vec.data(); // Thread-local work buffers (constant across ir loop) float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); @@ -8828,9 +8831,10 @@ static void ggml_compute_forward_flash_attn_ext_tiled( if (mxfp.apply_hadamard) { ggml_apply_hadamard_blocks(Q_f32 + tq * DK, DK); } - // SoA round-trip: quantize Q to SoA, then dequant back to float. - uint8_t q_mxfp_buf[512]; // max: DK=256 * 33/32 = 264 bytes (MXFP8) - GGML_ASSERT(ggml_row_size(k->type, DK) <= sizeof(q_mxfp_buf)); + // SoA round-trip: quantize Q to SoA, then dequant back to float + const size_t q_soa_bytes = ggml_row_size(k->type, DK); + GGML_ASSERT(q_soa_bytes <= 2048); + uint8_t q_mxfp_buf[2048]; // max: DK=1024 * 33/32 = 1056 bytes (MXFP8) mxfp.q_quantize(Q_f32 + tq * DK, q_mxfp_buf, DK); mxfp.k_dequantize(q_mxfp_buf, Q_f32 + tq * DK, DK); } @@ -8843,6 +8847,10 @@ static void ggml_compute_forward_flash_attn_ext_tiled( memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float)); memset(V32, 0, KV_TILE_SZ * DV * sizeof(float)); + // dequant scratch buffers for SoA MXFP — allocated once per tile, reused per KV token + std::vector<float> k_soa_buf(mxfp.k_soa_elems); + std::vector<float> v_soa_buf(mxfp.v_soa_elems); + for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) { const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic); @@ -8886,10 +8894,8 @@ static void ggml_compute_forward_flash_attn_ext_tiled( const char * k_soa_base = mxfp.k_multihead ? ((const char *)k->data + (ic + tk)*nbk1 + ik3*nbk3) : k_data; - float k_soa[4096]; - GGML_ASSERT(mxfp.k_soa_elems <= 4096); - mxfp.k_dequantize(k_soa_base, k_soa, mxfp.k_soa_elems); - const float * k_head = k_soa + (mxfp.k_multihead ?
ik2 * DK : 0); + mxfp.k_dequantize(k_soa_base, k_soa_buf.data(), mxfp.k_soa_elems); + const float * k_head = k_soa_buf.data() + (mxfp.k_multihead ? ik2 * DK : 0); for (int64_t dk = 0; dk < DK; dk++) { K_f32[dk * KV_TILE_SZ + tk] = k_head[dk]; } @@ -8962,10 +8968,8 @@ static void ggml_compute_forward_flash_attn_ext_tiled( const char * v_soa_base = mxfp.v_multihead ? ((const char *)v->data + (ic + tk)*nbv1 + iv3*nbv3) : v_data; - float v_soa[4096]; - GGML_ASSERT(mxfp.v_soa_elems <= 4096); - mxfp.v_dequantize(v_soa_base, v_soa, mxfp.v_soa_elems); - memcpy(V32 + tk * DV, v_soa + (mxfp.v_multihead ? iv2 * DV : 0), DV * sizeof(float)); + mxfp.v_dequantize(v_soa_base, v_soa_buf.data(), mxfp.v_soa_elems); + memcpy(V32 + tk * DV, v_soa_buf.data() + (mxfp.v_multihead ? iv2 * DV : 0), DV * sizeof(float)); } else { v_to_float(v_data, V32 + tk * DV, DV); } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 3915ab4db6..6dd245aeab 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -202,6 +202,7 @@ static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float // If soa_bytes < nb1 (per-head), iterate over (ne1 * ne2 * ne3) regions with stride nb1. // We use strides to compute offsets, handling views and permutations correctly. const int64_t heads_per_region = (int64_t)(soa_bytes / head_row_sz); + GGML_ASSERT(soa_bytes % head_row_sz == 0 && "soa_bytes must be a multiple of head_row_sz"); // For multi-head regions, we step by nb1 (KV-position stride) between regions. // For per-head, we step through all dimensions. 
@@ -210,7 +211,6 @@ static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float if (heads_per_region > 1) { // Multi-head SoA: iterate over (kv_positions * batches), each region = nb1 bytes for (int64_t i3 = 0; i3 < ne3; i3++) { - // ne2/heads_per_region = number of head groups (for GQA broadcast, usually 1) const int64_t n_groups = ne2 / heads_per_region; for (int64_t ig = 0; ig < n_groups; ig++) { for (int64_t i1 = 0; i1 < ne1; i1++) { From b8e8d291d119bba4be8baa6423b49e34b7ad2df4 Mon Sep 17 00:00:00 2001 From: Tim Burke Date: Sun, 15 Mar 2026 21:27:42 -0400 Subject: [PATCH 06/13] =?UTF-8?q?ggml:=20refactor=20x86=20AVX2=20and=20ARM?= =?UTF-8?q?=20NEON=20MXFP=20dequant=20=E2=80=94=20shared=20traits=20and=20?= =?UTF-8?q?helpers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add mxfp_dequant_traits_t to ggml-common.h as single source of truth for MXFP IEEE-754 reconstruction parameters. Define static const instances for all 4 formats (E4M3, E5M2, E2M3, E3M2), ready for CUDA/Metal/Vulkan reuse. Extract shared dequant and FP6 unpack helpers on both architectures, replacing duplicated inline code and macros. Net -215 lines. --- ggml/src/ggml-common.h | 41 ++ ggml/src/ggml-cpu/arch/arm/quants.c | 662 ++++++++++++---------------- ggml/src/ggml-cpu/arch/x86/quants.c | 562 +++++++++-------------- ggml/src/ggml-cpu/ops.cpp | 10 +- ggml/src/ggml-quants.h | 2 +- 5 files changed, 531 insertions(+), 746 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 9945fef137..cc9a4a0aca 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -267,6 +267,47 @@ static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4 #define MXFP6_E3M2_MANT_SHIFT 21 // 23-2 #define MXFP6_E3M2_SUB_SCALE (1.0f/16.0f) // 2^(-4) = 2^(1-3-2) +// Unified MXFP dequantization traits for SIMD backends (CPU x86/ARM, CUDA, Metal, Vulkan). 
+// Contains all parameters needed for IEEE-754 bit reconstruction of FP8/FP6 elements. +// FP4 uses LUT-based dequant and does not need this struct. +typedef struct { + int exp_mask; // (1<> 0) & 0x3F; + u[1] = (pk >> 6) & 0x3F; + u[2] = (pk >> 12) & 0x3F; + u[3] = (pk >> 18) & 0x3F; + const uint8x8_t raw8 = vcreate_u8( + (uint64_t)u[0] | ((uint64_t)u[1] << 8) | + ((uint64_t)u[2] << 16) | ((uint64_t)u[3] << 24)); + return vmovl_u16(vget_low_u16(vmovl_u8(raw8))); +} + +// Widen 8 raw bytes to two uint32x4_t halves. +static inline void widen_u8x8_to_u32x4x2(const uint8_t * src, + uint32x4_t * lo, uint32x4_t * hi) { + const uint8x8_t raw8 = vld1_u8(src); + const uint16x8_t raw16 = vmovl_u8(raw8); + *lo = vmovl_u16(vget_low_u16(raw16)); + *hi = vmovl_u16(vget_high_u16(raw16)); +} + +// Widen 8 Q8_0 int8 values to two float32x4_t halves. +static inline void widen_s8x8_to_f32x4x2(const int8_t * src, + float32x4_t * lo, float32x4_t * hi) { + const int8x8_t q8 = vld1_s8(src); + const int16x8_t q16 = vmovl_s8(q8); + *lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(q16))); + *hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16))); +} + +// ── MXFP FP8/FP6 vec_dot ────────────────────────────────────────────────── + +static void ggml_vec_dot_mxfp8_q8_0_neon( int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, - // FP8 format parameters: - const uint32_t exp_mask, // 0xF for E4M3, 0x1F for E5M2 - const uint32_t mant_mask, // 0x7 for E4M3, 0x3 for E5M2 - const int exp_shift, // 3 for E4M3, 2 for E5M2 - const uint32_t ieee_exp_off, // 120 for E4M3, 112 for E5M2 - const int mant_shift, // 20 for E4M3, 21 for E5M2 - const float sub_scale) { // 1/512 for E4M3, 1/65536 for E5M2 + const mxfp_neon_traits_t * t) { assert(n % QK_MXFP8 == 0); const int nb = n / QK_MXFP8; const block_mxfp8 * GGML_RESTRICT x = vx; const block_q8_0 * GGML_RESTRICT y = vy; + const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask); + const uint32x4_t v_mant_mask = 
vdupq_n_u32(t->mant_mask); + const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off); + const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale); + const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift); + const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift); + float32x4_t acc0 = vdupq_n_f32(0.0f); float32x4_t acc1 = vdupq_n_f32(0.0f); - const uint32x4_t v_exp_mask = vdupq_n_u32(exp_mask); - const uint32x4_t v_mant_mask = vdupq_n_u32(mant_mask); - const uint32x4_t v_ieee_off = vdupq_n_u32(ieee_exp_off); - const float32x4_t v_sub_sc = vdupq_n_f32(sub_scale); - // Use variable shifts (vshlq_u32) instead of constant shifts (vshlq_n_u32) - // because exp_shift/mant_shift are function parameters, not compile-time constants. - // Clang requires _n_ intrinsics to have literal constant arguments. - const int32x4_t v_neg_exp_shift = vdupq_n_s32(-exp_shift); - const int32x4_t v_mant_shift = vdupq_n_s32(mant_shift); - for (int ib = 0; ib < nb; ++ib) { - const float scale = GGML_E8M0_TO_FP32(x[ib].e) * GGML_CPU_FP16_TO_FP32(y[ib].d); - const float32x4_t v_scale = vdupq_n_f32(scale); + const float32x4_t v_scale = vdupq_n_f32( + GGML_E8M0_TO_FP32(x[ib].e) * GGML_CPU_FP16_TO_FP32(y[ib].d)); - // Process 32 FP8 elements in 8 groups of 4 for (int j = 0; j < 32; j += 8) { - // Load 8 FP8 bytes, extend to two uint32x4_t - const uint8x8_t raw8 = vld1_u8(x[ib].qs + j); - const uint16x8_t raw16 = vmovl_u8(raw8); - const uint32x4_t v_lo = vmovl_u16(vget_low_u16(raw16)); - const uint32x4_t v_hi = vmovl_u16(vget_high_u16(raw16)); + uint32x4_t v_lo, v_hi; + widen_u8x8_to_u32x4x2(x[ib].qs + j, &v_lo, &v_hi); - // Load 8 Q8_0 int8 values, extend to two int32x4_t → float32x4_t - const int8x8_t q8 = vld1_s8(y[ib].qs + j); - const int16x8_t q16 = vmovl_s8(q8); - const float32x4_t qf_lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(q16))); - const float32x4_t qf_hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16))); + float32x4_t qf_lo, qf_hi; + widen_s8x8_to_f32x4x2(y[ib].qs + j, &qf_lo, &qf_hi); - 
// Dequant FP8 → float for both groups of 4 - #define DEQUANT_FP8_NEON(v_raw, qf, acc) do { \ - const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x80)); \ - const uint32x4_t exp = vandq_u32( \ - vshlq_u32(v_raw, v_neg_exp_shift), \ - v_exp_mask); \ - const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask); \ - /* Normal: IEEE bits = (exp + offset) << 23 | mant << mant_shift */ \ - const uint32x4_t ieee = vorrq_u32( \ - vorrq_u32(vshlq_n_u32(sign, 24), \ - vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)), \ - vshlq_u32(mant, v_mant_shift)); \ - const float32x4_t normal = vreinterpretq_f32_u32(ieee); \ - /* Subnormal: sign * mant * sub_scale */ \ - const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc); \ - const uint32x4_t sub_bits = vorrq_u32( \ - vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 24)); \ - const float32x4_t sub_val = vreinterpretq_f32_u32(sub_bits); \ - /* Select: subnormal when exp == 0, else normal */ \ - const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0)); \ - const float32x4_t val = vbslq_f32(is_sub, sub_val, normal); \ - /* Multiply by scale and Q8 value, accumulate */ \ - (acc) = vfmaq_f32((acc), vmulq_f32(val, v_scale), qf); \ - } while (0) + const float32x4_t val_lo = mxfp8_dequant_neon(v_lo, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); + const float32x4_t val_hi = mxfp8_dequant_neon(v_hi, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); - DEQUANT_FP8_NEON(v_lo, qf_lo, acc0); - DEQUANT_FP8_NEON(v_hi, qf_hi, acc1); - #undef DEQUANT_FP8_NEON + acc0 = vfmaq_f32(acc0, vmulq_f32(val_lo, v_scale), qf_lo); + acc1 = vfmaq_f32(acc1, vmulq_f32(val_hi, v_scale), qf_hi); } } *s = vaddvq_f32(vaddq_f32(acc0, acc1)); } -#endif -void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); -#if defined(__ARM_NEON) - 
// E4M3: sign(1) exp(4) mant(3), bias=7 - ggml_vec_dot_mxfp8_q8_0_neon(n, s, vx, vy, - MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT, - MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE); -#else - ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif -} - -// NEON-optimized MXFP6 × Q8_0 dot product. -// Unpacks tight 6-bit packing (4 values per 3 bytes), then dequants to float. -#if defined(__ARM_NEON) -static inline void ggml_vec_dot_mxfp6_q8_0_neon( +static void ggml_vec_dot_mxfp6_q8_0_neon( int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, - size_t block_size, - // FP6 format parameters: - const uint32_t exp_mask, // 0x3 for E2M3, 0x7 for E3M2 - const uint32_t mant_mask, // 0x7 for E2M3, 0x3 for E3M2 - const int exp_shift, // 3 for E2M3, 2 for E3M2 - const uint32_t ieee_exp_off, // 126 for E2M3, 124 for E3M2 - const int mant_shift, // 20 for E2M3, 21 for E3M2 - const float sub_scale) { // 1/8 for E2M3, 1/16 for E3M2 + const mxfp_neon_traits_t * t) { assert(n % QK_MXFP6 == 0); const int nb = n / QK_MXFP6; const block_q8_0 * GGML_RESTRICT y = vy; + const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask); + const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask); + const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off); + const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale); + const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift); + const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift); + float32x4_t acc0 = vdupq_n_f32(0.0f); float32x4_t acc1 = vdupq_n_f32(0.0f); - const uint32x4_t v_exp_mask = vdupq_n_u32(exp_mask); - const uint32x4_t v_mant_mask = vdupq_n_u32(mant_mask); - const uint32x4_t v_ieee_off = vdupq_n_u32(ieee_exp_off); - const float32x4_t v_sub_sc = vdupq_n_f32(sub_scale); - const int32x4_t v_neg_exp_shift = vdupq_n_s32(-exp_shift); - const int32x4_t v_mant_shift = vdupq_n_s32(mant_shift); - for (int ib = 0; ib < nb; ++ib) { - const 
block_mxfp6 * GGML_RESTRICT xb = (const block_mxfp6 *)((const char *)vx + ib * block_size); - const float scale = GGML_E8M0_TO_FP32(xb->e) * GGML_CPU_FP16_TO_FP32(y[ib].d); - const float32x4_t v_scale = vdupq_n_f32(scale); + const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib; + const float32x4_t v_scale = vdupq_n_f32( + GGML_E8M0_TO_FP32(xb->e) * GGML_CPU_FP16_TO_FP32(y[ib].d)); - // Process 32 FP6 elements: 8 groups of 4, each packed in 3 bytes for (int j = 0; j < 32; j += 8) { - // Unpack two groups of 4 FP6 values (6 bytes → 8 values) - uint8_t unpacked[8]; - // Group 1: 3 bytes → 4 values - { - const uint8_t * p = xb->qs + (j * 3 / 4); - const uint32_t packed = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); - unpacked[0] = (packed >> 0) & 0x3F; - unpacked[1] = (packed >> 6) & 0x3F; - unpacked[2] = (packed >> 12) & 0x3F; - unpacked[3] = (packed >> 18) & 0x3F; - } - // Group 2: next 3 bytes → 4 values - { - const uint8_t * p = xb->qs + ((j + 4) * 3 / 4); - const uint32_t packed = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); - unpacked[4] = (packed >> 0) & 0x3F; - unpacked[5] = (packed >> 6) & 0x3F; - unpacked[6] = (packed >> 12) & 0x3F; - unpacked[7] = (packed >> 18) & 0x3F; - } + const uint32x4_t v_lo = unpack_fp6x4_neon(xb->qs + (j * 3 / 4)); + const uint32x4_t v_hi = unpack_fp6x4_neon(xb->qs + ((j + 4) * 3 / 4)); - // Extend to uint32x4_t - const uint8x8_t raw8 = vld1_u8(unpacked); - const uint16x8_t raw16 = vmovl_u8(raw8); - const uint32x4_t v_lo = vmovl_u16(vget_low_u16(raw16)); - const uint32x4_t v_hi = vmovl_u16(vget_high_u16(raw16)); + float32x4_t qf_lo, qf_hi; + widen_s8x8_to_f32x4x2(y[ib].qs + j, &qf_lo, &qf_hi); - // Load Q8_0 int8 values - const int8x8_t q8 = vld1_s8(y[ib].qs + j); - const int16x8_t q16 = vmovl_s8(q8); - const float32x4_t qf_lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(q16))); - const float32x4_t qf_hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16))); + const float32x4_t val_lo = 
mxfp6_dequant_neon(v_lo, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); + const float32x4_t val_hi = mxfp6_dequant_neon(v_hi, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); - // Dequant FP6 → float (same IEEE construction as FP8, sign bit at position 5) - #define DEQUANT_FP6_NEON(v_raw, qf, acc) do { \ - const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x20)); \ - const uint32x4_t exp = vandq_u32( \ - vshlq_u32(v_raw, v_neg_exp_shift), \ - v_exp_mask); \ - const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask); \ - const uint32x4_t ieee = vorrq_u32( \ - vorrq_u32(vshlq_n_u32(sign, 26), \ - vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)), \ - vshlq_u32(mant, v_mant_shift)); \ - const float32x4_t normal = vreinterpretq_f32_u32(ieee); \ - const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc); \ - const uint32x4_t sub_bits = vorrq_u32( \ - vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 26)); \ - const float32x4_t sub_val = vreinterpretq_f32_u32(sub_bits); \ - const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0)); \ - const float32x4_t val = vbslq_f32(is_sub, sub_val, normal); \ - (acc) = vfmaq_f32((acc), vmulq_f32(val, v_scale), qf); \ - } while (0) - - DEQUANT_FP6_NEON(v_lo, qf_lo, acc0); - DEQUANT_FP6_NEON(v_hi, qf_hi, acc1); - #undef DEQUANT_FP6_NEON + acc0 = vfmaq_f32(acc0, vmulq_f32(val_lo, v_scale), qf_lo); + acc1 = vfmaq_f32(acc1, vmulq_f32(val_hi, v_scale), qf_hi); } } *s = vaddvq_f32(vaddq_f32(acc0, acc1)); } -#endif -void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); -#if defined(__ARM_NEON) - // E2M3: sign(1) exp(2) mant(3), bias=1 - ggml_vec_dot_mxfp6_q8_0_neon(n, s, vx, vy, sizeof(block_mxfp6), - MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, - MXFP6_E2M3_IEEE_EXP_OFF, 
MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); -#else - ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif -} +// ── MXFP FP8/FP6 dequantize_row (AoS) ───────────────────────────────────── -// ---- MXFP dequantize_row (to_float) — NEON-optimized ---- - -#if defined(__ARM_NEON) -static inline void dequantize_row_mxfp8_neon( +static void dequantize_row_mxfp8_neon( const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k, - const uint32_t exp_mask, const uint32_t mant_mask, - const int exp_shift, const uint32_t ieee_exp_off, - const int mant_shift, const float sub_scale) { + const mxfp_neon_traits_t * t) { assert(k % QK_MXFP8 == 0); const int nb = k / QK_MXFP8; const block_mxfp8 * GGML_RESTRICT x = vx; - const uint32x4_t v_exp_mask = vdupq_n_u32(exp_mask); - const uint32x4_t v_mant_mask = vdupq_n_u32(mant_mask); - const uint32x4_t v_ieee_off = vdupq_n_u32(ieee_exp_off); - const float32x4_t v_sub_sc = vdupq_n_f32(sub_scale); - const int32x4_t v_neg_exp_shift = vdupq_n_s32(-exp_shift); - const int32x4_t v_mant_shift_v = vdupq_n_s32(mant_shift); + const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask); + const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask); + const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off); + const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale); + const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift); + const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift); for (int ib = 0; ib < nb; ++ib) { const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32(x[ib].e)); for (int j = 0; j < 32; j += 8) { - const uint8x8_t raw8 = vld1_u8(x[ib].qs + j); - const uint16x8_t raw16 = vmovl_u8(raw8); - const uint32x4_t v_lo = vmovl_u16(vget_low_u16(raw16)); - const uint32x4_t v_hi = vmovl_u16(vget_high_u16(raw16)); + uint32x4_t v_lo, v_hi; + widen_u8x8_to_u32x4x2(x[ib].qs + j, &v_lo, &v_hi); - #define DEQUANT_FP8_STORE(v_raw, dst) do { \ - const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x80)); \ - const uint32x4_t 
exp = vandq_u32( \ - vshlq_u32(v_raw, v_neg_exp_shift), \ - v_exp_mask); \ - const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask); \ - const uint32x4_t ieee = vorrq_u32( \ - vorrq_u32(vshlq_n_u32(sign, 24), \ - vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)), \ - vshlq_u32(mant, v_mant_shift_v)); \ - const float32x4_t normal = vreinterpretq_f32_u32(ieee); \ - const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc); \ - const uint32x4_t sub_bits = vorrq_u32( \ - vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 24)); \ - const float32x4_t sub_val = vreinterpretq_f32_u32(sub_bits); \ - const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0)); \ - const float32x4_t val = vbslq_f32(is_sub, sub_val, normal); \ - vst1q_f32(dst, vmulq_f32(val, v_scale)); \ - } while (0) + const float32x4_t val_lo = mxfp8_dequant_neon(v_lo, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); + const float32x4_t val_hi = mxfp8_dequant_neon(v_hi, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); - DEQUANT_FP8_STORE(v_lo, y + ib * QK_MXFP8 + j); - DEQUANT_FP8_STORE(v_hi, y + ib * QK_MXFP8 + j + 4); - #undef DEQUANT_FP8_STORE + vst1q_f32(y + ib * QK_MXFP8 + j, vmulq_f32(val_lo, v_scale)); + vst1q_f32(y + ib * QK_MXFP8 + j + 4, vmulq_f32(val_hi, v_scale)); } } } -static inline void dequantize_row_mxfp6_neon( +static void dequantize_row_mxfp6_neon( const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k, - size_t block_size, - const uint32_t exp_mask, const uint32_t mant_mask, - const int exp_shift, const uint32_t ieee_exp_off, - const int mant_shift, const float sub_scale) { + const mxfp_neon_traits_t * t) { assert(k % QK_MXFP6 == 0); const int nb = k / QK_MXFP6; - const uint32x4_t v_exp_mask = vdupq_n_u32(exp_mask); - const uint32x4_t v_mant_mask = vdupq_n_u32(mant_mask); - const uint32x4_t v_ieee_off = vdupq_n_u32(ieee_exp_off); - const float32x4_t v_sub_sc = vdupq_n_f32(sub_scale); - const int32x4_t v_neg_exp_shift = 
vdupq_n_s32(-exp_shift); - const int32x4_t v_mant_shift_v = vdupq_n_s32(mant_shift); + const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask); + const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask); + const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off); + const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale); + const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift); + const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift); for (int ib = 0; ib < nb; ++ib) { - const block_mxfp6 * GGML_RESTRICT xb = (const block_mxfp6 *)((const char *)vx + ib * block_size); + const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib; const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32(xb->e)); for (int j = 0; j < 32; j += 4) { - const uint8_t * p = xb->qs + (j * 3 / 4); - const uint32_t pk = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); - uint8_t unpacked[4]; - unpacked[0] = (pk >> 0) & 0x3F; - unpacked[1] = (pk >> 6) & 0x3F; - unpacked[2] = (pk >> 12) & 0x3F; - unpacked[3] = (pk >> 18) & 0x3F; + const uint32x4_t v_raw = unpack_fp6x4_neon(xb->qs + (j * 3 / 4)); - const uint8x8_t raw8 = vcreate_u8( - (uint64_t)unpacked[0] | ((uint64_t)unpacked[1] << 8) | - ((uint64_t)unpacked[2] << 16) | ((uint64_t)unpacked[3] << 24)); - const uint32x4_t v_raw = vmovl_u16(vget_low_u16(vmovl_u8(raw8))); - - const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x20)); - const uint32x4_t exp = vandq_u32( - vshlq_u32(v_raw, v_neg_exp_shift), - v_exp_mask); - const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask); - - const uint32x4_t ieee = vorrq_u32( - vorrq_u32(vshlq_n_u32(sign, 26), - vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)), - vshlq_u32(mant, v_mant_shift_v)); - const float32x4_t normal = vreinterpretq_f32_u32(ieee); - - const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc); - const uint32x4_t sub_bits = vorrq_u32( - vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 26)); - const float32x4_t sub_val = 
vreinterpretq_f32_u32(sub_bits); - - const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0)); - const float32x4_t val = vbslq_f32(is_sub, sub_val, normal); + const float32x4_t val = mxfp6_dequant_neon(v_raw, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); vst1q_f32(y + ib * QK_MXFP6 + j, vmulq_f32(val, v_scale)); } } } -#endif -void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { -#if defined(__ARM_NEON) - dequantize_row_mxfp8_neon(x, y, k, - MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT, - MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE); -#else - dequantize_row_mxfp8_cpu_generic(x, y, k); -#endif +// ── MXFP SoA dequant (flash attention) ───────────────────────────────────── + +static void dequantize_row_mxfp8_soa_neon( + const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, + const mxfp_neon_traits_t * t) { + assert(k % QK_MXFP8 == 0); + const int nb = k / QK_MXFP8; + const char * qs_base = (const char *)src; + const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK); + + const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask); + const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask); + const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off); + const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale); + const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift); + const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift); + + for (int ib = 0; ib < nb; ++ib) { + const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib])); + const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP8_SOA_QS_PER_BLOCK)); + + for (int j = 0; j < 32; j += 8) { + uint32x4_t v_lo, v_hi; + widen_u8x8_to_u32x4x2(qs + j, &v_lo, &v_hi); + + const float32x4_t val_lo = mxfp8_dequant_neon(v_lo, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); + const float32x4_t val_hi = 
mxfp8_dequant_neon(v_hi, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); + + vst1q_f32(y + ib * QK_MXFP8 + j, vmulq_f32(val_lo, v_scale)); + vst1q_f32(y + ib * QK_MXFP8 + j + 4, vmulq_f32(val_hi, v_scale)); + } + } } -void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { -#if defined(__ARM_NEON) - dequantize_row_mxfp6_neon(x, y, k, sizeof(block_mxfp6), - MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, - MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); -#else - dequantize_row_mxfp6_cpu_generic(x, y, k); -#endif +static void dequantize_row_mxfp6_soa_neon( + const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, + const mxfp_neon_traits_t * t) { + assert(k % QK_MXFP6 == 0); + const int nb = k / QK_MXFP6; + const char * qs_base = (const char *)src; + const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK); + + const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask); + const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask); + const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off); + const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale); + const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift); + const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift); + + for (int ib = 0; ib < nb; ++ib) { + const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib])); + const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP6_SOA_QS_PER_BLOCK)); + + for (int j = 0; j < 32; j += 4) { + const uint32x4_t v_raw = unpack_fp6x4_neon(qs + (j * 3 / 4)); + + const float32x4_t val = mxfp6_dequant_neon(v_raw, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); + + vst1q_f32(y + ib * QK_MXFP6 + j, vmulq_f32(val, v_scale)); + } + } } -// ---- MXFP SoA dequantize_row (to_float) — NEON-optimized ---- - -#if defined(__ARM_NEON) -static inline void dequantize_row_mxfp4_soa_neon( +// 
MXFP4 SoA dequant — LUT-based, no IEEE reconstruction needed. +static void dequantize_row_mxfp4_soa_neon( const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) { assert(k % QK_MXFP4 == 0); const int nb = k / QK_MXFP4; - const char * row = (const char *)src; - const char * qs_base = row; - const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK); + const char * qs_base = (const char *)src; + const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK); const int8x16_t values = vld1q_s8(kvalues_mxfp4); const uint8x16_t m4b = vdupq_n_u8(0x0f); @@ -4527,122 +4490,45 @@ static inline void dequantize_row_mxfp4_soa_neon( } } -static inline void dequantize_row_mxfp8_soa_neon( - const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, - const uint32_t exp_mask, const uint32_t mant_mask, - const int exp_shift, const uint32_t ieee_exp_off, - const int mant_shift, const float sub_scale) { - assert(k % QK_MXFP8 == 0); - const int nb = k / QK_MXFP8; - const char * row = (const char *)src; - const char * qs_base = row; - const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK); +#endif // __ARM_NEON - const uint32x4_t v_exp_mask = vdupq_n_u32(exp_mask); - const uint32x4_t v_mant_mask = vdupq_n_u32(mant_mask); - const uint32x4_t v_ieee_off = vdupq_n_u32(ieee_exp_off); - const float32x4_t v_sub_sc = vdupq_n_f32(sub_scale); - const int32x4_t v_neg_exp_shift = vdupq_n_s32(-exp_shift); - const int32x4_t v_mant_shift_v = vdupq_n_s32(mant_shift); +// ── Public dispatch functions ────────────────────────────────────────────── - for (int ib = 0; ib < nb; ++ib) { - const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib])); - const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP8_SOA_QS_PER_BLOCK)); - - for (int j = 0; j < 32; j += 8) { - const uint8x8_t raw8 = vld1_u8(qs + j); - const uint16x8_t raw16 = vmovl_u8(raw8); - const uint32x4_t v_lo = 
vmovl_u16(vget_low_u16(raw16)); - const uint32x4_t v_hi = vmovl_u16(vget_high_u16(raw16)); - - #define DEQUANT_FP8_STORE_SOA(v_raw, dst) do { \ - const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x80)); \ - const uint32x4_t exp = vandq_u32( \ - vshlq_u32(v_raw, v_neg_exp_shift), \ - v_exp_mask); \ - const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask); \ - const uint32x4_t ieee = vorrq_u32( \ - vorrq_u32(vshlq_n_u32(sign, 24), \ - vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)), \ - vshlq_u32(mant, v_mant_shift_v)); \ - const float32x4_t normal = vreinterpretq_f32_u32(ieee); \ - const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc); \ - const uint32x4_t sub_bits = vorrq_u32( \ - vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 24)); \ - const float32x4_t sub_val = vreinterpretq_f32_u32(sub_bits); \ - const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0)); \ - const float32x4_t val = vbslq_f32(is_sub, sub_val, normal); \ - vst1q_f32(dst, vmulq_f32(val, v_scale)); \ - } while (0) - - DEQUANT_FP8_STORE_SOA(v_lo, y + ib * QK_MXFP8 + j); - DEQUANT_FP8_STORE_SOA(v_hi, y + ib * QK_MXFP8 + j + 4); - #undef DEQUANT_FP8_STORE_SOA - } - } -} - -static inline void dequantize_row_mxfp6_soa_neon( - const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, - const uint32_t exp_mask, const uint32_t mant_mask, - const int exp_shift, const uint32_t ieee_exp_off, - const int mant_shift, const float sub_scale) { - assert(k % QK_MXFP6 == 0); - const int nb = k / QK_MXFP6; - const char * row = (const char *)src; - const char * qs_base = row; - const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK); - - const uint32x4_t v_exp_mask = vdupq_n_u32(exp_mask); - const uint32x4_t v_mant_mask = vdupq_n_u32(mant_mask); - const uint32x4_t v_ieee_off = vdupq_n_u32(ieee_exp_off); - const float32x4_t v_sub_sc = vdupq_n_f32(sub_scale); - const int32x4_t v_neg_exp_shift = vdupq_n_s32(-exp_shift); - const int32x4_t v_mant_shift_v = 
vdupq_n_s32(mant_shift); - - for (int ib = 0; ib < nb; ++ib) { - const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib])); - const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP6_SOA_QS_PER_BLOCK)); - - for (int j = 0; j < 32; j += 4) { - const uint8_t * p = qs + (j * 3 / 4); - const uint32_t pk = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); - uint8_t unpacked[4]; - unpacked[0] = (pk >> 0) & 0x3F; - unpacked[1] = (pk >> 6) & 0x3F; - unpacked[2] = (pk >> 12) & 0x3F; - unpacked[3] = (pk >> 18) & 0x3F; - - const uint8x8_t raw8 = vcreate_u8( - (uint64_t)unpacked[0] | ((uint64_t)unpacked[1] << 8) | - ((uint64_t)unpacked[2] << 16) | ((uint64_t)unpacked[3] << 24)); - const uint32x4_t v_raw = vmovl_u16(vget_low_u16(vmovl_u8(raw8))); - - const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x20)); - const uint32x4_t exp = vandq_u32( - vshlq_u32(v_raw, v_neg_exp_shift), - v_exp_mask); - const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask); - - const uint32x4_t ieee = vorrq_u32( - vorrq_u32(vshlq_n_u32(sign, 26), - vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)), - vshlq_u32(mant, v_mant_shift_v)); - const float32x4_t normal = vreinterpretq_f32_u32(ieee); - - const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc); - const uint32x4_t sub_bits = vorrq_u32( - vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 26)); - const float32x4_t sub_val = vreinterpretq_f32_u32(sub_bits); - - const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0)); - const float32x4_t val = vbslq_f32(is_sub, sub_val, normal); - - vst1q_f32(y + ib * QK_MXFP6 + j, vmulq_f32(val, v_scale)); - } - } -} +void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); +#if defined(__ARM_NEON) + ggml_vec_dot_mxfp8_q8_0_neon(n, s, vx, vy, &MXFP_TRAITS_E4M3); 
+#else + ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); #endif +} + +void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); +#if defined(__ARM_NEON) + ggml_vec_dot_mxfp6_q8_0_neon(n, s, vx, vy, &MXFP_TRAITS_E2M3); +#else + ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { +#if defined(__ARM_NEON) + dequantize_row_mxfp8_neon(x, y, k, &MXFP_TRAITS_E4M3); +#else + dequantize_row_mxfp8_cpu_generic(x, y, k); +#endif +} + +void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { +#if defined(__ARM_NEON) + dequantize_row_mxfp6_neon(x, y, k, &MXFP_TRAITS_E2M3); +#else + dequantize_row_mxfp6_cpu_generic(x, y, k); +#endif +} void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { #if defined(__ARM_NEON) @@ -4654,9 +4540,7 @@ void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RES void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { #if defined(__ARM_NEON) - dequantize_row_mxfp8_soa_neon(x, y, k, - MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT, - MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE); + dequantize_row_mxfp8_soa_neon(x, y, k, &MXFP_TRAITS_E4M3); #else dequantize_row_mxfp8_soa_cpu_generic(x, y, k); #endif @@ -4664,9 +4548,7 @@ void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RES void dequantize_row_mxfp6_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { #if defined(__ARM_NEON) - dequantize_row_mxfp6_soa_neon(x, y, k, - MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, - 
MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); + dequantize_row_mxfp6_soa_neon(x, y, k, &MXFP_TRAITS_E2M3); #else dequantize_row_mxfp6_soa_cpu_generic(x, y, k); #endif diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c index 0c6f6ed49a..b00b1467d3 100644 --- a/ggml/src/ggml-cpu/arch/x86/quants.c +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -3819,30 +3819,77 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -// AVX2-optimized MXFP8 × Q8_0 dot product. -// Dequants FP8 elements to float via IEEE 754 bit construction, then dots against Q8_0. -// Parameters encode the FP8 format: exp_mask, mant_mask, exp_shift, ieee_exp_offset, mant_shift, sub_scale. +// ── MXFP FP8/FP6 AVX2 helpers ────────────────────────────────────────────── +// Shared IEEE-754 bit reconstruction and FP6 unpacking used by vec_dot, +// dequantize_row, and SoA dequant functions. + #if defined(__AVX2__) -static inline void ggml_vec_dot_mxfp8_q8_0_avx2( + +// Use shared mxfp_dequant_traits_t from ggml-common.h. +// Aliases for readability within this file. +#define mxfp_avx2_traits_t mxfp_dequant_traits_t + +// Dequantize 8 raw MXFP values (widened to int32) → 8 IEEE-754 floats. +// Handles both normal and subnormal paths. Works for any FP6/FP8 format. 
+static inline __m256 mxfp_dequant_avx2( + const __m256i v_raw, + const __m256i v_exp_mask, const __m256i v_mant_mask, + const __m256i v_ieee_off, const __m256 v_sub_sc, + const __m256i v_sign_mask, const __m256i v_zero, + int exp_shift, int sign_shift, int mant_shift) { + const __m256i sign = _mm256_and_si256(v_raw, v_sign_mask); + const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask); + const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask); + + const __m256i ieee = _mm256_or_si256( + _mm256_or_si256(_mm256_slli_epi32(sign, sign_shift), + _mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)), + _mm256_slli_epi32(mant, mant_shift)); + const __m256 normal = _mm256_castsi256_ps(ieee); + + const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc); + const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256( + _mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, sign_shift))); + + const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero)); + return _mm256_blendv_ps(normal, sub_val, is_sub); +} + +// Unpack 4 tightly-packed 6-bit values from 3 bytes into separate bytes. +static inline void unpack_fp6x4_avx2(const uint8_t * p, uint8_t out[4]) { + const uint32_t pk = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); + out[0] = (pk >> 0) & 0x3F; + out[1] = (pk >> 6) & 0x3F; + out[2] = (pk >> 12) & 0x3F; + out[3] = (pk >> 18) & 0x3F; +} + +// Unpack 8 FP6 values (two groups of 4) from packed qs data at offset j. +static inline __m256i unpack_fp6x8_avx2(const uint8_t * qs, int j) { + uint8_t unpacked[8]; + unpack_fp6x4_avx2(qs + (j * 3 / 4), unpacked); + unpack_fp6x4_avx2(qs + ((j + 4) * 3 / 4), unpacked + 4); + return _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)unpacked)); +} + +// ── MXFP FP8/FP6 vec_dot ────────────────────────────────────────────────── + +// Unified FP8 × Q8_0 dot product (works for E4M3 and E5M2). 
+static void ggml_vec_dot_mxfp8_q8_0_avx2( int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, - // FP8 format parameters: - const int exp_mask, // 0xF for E4M3, 0x1F for E5M2 - const int mant_mask, // 0x7 for E4M3, 0x3 for E5M2 - const int exp_shift, // 3 for E4M3, 2 for E5M2 - const int ieee_exp_off, // 120 for E4M3, 112 for E5M2 - const int mant_shift, // 20 for E4M3, 21 for E5M2 - const float sub_scale) { // 1/512 for E4M3, 1/65536 for E5M2 + const mxfp_avx2_traits_t * t) { assert(n % QK_MXFP8 == 0); const int nb = n / QK_MXFP8; const block_mxfp8 * GGML_RESTRICT x = vx; const block_q8_0 * GGML_RESTRICT y = vy; - const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask); - const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask); - const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off); - const __m256 v_sub_sc = _mm256_set1_ps(sub_scale); + const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask); + const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask); + const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off); + const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale); + const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask); const __m256i v_zero = _mm256_setzero_si256(); __m256 acc = _mm256_setzero_ps(); @@ -3851,141 +3898,55 @@ static inline void ggml_vec_dot_mxfp8_q8_0_avx2( const __m256 v_scale = _mm256_set1_ps( GGML_E8M0_TO_FP32(x[ib].e) * GGML_CPU_FP16_TO_FP32(y[ib].d)); - // Process 32 FP8 elements in 4 groups of 8 - // AVX2 _mm256_cvtepu8_epi32 widens 8 bytes → 8 int32s directly for (int j = 0; j < 32; j += 8) { - // Load 8 FP8 bytes → 8 int32s - const __m128i raw8 = _mm_loadl_epi64((const __m128i *)(x[ib].qs + j)); - const __m256i v_raw = _mm256_cvtepu8_epi32(raw8); + const __m256i v_raw = _mm256_cvtepu8_epi32( + _mm_loadl_epi64((const __m128i *)(x[ib].qs + j))); + const __m256 qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32( + _mm_loadl_epi64((const __m128i *)(y[ib].qs + j)))); - // Load 8 Q8_0 
int8 values → float - const __m128i q8 = _mm_loadl_epi64((const __m128i *)(y[ib].qs + j)); - const __m256 qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(q8)); + const __m256 val = mxfp_dequant_avx2(v_raw, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, + v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift); - // Extract sign (bit 7), exponent, mantissa - const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x80)); - const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask); - const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask); - - // Normal path: IEEE bits = (sign << 24) | ((exp + offset) << 23) | (mant << mant_shift) - const __m256i ieee = _mm256_or_si256( - _mm256_or_si256(_mm256_slli_epi32(sign, 24), - _mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)), - _mm256_slli_epi32(mant, mant_shift)); - const __m256 normal = _mm256_castsi256_ps(ieee); - - // Subnormal path: |val| = mant * sub_scale, then apply sign - const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc); - const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256( - _mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 24))); - - // Select: subnormal when exp == 0, else normal - const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero)); - const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub); - - // Accumulate: val * scale * q8_float acc = _mm256_fmadd_ps(_mm256_mul_ps(val, v_scale), qf, acc); } } *s = hsum_float_8(acc); } -#endif -void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); -#if defined(__AVX2__) - // E4M3: sign(1) exp(4) mant(3), bias=7 - ggml_vec_dot_mxfp8_q8_0_avx2(n, s, vx, vy, - MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT, - MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, 
MXFP8_E4M3_SUB_SCALE); -#else - ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif -} - -// AVX2-optimized MXFP6 × Q8_0 dot product. -// Unpacks tight 6-bit packing (4 values per 3 bytes), then dequants to float. -#if defined(__AVX2__) -static inline void ggml_vec_dot_mxfp6_q8_0_avx2( +// Unified FP6 × Q8_0 dot product (works for E2M3 and E3M2). +static void ggml_vec_dot_mxfp6_q8_0_avx2( int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, - size_t block_size, - // FP6 format parameters: - const int exp_mask, // 0x3 for E2M3, 0x7 for E3M2 - const int mant_mask, // 0x7 for E2M3, 0x3 for E3M2 - const int exp_shift, // 3 for E2M3, 2 for E3M2 - const int ieee_exp_off, // 126 for E2M3, 124 for E3M2 - const int mant_shift, // 20 for E2M3, 21 for E3M2 - const float sub_scale) { // 1/8 for E2M3, 1/16 for E3M2 + const mxfp_avx2_traits_t * t) { assert(n % QK_MXFP6 == 0); const int nb = n / QK_MXFP6; const block_q8_0 * GGML_RESTRICT y = vy; - const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask); - const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask); - const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off); - const __m256 v_sub_sc = _mm256_set1_ps(sub_scale); + const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask); + const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask); + const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off); + const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale); + const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask); const __m256i v_zero = _mm256_setzero_si256(); __m256 acc = _mm256_setzero_ps(); for (int ib = 0; ib < nb; ++ib) { - const block_mxfp6 * GGML_RESTRICT xb = (const block_mxfp6 *)((const char *)vx + ib * block_size); + const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib; const __m256 v_scale = _mm256_set1_ps( GGML_E8M0_TO_FP32(xb->e) * GGML_CPU_FP16_TO_FP32(y[ib].d)); - // Process 32 FP6 elements in 4 groups of 8 (each group = 2 × 
3-byte packs) for (int j = 0; j < 32; j += 8) { - // Unpack 8 FP6 values from 6 bytes (two groups of 3 bytes → 4 values each) - uint8_t unpacked[8]; - { - const uint8_t * p = xb->qs + (j * 3 / 4); - const uint32_t pk0 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); - unpacked[0] = (pk0 >> 0) & 0x3F; - unpacked[1] = (pk0 >> 6) & 0x3F; - unpacked[2] = (pk0 >> 12) & 0x3F; - unpacked[3] = (pk0 >> 18) & 0x3F; - } - { - const uint8_t * p = xb->qs + ((j + 4) * 3 / 4); - const uint32_t pk1 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); - unpacked[4] = (pk1 >> 0) & 0x3F; - unpacked[5] = (pk1 >> 6) & 0x3F; - unpacked[6] = (pk1 >> 12) & 0x3F; - unpacked[7] = (pk1 >> 18) & 0x3F; - } + const __m256i v_raw = unpack_fp6x8_avx2(xb->qs, j); + const __m256 qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32( + _mm_loadl_epi64((const __m128i *)(y[ib].qs + j)))); - // Widen 8 bytes → 8 int32s - const __m128i raw8 = _mm_loadl_epi64((const __m128i *)unpacked); - const __m256i v_raw = _mm256_cvtepu8_epi32(raw8); - - // Load 8 Q8_0 int8 values → float - const __m128i q8 = _mm_loadl_epi64((const __m128i *)(y[ib].qs + j)); - const __m256 qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(q8)); - - // Extract sign (bit 5 for FP6), exponent, mantissa - const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x20)); - const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask); - const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask); - - // Normal: IEEE bits = (sign << 26) | ((exp + offset) << 23) | (mant << mant_shift) - const __m256i ieee = _mm256_or_si256( - _mm256_or_si256(_mm256_slli_epi32(sign, 26), - _mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)), - _mm256_slli_epi32(mant, mant_shift)); - const __m256 normal = _mm256_castsi256_ps(ieee); - - // Subnormal: |val| = mant * sub_scale, apply sign - const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc); - const __m256 sub_val = 
_mm256_castsi256_ps(_mm256_or_si256( - _mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 26))); - - // Select: subnormal when exp == 0 - const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero)); - const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub); + const __m256 val = mxfp_dequant_avx2(v_raw, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, + v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift); acc = _mm256_fmadd_ps(_mm256_mul_ps(val, v_scale), qf, acc); } @@ -3993,162 +3954,140 @@ static inline void ggml_vec_dot_mxfp6_q8_0_avx2( *s = hsum_float_8(acc); } -#endif -void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); -#if defined(__AVX2__) - // E2M3: sign(1) exp(2) mant(3), bias=1 - ggml_vec_dot_mxfp6_q8_0_avx2(n, s, vx, vy, sizeof(block_mxfp6), - MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, - MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); -#else - ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif -} +// ── MXFP FP8/FP6 dequantize_row (AoS) ───────────────────────────────────── -// ---- MXFP dequantize_row (to_float) — AVX2-optimized ---- -// Extracts the SIMD dequant logic from vec_dot above, writing floats to output buffer -// instead of accumulating a dot product. 
- -#if defined(__AVX2__) -static inline void dequantize_row_mxfp8_avx2( +static void dequantize_row_mxfp8_avx2( const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k, - const int exp_mask, const int mant_mask, const int exp_shift, - const int ieee_exp_off, const int mant_shift, const float sub_scale) { + const mxfp_avx2_traits_t * t) { assert(k % QK_MXFP8 == 0); const int nb = k / QK_MXFP8; const block_mxfp8 * GGML_RESTRICT x = vx; - const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask); - const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask); - const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off); - const __m256 v_sub_sc = _mm256_set1_ps(sub_scale); + const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask); + const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask); + const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off); + const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale); + const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask); const __m256i v_zero = _mm256_setzero_si256(); for (int ib = 0; ib < nb; ++ib) { const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32(x[ib].e)); for (int j = 0; j < 32; j += 8) { - const __m128i raw8 = _mm_loadl_epi64((const __m128i *)(x[ib].qs + j)); - const __m256i v_raw = _mm256_cvtepu8_epi32(raw8); + const __m256i v_raw = _mm256_cvtepu8_epi32( + _mm_loadl_epi64((const __m128i *)(x[ib].qs + j))); - const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x80)); - const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask); - const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask); - - const __m256i ieee = _mm256_or_si256( - _mm256_or_si256(_mm256_slli_epi32(sign, 24), - _mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)), - _mm256_slli_epi32(mant, mant_shift)); - const __m256 normal = _mm256_castsi256_ps(ieee); - - const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc); - const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256( 
- _mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 24))); - - const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero)); - const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub); + const __m256 val = mxfp_dequant_avx2(v_raw, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, + v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift); _mm256_storeu_ps(y + ib * QK_MXFP8 + j, _mm256_mul_ps(val, v_scale)); } } } -static inline void dequantize_row_mxfp6_avx2( +static void dequantize_row_mxfp6_avx2( const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k, - size_t block_size, - const int exp_mask, const int mant_mask, const int exp_shift, - const int ieee_exp_off, const int mant_shift, const float sub_scale) { + const mxfp_avx2_traits_t * t) { assert(k % QK_MXFP6 == 0); const int nb = k / QK_MXFP6; - const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask); - const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask); - const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off); - const __m256 v_sub_sc = _mm256_set1_ps(sub_scale); + const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask); + const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask); + const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off); + const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale); + const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask); const __m256i v_zero = _mm256_setzero_si256(); for (int ib = 0; ib < nb; ++ib) { - const block_mxfp6 * GGML_RESTRICT xb = (const block_mxfp6 *)((const char *)vx + ib * block_size); + const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib; const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32(xb->e)); for (int j = 0; j < 32; j += 8) { - // Unpack 8 FP6 values from 6 bytes (two groups of 3 bytes → 4 values each) - uint8_t unpacked[8]; - { - const uint8_t * p = xb->qs + (j * 3 / 4); - const uint32_t pk0 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); - unpacked[0] = (pk0 
>> 0) & 0x3F; - unpacked[1] = (pk0 >> 6) & 0x3F; - unpacked[2] = (pk0 >> 12) & 0x3F; - unpacked[3] = (pk0 >> 18) & 0x3F; - } - { - const uint8_t * p = xb->qs + ((j + 4) * 3 / 4); - const uint32_t pk1 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); - unpacked[4] = (pk1 >> 0) & 0x3F; - unpacked[5] = (pk1 >> 6) & 0x3F; - unpacked[6] = (pk1 >> 12) & 0x3F; - unpacked[7] = (pk1 >> 18) & 0x3F; - } + const __m256i v_raw = unpack_fp6x8_avx2(xb->qs, j); - const __m128i raw8 = _mm_loadl_epi64((const __m128i *)unpacked); - const __m256i v_raw = _mm256_cvtepu8_epi32(raw8); - - const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x20)); - const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask); - const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask); - - const __m256i ieee = _mm256_or_si256( - _mm256_or_si256(_mm256_slli_epi32(sign, 26), - _mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)), - _mm256_slli_epi32(mant, mant_shift)); - const __m256 normal = _mm256_castsi256_ps(ieee); - - const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc); - const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256( - _mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 26))); - - const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero)); - const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub); + const __m256 val = mxfp_dequant_avx2(v_raw, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, + v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift); _mm256_storeu_ps(y + ib * QK_MXFP6 + j, _mm256_mul_ps(val, v_scale)); } } } -#endif -void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { -#if defined(__AVX2__) - dequantize_row_mxfp8_avx2(x, y, k, - MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT, - MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE); -#else - dequantize_row_mxfp8_cpu_generic(x, y, k); 
-#endif +// ── MXFP SoA dequant (flash attention) ───────────────────────────────────── + +static void dequantize_row_mxfp8_soa_avx2( + const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, + const mxfp_avx2_traits_t * t) { + assert(k % QK_MXFP8 == 0); + const int nb = k / QK_MXFP8; + const char * qs_base = (const char *)src; + const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK); + + const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask); + const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask); + const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off); + const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale); + const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask); + const __m256i v_zero = _mm256_setzero_si256(); + + for (int ib = 0; ib < nb; ++ib) { + const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib])); + const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP8_SOA_QS_PER_BLOCK)); + + for (int j = 0; j < 32; j += 8) { + const __m256i v_raw = _mm256_cvtepu8_epi32( + _mm_loadl_epi64((const __m128i *)(qs + j))); + + const __m256 val = mxfp_dequant_avx2(v_raw, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, + v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift); + + _mm256_storeu_ps(y + ib * QK_MXFP8 + j, _mm256_mul_ps(val, v_scale)); + } + } } -void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { -#if defined(__AVX2__) - dequantize_row_mxfp6_avx2(x, y, k, sizeof(block_mxfp6), - MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, - MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); -#else - dequantize_row_mxfp6_cpu_generic(x, y, k); -#endif +static void dequantize_row_mxfp6_soa_avx2( + const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, + const mxfp_avx2_traits_t * t) { + assert(k % QK_MXFP6 == 0); + const int nb = k / QK_MXFP6; + const char * 
qs_base = (const char *)src; + const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK); + + const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask); + const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask); + const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off); + const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale); + const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask); + const __m256i v_zero = _mm256_setzero_si256(); + + for (int ib = 0; ib < nb; ++ib) { + const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib])); + const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP6_SOA_QS_PER_BLOCK)); + + for (int j = 0; j < 32; j += 8) { + const __m256i v_raw = unpack_fp6x8_avx2(qs, j); + + const __m256 val = mxfp_dequant_avx2(v_raw, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, + v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift); + + _mm256_storeu_ps(y + ib * QK_MXFP6 + j, _mm256_mul_ps(val, v_scale)); + } + } } -// SoA dequant for flash attention — contiguous qs region + separate e8m0 region -#if defined(__AVX2__) -static inline void dequantize_row_mxfp4_soa_avx2( +// MXFP4 SoA dequant — LUT-based, no IEEE reconstruction needed. 
+static void dequantize_row_mxfp4_soa_avx2( const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) { assert(k % QK_MXFP4 == 0); const int nb = k / QK_MXFP4; - const char * row = (const char *)src; - const char * qs_base = row; - const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK); + const char * qs_base = (const char *)src; + const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK); const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4); const __m128i m4b = _mm_set1_epi8(0x0f); @@ -4163,13 +4102,11 @@ static inline void dequantize_row_mxfp4_soa_avx2( const __m128i lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits, m4b)); const __m128i hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4b)); - // lo nibbles → first 16 floats const __m256i lo32_0 = _mm256_cvtepi8_epi32(lo); const __m256i lo32_1 = _mm256_cvtepi8_epi32(_mm_srli_si128(lo, 8)); _mm256_storeu_ps(y + i * QK_MXFP4 + 0, _mm256_mul_ps(_mm256_cvtepi32_ps(lo32_0), v_scale)); _mm256_storeu_ps(y + i * QK_MXFP4 + 8, _mm256_mul_ps(_mm256_cvtepi32_ps(lo32_1), v_scale)); - // hi nibbles → second 16 floats const __m256i hi32_0 = _mm256_cvtepi8_epi32(hi); const __m256i hi32_1 = _mm256_cvtepi8_epi32(_mm_srli_si128(hi, 8)); _mm256_storeu_ps(y + i * QK_MXFP4 + 16, _mm256_mul_ps(_mm256_cvtepi32_ps(hi32_0), v_scale)); @@ -4177,116 +4114,45 @@ static inline void dequantize_row_mxfp4_soa_avx2( } } -static inline void dequantize_row_mxfp8_soa_avx2( - const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, - const int exp_mask, const int mant_mask, const int exp_shift, - const int ieee_exp_off, const int mant_shift, const float sub_scale) { - assert(k % QK_MXFP8 == 0); - const int nb = k / QK_MXFP8; - const char * row = (const char *)src; - const char * qs_base = row; - const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK); +#endif // __AVX2__ - const __m256i v_exp_mask = 
_mm256_set1_epi32(exp_mask); - const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask); - const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off); - const __m256 v_sub_sc = _mm256_set1_ps(sub_scale); - const __m256i v_zero = _mm256_setzero_si256(); +// ── Public dispatch functions ────────────────────────────────────────────── - for (int ib = 0; ib < nb; ++ib) { - const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib])); - const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP8_SOA_QS_PER_BLOCK)); - - for (int j = 0; j < 32; j += 8) { - const __m128i raw8 = _mm_loadl_epi64((const __m128i *)(qs + j)); - const __m256i v_raw = _mm256_cvtepu8_epi32(raw8); - - const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x80)); - const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask); - const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask); - - const __m256i ieee = _mm256_or_si256( - _mm256_or_si256(_mm256_slli_epi32(sign, 24), - _mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)), - _mm256_slli_epi32(mant, mant_shift)); - const __m256 normal = _mm256_castsi256_ps(ieee); - - const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc); - const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256( - _mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 24))); - - const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero)); - const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub); - - _mm256_storeu_ps(y + ib * QK_MXFP8 + j, _mm256_mul_ps(val, v_scale)); - } - } -} - -static inline void dequantize_row_mxfp6_soa_avx2( - const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, - const int exp_mask, const int mant_mask, const int exp_shift, - const int ieee_exp_off, const int mant_shift, const float sub_scale) { - assert(k % QK_MXFP6 == 0); - const int nb = k / QK_MXFP6; - const char * row = (const char *)src; - const char * qs_base = row; - 
const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK); - - const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask); - const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask); - const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off); - const __m256 v_sub_sc = _mm256_set1_ps(sub_scale); - const __m256i v_zero = _mm256_setzero_si256(); - - for (int ib = 0; ib < nb; ++ib) { - const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib])); - const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP6_SOA_QS_PER_BLOCK)); - - for (int j = 0; j < 32; j += 8) { - uint8_t unpacked[8]; - { - const uint8_t * p = qs + (j * 3 / 4); - const uint32_t pk0 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); - unpacked[0] = (pk0 >> 0) & 0x3F; - unpacked[1] = (pk0 >> 6) & 0x3F; - unpacked[2] = (pk0 >> 12) & 0x3F; - unpacked[3] = (pk0 >> 18) & 0x3F; - } - { - const uint8_t * p = qs + ((j + 4) * 3 / 4); - const uint32_t pk1 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); - unpacked[4] = (pk1 >> 0) & 0x3F; - unpacked[5] = (pk1 >> 6) & 0x3F; - unpacked[6] = (pk1 >> 12) & 0x3F; - unpacked[7] = (pk1 >> 18) & 0x3F; - } - - const __m128i raw8 = _mm_loadl_epi64((const __m128i *)unpacked); - const __m256i v_raw = _mm256_cvtepu8_epi32(raw8); - - const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x20)); - const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask); - const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask); - - const __m256i ieee = _mm256_or_si256( - _mm256_or_si256(_mm256_slli_epi32(sign, 26), - _mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)), - _mm256_slli_epi32(mant, mant_shift)); - const __m256 normal = _mm256_castsi256_ps(ieee); - - const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc); - const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256( - _mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 26))); 
- - const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero)); - const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub); - - _mm256_storeu_ps(y + ib * QK_MXFP6 + j, _mm256_mul_ps(val, v_scale)); - } - } -} +void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); +#if defined(__AVX2__) + ggml_vec_dot_mxfp8_q8_0_avx2(n, s, vx, vy, &MXFP_TRAITS_E4M3); +#else + ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); #endif +} + +void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); +#if defined(__AVX2__) + ggml_vec_dot_mxfp6_q8_0_avx2(n, s, vx, vy, &MXFP_TRAITS_E2M3); +#else + ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { +#if defined(__AVX2__) + dequantize_row_mxfp8_avx2(x, y, k, &MXFP_TRAITS_E4M3); +#else + dequantize_row_mxfp8_cpu_generic(x, y, k); +#endif +} + +void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { +#if defined(__AVX2__) + dequantize_row_mxfp6_avx2(x, y, k, &MXFP_TRAITS_E2M3); +#else + dequantize_row_mxfp6_cpu_generic(x, y, k); +#endif +} void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { #if defined(__AVX2__) @@ -4298,9 +4164,7 @@ void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RES void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { #if defined(__AVX2__) - dequantize_row_mxfp8_soa_avx2(x, y, k, - MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT, 
- MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE); + dequantize_row_mxfp8_soa_avx2(x, y, k, &MXFP_TRAITS_E4M3); #else dequantize_row_mxfp8_soa_cpu_generic(x, y, k); #endif @@ -4308,9 +4172,7 @@ void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RES void dequantize_row_mxfp6_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { #if defined(__AVX2__) - dequantize_row_mxfp6_soa_avx2(x, y, k, - MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, - MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); + dequantize_row_mxfp6_soa_avx2(x, y, k, &MXFP_TRAITS_E2M3); #else dequantize_row_mxfp6_soa_cpu_generic(x, y, k); #endif diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 27275ca1e1..12f04905af 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8297,8 +8297,8 @@ static mxfp_fa_params mxfp_fa_params_init( if (is_mxfp_k) { switch (k->type) { - case GGML_TYPE_MXFP4_E2M1: p.q_quantize = quantize_row_mxfp4_soa; p.k_dequantize = dequantize_row_mxfp4_soa_cpu; break; - case GGML_TYPE_MXFP8_E4M3: p.q_quantize = quantize_row_mxfp8_soa; p.k_dequantize = dequantize_row_mxfp8_soa_cpu; break; + case GGML_TYPE_MXFP4_E2M1: p.q_quantize = quantize_row_mxfp4_soa; p.k_dequantize = dequantize_row_mxfp4_soa_cpu; break; + case GGML_TYPE_MXFP8_E4M3: p.q_quantize = quantize_row_mxfp8_soa; p.k_dequantize = dequantize_row_mxfp8_soa_cpu; break; case GGML_TYPE_MXFP6_E2M3: p.q_quantize = quantize_row_mxfp6_soa; p.k_dequantize = dequantize_row_mxfp6_soa_cpu; break; default: GGML_ABORT("unsupported MXFP K type"); } @@ -8306,8 +8306,8 @@ static mxfp_fa_params mxfp_fa_params_init( if (is_mxfp_v) { switch (v->type) { - case GGML_TYPE_MXFP4_E2M1: p.v_dequantize = dequantize_row_mxfp4_soa_cpu; break; - case GGML_TYPE_MXFP8_E4M3: p.v_dequantize = dequantize_row_mxfp8_soa_cpu; break; + case GGML_TYPE_MXFP4_E2M1: p.v_dequantize = dequantize_row_mxfp4_soa_cpu; break; 
+ case GGML_TYPE_MXFP8_E4M3: p.v_dequantize = dequantize_row_mxfp8_soa_cpu; break; case GGML_TYPE_MXFP6_E2M3: p.v_dequantize = dequantize_row_mxfp6_soa_cpu; break; default: GGML_ABORT("unsupported MXFP V type"); } @@ -8328,6 +8328,7 @@ static mxfp_fa_params mxfp_fa_params_init( // Per-head SoA addressing for multihead mode. // Precompute byte offsets so the hot loop can skip per-head pointer math. + // qs_per_block values from centralized MXFP_QS_PER_BLOCK_* defines in ggml-common.h. auto mxfp_qs_per_block = [](ggml_type type) -> int { switch (type) { case GGML_TYPE_MXFP4_E2M1: return MXFP4_SOA_QS_PER_BLOCK; @@ -8341,7 +8342,6 @@ static mxfp_fa_params mxfp_fa_params_init( p.k_qs_per_block = mxfp_qs_per_block(k->type); p.k_blocks_per_head = (int)(DK / 32); p.k_head_qs_bytes = p.k_blocks_per_head * p.k_qs_per_block; - // e8m0 offset from row start = total_blocks * qs_per_block const int64_t k_total_blocks = p.k_multihead ? nek2 * p.k_blocks_per_head : p.k_blocks_per_head; p.k_head_e8m0_offset = k_total_blocks * p.k_qs_per_block; } diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index a0f6928e10..b386446035 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -61,7 +61,7 @@ GGML_API void quantize_row_mxfp4_soa (const float * GGML_RESTRICT x, void * G GGML_API void dequantize_row_mxfp4_soa (const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_mxfp8_soa (const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); GGML_API void dequantize_row_mxfp8_soa (const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); -GGML_API void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); +GGML_API void quantize_row_mxfp6_soa (const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); GGML_API void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q2_K(const block_q2_K * 
GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); From 8036edc99aa8b7d6c7d6bbe2b89ca480c3ad9006 Mon Sep 17 00:00:00 2001 From: Tim Burke Date: Sun, 15 Mar 2026 22:55:27 -0400 Subject: [PATCH 07/13] ggml: eliminate hot-path heap allocations and fix tiled MXFP multihead dequant Replace per-row/per-tile std::vector heap allocations with stack buffers in set_rows, one_chunk, and tiled flash attention paths. Fix tiled path to use per-head SoA extraction (matching one_chunk) instead of dequanting the full multihead region per token. --- ggml/src/ggml-cpu/arch-fallback.h | 12 ++--- ggml/src/ggml-cpu/ops.cpp | 74 ++++++++++++++++++++----------- ggml/src/ggml-quants.h | 10 ++--- tests/test-backend-ops.cpp | 3 +- 4 files changed, 61 insertions(+), 38 deletions(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 612786e941..3f01c0b1c7 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -16,8 +16,8 @@ #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 -#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0 -#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0 +#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0 +#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -348,9 +348,9 @@ // All other targets use the scalar generic as the public cpu function. 
#if !defined(__aarch64__) && !defined(__arm__) && !defined(_M_ARM) && !defined(_M_ARM64) && \ !defined(__x86_64__) && !defined(__i386__) && !defined(_M_IX86) && !defined(_M_X64) -#define dequantize_row_mxfp8_cpu_generic dequantize_row_mxfp8_cpu -#define dequantize_row_mxfp6_cpu_generic dequantize_row_mxfp6_cpu -#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu -#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu +#define dequantize_row_mxfp8_cpu_generic dequantize_row_mxfp8_cpu +#define dequantize_row_mxfp6_cpu_generic dequantize_row_mxfp6_cpu +#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu +#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu #define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu #endif diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 12f04905af..81cbbdb4bf 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -5046,6 +5046,13 @@ static void ggml_compute_forward_set_rows_f32( break; } + // Pre-allocate Hadamard temp buffer once outside the hot loop (nc is constant). + // nc == n_embd_k_gqa which is bounded by model architecture (typically <= 8192). 
+ std::vector<float> had_tmp; + if (apply_hadamard) { + had_tmp.resize(nc); + } + for (int64_t i03 = 0; i03 < ne03; ++i03) { for (int64_t i02 = 0; i02 < ne02; ++i02) { for (int64_t i = ir0; i < ir1; ++i) { @@ -5061,13 +5068,12 @@ static void ggml_compute_forward_set_rows_f32( char * dst_row = ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3); if (apply_hadamard) { - std::vector<float> tmp(nc); - memcpy(tmp.data(), src_row, nc * sizeof(float)); - ggml_apply_hadamard_blocks(tmp.data(), nc); + memcpy(had_tmp.data(), src_row, nc * sizeof(float)); + ggml_apply_hadamard_blocks(had_tmp.data(), nc); if (mxfp_soa_quantize) { - mxfp_soa_quantize(tmp.data(), dst_row, nc); + mxfp_soa_quantize(had_tmp.data(), dst_row, nc); } else { - from_float(tmp.data(), dst_row, nc); + from_float(had_tmp.data(), dst_row, nc); } } else { if (mxfp_soa_quantize) { @@ -8465,12 +8471,10 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( // Per-head SoA temp buffer: holds [qs | e8m0] for one head in multihead mode. // For DK=1024 with MXFP8: 32 blocks * 32 qs + 32 e8m0 = 1056 bytes. - const size_t k_head_soa_size = is_mxfp_k ? (size_t)(mxfp.k_head_qs_bytes + mxfp.k_blocks_per_head) : 0; - const size_t v_head_soa_size = is_mxfp_v ? (size_t)(mxfp.v_head_qs_bytes + mxfp.v_blocks_per_head) : 0; - std::vector<char> k_head_soa_vec(k_head_soa_size); - std::vector<char> v_head_soa_vec(v_head_soa_size); - char * k_head_soa = k_head_soa_vec.data(); - char * v_head_soa = v_head_soa_vec.data(); + // Stack-allocated since sizes are bounded by DK/DV <= 1024. + // Max: 1024/32 * 32(qs) + 1024/32 = 1056 bytes (MXFP8). 
+ char k_head_soa[1088]; // 1056 rounded up for alignment + char v_head_soa[1088]; // Thread-local work buffers (constant across ir loop) float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); @@ -8769,6 +8773,15 @@ static void ggml_compute_forward_flash_attn_ext_tiled( static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q; static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV; + // MXFP dequant scratch buffers — allocated once per thread, reused across all tiles. + // DK/DV bounded by 1024, so per-head dequant fits in stack buffers. + float k_dequant_buf[1024]; + float v_dequant_buf[1024]; + + // Per-head SoA temp buffers for multihead extraction (same as one_chunk path). + char k_head_soa[1088]; + char v_head_soa[1088]; + int ir = ir0; while (ir < ir1) { // q indices for the start of this tile @@ -8847,10 +8860,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled( memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float)); memset(V32, 0, KV_TILE_SZ * DV * sizeof(float)); - // dequant scratch buffers for SoA MXFP — allocated once per tile, reused per KV token - std::vector<float> k_soa_buf(mxfp.k_soa_elems); - std::vector<float> v_soa_buf(mxfp.v_soa_elems); - for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) { const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic); @@ -8891,13 +8900,19 @@ static void ggml_compute_forward_flash_attn_ext_tiled( K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk]; } } else if (mxfp.k_dequantize) { - const char * k_soa_base = mxfp.k_multihead - ? ((const char *)k->data + (ic + tk)*nbk1 + ik3*nbk3) - : k_data; - mxfp.k_dequantize(k_soa_base, k_soa_buf.data(), mxfp.k_soa_elems); - const float * k_head = k_soa_buf.data() + (mxfp.k_multihead ? ik2 * DK : 0); + if (mxfp.k_multihead) { + // Per-head extraction: copy only this head's SoA blocks + e8m0, dequant DK elements. 
+ const char * row = (const char *)k->data + (ic + tk)*nbk1 + ik3*nbk3; + const int kqs = ik2 * mxfp.k_head_qs_bytes; + const int ke8 = (int)mxfp.k_head_e8m0_offset + ik2 * mxfp.k_blocks_per_head; + memcpy(k_head_soa, row + kqs, mxfp.k_head_qs_bytes); + memcpy(k_head_soa + mxfp.k_head_qs_bytes, row + ke8, mxfp.k_blocks_per_head); + mxfp.k_dequantize(k_head_soa, k_dequant_buf, DK); + } else { + mxfp.k_dequantize(k_data, k_dequant_buf, DK); + } for (int64_t dk = 0; dk < DK; dk++) { - K_f32[dk * KV_TILE_SZ + tk] = k_head[dk]; + K_f32[dk * KV_TILE_SZ + tk] = k_dequant_buf[dk]; } } else { float k_tmp[1024]; @@ -8965,11 +8980,18 @@ static void ggml_compute_forward_flash_attn_ext_tiled( } else if (v_type == GGML_TYPE_F32) { memcpy(V32 + tk * DV, v_data, DV * sizeof(float)); } else if (mxfp.v_dequantize) { - const char * v_soa_base = mxfp.v_multihead - ? ((const char *)v->data + (ic + tk)*nbv1 + iv3*nbv3) - : v_data; - mxfp.v_dequantize(v_soa_base, v_soa_buf.data(), mxfp.v_soa_elems); - memcpy(V32 + tk * DV, v_soa_buf.data() + (mxfp.v_multihead ? iv2 * DV : 0), DV * sizeof(float)); + if (mxfp.v_multihead) { + // Per-head extraction: copy only this head's SoA blocks + e8m0, dequant DV elements. 
+ const char * row = (const char *)v->data + (ic + tk)*nbv1 + iv3*nbv3; + const int vqs = iv2 * mxfp.v_head_qs_bytes; + const int ve8 = (int)mxfp.v_head_e8m0_offset + iv2 * mxfp.v_blocks_per_head; + memcpy(v_head_soa, row + vqs, mxfp.v_head_qs_bytes); + memcpy(v_head_soa + mxfp.v_head_qs_bytes, row + ve8, mxfp.v_blocks_per_head); + mxfp.v_dequantize(v_head_soa, v_dequant_buf, DV); + } else { + mxfp.v_dequantize(v_data, v_dequant_buf, DV); + } + memcpy(V32 + tk * DV, v_dequant_buf, DV * sizeof(float)); } else { v_to_float(v_data, V32 + tk * DV, DV); } diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index b386446035..c4d2ae86d3 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -57,11 +57,11 @@ GGML_API void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * // SoA (Struct-of-Arrays) quantize/dequantize — canonical reference for flash attention. // Layout: [qs contiguous][e8m0 contiguous] per row. Same total bytes as AoS. -GGML_API void quantize_row_mxfp4_soa (const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); -GGML_API void dequantize_row_mxfp4_soa (const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); -GGML_API void quantize_row_mxfp8_soa (const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); -GGML_API void dequantize_row_mxfp8_soa (const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); -GGML_API void quantize_row_mxfp6_soa (const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); +GGML_API void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); +GGML_API void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); +GGML_API void dequantize_row_mxfp8_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_mxfp6_soa(const 
float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); GGML_API void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 6dd245aeab..d123211505 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -150,7 +150,8 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m } } -// SoA quantize/dequantize functions — declared here because ggml-quants.h is not in the test include path. +// MXFP SoA quantize/dequantize (from ggml-quants.h, which is internal to ggml +// and not in the test include path). Signatures must match ggml-quants.h exactly. typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t); extern "C" { void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); From 23e88631c41b30665811909148e2bf877514203b Mon Sep 17 00:00:00 2001 From: Tim Burke Date: Fri, 20 Mar 2026 23:38:43 -0400 Subject: [PATCH 08/13] fix: gate tiled GEMM and split-KV paths to preserve q8_0/q4_0 vec_dot semantics --- ggml/src/ggml-cpu/ops.cpp | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 6a81e2b0db..8ac3cd0912 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -9171,8 +9171,12 @@ static void ggml_compute_forward_flash_attn_ext_f16( const bool use_ref = params->use_ref; // Split-KV: parallelize across KV chunks for single-query decode (token generation). - // Delegates to one_chunk which handles all supported types (F16, Q8_0, Q4_0, MXFP, etc). + // Only for types whose tiled/one_chunk paths produce identical results (f32, f16, MXFP). + // Standard quant types (q8_0, q4_0) must use the scalar path to preserve vec_dot semantics. 
+ const bool kv_is_f32_f16_or_mxfp = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16 + || ggml_is_type_mxfp(k->type)); const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) + && kv_is_f32_f16_or_mxfp && q->type == GGML_TYPE_F32 && nek1 >= 512; if (use_split_kv_path) { @@ -9230,8 +9234,14 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t dr = (nr + nchunk - 1) / nchunk; static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q; + // Tiled GEMM path: dequant K/V to float, then simd_gemm. + // Only for types that natively dequant to float (f32, f16, MXFP). + // Standard quant types (q8_0, q4_0) must use the scalar one_chunk path + // to preserve vec_dot semantics and produce identical results to master. bool use_tiled = !use_ref && (q->type == GGML_TYPE_F32 && + kv_is_f32_f16_or_mxfp && + (k->type == v->type || ggml_is_type_mxfp(k->type)) && neq1 >= Q_TILE_SZ); #ifdef GGML_SIMD use_tiled &= (DV % GGML_F32_EPR == 0); From 5bb05ed21c63cee908f982299a9727dc28ab1b55 Mon Sep 17 00:00:00 2001 From: Tim Burke Date: Sat, 21 Mar 2026 13:37:09 -0400 Subject: [PATCH 09/13] Comment consistency pass and cleanup. 
--- common/arg.cpp | 2 - ggml/src/ggml-common.h | 135 +++++++++++----------------- ggml/src/ggml-cpu/arch-fallback.h | 1 - ggml/src/ggml-cpu/arch/arm/quants.c | 24 ++--- ggml/src/ggml-cpu/arch/x86/quants.c | 21 ++--- ggml/src/ggml-cpu/ops.cpp | 68 +++++--------- ggml/src/ggml-cpu/quants.c | 7 +- ggml/src/ggml-impl.h | 28 ++---- ggml/src/ggml-quants.c | 19 +--- ggml/src/ggml-quants.h | 70 ++------------- src/llama-kv-cache.cpp | 5 +- tests/test-backend-ops.cpp | 27 ++---- 12 files changed, 113 insertions(+), 294 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 6f61ad07ff..5f221b7263 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -404,8 +404,6 @@ const std::vector kv_cache_types = { }; static ggml_type kv_cache_type_from_str(const std::string & s) { - // Short aliases: "mxfp4" → E2M1, "mxfp6" → E2M3, "mxfp8" → E4M3. - // Full names (mxfp4_e2m1, mxfp8_e4m3, mxfp6_e2m3, etc.) match via ggml_type_name() below. if (s == "mxfp4") { return GGML_TYPE_MXFP4_E2M1; } diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index cc9a4a0aca..7308c3749b 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -71,7 +71,6 @@ typedef sycl::half2 ggml_half2; #define GGML_COMMON_DECL #endif -// Pure numeric constants needed by both DECL and IMPL sections. #define MXFP_HADAMARD_32_NORM 0.17677669529663689f // 1/sqrt(32) #if defined(GGML_COMMON_DECL) @@ -199,11 +198,8 @@ typedef struct { } block_q4_1; static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding"); -// MXFP E8M0 shared exponent constants (OCP MX v1.0 §5.3). -// EMAX_OFFSET: ceil(log2(max_finite)) for each element type — used to center the E8M0 scale. -// MSE_RANGE: search radius around round(log2(amax)). Tests 2*range+1 candidate exponents, -// picking the one that minimizes total round-trip quantization error per block. -// Inspired by "Four Over Six" (arXiv:2512.02010); generalized to all MX types. 
+// E8M0 shared exponent constants (OCP MX v1.0 SS5.3). +// EMAX_OFFSET = ceil(log2(max_finite)), MSE_RANGE = search radius for optimal scale. #define MXFP_E8M0_MSE_RANGE 2 #define MXFP4_E2M1_EMAX_OFFSET 2 // ceil(log2(6.0)) #define MXFP6_E2M3_EMAX_OFFSET 3 // ceil(log2(7.5)) @@ -211,11 +207,7 @@ static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4 #define MXFP8_E4M3_EMAX_OFFSET 8 // ceil(log2(448)) #define MXFP8_E5M2_EMAX_OFFSET 16 // ceil(log2(57344)) -// MXFP type properties — single source of truth for all backends. -// Bits per element, quantized bytes per block, and Hadamard rotation flag. -// USE_HADAMARD: 1 for types with >= 3-bit mantissa (E2M1, E4M3, E2M3). -// 0 for 2-bit mantissa types (E5M2, E3M2) where Hadamard provides -// no quality benefit and hurts models with D_head ≤ 64. +// MXFP type properties -- shared across all backends. #define MXFP_BITS_PER_ELEM_E2M1 4 #define MXFP_BITS_PER_ELEM_E4M3 8 #define MXFP_BITS_PER_ELEM_E5M2 8 @@ -239,52 +231,48 @@ static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4 // EXP_MASK = (1< float. GGML_TABLE_BEGIN(float, kvalues_mxfp6_e2m3, 64) 0.0f, 0.125f, 0.25f, 0.375f, 0.5f, 0.625f, 0.75f, 0.875f, 1.0f, 1.125f, 1.25f, 1.375f, 1.5f, 1.625f, 1.75f, 1.875f, @@ -1316,8 +1296,7 @@ GGML_TABLE_BEGIN(float, kvalues_mxfp6_e2m3, 64) -4.0f, -4.5f, -5.0f, -5.5f, -6.0f, -6.5f, -7.0f, -7.5f, GGML_TABLE_END() -// FP6 E3M2 dequantization LUT: 6-bit value → float (64 entries). -// Generated from ggml_mxfp_fp6_e3m2_to_float(). No NaN/Inf — all bit patterns are valid. +// FP6 E3M2 dequantization LUT: 6-bit value -> float. No NaN/Inf. 
GGML_TABLE_BEGIN(float, kvalues_mxfp6_e3m2, 64) 0.0f, 0.0625f, 0.125f, 0.1875f, 0.25f, 0.3125f, 0.375f, 0.4375f, 0.5f, 0.625f, 0.75f, 0.875f, 1.0f, 1.25f, 1.5f, 1.75f, @@ -1329,8 +1308,7 @@ GGML_TABLE_BEGIN(float, kvalues_mxfp6_e3m2, 64) -8.0f, -10.0f, -12.0f, -14.0f, -16.0f, -20.0f, -24.0f, -28.0f, GGML_TABLE_END() -// FP8 E4M3 dequantization LUT: byte → float (256 entries). -// Generated from ggml_mxfp_fp8_e4m3_to_float(). Entry 127 = 448 (max finite), 255 = NaN. +// FP8 E4M3 dequantization LUT: byte -> float. Entry 127 = 448 (max finite), 255 = NaN. GGML_TABLE_BEGIN(float, kvalues_mxfp8_e4m3, 256) 0.0f, 0.001953125f, 0.00390625f, 0.005859375f, 0.0078125f, 0.009765625f, 0.01171875f, 0.013671875f, 0.015625f, 0.017578125f, 0.01953125f, 0.021484375f, 0.0234375f, 0.025390625f, 0.02734375f, 0.029296875f, @@ -1366,8 +1344,7 @@ GGML_TABLE_BEGIN(float, kvalues_mxfp8_e4m3, 256) -256.0f, -288.0f, -320.0f, -352.0f, -384.0f, -416.0f, -448.0f, NAN, GGML_TABLE_END() -// FP8 E5M2 dequantization LUT: byte → float (256 entries). -// Generated from ggml_mxfp_fp8_e5m2_to_float(). Entries 124-127 = {Inf, NaN, NaN, NaN}. +// FP8 E5M2 dequantization LUT: byte -> float. Entries 124-127 = {Inf, NaN, NaN, NaN}. GGML_TABLE_BEGIN(float, kvalues_mxfp8_e5m2, 256) 0.0f, 1.525879e-05f, 3.051758e-05f, 4.577637e-05f, 6.103516e-05f, 7.629395e-05f, 9.155273e-05f, 1.068115e-04f, 1.220703e-04f, 1.525879e-04f, 1.831055e-04f, 2.136230e-04f, 2.441406e-04f, 3.051758e-04f, 3.662109e-04f, 4.272461e-04f, @@ -1403,16 +1380,10 @@ GGML_TABLE_BEGIN(float, kvalues_mxfp8_e5m2, 256) -32768.0f, -40960.0f, -49152.0f, -57344.0f, -INFINITY, NAN, NAN, NAN, GGML_TABLE_END() -// ------------------------------------------------------------------------------------------------------------------ -// Canonical MXFP element converters — portable IEEE-754 bit manipulation. -// Single source of truth for CPU, CUDA, HIP, MUSA, SYCL. Metal/Vulkan keep MSL/GLSL copies. 
-// ------------------------------------------------------------------------------------------------------------------ +// MXFP element converters -- portable IEEE-754 bit manipulation. #if defined(GGML_MXFP_FUNC) -// --- FP4 E2M1: [S(1) | E(2) | M(1)] — max normal = 6.0 --- -// Canonical converters using true E2M1 values {0, 0.5, 1, 1.5, 2, 3, 4, 6}. -// The int8 kvalues_mxfp4 LUT stores doubled values {0,1,2,3,4,6,8,12} for -// CPU/CUDA nibble-indexed integer arithmetic — that doubling is an implementation detail. +// FP4 E2M1: [S(1) | E(2) | M(1)], max normal = 6.0 GGML_MXFP_FUNC float ggml_mxfp_fp4_e2m1_to_float(uint8_t v) { const float sign = (v & 0x8) ? -1.0f : 1.0f; @@ -1427,8 +1398,6 @@ GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp4_e2m1(float x) { if (x < 0) { sign = 0x8; x = -x; } if (x == 0) return sign; if (x >= 6.0f) return sign | 0x7; // max finite - // Decision boundaries (midpoints of adjacent canonical values): - // {0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0} if (x < 0.25f) return sign | 0x0; // 0 else if (x < 0.75f) return sign | 0x1; // 0.5 else if (x < 1.25f) return sign | 0x2; // 1.0 @@ -1439,7 +1408,7 @@ GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp4_e2m1(float x) { else return sign | 0x7; // 6.0 } -// --- FP6 E2M3: [S(1) | E(2) | M(3)] — max normal = 7.5 --- +// FP6 E2M3: [S(1) | E(2) | M(3)], max normal = 7.5 GGML_MXFP_FUNC float ggml_mxfp_fp6_e2m3_to_float(uint8_t v) { const float sign = (v & 0x20) ? -1.0f : 1.0f; @@ -1474,7 +1443,7 @@ GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp6_e2m3(float x) { return sign | (uint8_t)(((f32_exp + 1) << 3) | mant); } -// --- FP6 E3M2: [S(1) | E(3) | M(2)] — max normal = 28.0, no NaN/Inf --- +// FP6 E3M2: [S(1) | E(3) | M(2)], max normal = 28.0, no NaN/Inf GGML_MXFP_FUNC float ggml_mxfp_fp6_e3m2_to_float(uint8_t v) { const float sign = (v & 0x20) ? 
-1.0f : 1.0f; @@ -1512,7 +1481,7 @@ GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp6_e3m2(float x) { return sign | (uint8_t)((biased_exp << 2) | mant); } -// --- FP8 E4M3: [S(1) | E(4) | M(3)] — bias=7, max finite=448 --- +// FP8 E4M3: [S(1) | E(4) | M(3)], bias=7, max finite=448 GGML_MXFP_FUNC float ggml_mxfp_fp8_e4m3_to_float(uint8_t v) { uint32_t sign = ((uint32_t)(v & 0x80)) << 24; @@ -1571,7 +1540,7 @@ GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp8_e4m3(float x) { return sign | (uint8_t)((e4m3_exp << 3) | mant3); } -// --- FP8 E5M2: [S(1) | E(5) | M(2)] — bias=15, max finite=57344 --- +// FP8 E5M2: [S(1) | E(5) | M(2)], bias=15, max finite=57344 GGML_MXFP_FUNC float ggml_mxfp_fp8_e5m2_to_float(uint8_t v) { uint32_t sign = ((uint32_t)(v & 0x80)) << 24; @@ -1629,7 +1598,7 @@ GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp8_e5m2(float x) { return sign | (uint8_t)((e5m2_exp << 2) | mant2); } -// --- FP6 packing/unpacking --- +// FP6 packing/unpacking // Pack 4 six-bit values into 3 bytes GGML_MXFP_FUNC void ggml_mxfp_pack_fp6x4(const uint8_t v[4], uint8_t out[3]) { @@ -1674,8 +1643,6 @@ GGML_MXFP_FUNC int ggml_mxfp_e8m0_base_estimate(float amax, int emax_offset) { } // Block-32 Walsh-Hadamard Transform, normalized by 1/sqrt(32). -// Spreads outlier energy across all elements sharing an E8M0 exponent, -// improving quantization quality (see QuaRot arXiv:2404.00456). GGML_MXFP_FUNC void ggml_mxfp_hadamard_32_inplace(GGML_MXFP_THREAD float * vals) { GGML_MXFP_UNROLL for (int stride = 1; stride < 32; stride *= 2) { diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 3f01c0b1c7..f622658918 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -2,7 +2,6 @@ #pragma once // Rename `_generic` functions if no native implementation is available. -// This effectively selects the generic implementation. 
#if defined(GGML_CPU_GENERIC) // quants.c diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index 6564131f2b..9ad9f29ae5 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -4134,21 +4134,14 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -// ── MXFP FP8/FP6 NEON helpers ────────────────────────────────────────────── -// Shared IEEE-754 bit reconstruction and FP6 unpacking used by vec_dot, -// dequantize_row, and SoA dequant functions. -// -// NEON requires vshlq_n_u32 to have a compile-time literal constant, so we use -// two separate helpers for FP8 (sign at bit 7, shift 24) and FP6 (sign at bit 5, -// shift 26) rather than a single parameterized function. +// MXFP FP8/FP6 NEON helpers +// Separate FP8/FP6 functions because NEON vshlq_n_u32 requires compile-time constants. #if defined(__ARM_NEON) -// Use shared mxfp_dequant_traits_t from ggml-common.h. #define mxfp_neon_traits_t mxfp_dequant_traits_t -// Dequantize 4 raw FP8 values (uint32x4_t) → 4 IEEE-754 floats. -// Sign bit at position 7, sign shift = 24. +// Dequantize 4 FP8 values to floats. static inline float32x4_t mxfp8_dequant_neon( const uint32x4_t v_raw, const uint32x4_t v_exp_mask, const uint32x4_t v_mant_mask, @@ -4172,8 +4165,7 @@ static inline float32x4_t mxfp8_dequant_neon( return vbslq_f32(is_sub, sub_val, normal); } -// Dequantize 4 raw FP6 values (uint32x4_t) → 4 IEEE-754 floats. -// Sign bit at position 5, sign shift = 26. +// Dequantize 4 FP6 values to floats. 
static inline float32x4_t mxfp6_dequant_neon( const uint32x4_t v_raw, const uint32x4_t v_exp_mask, const uint32x4_t v_mant_mask, @@ -4229,7 +4221,7 @@ static inline void widen_s8x8_to_f32x4x2(const int8_t * src, *hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16))); } -// ── MXFP FP8/FP6 vec_dot ────────────────────────────────────────────────── +// MXFP FP8/FP6 vec_dot static void ggml_vec_dot_mxfp8_q8_0_neon( int n, float * GGML_RESTRICT s, @@ -4319,7 +4311,7 @@ static void ggml_vec_dot_mxfp6_q8_0_neon( *s = vaddvq_f32(vaddq_f32(acc0, acc1)); } -// ── MXFP FP8/FP6 dequantize_row (AoS) ───────────────────────────────────── +// MXFP FP8/FP6 dequantize_row (AoS) static void dequantize_row_mxfp8_neon( const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k, @@ -4381,7 +4373,7 @@ static void dequantize_row_mxfp6_neon( } } -// ── MXFP SoA dequant (flash attention) ───────────────────────────────────── +// MXFP SoA dequant (flash attention) static void dequantize_row_mxfp8_soa_neon( const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, @@ -4492,7 +4484,7 @@ static void dequantize_row_mxfp4_soa_neon( #endif // __ARM_NEON -// ── Public dispatch functions ────────────────────────────────────────────── +// Public dispatch functions void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c index b00b1467d3..21b3fb4605 100644 --- a/ggml/src/ggml-cpu/arch/x86/quants.c +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -3819,18 +3819,13 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -// ── MXFP FP8/FP6 AVX2 helpers ────────────────────────────────────────────── -// Shared IEEE-754 bit reconstruction and FP6 unpacking used by vec_dot, -// dequantize_row, and SoA dequant functions. 
+// MXFP FP8/FP6 AVX2 helpers #if defined(__AVX2__) -// Use shared mxfp_dequant_traits_t from ggml-common.h. -// Aliases for readability within this file. #define mxfp_avx2_traits_t mxfp_dequant_traits_t -// Dequantize 8 raw MXFP values (widened to int32) → 8 IEEE-754 floats. -// Handles both normal and subnormal paths. Works for any FP6/FP8 format. +// Dequantize 8 FP8/FP6 values to floats. static inline __m256 mxfp_dequant_avx2( const __m256i v_raw, const __m256i v_exp_mask, const __m256i v_mant_mask, @@ -3872,9 +3867,9 @@ static inline __m256i unpack_fp6x8_avx2(const uint8_t * qs, int j) { return _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)unpacked)); } -// ── MXFP FP8/FP6 vec_dot ────────────────────────────────────────────────── +// MXFP FP8/FP6 vec_dot -// Unified FP8 × Q8_0 dot product (works for E4M3 and E5M2). +// FP8 x Q8_0 dot product (E4M3/E5M2). static void ggml_vec_dot_mxfp8_q8_0_avx2( int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, @@ -3915,7 +3910,7 @@ static void ggml_vec_dot_mxfp8_q8_0_avx2( *s = hsum_float_8(acc); } -// Unified FP6 × Q8_0 dot product (works for E2M3 and E3M2). +// FP6 x Q8_0 dot product (E2M3/E3M2). 
static void ggml_vec_dot_mxfp6_q8_0_avx2( int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, @@ -3955,7 +3950,7 @@ static void ggml_vec_dot_mxfp6_q8_0_avx2( *s = hsum_float_8(acc); } -// ── MXFP FP8/FP6 dequantize_row (AoS) ───────────────────────────────────── +// MXFP FP8/FP6 dequantize_row (AoS) static void dequantize_row_mxfp8_avx2( const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k, @@ -4016,7 +4011,7 @@ static void dequantize_row_mxfp6_avx2( } } -// ── MXFP SoA dequant (flash attention) ───────────────────────────────────── +// MXFP SoA dequant (flash attention) static void dequantize_row_mxfp8_soa_avx2( const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, @@ -4116,7 +4111,7 @@ static void dequantize_row_mxfp4_soa_avx2( #endif // __AVX2__ -// ── Public dispatch functions ────────────────────────────────────────────── +// Public dispatch functions void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 8ac3cd0912..a8f55efbed 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -4909,8 +4909,7 @@ void ggml_compute_forward_get_rows( //} } -// NEON-optimized Hadamard for ARM platforms; scalar fallback uses ggml_hadamard_32_inplace -// from ggml-quants.c (the reference implementation). +// NEON-optimized Hadamard; scalar fallback below #if defined(__ARM_NEON) static void hadamard_32_inplace(float vals[32]) { float32x4_t v0 = vld1q_f32(vals + 0); @@ -4978,13 +4977,11 @@ static void hadamard_32_inplace(float vals[32]) { vst1q_f32(vals + 28, vmulq_f32(v7, norm)); } #else -// Scalar fallback: delegate to reference implementation in ggml-quants.c static void hadamard_32_inplace(float vals[32]) { ggml_hadamard_32_inplace(vals); } #endif -// Apply Hadamard rotation to each 32-element block in a float buffer. 
static void ggml_apply_hadamard_blocks(float * data, int64_t n) { GGML_ASSERT(n % 32 == 0); for (int64_t i = 0; i < n; i += 32) { @@ -5031,8 +5028,6 @@ static void ggml_compute_forward_set_rows_f32( const int32_t apply_hadamard = ((const int32_t *)dst->op_params)[0]; - // For MXFP types, use SoA quantize (canonical FA layout). - // For non-MXFP types, use the standard AoS from_float. typedef void (*quantize_soa_fn)(const float *, void *, int64_t); quantize_soa_fn mxfp_soa_quantize = nullptr; ggml_from_float_t from_float = nullptr; @@ -5046,8 +5041,6 @@ static void ggml_compute_forward_set_rows_f32( break; } - // Pre-allocate Hadamard temp buffer once outside the hot loop (nc is constant). - // nc == n_embd_k_gqa which is bounded by model architecture (typically <= 8192). std::vector had_tmp; if (apply_hadamard) { had_tmp.resize(nc); @@ -8275,8 +8268,7 @@ void ggml_compute_forward_top_k( typedef void (*mxfp_soa_quantize_fn)(const float *, void *, int64_t); typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t); -// Shared MXFP dispatch parameters for flash attention. -// Populated once and used by both the one_chunk and tiled paths. +// MXFP dispatch parameters for flash attention. struct mxfp_fa_params { mxfp_soa_quantize_fn q_quantize; mxfp_soa_dequantize_fn k_dequantize; @@ -8287,9 +8279,6 @@ struct mxfp_fa_params { int64_t v_soa_elems; bool apply_hadamard; // Per-head SoA addressing (avoids dequanting all heads in multihead mode). - // qs_per_block: bytes of quantized data per 32-element block. - // head_qs_bytes: total qs bytes for one head (blocks_per_head * qs_per_block). - // head_e8m0_offset: byte offset from row start to e8m0 region. 
int k_qs_per_block; int v_qs_per_block; int k_head_qs_bytes; @@ -8455,9 +8444,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( ggml_vec_dot_t kq_vec_dot = nullptr; ggml_to_float_t v_to_float = nullptr; - if (is_mxfp_k) { - kq_vec_dot = nullptr; - } else { + if (!is_mxfp_k) { ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type; q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float; kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot; @@ -8472,20 +8459,15 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( int ith = params->ith; - // Dequant buffers for MXFP SoA — stack-allocated, no heap allocation in the hot path. - // In multihead mode, only dequant one head (DK or DV elements) instead of all heads. - // DK/DV are bounded by 1024 (asserted below for MXFP). + if (is_mxfp_k) { GGML_ASSERT(DK <= 1024); } + if (is_mxfp_v) { GGML_ASSERT(DV <= 1024); } + float k_dequant_buf[1024]; float v_dequant_buf[1024]; - // Per-head SoA temp buffer: holds [qs | e8m0] for one head in multihead mode. - // For DK=1024 with MXFP8: 32 blocks * 32 qs + 32 e8m0 = 1056 bytes. - // Stack-allocated since sizes are bounded by DK/DV <= 1024. - // Max: 1024/32 * 32(qs) + 1024/32 = 1056 bytes (MXFP8). - char k_head_soa[1088]; // 1056 rounded up for alignment + char k_head_soa[1088]; // max: DK=1024 MXFP8 -> 1056 bytes, rounded up char v_head_soa[1088]; - // Thread-local work buffers (constant across ir loop) float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); float * V32 = (VKQ32 + 1*DV); ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); @@ -8521,31 +8503,24 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( const int iv3 = iq3 / rv3; const int iv2 = iq2 / rv2; - // Precompute loop-invariant base pointer offsets for K and V. - // Only ic varies in the inner loop; head/batch offsets are constant. 
const size_t k_base_offset = ik2*nbk2 + ik3*nbk3; const size_t v_base_offset = iv2*nbv2 + iv3*nbv3; const char * k_base = (const char *) k->data + k_base_offset; const char * v_base = (const char *) v->data + v_base_offset; - // For multihead MXFP: precompute per-head SoA byte offsets (constant per query row). - // head_qs_start: byte offset to this head's qs blocks within the SoA row. - // head_e8m0_start: byte offset to this head's e8m0 scales within the SoA row. + // Per-head SoA byte offsets const int k_head_qs_start = mxfp.k_multihead ? ik2 * mxfp.k_head_qs_bytes : 0; const int k_head_e8m0_start = mxfp.k_multihead ? (int)mxfp.k_head_e8m0_offset + ik2 * mxfp.k_blocks_per_head : 0; const int v_head_qs_start = mxfp.v_multihead ? iv2 * mxfp.v_head_qs_bytes : 0; const int v_head_e8m0_start = mxfp.v_multihead ? (int)mxfp.v_head_e8m0_offset + iv2 * mxfp.v_blocks_per_head : 0; - // Multihead MXFP row base (without head offset) — only ic varies. const char * k_row_base = mxfp.k_multihead ? ((const char *) k->data + ik3*nbk3) : nullptr; const char * v_row_base = mxfp.v_multihead ? ((const char *) v->data + iv3*nbv3) : nullptr; const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); float Q_f32[1024]; if (is_mxfp_k) { - // Q preprocessing: Hadamard → SoA quantize → SoA dequant (round-trip). - // Captures the same quantization loss as K, matching GPU MMA semantics. - GGML_ASSERT(DK <= 1024); + // Q preprocessing: Hadamard + SoA round-trip captures same quantization loss as K. 
if (mxfp.apply_hadamard) { float q_tmp[1024]; memcpy(q_tmp, pq, DK * sizeof(float)); @@ -8557,7 +8532,6 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( mxfp.k_dequantize(Q_q, Q_f32, DK); } else { if (mxfp.apply_hadamard) { - GGML_ASSERT(DK <= 1024); float q_tmp[1024]; memcpy(q_tmp, pq, DK * sizeof(float)); ggml_apply_hadamard_blocks(q_tmp, DK); @@ -8581,8 +8555,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( if (is_mxfp_k) { if (mxfp.k_multihead) { - // Multihead: extract this head's SoA blocks into temp buffer, dequant only DK elements. - // Copy qs blocks then e8m0 scales for this head into contiguous [qs|e8m0] layout. + // Extract this head's SoA blocks const char * row = k_row_base + ic*nbk1; memcpy(k_head_soa, row + k_head_qs_start, mxfp.k_head_qs_bytes); memcpy(k_head_soa + mxfp.k_head_qs_bytes, row + k_head_e8m0_start, mxfp.k_blocks_per_head); @@ -8610,10 +8583,12 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( if (v_is_f16) { if (s > M) { + // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f M = s; ms = expf(Mold - M); ggml_vec_scale_f16(DV, VKQ16, ms); } else { + // no new maximum, ms == 1.0f, vs != 1.0f vs = expf(s - M); } @@ -8621,10 +8596,12 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) (v_base + ic*nbv1), vs); } else { if (s > M) { + // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f M = s; ms = expf(Mold - M); ggml_vec_scale_f32(DV, VKQ32, ms); } else { + // no new maximum, ms == 1.0f, vs != 1.0f vs = expf(s - M); } @@ -8643,6 +8620,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( v_to_float(v_base + ic*nbv1, V32, DV); ggml_vec_mad_f32(DV, VKQ32, V32, vs); } else { + // V is F32 ggml_vec_mad_f32(DV, VKQ32, (const float *) (v_base + ic*nbv1), vs); } } @@ -8782,12 +8760,12 @@ static void ggml_compute_forward_flash_attn_ext_tiled( static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q; static 
constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV; - // MXFP dequant scratch buffers — allocated once per thread, reused across all tiles. - // DK/DV bounded by 1024, so per-head dequant fits in stack buffers. + if (is_mxfp_k) { GGML_ASSERT(DK <= 1024); } + if (is_mxfp_v) { GGML_ASSERT(DV <= 1024); } + float k_dequant_buf[1024]; float v_dequant_buf[1024]; - // Per-head SoA temp buffers for multihead extraction (same as one_chunk path). char k_head_soa[1088]; char v_head_soa[1088]; @@ -8853,10 +8831,7 @@ static void ggml_compute_forward_flash_attn_ext_tiled( if (mxfp.apply_hadamard) { ggml_apply_hadamard_blocks(Q_f32 + tq * DK, DK); } - // SoA round-trip: quantize Q to SoA, then dequant back to float - const size_t q_soa_bytes = ggml_row_size(k->type, DK); - GGML_ASSERT(q_soa_bytes <= 2048); - uint8_t q_mxfp_buf[2048]; // max: DK=1024 * 33/32 = 1056 bytes (MXFP8) + uint8_t q_mxfp_buf[1088]; // max: DK=1024 MXFP8 -> 1056 bytes mxfp.q_quantize(Q_f32 + tq * DK, q_mxfp_buf, DK); mxfp.k_dequantize(q_mxfp_buf, Q_f32 + tq * DK, DK); } @@ -9234,10 +9209,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t dr = (nr + nchunk - 1) / nchunk; static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q; - // Tiled GEMM path: dequant K/V to float, then simd_gemm. - // Only for types that natively dequant to float (f32, f16, MXFP). - // Standard quant types (q8_0, q4_0) must use the scalar one_chunk path - // to preserve vec_dot semantics and produce identical results to master. + // Tiled path: f32, f16, and MXFP only (quant types use one_chunk) bool use_tiled = !use_ref && (q->type == GGML_TYPE_F32 && kv_is_f32_f16_or_mxfp && diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index 7303638c81..477bd07304 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -264,9 +264,7 @@ void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, *s = sumf; } -// Generic MXFP-to-Q8_0 dot product. 
Dequants one MX block (32 elements) -// to float via the existing public dequantize_row functions, then dots -// against Q8_0 int8 values. Reference implementation — not SIMD-optimized. +// Generic MXFP x Q8_0 dot product (scalar, not SIMD-optimized) static void ggml_vec_dot_mxfp_q8_0_impl( int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, size_t block_size, @@ -311,8 +309,7 @@ void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, (ggml_to_float_t)dequantize_row_mxfp6); } -// Generic (scalar) dequant wrappers — delegates to ggml-quants.c reference implementations. -// On x86/ARM, arch-specific SIMD versions override these via the fallback.h mapping. +// Generic dequant wrappers — arch-specific SIMD versions override via fallback.h. void dequantize_row_mxfp8_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { dequantize_row_mxfp8(x, y, k); } diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 90f020b0e8..f9358b0432 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -431,15 +431,11 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) // E8M0 shared exponent to float: returns 2^(x - 127). -// Canonical source: ggml_mxfp_e8m0_to_fp32() in ggml-common.h. -// Kept here because ggml-impl.h cannot depend on ggml-common.h IMPL section. -// NaN (x == 0xFF) is not handled — callers guarantee valid exponents. static inline float ggml_e8m0_to_fp32(uint8_t x) { uint32_t bits; if (x == 0) { bits = 0x00400000; // denorm: 0.5 * 2^(-126) = 2^(-127) } else { - // normalized: exponent x placed in bits [30:23], mantissa = 0 → 2^(x-127) bits = (uint32_t) x << 23; } float result; @@ -447,9 +443,7 @@ static inline float ggml_e8m0_to_fp32(uint8_t x) { return result; } -// E8M0 to float/2: returns 2^(x - 128). Equal to ggml_e8m0_to_fp32(x) / 2. -// Useful with MXFP4 because the E2M1 kvalues table stores 2 * float_value. 
-// Canonical source: ggml_mxfp_e8m0_to_fp32_half() in ggml-common.h. +// E8M0 to float/2: returns 2^(x - 128). static inline float ggml_e8m0_to_fp32_half(uint8_t x) { uint32_t bits; if (x < 2) { @@ -467,8 +461,7 @@ static inline float ggml_e8m0_to_fp32_half(uint8_t x) { #define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x) #define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x) -// UE4M3 (unsigned E4M3): 4 exponent bits (bias 7), 3 mantissa bits, no sign. -// Range: [0, 448], with 0x7F = NaN treated as zero. +// UE4M3 (unsigned E4M3): 4 exponent bits (bias 7), 3 mantissa bits. // Returns value * 0.5 to match kvalues_mxfp4 convention (kvalues = 2 * E2M1_float). static inline float ggml_ue4m3_to_fp32(uint8_t x) { if (x == 0 || x == 0x7F) { @@ -487,20 +480,19 @@ static inline float ggml_ue4m3_to_fp32(uint8_t x) { return raw * 0.5f; } -// Float32 to UE4M3 with round-to-nearest-even. -// Clamps to [0, 448]. Max representable: 0x7E = (1 + 7/8) * 2^8 = 448. +// Float32 to UE4M3 with round-to-nearest. 
static inline uint8_t ggml_fp32_to_ue4m3(float x) { if (!(x > 0.0f)) { - return 0; // negative, zero, NaN → 0 + return 0; } if (x > 448.0f) { - x = 448.0f; // clamp to max representable + x = 448.0f; } uint32_t bits; memcpy(&bits, &x, 4); int fp32_exp = ((bits >> 23) & 0xFF) - 127; - int fp32_man = (bits >> 20) & 0x7; // top 3 mantissa bits - int ue4m3_exp = fp32_exp + 7; // rebias: FP32 bias 127 → UE4M3 bias 7 + int fp32_man = (bits >> 20) & 0x7; + int ue4m3_exp = fp32_exp + 7; if (ue4m3_exp <= 0) { // subnormal: value = man * 2^(-9), so man = round(x * 512) int man = (int) (x * 512.0f + 0.5f); @@ -513,17 +505,15 @@ static inline uint8_t ggml_fp32_to_ue4m3(float x) { return (uint8_t) man; } if (ue4m3_exp >= 15) { - return 0x7E; // max normal + return 0x7E; } - // round-to-nearest using bit 19 (first bit below UE4M3 mantissa) int round_bit = (bits >> 19) & 1; int ue4m3_man = fp32_man + round_bit; if (ue4m3_man > 7) { - // mantissa overflow → carry into exponent ue4m3_man = 0; ue4m3_exp++; if (ue4m3_exp >= 15) { - return 0x7E; // max normal + return 0x7E; } } return (uint8_t) ((ue4m3_exp << 3) | ue4m3_man); diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index e88435061d..1a99711401 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -259,16 +259,12 @@ void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_REST // ====================== MXFP element conversions (wrappers around ggml-common.h) -// FP8 E4M3: 1 sign, 4 exp (bias 7), 3 mantissa. Max finite: 448, NaN: 0x7F (saturated to max). float fp8_e4m3_to_float(uint8_t v) { return ggml_mxfp_fp8_e4m3_to_float(v); } uint8_t float_to_fp8_e4m3_rn(float x) { return ggml_mxfp_float_to_fp8_e4m3(x); } // ====================== MXFP quantization infrastructure // -// MSE-optimal E8M0 shared exponent: tests ±R candidates around round(log2(amax)) -// and picks whichever minimizes total round-trip quantization error per block. 
-// Improves on OCP MX v1.0 §5.3 floor(log2(amax)) by 0.05-0.2 PPL. -// Ref: OCP MX v1.0 spec, Four Over Six (arXiv:2512.02010) +// MSE-optimal E8M0: tests candidates around round(log2(amax)), picks lowest quantization error. typedef struct { int emax_offset; // type-specific offset to max representable exponent @@ -279,8 +275,7 @@ typedef struct { static inline int best_index_mxfp4(float x, float e); -// MXFP4 MSE error using decision boundary quantization with half-scale -// (kvalues_mxfp4 are doubled E2M1 values, so scale is halved to compensate) +// MSE error for MXFP4 (kvalues are doubled, so scale is halved) static float mse_error_mxfp4(float val, float inv_scale, float scale) { const float d = scale * 0.5f; const float inv_d = (d > 0.0f) ? 1.0f / d : 0.0f; @@ -301,7 +296,6 @@ static float mse_error_mxfp4(float val, float inv_scale, float scale) { static const mxfp_elem_traits_t mxfp4_traits = { MXFP4_E2M1_EMAX_OFFSET, NULL, NULL, mse_error_mxfp4 }; -// Find MSE-optimal E8M0 exponent by testing ±R candidates around round(log2(amax)) static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_elem_traits_t * traits) { float amax = 0.0f; for (int j = 0; j < qk; j++) { @@ -336,7 +330,6 @@ static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_ return (uint8_t)best_e; } -// Decision boundary quantization for kvalues_mxfp4 {0,1,2,3,4,6,8,12} static inline int best_index_mxfp4(float x, float e) { const float inv_e = (e > 0.0f) ? 1.0f / e : 0.0f; const float normalized = fabsf(x) * inv_e; @@ -566,11 +559,6 @@ void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_REST } // ====================== Hadamard rotation -// -// 32-element Walsh-Hadamard transform applied before MX quantization to spread -// outlier energy across the shared-exponent group. Orthogonal (H^T·H = I), so -// H(K)·H(Q) = K·Q — attention scores are preserved when both K and Q are rotated. -// Prior art: QuIP# (Tseng et al. 
2024), BRQ (Huang et al. 2024) void ggml_hadamard_32_inplace(float vals[32]) { ggml_mxfp_hadamard_32_inplace(vals); @@ -586,7 +574,6 @@ uint8_t float_to_fp8_e5m2_rn(float x) { return ggml_mxfp_float_to_fp8_e5m2(x); } void pack_fp6x4(const uint8_t v[4], uint8_t out[3]) { ggml_mxfp_pack_fp6x4(v, out); } void unpack_fp6x4(const uint8_t in[3], uint8_t v[4]) { ggml_mxfp_unpack_fp6x4(in, v); } -// round-trip quantization error for MSE-optimal exponent search static float mse_error_fp8_e4m3(float val, float inv_scale, float scale) { const float recon = fp8_e4m3_to_float(float_to_fp8_e4m3_rn(val * inv_scale)) * scale; const float err = val - recon; @@ -685,8 +672,6 @@ void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_REST } // ====================== SoA (Struct-of-Arrays) quantize/dequantize for flash attention -// -// Layout per row: [qs_block0|qs_block1|...|qs_blockN][e8m0_0|e8m0_1|...|e8m0_N] void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) { assert(k % QK_MXFP4 == 0); diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index c4d2ae86d3..7d848edffe 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -55,8 +55,7 @@ GGML_API void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_API void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -// SoA (Struct-of-Arrays) quantize/dequantize — canonical reference for flash attention. -// Layout: [qs contiguous][e8m0 contiguous] per row. Same total bytes as AoS. 
+// SoA quantize/dequantize for flash attention GGML_API void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); GGML_API void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); @@ -114,83 +113,26 @@ GGML_API size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_REST GGML_API size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -// -// MXFP element-level conversion functions (reference implementations) -// -// These implement the OCP Microscaling (MX) format element types as defined in: -// OCP Microscaling Formats (MX) Specification v1.0, Sep 2023 -// https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf -// -// Each MX block contains 32 elements sharing a single E8M0 exponent. The element -// types define the per-element mantissa format within each block. -// -// All converters use IEEE-754 bit manipulation for exact results (no floating-point -// rounding in the conversion itself). Quantization functions use round-to-nearest-even -// (RNE) per the MX specification. -// -// GPU backends (CUDA, Metal, Vulkan) provide their own optimized versions — these -// functions serve as the canonical reference and are used by the CPU backend. 
-// - -// FP8 E4M3: 1 sign, 4 exponent (bias 7), 3 mantissa bits -// Range: ±[2^-9, 448], NaN: exp=15 mant=7 (only NaN encoding) -// Ref: OCP MX v1.0 §4.2 +// MXFP element converters GGML_API float fp8_e4m3_to_float(uint8_t v); GGML_API uint8_t float_to_fp8_e4m3_rn(float x); -// FP8 E5M2: 1 sign, 5 exponent (bias 15), 2 mantissa bits -// Range: ±[2^-16, 57344], NaN/Inf: exp=31 (standard IEEE-like) -// Ref: OCP MX v1.0 §4.2 GGML_API float fp8_e5m2_to_float(uint8_t v); GGML_API uint8_t float_to_fp8_e5m2_rn(float x); -// FP6 E2M3: 1 sign, 2 exponent (bias 1), 3 mantissa bits -// Range: ±[2^-3, 7.5], stored as low 6 bits of a byte (00xxxxxx) -// MX format: NO NaN/Inf — all bit patterns are valid numbers -// Ref: OCP MX v1.0 §4.2 +// no NaN/Inf in FP6 — all bit patterns are valid numbers GGML_API float fp6_e2m3_to_float(uint8_t v); GGML_API uint8_t float_to_fp6_e2m3_rn(float x); -// FP6 E3M2: 1 sign, 3 exponent (bias 3), 2 mantissa bits -// Range: ±[2^-4, 28.0], stored as low 6 bits of a byte (00xxxxxx) -// MX format: NO NaN/Inf — exp=7 is a valid normal value (unlike IEEE-754) -// CRITICAL: subnormal scale is 2^(1-bias-m) = 2^(-4) = 1/16, NOT 1/4 -// Ref: OCP MX v1.0 §4.2 +// no NaN/Inf — exp=7 is a valid normal value (unlike IEEE-754) GGML_API float fp6_e3m2_to_float(uint8_t v); GGML_API uint8_t float_to_fp6_e3m2_rn(float x); -// FP6 tight packing: pack/unpack 4 six-bit values into/from 3 bytes -// Layout: v[0]=bits[5:0], v[1]=bits[11:6], v[2]=bits[17:12], v[3]=bits[23:18] -// Saves 25% memory vs byte-padded storage (24B vs 32B per MX block) +// Pack/unpack 4 six-bit values into 3 bytes GGML_API void pack_fp6x4(const uint8_t v[4], uint8_t out[3]); GGML_API void unpack_fp6x4(const uint8_t in[3], uint8_t v[4]); -// -// Hadamard rotation (reference scalar implementation) -// -// 32-element Walsh-Hadamard transform applied to MX blocks before quantization. 
-// Distributes outlier energy uniformly across the block, dramatically improving -// quantization quality for types with shared exponents. -// -// Mathematical property: H^T·H = I (orthogonal), so H(K)·H(Q) = K·Q. -// Flash attention applies matching rotation to Q, preserving attention scores exactly. -// -// Implementation: 5 butterfly stages (log2(32) = 5) + normalization by 1/sqrt(32). -// This is the standard "fast Walsh-Hadamard transform" with O(n log n) operations. -// -// Applied in set_rows (K cache quantization) and flash_attn (Q quantization). -// Skipped for MLA models (DK != DV) where V is a view of K — rotation would corrupt V. -// -// Empirical impact (PPL degradation WITHOUT rotation, Qwen3-Coder-30B): -// MXFP8 E4M3: +0.22, MXFP8 E5M2: +1.38, MXFP6 E2M3: +3.34, MXFP6 E3M2: +4.60 -// -// Prior art: QuIP# (Tseng et al. 2024), BRQ (Huang et al. 2024) use Hadamard for -// weight quantization. Our contribution applies it to KV cache quantization at the -// MX block boundary, where block-32 is optimal because it matches the shared exponent -// group size exactly. -// -// GPU backends provide optimized versions (CUDA warp shuffles, Metal SIMD groups). 
-// +// Block-32 Walsh-Hadamard transform, normalized by 1/sqrt(32) GGML_API void ggml_hadamard_32_inplace(float vals[32]); GGML_API void iq2xs_init_impl(enum ggml_type type); diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 0e501697b0..a890510edf 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -51,7 +51,6 @@ llama_kv_cache::llama_kv_cache( auto it = ctx_map.find(buft); if (it == ctx_map.end()) { ggml_init_params params = { - // 2 base tensors (K+V) + 2*n_stream view tensors /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, @@ -140,10 +139,10 @@ llama_kv_cache::llama_kv_cache( uint32_t n_embd_k_alloc = n_embd_k_gqa; const bool is_mxfp_k = ggml_is_type_mxfp(type_k); if (is_mxfp_k) { - const int qk = (int)ggml_blck_size(type_k); // 32 for all MXFP types + const int qk = (int)ggml_blck_size(type_k); GGML_ASSERT(n_embd_k_gqa % qk == 0 && "MXFP K cache requires n_embd_k_gqa divisible by block size"); const int blocks = (int)n_embd_k_gqa / qk; - const int blocks_aligned = (blocks + 15) & ~15; // align to 16 + const int blocks_aligned = (blocks + 15) & ~15; n_embd_k_alloc = (uint32_t)(blocks_aligned * qk); } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index d123211505..0e48e9e354 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -150,8 +150,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m } } -// MXFP SoA quantize/dequantize (from ggml-quants.h, which is internal to ggml -// and not in the test include path). Signatures must match ggml-quants.h exactly. 
+// MXFP SoA functions (internal to ggml, not in test include path) typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t); extern "C" { void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); @@ -162,10 +161,7 @@ extern "C" { void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); } -// Initialize an MXFP tensor with SoA (Struct-of-Arrays) layout. -// soa_bytes: byte width of one SoA region. Default 0 = ne[0] elements (one ggml row). -// For FA K/V tensors, pass nb[1] so that when heads are physically contiguous -// within one KV-position stride, the SoA region spans all heads (matching FA's read pattern). +// Initialize an MXFP tensor with SoA layout (soa_bytes = region width, 0 = one row). static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f, size_t soa_bytes = 0) { GGML_ASSERT(ggml_is_type_mxfp(tensor->type)); @@ -189,9 +185,6 @@ static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float std::uniform_real_distribution dist(min, max); std::vector region_f32(soa_elems); - // Iterate over logical SoA regions using tensor strides. - // Each SoA region is soa_bytes wide at the innermost stride level. - // Outer dimensions (those with stride > soa_bytes) are iterated explicitly. const size_t nb1 = tensor->nb[1]; const size_t nb2 = tensor->nb[2]; const size_t nb3 = tensor->nb[3]; @@ -199,18 +192,13 @@ static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float const int64_t ne2 = tensor->ne[2]; const int64_t ne3 = tensor->ne[3]; - // Determine iteration: if soa_bytes == nb1, iterate over (ne1 * ne2 * ne3) regions. - // If soa_bytes < nb1 (per-head), iterate over (ne1 * ne2 * ne3) regions with stride nb1. - // We use strides to compute offsets, handling views and permutations correctly. 
const int64_t heads_per_region = (int64_t)(soa_bytes / head_row_sz); GGML_ASSERT(soa_bytes % head_row_sz == 0 && "soa_bytes must be a multiple of head_row_sz"); - // For multi-head regions, we step by nb1 (KV-position stride) between regions. - // For per-head, we step through all dimensions. std::vector buf(ggml_nbytes(tensor), 0); if (heads_per_region > 1) { - // Multi-head SoA: iterate over (kv_positions * batches), each region = nb1 bytes + // Multi-head SoA: for (int64_t i3 = 0; i3 < ne3; i3++) { const int64_t n_groups = ne2 / heads_per_region; for (int64_t ig = 0; ig < n_groups; ig++) { @@ -222,7 +210,7 @@ static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float } } } else { - // Per-head SoA: one SoA region per ggml row + // Per-head SoA: for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i2 = 0; i2 < ne2; i2++) { for (int64_t i1 = 0; i1 < ne1; i1++) { @@ -328,7 +316,6 @@ static std::vector tensor_to_float(const ggml_tensor * t) { bool quantized = ggml_is_quantized(t->type); const bool is_mxfp = ggml_is_type_mxfp(t->type); - // SoA dequant for MXFP readback mxfp_soa_dequantize_fn mxfp_dequant_soa = nullptr; if (is_mxfp) { switch (t->type) { @@ -4051,7 +4038,6 @@ struct test_mul_mat_id : public test_case { return 5e-4; } - // Same Blackwell FP4 tolerance as test_mul_mat above. double max_nmse_err(ggml_backend_t backend) override { // for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance if (type_a == GGML_TYPE_MXFP4_E2M1 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) { @@ -6401,9 +6387,7 @@ struct test_flash_attn_ext : public test_case { } else if (strcmp(t->name, "m") == 0) { init_tensor_kq_mask(t); } else if ((strcmp(t->name, "k") == 0 || strcmp(t->name, "v") == 0) && ggml_is_type_mxfp(t->type)) { - // MXFP K/V tensors use SoA layout. 
Pass nb[1] (KV-position stride) as the - // SoA region width — when heads are physically contiguous within that stride, - // the FA kernel dequants the full multi-head region as one SoA block. + // MXFP K/V use SoA layout; nb[1] spans all heads in one KV-position stride init_tensor_mxfp_soa(t, -1.0f, 1.0f, t->nb[1]); } else { init_tensor_uniform(t); @@ -8778,7 +8762,6 @@ static std::vector> make_test_cases_eval() { for (ggml_type type_V : {GGML_TYPE_MXFP4_E2M1}) { if (type_K == type_V) continue; for (int nb : {1, 3, 32}) { - // hsk hsv nh nr23 kv nb mask sinks bias softcap prec type_K permute type_V test_cases.emplace_back(new test_flash_attn_ext( 128, 128, 4, {1, 1}, 512, nb, true, false, 0.0f, 0.0f, GGML_PREC_F32, type_K, {0, 1, 2, 3}, type_V)); } From dd263ff567b8fc0df77d93341d17be13be7f87aa Mon Sep 17 00:00:00 2001 From: Tim Burke Date: Sat, 21 Mar 2026 15:09:49 -0400 Subject: [PATCH 10/13] mxfp traits : ensure mxfp soa quant and dequant functions are tested --- ggml/include/ggml-cpu.h | 6 ++++-- ggml/include/ggml.h | 1 + ggml/src/ggml-cpu/ggml-cpu.c | 6 ++++++ ggml/src/ggml-cpu/ops.cpp | 31 +++++++--------------------- ggml/src/ggml-cpu/quants.h | 5 ++++- ggml/src/ggml-quants.c | 8 ++++++++ tests/test-backend-ops.cpp | 29 ++++++--------------------- tests/test-quantize-fns.cpp | 39 +++++++++++++++++++++++++++++++++++- 8 files changed, 74 insertions(+), 51 deletions(-) diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h index 2e13dd58ba..19c06a033d 100644 --- a/ggml/include/ggml-cpu.h +++ b/ggml/include/ggml-cpu.h @@ -115,10 +115,12 @@ extern "C" { struct ggml_type_traits_cpu { ggml_from_float_t from_float; - ggml_to_float_t to_float; // SIMD-optimized dequant (NULL = use global to_float) + ggml_to_float_t to_float; + ggml_from_float_t from_float_soa; // SoA quantize (MXFP flash attention layout) + ggml_to_float_t to_float_soa; // SoA dequant (MXFP flash attention layout) ggml_vec_dot_t vec_dot; enum ggml_type vec_dot_type; - int64_t nrows; // number 
of rows to process simultaneously + int64_t nrows; }; GGML_BACKEND_API const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type); diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 81b552ec78..6edf9909cf 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -427,6 +427,7 @@ extern "C" { // GGML_TYPE_IQ4_NL_4_8 = 37, // GGML_TYPE_IQ4_NL_8_8 = 38, GGML_TYPE_MXFP4_E2M1 = 39, // MX FP4 E2M1 + GGML_TYPE_MXFP4 = GGML_TYPE_MXFP4_E2M1, // compat alias GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale) GGML_TYPE_MXFP8_E4M3 = 41, // MX FP8 E4M3 GGML_TYPE_MXFP6_E2M3 = 42, // MX FP6 E2M3 diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 7b7fb1e5ce..9b8618423c 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -266,6 +266,8 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { }, [GGML_TYPE_MXFP4_E2M1] = { .from_float = quantize_row_mxfp4, + .from_float_soa = quantize_row_mxfp4_soa, + .to_float_soa = dequantize_row_mxfp4_soa_cpu, .vec_dot = ggml_vec_dot_mxfp4_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, @@ -279,6 +281,8 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { [GGML_TYPE_MXFP8_E4M3] = { .from_float = quantize_row_mxfp8, .to_float = dequantize_row_mxfp8_cpu, + .from_float_soa = quantize_row_mxfp8_soa, + .to_float_soa = dequantize_row_mxfp8_soa_cpu, .vec_dot = ggml_vec_dot_mxfp8_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, @@ -286,6 +290,8 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { [GGML_TYPE_MXFP6_E2M3] = { .from_float = quantize_row_mxfp6, .to_float = dequantize_row_mxfp6_cpu, + .from_float_soa = quantize_row_mxfp6_soa, + .to_float_soa = dequantize_row_mxfp6_soa_cpu, .vec_dot = ggml_vec_dot_mxfp6_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index a8f55efbed..cb1f881391 100644 --- 
a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -5028,18 +5028,9 @@ static void ggml_compute_forward_set_rows_f32( const int32_t apply_hadamard = ((const int32_t *)dst->op_params)[0]; - typedef void (*quantize_soa_fn)(const float *, void *, int64_t); - quantize_soa_fn mxfp_soa_quantize = nullptr; - ggml_from_float_t from_float = nullptr; - - switch (dst->type) { - case GGML_TYPE_MXFP4_E2M1: mxfp_soa_quantize = quantize_row_mxfp4_soa; break; - case GGML_TYPE_MXFP8_E4M3: mxfp_soa_quantize = quantize_row_mxfp8_soa; break; - case GGML_TYPE_MXFP6_E2M3: mxfp_soa_quantize = quantize_row_mxfp6_soa; break; - default: - from_float = ggml_get_type_traits_cpu(dst->type)->from_float; - break; - } + const struct ggml_type_traits_cpu * dst_traits = ggml_get_type_traits_cpu(dst->type); + ggml_from_float_t mxfp_soa_quantize = dst_traits->from_float_soa; + ggml_from_float_t from_float = mxfp_soa_quantize ? nullptr : dst_traits->from_float; std::vector had_tmp; if (apply_hadamard) { @@ -8300,21 +8291,13 @@ static mxfp_fa_params mxfp_fa_params_init( const bool is_mxfp_v = ggml_is_type_mxfp(v->type); if (is_mxfp_k) { - switch (k->type) { - case GGML_TYPE_MXFP4_E2M1: p.q_quantize = quantize_row_mxfp4_soa; p.k_dequantize = dequantize_row_mxfp4_soa_cpu; break; - case GGML_TYPE_MXFP8_E4M3: p.q_quantize = quantize_row_mxfp8_soa; p.k_dequantize = dequantize_row_mxfp8_soa_cpu; break; - case GGML_TYPE_MXFP6_E2M3: p.q_quantize = quantize_row_mxfp6_soa; p.k_dequantize = dequantize_row_mxfp6_soa_cpu; break; - default: GGML_ABORT("unsupported MXFP K type"); - } + const struct ggml_type_traits_cpu * k_traits = ggml_get_type_traits_cpu(k->type); + p.q_quantize = k_traits->from_float_soa; + p.k_dequantize = k_traits->to_float_soa; } if (is_mxfp_v) { - switch (v->type) { - case GGML_TYPE_MXFP4_E2M1: p.v_dequantize = dequantize_row_mxfp4_soa_cpu; break; - case GGML_TYPE_MXFP8_E4M3: p.v_dequantize = dequantize_row_mxfp8_soa_cpu; break; - case GGML_TYPE_MXFP6_E2M3: p.v_dequantize = 
dequantize_row_mxfp6_soa_cpu; break; - default: GGML_ABORT("unsupported MXFP V type"); - } + p.v_dequantize = ggml_get_type_traits_cpu(v->type)->to_float_soa; } // Hadamard rotation must match K rotation. diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h index 0a7ea64135..c16e87a2e9 100644 --- a/ggml/src/ggml-cpu/quants.h +++ b/ggml/src/ggml-cpu/quants.h @@ -90,7 +90,10 @@ void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, void dequantize_row_mxfp8_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_mxfp6_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -// SoA dequant (SIMD-optimized for FA) +// SoA quantize/dequant for MXFP flash attention +void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); +void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); +void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_mxfp6_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 1a99711401..cca3d99c82 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -5639,6 +5639,14 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb); } break; + case GGML_TYPE_MXFP8_E4M3: + { + VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp8, data, nb); + } break; + case GGML_TYPE_MXFP6_E2M3: + { + VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp6, data, nb); + } break; case GGML_TYPE_NVFP4: { // UE4M3 scales are uint8_t — all byte values are valid diff --git 
a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0e48e9e354..7f47079835 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -150,29 +151,15 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m } } -// MXFP SoA functions (internal to ggml, not in test include path) typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t); -extern "C" { - void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); - void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); - void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); - void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); - void dequantize_row_mxfp8_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); - void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); -} // Initialize an MXFP tensor with SoA layout (soa_bytes = region width, 0 = one row). 
static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f, size_t soa_bytes = 0) { GGML_ASSERT(ggml_is_type_mxfp(tensor->type)); - typedef void (*soa_quantize_fn)(const float *, void *, int64_t); - soa_quantize_fn quantize_soa = nullptr; - switch (tensor->type) { - case GGML_TYPE_MXFP4_E2M1: quantize_soa = quantize_row_mxfp4_soa; break; - case GGML_TYPE_MXFP8_E4M3: quantize_soa = quantize_row_mxfp8_soa; break; - case GGML_TYPE_MXFP6_E2M3: quantize_soa = quantize_row_mxfp6_soa; break; - default: GGML_ABORT("unsupported MXFP type for SoA init"); - } + const auto * traits = ggml_get_type_traits_cpu(tensor->type); + GGML_ASSERT(traits->from_float_soa && "MXFP type missing SoA quantize in traits"); + auto quantize_soa = traits->from_float_soa; const int qk = (int)ggml_blck_size(tensor->type); const size_t block_size = ggml_type_size(tensor->type); @@ -318,12 +305,8 @@ static std::vector tensor_to_float(const ggml_tensor * t) { mxfp_soa_dequantize_fn mxfp_dequant_soa = nullptr; if (is_mxfp) { - switch (t->type) { - case GGML_TYPE_MXFP4_E2M1: mxfp_dequant_soa = dequantize_row_mxfp4_soa; break; - case GGML_TYPE_MXFP8_E4M3: mxfp_dequant_soa = dequantize_row_mxfp8_soa; break; - case GGML_TYPE_MXFP6_E2M3: mxfp_dequant_soa = dequantize_row_mxfp6_soa; break; - default: GGML_ABORT("unsupported MXFP type in tensor_to_float"); - } + mxfp_dequant_soa = (mxfp_soa_dequantize_fn) ggml_get_type_traits_cpu(t->type)->to_float_soa; + GGML_ASSERT(mxfp_dequant_soa && "MXFP type missing SoA dequant in traits"); } // access elements by index to avoid gaps in views diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp index a8fb192623..ca2f4a2994 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -21,9 +21,13 @@ constexpr float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075f; constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040f; constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS = 0.0050f; constexpr float 
MAX_QUANTIZATION_TOTAL_ERROR_FP4 = 0.0030f; +constexpr float MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 = 0.0070f; +constexpr float MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 = 0.0040f; +constexpr float MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 = 0.0020f; constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f; constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f; constexpr float MAX_DOT_PRODUCT_ERROR_FP4 = 0.03f; +constexpr float MAX_DOT_PRODUCT_ERROR_MXFP = 0.04f; constexpr float MAX_DOT_PRODUCT_ERROR_TERNARY = 0.15f; static const char* RESULT_STR[] = {"ok", "FAILED"}; @@ -152,7 +156,10 @@ int main(int argc, char * argv[]) { type == GGML_TYPE_Q3_K ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS : type == GGML_TYPE_IQ3_S ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS : type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS : - type == GGML_TYPE_NVFP4 ? MAX_QUANTIZATION_TOTAL_ERROR_FP4 : MAX_QUANTIZATION_TOTAL_ERROR; + type == GGML_TYPE_NVFP4 ? MAX_QUANTIZATION_TOTAL_ERROR_FP4 : + type == GGML_TYPE_MXFP4_E2M1 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 : + type == GGML_TYPE_MXFP6_E2M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 : + type == GGML_TYPE_MXFP8_E4M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 : MAX_QUANTIZATION_TOTAL_ERROR; failed = !(total_error < max_quantization_error); num_failed += failed; if (failed || verbose) { @@ -174,6 +181,8 @@ int main(int argc, char * argv[]) { ? MAX_DOT_PRODUCT_ERROR_TERNARY : type == GGML_TYPE_NVFP4 ? MAX_DOT_PRODUCT_ERROR_FP4 + : type == GGML_TYPE_MXFP4_E2M1 || type == GGML_TYPE_MXFP6_E2M3 || type == GGML_TYPE_MXFP8_E4M3 + ? 
MAX_DOT_PRODUCT_ERROR_MXFP : MAX_DOT_PRODUCT_ERROR; failed = !(vec_dot_error < max_allowed_error); num_failed += failed; @@ -183,6 +192,34 @@ int main(int argc, char * argv[]) { } } + // MXFP SoA roundtrip: test from_float_soa → to_float_soa through the traits system + for (int i = 0; i < GGML_TYPE_COUNT; i++) { + ggml_type type = (ggml_type) i; + const auto * qfns_cpu = ggml_get_type_traits_cpu(type); + + if (!qfns_cpu->from_float_soa || !qfns_cpu->to_float_soa) { + continue; + } + + const size_t buf_size = ggml_row_size(type, test_size); + std::vector tmp_q(buf_size); + std::vector tmp_out(test_size); + + qfns_cpu->from_float_soa(test_data.data(), tmp_q.data(), test_size); + qfns_cpu->to_float_soa(tmp_q.data(), tmp_out.data(), test_size); + + const float soa_error = array_rmse(test_data.data(), tmp_out.data(), test_size); + const float max_soa_error = + type == GGML_TYPE_MXFP4_E2M1 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 : + type == GGML_TYPE_MXFP6_E2M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 : + type == GGML_TYPE_MXFP8_E4M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 : MAX_QUANTIZATION_TOTAL_ERROR; + failed = !(soa_error < max_soa_error); + num_failed += failed; + if (failed || verbose) { + printf("%5s SoA quantization error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], soa_error); + } + } + if (num_failed || verbose) { printf("%d tests failed\n", num_failed); } From ad2fa9035a89add2954dc22fe390eb56b2e68b22 Mon Sep 17 00:00:00 2001 From: Tim Burke Date: Sun, 22 Mar 2026 01:07:55 -0400 Subject: [PATCH 11/13] test : add testing and fixes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * cleanup : hoist mxfp soa functions * fix: CI failures — CUDA __device__ init, Metal MXFP supports_op, SoA test assert Three fixes for CI failures: 1. Remove from CUDA/HIP/MUSA section of ggml-common.h — the include causes NAN/INFINITY to become non-constexpr, breaking __device__ static table initialization for the MXFP LUTs. 2. 
Add MXFP type guards to Metal's supports_op: MXFP8/MXFP6 have no Metal shaders yet (reject all ops), MXFP4 has AoS shaders (MUL_MAT, GET_ROWS) but no SoA/flash attention support yet (reject FLASH_ATTN_EXT, SET_ROWS). 3. Replace strict assert in test-backend-ops init_tensor_mxfp_soa with a conditional fallback — when ne2 is not divisible by heads_per_region, fall back to per-head SoA init instead of crashing. * fix : correct guard for mxfp cpu dequant functions * fix: CUDA MXFP LUT init and MXFP flash attention SoA test layout - Add per-platform GGML_TABLE_NAN/GGML_TABLE_INFINITY macros for MXFP LUTs — uses __uint_as_float on CUDA to avoid MSVC non-constexpr INFINITY - Fix init_tensor_mxfp_soa to detect multihead SoA from tensor strides, matching the KV cache layout for permuted flash attention tests * fix: CUDA MXFP LUT init — use __builtin_nanf/__builtin_inff for constexpr device tables CUDA/HIP/MUSA __device__ static tables require constexpr initializers. Standard NAN/INFINITY macros may expand to non-constexpr expressions (e.g. MSVC: (float)(1e+300), nvcc: __uint_as_float is not constexpr for static init). Previous fix attempted __uint_as_float for nvcc and __builtin_bit_cast for clang — neither worked universally. Use __builtin_nanf("") and __builtin_inff() which are constexpr on all target compilers (nvcc, clang for HIP/MUSA, GCC, MSVC). Define once before the platform #if chain instead of per-platform copies. * fix: correct E5M2 LUT precision and add converter-vs-LUT validation tests The kvalues_mxfp8_e5m2 LUT had 50 values with insufficient decimal precision, causing bitwise mismatches against the IEEE-754 element converter. Regenerated from ggml_mxfp_fp8_e5m2_to_float() with %.9e precision for exact float round-trip on all 256 entries. 
Also consolidates GGML_TABLE_NAN/GGML_TABLE_INFINITY into a single definition using __builtin_nanf/__builtin_inff (constexpr on all target compilers), and adds LUT validation tests to test-quantize-fns that verify all 5 MXFP element converters match their canonical LUT values (FP4 E2M1: 16, FP6 E2M3: 64, FP6 E3M2: 64, FP8 E4M3: 256, FP8 E5M2: 256 — 656 total values verified). * fix: MSVC compat for GGML_TABLE_NAN/INFINITY — use builtins only on GCC/Clang/nvcc MSVC does not support __builtin_nanf/__builtin_inff. Use standard NAN/INFINITY macros on MSVC (which work for regular static tables), and compiler builtins only on GCC/Clang/nvcc (needed for CUDA __device__ table constexpr initialization). * fix: handle nvcc+MSVC host — check __CUDACC__ before _MSC_VER for NAN/INF macros When nvcc uses MSVC as the host compiler, both _MSC_VER and __CUDACC__ are defined. The previous fix checked _MSC_VER first, giving nvcc the MSVC NAN/INFINITY macros which are not constexpr for __device__ tables. Add __CUDACC__ exclusion so nvcc gets __builtin_nanf/__builtin_inff. * cleanup: remove AoS MXFP6/MXFP8 dequant code — these types are KV-cache-only (SoA) MXFP6 (E2M3) and MXFP8 (E4M3) exist only for KV cache flash attention, which uses SoA (Struct-of-Arrays) layout. The AoS dequant functions (NEON, AVX2, CPU dispatch, generic wrappers) were incorrectly added and are dead code — no model stores weights in these formats. 
Removed: - AoS NEON dequant: dequantize_row_mxfp{6,8}_neon, _cpu dispatch - AoS AVX2 dequant: dequantize_row_mxfp{6,8}_avx2, _cpu dispatch - AoS generic wrappers: dequantize_row_mxfp{6,8}_cpu_generic - AoS fallback defines in arch-fallback.h - CPU traits .to_float entries for MXFP6/MXFP8 - MXFP6/MXFP8 from all_types[] in test-backend-ops (no AoS tests) Kept (correct SoA code): - All *_soa_* functions (NEON, AVX2, generic, dispatch) - CPU traits .from_float_soa / .to_float_soa - Flash attention and SET_ROWS Hadamard test cases - Scalar reference dequant in ggml-quants.c (test-quantize-fns roundtrip) - MXFP4 AoS code (upstream model weight support, untouched) Fixes ARM64 CI failure: GET_ROWS(mxfp6_e2m3) was testing dead AoS code that had a NEON bug. The test no longer runs because the type is correctly excluded from AoS test paths. * test: guard all MXFP types must have SoA traits for flash attention All MXFP flash attention uses SoA layout exclusively. Test validates: - ALL MXFP types (MXFP4, MXFP6, MXFP8) have from_float_soa and to_float_soa - MXFP6/MXFP8 (KV-cache-only) do NOT have AoS CPU to_float Prevents regression: if someone adds AoS dequant back for MXFP6/MXFP8, or removes SoA traits from any MXFP type, CI will catch it. 
* test: add Hadamard, SoA cross-check, E8M0, and layout offset tests * test: add MXFP converter edge cases, FP6 packing, E8M0 known-answer tests Add comprehensive tests to catch the bugs backend implementers hit most: - Element converter edge cases: subnormals, max finite, saturation, NaN, sign - FP6 pack/unpack exhaustive round-trip with known-answer byte verification - E8M0 known-answer decode + HALF vs FULL scale distinction - E8M0 rounding boundary at sqrt(2) threshold (catches floor-only bugs) - Converter exhaustive round-trip: quantize(dequantize(i))==i for all formats - Consolidate duplicate SoA switches into single table in test-backend-ops * test: add AoS/SoA cross-check, Hadamard pipeline, format spec, and mxfp_rmse - MXFP4 AoS vs SoA cross-check: two independent code paths, bitwise match - Full Hadamard pipeline roundtrip: H→quantize→dequant→H for all 3 types - mxfp_rmse helper: computes sqrt(sum/n), with named pipeline constants - Block size consistency: verify QK_MXFP{4,8,6} == 32 - EMAX_OFFSET vs format max: validate constants produce valid E8M0 - Edge case LUT validation: expected_bits verified against canonical LUTs - FP4 E2M1 exhaustive converter round-trip (16/16) * cleanup: tighten MXFP test comments to match repo conventions * fix: platform-specific NaN/Infinity for GPU device table initializers FP8 E4M3/E5M2 LUTs contain NaN/Inf which cannot be constexpr-initialized in __device__ tables on any CUDA/HIP/MUSA version. No GPU backend uses these LUTs (they use converter functions instead), so guard them out of GPU builds entirely. Simplify GGML_TABLE_NAN/INFINITY to CPU-only macros. 
--- ggml/include/ggml.h | 1 + ggml/src/ggml-common.h | 87 ++- ggml/src/ggml-cpu/arch-fallback.h | 8 +- ggml/src/ggml-cpu/arch/arm/quants.c | 78 --- ggml/src/ggml-cpu/arch/x86/quants.c | 77 -- ggml/src/ggml-cpu/ggml-cpu.c | 3 +- ggml/src/ggml-cpu/ops.cpp | 16 +- ggml/src/ggml-cpu/quants.c | 8 +- ggml/src/ggml-cpu/quants.h | 12 +- ggml/src/ggml-metal/ggml-metal-device.m | 13 + ggml/src/ggml.c | 2 +- tests/CMakeLists.txt | 1 + tests/test-backend-ops.cpp | 117 ++-- tests/test-quantize-fns.cpp | 891 +++++++++++++++++++++++- 14 files changed, 1032 insertions(+), 282 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 6edf9909cf..1f66550459 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -467,6 +467,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors GGML_FTYPE_MOSTLY_MXFP4_E2M1 = 25, // except 1d tensors + GGML_FTYPE_MOSTLY_MXFP4 = GGML_FTYPE_MOSTLY_MXFP4_E2M1, // compat alias GGML_FTYPE_MOSTLY_NVFP4 = 26, // except 1d tensors }; diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 7308c3749b..271de1943c 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -574,11 +574,20 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_ #ifndef GGML_COMMON_IMPL +// NaN/Infinity for FP8 LUT initializers (CPU-only, guarded out of GPU builds). 
+#if defined(_MSC_VER) && !defined(__clang__) +#include +#define GGML_TABLE_NAN NAN +#define GGML_TABLE_INFINITY INFINITY +#else +#define GGML_TABLE_NAN __builtin_nanf("") +#define GGML_TABLE_INFINITY __builtin_inff() +#endif + #if defined(GGML_COMMON_IMPL_C) #include #include #include - #define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = { #define GGML_TABLE_END() }; #define GGML_MXFP_FUNC static inline @@ -636,7 +645,6 @@ static inline float ggml_mxfp_u32_as_f32_(uint32_t u) { float f; memcpy(&f, & #define GGML_COMMON_IMPL #elif defined(GGML_COMMON_IMPL_SYCL) - #include #include #include @@ -1308,6 +1316,10 @@ GGML_TABLE_BEGIN(float, kvalues_mxfp6_e3m2, 64) -8.0f, -10.0f, -12.0f, -14.0f, -16.0f, -20.0f, -24.0f, -28.0f, GGML_TABLE_END() +// FP8 E4M3/E5M2 LUTs contain NaN/Inf which cannot be constexpr-initialized in +// __device__ tables. GPU backends use the converter functions instead. +#if !defined(GGML_COMMON_DECL_CUDA) && !defined(GGML_COMMON_DECL_HIP) && !defined(GGML_COMMON_DECL_MUSA) + // FP8 E4M3 dequantization LUT: byte -> float. Entry 127 = 448 (max finite), 255 = NaN. 
GGML_TABLE_BEGIN(float, kvalues_mxfp8_e4m3, 256) 0.0f, 0.001953125f, 0.00390625f, 0.005859375f, 0.0078125f, 0.009765625f, 0.01171875f, 0.013671875f, @@ -1325,7 +1337,7 @@ GGML_TABLE_BEGIN(float, kvalues_mxfp8_e4m3, 256) 32.0f, 36.0f, 40.0f, 44.0f, 48.0f, 52.0f, 56.0f, 60.0f, 64.0f, 72.0f, 80.0f, 88.0f, 96.0f, 104.0f, 112.0f, 120.0f, 128.0f, 144.0f, 160.0f, 176.0f, 192.0f, 208.0f, 224.0f, 240.0f, - 256.0f, 288.0f, 320.0f, 352.0f, 384.0f, 416.0f, 448.0f, NAN, + 256.0f, 288.0f, 320.0f, 352.0f, 384.0f, 416.0f, 448.0f, GGML_TABLE_NAN, -0.0f,-0.001953125f, -0.00390625f,-0.005859375f, -0.0078125f,-0.009765625f, -0.01171875f,-0.013671875f, -0.015625f,-0.017578125f, -0.01953125f,-0.021484375f, -0.0234375f,-0.025390625f, -0.02734375f,-0.029296875f, -0.03125f, -0.03515625f, -0.0390625f, -0.04296875f, -0.046875f, -0.05078125f, -0.0546875f, -0.05859375f, @@ -1341,45 +1353,48 @@ GGML_TABLE_BEGIN(float, kvalues_mxfp8_e4m3, 256) -32.0f, -36.0f, -40.0f, -44.0f, -48.0f, -52.0f, -56.0f, -60.0f, -64.0f, -72.0f, -80.0f, -88.0f, -96.0f, -104.0f, -112.0f, -120.0f, -128.0f, -144.0f, -160.0f, -176.0f, -192.0f, -208.0f, -224.0f, -240.0f, - -256.0f, -288.0f, -320.0f, -352.0f, -384.0f, -416.0f, -448.0f, NAN, + -256.0f, -288.0f, -320.0f, -352.0f, -384.0f, -416.0f, -448.0f, GGML_TABLE_NAN, GGML_TABLE_END() // FP8 E5M2 dequantization LUT: byte -> float. Entries 124-127 = {Inf, NaN, NaN, NaN}. +// Generated from ggml_mxfp_fp8_e5m2_to_float() with %.9e precision for exact float round-trip. 
GGML_TABLE_BEGIN(float, kvalues_mxfp8_e5m2, 256) - 0.0f, 1.525879e-05f, 3.051758e-05f, 4.577637e-05f, 6.103516e-05f, 7.629395e-05f, 9.155273e-05f, 1.068115e-04f, - 1.220703e-04f, 1.525879e-04f, 1.831055e-04f, 2.136230e-04f, 2.441406e-04f, 3.051758e-04f, 3.662109e-04f, 4.272461e-04f, - 4.882812e-04f, 6.103516e-04f, 7.324219e-04f, 8.544922e-04f, 9.765625e-04f, 1.220703e-03f, 1.464844e-03f, 1.708984e-03f, - 1.953125e-03f, 2.441406e-03f, 2.929688e-03f, 3.417969e-03f, 3.906250e-03f, 4.882812e-03f, 5.859375e-03f, 6.835938e-03f, - 7.812500e-03f, 9.765625e-03f, 1.171875e-02f, 1.367188e-02f, 1.562500e-02f, 1.953125e-02f, 2.343750e-02f, 2.734375e-02f, - 3.125000e-02f, 3.906250e-02f, 4.687500e-02f, 5.468750e-02f, 6.250000e-02f, 7.812500e-02f, 9.375000e-02f, 1.093750e-01f, - 0.125f, 0.15625f, 0.1875f, 0.21875f, 0.25f, 0.3125f, 0.375f, 0.4375f, - 0.5f, 0.625f, 0.75f, 0.875f, 1.0f, 1.25f, 1.5f, 1.75f, - 2.0f, 2.5f, 3.0f, 3.5f, 4.0f, 5.0f, 6.0f, 7.0f, - 8.0f, 10.0f, 12.0f, 14.0f, 16.0f, 20.0f, 24.0f, 28.0f, - 32.0f, 40.0f, 48.0f, 56.0f, 64.0f, 80.0f, 96.0f, 112.0f, - 128.0f, 160.0f, 192.0f, 224.0f, 256.0f, 320.0f, 384.0f, 448.0f, - 512.0f, 640.0f, 768.0f, 896.0f, 1024.0f, 1280.0f, 1536.0f, 1792.0f, - 2048.0f, 2560.0f, 3072.0f, 3584.0f, 4096.0f, 5120.0f, 6144.0f, 7168.0f, - 8192.0f, 10240.0f, 12288.0f, 14336.0f, 16384.0f, 20480.0f, 24576.0f, 28672.0f, - 32768.0f, 40960.0f, 49152.0f, 57344.0f, INFINITY, NAN, NAN, NAN, - -0.0f,-1.525879e-05f,-3.051758e-05f,-4.577637e-05f,-6.103516e-05f,-7.629395e-05f,-9.155273e-05f,-1.068115e-04f, - -1.220703e-04f,-1.525879e-04f,-1.831055e-04f,-2.136230e-04f,-2.441406e-04f,-3.051758e-04f,-3.662109e-04f,-4.272461e-04f, - -4.882812e-04f,-6.103516e-04f,-7.324219e-04f,-8.544922e-04f,-9.765625e-04f,-1.220703e-03f,-1.464844e-03f,-1.708984e-03f, - -1.953125e-03f,-2.441406e-03f,-2.929688e-03f,-3.417969e-03f,-3.906250e-03f,-4.882812e-03f,-5.859375e-03f,-6.835938e-03f, - 
-7.812500e-03f,-9.765625e-03f,-1.171875e-02f,-1.367188e-02f,-1.562500e-02f,-1.953125e-02f,-2.343750e-02f,-2.734375e-02f, - -3.125000e-02f,-3.906250e-02f,-4.687500e-02f,-5.468750e-02f,-6.250000e-02f,-7.812500e-02f,-9.375000e-02f,-1.093750e-01f, - -0.125f, -0.15625f, -0.1875f, -0.21875f, -0.25f, -0.3125f, -0.375f, -0.4375f, - -0.5f, -0.625f, -0.75f, -0.875f, -1.0f, -1.25f, -1.5f, -1.75f, - -2.0f, -2.5f, -3.0f, -3.5f, -4.0f, -5.0f, -6.0f, -7.0f, - -8.0f, -10.0f, -12.0f, -14.0f, -16.0f, -20.0f, -24.0f, -28.0f, - -32.0f, -40.0f, -48.0f, -56.0f, -64.0f, -80.0f, -96.0f, -112.0f, - -128.0f, -160.0f, -192.0f, -224.0f, -256.0f, -320.0f, -384.0f, -448.0f, - -512.0f, -640.0f, -768.0f, -896.0f, -1024.0f, -1280.0f, -1536.0f, -1792.0f, - -2048.0f, -2560.0f, -3072.0f, -3584.0f, -4096.0f, -5120.0f, -6144.0f, -7168.0f, - -8192.0f, -10240.0f, -12288.0f, -14336.0f, -16384.0f, -20480.0f, -24576.0f, -28672.0f, - -32768.0f, -40960.0f, -49152.0f, -57344.0f, -INFINITY, NAN, NAN, NAN, + 0.000000000e+00f, 1.525878906e-05f, 3.051757812e-05f, 4.577636719e-05f, 6.103515625e-05f, 7.629394531e-05f, 9.155273438e-05f, 1.068115234e-04f, + 1.220703125e-04f, 1.525878906e-04f, 1.831054688e-04f, 2.136230469e-04f, 2.441406250e-04f, 3.051757812e-04f, 3.662109375e-04f, 4.272460938e-04f, + 4.882812500e-04f, 6.103515625e-04f, 7.324218750e-04f, 8.544921875e-04f, 9.765625000e-04f, 1.220703125e-03f, 1.464843750e-03f, 1.708984375e-03f, + 1.953125000e-03f, 2.441406250e-03f, 2.929687500e-03f, 3.417968750e-03f, 3.906250000e-03f, 4.882812500e-03f, 5.859375000e-03f, 6.835937500e-03f, + 7.812500000e-03f, 9.765625000e-03f, 1.171875000e-02f, 1.367187500e-02f, 1.562500000e-02f, 1.953125000e-02f, 2.343750000e-02f, 2.734375000e-02f, + 3.125000000e-02f, 3.906250000e-02f, 4.687500000e-02f, 5.468750000e-02f, 6.250000000e-02f, 7.812500000e-02f, 9.375000000e-02f, 1.093750000e-01f, + 1.250000000e-01f, 1.562500000e-01f, 1.875000000e-01f, 2.187500000e-01f, 2.500000000e-01f, 3.125000000e-01f, 3.750000000e-01f, 4.375000000e-01f, + 
5.000000000e-01f, 6.250000000e-01f, 7.500000000e-01f, 8.750000000e-01f, 1.000000000e+00f, 1.250000000e+00f, 1.500000000e+00f, 1.750000000e+00f, + 2.000000000e+00f, 2.500000000e+00f, 3.000000000e+00f, 3.500000000e+00f, 4.000000000e+00f, 5.000000000e+00f, 6.000000000e+00f, 7.000000000e+00f, + 8.000000000e+00f, 1.000000000e+01f, 1.200000000e+01f, 1.400000000e+01f, 1.600000000e+01f, 2.000000000e+01f, 2.400000000e+01f, 2.800000000e+01f, + 3.200000000e+01f, 4.000000000e+01f, 4.800000000e+01f, 5.600000000e+01f, 6.400000000e+01f, 8.000000000e+01f, 9.600000000e+01f, 1.120000000e+02f, + 1.280000000e+02f, 1.600000000e+02f, 1.920000000e+02f, 2.240000000e+02f, 2.560000000e+02f, 3.200000000e+02f, 3.840000000e+02f, 4.480000000e+02f, + 5.120000000e+02f, 6.400000000e+02f, 7.680000000e+02f, 8.960000000e+02f, 1.024000000e+03f, 1.280000000e+03f, 1.536000000e+03f, 1.792000000e+03f, + 2.048000000e+03f, 2.560000000e+03f, 3.072000000e+03f, 3.584000000e+03f, 4.096000000e+03f, 5.120000000e+03f, 6.144000000e+03f, 7.168000000e+03f, + 8.192000000e+03f, 1.024000000e+04f, 1.228800000e+04f, 1.433600000e+04f, 1.638400000e+04f, 2.048000000e+04f, 2.457600000e+04f, 2.867200000e+04f, + 3.276800000e+04f, 4.096000000e+04f, 4.915200000e+04f, 5.734400000e+04f, GGML_TABLE_INFINITY, GGML_TABLE_NAN, GGML_TABLE_NAN, GGML_TABLE_NAN, + -0.000000000e+00f,-1.525878906e-05f,-3.051757812e-05f,-4.577636719e-05f,-6.103515625e-05f,-7.629394531e-05f,-9.155273438e-05f,-1.068115234e-04f, + -1.220703125e-04f,-1.525878906e-04f,-1.831054688e-04f,-2.136230469e-04f,-2.441406250e-04f,-3.051757812e-04f,-3.662109375e-04f,-4.272460938e-04f, + -4.882812500e-04f,-6.103515625e-04f,-7.324218750e-04f,-8.544921875e-04f,-9.765625000e-04f,-1.220703125e-03f,-1.464843750e-03f,-1.708984375e-03f, + -1.953125000e-03f,-2.441406250e-03f,-2.929687500e-03f,-3.417968750e-03f,-3.906250000e-03f,-4.882812500e-03f,-5.859375000e-03f,-6.835937500e-03f, + 
-7.812500000e-03f,-9.765625000e-03f,-1.171875000e-02f,-1.367187500e-02f,-1.562500000e-02f,-1.953125000e-02f,-2.343750000e-02f,-2.734375000e-02f, + -3.125000000e-02f,-3.906250000e-02f,-4.687500000e-02f,-5.468750000e-02f,-6.250000000e-02f,-7.812500000e-02f,-9.375000000e-02f,-1.093750000e-01f, + -1.250000000e-01f,-1.562500000e-01f,-1.875000000e-01f,-2.187500000e-01f,-2.500000000e-01f,-3.125000000e-01f,-3.750000000e-01f,-4.375000000e-01f, + -5.000000000e-01f,-6.250000000e-01f,-7.500000000e-01f,-8.750000000e-01f,-1.000000000e+00f,-1.250000000e+00f,-1.500000000e+00f,-1.750000000e+00f, + -2.000000000e+00f,-2.500000000e+00f,-3.000000000e+00f,-3.500000000e+00f,-4.000000000e+00f,-5.000000000e+00f,-6.000000000e+00f,-7.000000000e+00f, + -8.000000000e+00f,-1.000000000e+01f,-1.200000000e+01f,-1.400000000e+01f,-1.600000000e+01f,-2.000000000e+01f,-2.400000000e+01f,-2.800000000e+01f, + -3.200000000e+01f,-4.000000000e+01f,-4.800000000e+01f,-5.600000000e+01f,-6.400000000e+01f,-8.000000000e+01f,-9.600000000e+01f,-1.120000000e+02f, + -1.280000000e+02f,-1.600000000e+02f,-1.920000000e+02f,-2.240000000e+02f,-2.560000000e+02f,-3.200000000e+02f,-3.840000000e+02f,-4.480000000e+02f, + -5.120000000e+02f,-6.400000000e+02f,-7.680000000e+02f,-8.960000000e+02f,-1.024000000e+03f,-1.280000000e+03f,-1.536000000e+03f,-1.792000000e+03f, + -2.048000000e+03f,-2.560000000e+03f,-3.072000000e+03f,-3.584000000e+03f,-4.096000000e+03f,-5.120000000e+03f,-6.144000000e+03f,-7.168000000e+03f, + -8.192000000e+03f,-1.024000000e+04f,-1.228800000e+04f,-1.433600000e+04f,-1.638400000e+04f,-2.048000000e+04f,-2.457600000e+04f,-2.867200000e+04f, + -3.276800000e+04f,-4.096000000e+04f,-4.915200000e+04f,-5.734400000e+04f, -GGML_TABLE_INFINITY, GGML_TABLE_NAN, GGML_TABLE_NAN, GGML_TABLE_NAN, GGML_TABLE_END() +#endif // !CUDA && !HIP && !MUSA + // MXFP element converters -- portable IEEE-754 bit manipulation. 
#if defined(GGML_MXFP_FUNC) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index f622658918..ddeee1fa7e 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -343,12 +343,8 @@ #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #endif -// MXFP dequantize has no arch-specific (SIMD) implementations except on arm and x86. -// All other targets use the scalar generic as the public cpu function. -#if !defined(__aarch64__) && !defined(__arm__) && !defined(_M_ARM) && !defined(_M_ARM64) && \ - !defined(__x86_64__) && !defined(__i386__) && !defined(_M_IX86) && !defined(_M_X64) -#define dequantize_row_mxfp8_cpu_generic dequantize_row_mxfp8_cpu -#define dequantize_row_mxfp6_cpu_generic dequantize_row_mxfp6_cpu +// MXFP dequantize fallbacks (same GGML_CPU_GENERIC guard as above) +#if defined(GGML_CPU_GENERIC) #define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu #define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu #define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index 9ad9f29ae5..53507b97ce 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -4311,68 +4311,6 @@ static void ggml_vec_dot_mxfp6_q8_0_neon( *s = vaddvq_f32(vaddq_f32(acc0, acc1)); } -// MXFP FP8/FP6 dequantize_row (AoS) - -static void dequantize_row_mxfp8_neon( - const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k, - const mxfp_neon_traits_t * t) { - assert(k % QK_MXFP8 == 0); - const int nb = k / QK_MXFP8; - const block_mxfp8 * GGML_RESTRICT x = vx; - - const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask); - const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask); - const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off); - const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale); - const int32x4_t v_neg_exp = 
vdupq_n_s32(-(int)t->exp_shift); - const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift); - - for (int ib = 0; ib < nb; ++ib) { - const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32(x[ib].e)); - - for (int j = 0; j < 32; j += 8) { - uint32x4_t v_lo, v_hi; - widen_u8x8_to_u32x4x2(x[ib].qs + j, &v_lo, &v_hi); - - const float32x4_t val_lo = mxfp8_dequant_neon(v_lo, - v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); - const float32x4_t val_hi = mxfp8_dequant_neon(v_hi, - v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); - - vst1q_f32(y + ib * QK_MXFP8 + j, vmulq_f32(val_lo, v_scale)); - vst1q_f32(y + ib * QK_MXFP8 + j + 4, vmulq_f32(val_hi, v_scale)); - } - } -} - -static void dequantize_row_mxfp6_neon( - const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k, - const mxfp_neon_traits_t * t) { - assert(k % QK_MXFP6 == 0); - const int nb = k / QK_MXFP6; - - const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask); - const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask); - const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off); - const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale); - const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift); - const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift); - - for (int ib = 0; ib < nb; ++ib) { - const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib; - const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32(xb->e)); - - for (int j = 0; j < 32; j += 4) { - const uint32x4_t v_raw = unpack_fp6x4_neon(xb->qs + (j * 3 / 4)); - - const float32x4_t val = mxfp6_dequant_neon(v_raw, - v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); - - vst1q_f32(y + ib * QK_MXFP6 + j, vmulq_f32(val, v_scale)); - } - } -} - // MXFP SoA dequant (flash attention) static void dequantize_row_mxfp8_soa_neon( @@ -4506,22 +4444,6 @@ void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -void dequantize_row_mxfp8_cpu(const 
void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { -#if defined(__ARM_NEON) - dequantize_row_mxfp8_neon(x, y, k, &MXFP_TRAITS_E4M3); -#else - dequantize_row_mxfp8_cpu_generic(x, y, k); -#endif -} - -void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { -#if defined(__ARM_NEON) - dequantize_row_mxfp6_neon(x, y, k, &MXFP_TRAITS_E2M3); -#else - dequantize_row_mxfp6_cpu_generic(x, y, k); -#endif -} - void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { #if defined(__ARM_NEON) dequantize_row_mxfp4_soa_neon(x, y, k); diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c index 21b3fb4605..4b8f3386fa 100644 --- a/ggml/src/ggml-cpu/arch/x86/quants.c +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -3950,67 +3950,6 @@ static void ggml_vec_dot_mxfp6_q8_0_avx2( *s = hsum_float_8(acc); } -// MXFP FP8/FP6 dequantize_row (AoS) - -static void dequantize_row_mxfp8_avx2( - const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k, - const mxfp_avx2_traits_t * t) { - assert(k % QK_MXFP8 == 0); - const int nb = k / QK_MXFP8; - const block_mxfp8 * GGML_RESTRICT x = vx; - - const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask); - const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask); - const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off); - const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale); - const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask); - const __m256i v_zero = _mm256_setzero_si256(); - - for (int ib = 0; ib < nb; ++ib) { - const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32(x[ib].e)); - - for (int j = 0; j < 32; j += 8) { - const __m256i v_raw = _mm256_cvtepu8_epi32( - _mm_loadl_epi64((const __m128i *)(x[ib].qs + j))); - - const __m256 val = mxfp_dequant_avx2(v_raw, - v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, - v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift); - - _mm256_storeu_ps(y + 
ib * QK_MXFP8 + j, _mm256_mul_ps(val, v_scale)); - } - } -} - -static void dequantize_row_mxfp6_avx2( - const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k, - const mxfp_avx2_traits_t * t) { - assert(k % QK_MXFP6 == 0); - const int nb = k / QK_MXFP6; - - const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask); - const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask); - const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off); - const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale); - const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask); - const __m256i v_zero = _mm256_setzero_si256(); - - for (int ib = 0; ib < nb; ++ib) { - const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib; - const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32(xb->e)); - - for (int j = 0; j < 32; j += 8) { - const __m256i v_raw = unpack_fp6x8_avx2(xb->qs, j); - - const __m256 val = mxfp_dequant_avx2(v_raw, - v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, - v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift); - - _mm256_storeu_ps(y + ib * QK_MXFP6 + j, _mm256_mul_ps(val, v_scale)); - } - } -} - // MXFP SoA dequant (flash attention) static void dequantize_row_mxfp8_soa_avx2( @@ -4133,22 +4072,6 @@ void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { -#if defined(__AVX2__) - dequantize_row_mxfp8_avx2(x, y, k, &MXFP_TRAITS_E4M3); -#else - dequantize_row_mxfp8_cpu_generic(x, y, k); -#endif -} - -void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { -#if defined(__AVX2__) - dequantize_row_mxfp6_avx2(x, y, k, &MXFP_TRAITS_E2M3); -#else - dequantize_row_mxfp6_cpu_generic(x, y, k); -#endif -} - void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { #if defined(__AVX2__) dequantize_row_mxfp4_soa_avx2(x, y, k); diff 
--git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 9b8618423c..b84b6e0031 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -7,6 +7,7 @@ #include "ggml-cpu-impl.h" #include "ggml-impl.h" #include "quants.h" +#include "ggml-quants.h" #include "ggml-threading.h" #include "unary-ops.h" #include "binary-ops.h" @@ -280,7 +281,6 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { }, [GGML_TYPE_MXFP8_E4M3] = { .from_float = quantize_row_mxfp8, - .to_float = dequantize_row_mxfp8_cpu, .from_float_soa = quantize_row_mxfp8_soa, .to_float_soa = dequantize_row_mxfp8_soa_cpu, .vec_dot = ggml_vec_dot_mxfp8_q8_0, @@ -289,7 +289,6 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { }, [GGML_TYPE_MXFP6_E2M3] = { .from_float = quantize_row_mxfp6, - .to_float = dequantize_row_mxfp6_cpu, .from_float_soa = quantize_row_mxfp6_soa, .to_float_soa = dequantize_row_mxfp6_soa_cpu, .vec_dot = ggml_vec_dot_mxfp6_q8_0, diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index cb1f881391..9291af62dc 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8313,20 +8313,8 @@ static mxfp_fa_params mxfp_fa_params_init( p.v_multihead = is_mxfp_v && (nbv2 == (size_t)ggml_row_size(v->type, DV)); p.v_soa_elems = is_mxfp_v ? (p.v_multihead ? nev2 * DV : DV) : 0; - // Per-head SoA addressing for multihead mode. - // Precompute byte offsets so the hot loop can skip per-head pointer math. - // qs_per_block values from centralized MXFP_QS_PER_BLOCK_* defines in ggml-common.h. 
- auto mxfp_qs_per_block = [](ggml_type type) -> int { - switch (type) { - case GGML_TYPE_MXFP4_E2M1: return MXFP4_SOA_QS_PER_BLOCK; - case GGML_TYPE_MXFP8_E4M3: return MXFP8_SOA_QS_PER_BLOCK; - case GGML_TYPE_MXFP6_E2M3: return MXFP6_SOA_QS_PER_BLOCK; - default: return 0; - } - }; - if (is_mxfp_k) { - p.k_qs_per_block = mxfp_qs_per_block(k->type); + p.k_qs_per_block = ggml_mxfp_qs_per_block(k->type); p.k_blocks_per_head = (int)(DK / 32); p.k_head_qs_bytes = p.k_blocks_per_head * p.k_qs_per_block; const int64_t k_total_blocks = p.k_multihead ? nek2 * p.k_blocks_per_head : p.k_blocks_per_head; @@ -8334,7 +8322,7 @@ static mxfp_fa_params mxfp_fa_params_init( } if (is_mxfp_v) { - p.v_qs_per_block = mxfp_qs_per_block(v->type); + p.v_qs_per_block = ggml_mxfp_qs_per_block(v->type); p.v_blocks_per_head = (int)(DV / 32); p.v_head_qs_bytes = p.v_blocks_per_head * p.v_qs_per_block; const int64_t v_total_blocks = p.v_multihead ? nev2 * p.v_blocks_per_head : p.v_blocks_per_head; diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index 477bd07304..eed3be90fc 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -309,13 +309,7 @@ void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, (ggml_to_float_t)dequantize_row_mxfp6); } -// Generic dequant wrappers — arch-specific SIMD versions override via fallback.h. -void dequantize_row_mxfp8_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { - dequantize_row_mxfp8(x, y, k); -} -void dequantize_row_mxfp6_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { - dequantize_row_mxfp6(x, y, k); -} +// Generic SoA dequant wrappers — arch-specific SIMD versions override via fallback.h. 
void dequantize_row_mxfp4_soa_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { dequantize_row_mxfp4_soa(x, y, k); } diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h index c16e87a2e9..78c9984bdc 100644 --- a/ggml/src/ggml-cpu/quants.h +++ b/ggml/src/ggml-cpu/quants.h @@ -24,10 +24,6 @@ void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, i void quantize_row_mxfp8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_mxfp6(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -// Dequantization (SIMD-optimized, arch-dispatched) -void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); - void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -87,13 +83,7 @@ void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_vec_dot_mxfp8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void dequantize_row_mxfp8_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_mxfp6_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); - -// SoA quantize/dequant for MXFP flash attention -void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); -void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * 
GGML_RESTRICT dst, int64_t k); -void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); +// SoA dequant (SIMD-dispatched, CPU backend) void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_mxfp6_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 82101f4714..a8996a2ab5 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1010,6 +1010,19 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te } } + // MXFP8/MXFP6: no Metal shaders yet — reject for all ops. + // MXFP4: has AoS shaders (MUL_MAT, GET_ROWS) but no SoA/flash attention support yet. + for (size_t i = 0, n = 3; i < n; ++i) { + if (op->src[i] != NULL && ggml_is_type_mxfp(op->src[i]->type)) { + if (op->src[i]->type != GGML_TYPE_MXFP4_E2M1) { + return false; + } + if (op->op == GGML_OP_FLASH_ATTN_EXT || op->op == GGML_OP_SET_ROWS) { + return false; + } + } + } + switch (op->op) { case GGML_OP_SCALE: case GGML_OP_FILL: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index ad48236169..40a0aab62b 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -711,7 +711,7 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref, }, [GGML_TYPE_MXFP4_E2M1] = { - .type_name = "mxfp4_e2m1", + .type_name = "mxfp4", .blck_size = QK_MXFP4, .type_size = sizeof(block_mxfp4), .is_quantized = true, diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9582164b58..575928e636 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -252,6 +252,7 @@ if (NOT GGML_BACKEND_DL) # these tests use the backends directly and cannot be built with 
dynamic loading llama_build_and_test(test-barrier.cpp) llama_build_and_test(test-quantize-fns.cpp) + target_include_directories(test-quantize-fns PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src) llama_build_and_test(test-quantize-perf.cpp) llama_build_and_test(test-rope.cpp) endif() diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 7f47079835..8e57cb1d1d 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -18,7 +18,6 @@ #include #include #include -#include #include #include @@ -151,59 +150,79 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m } } +// MXFP SoA quantization functions +extern "C" { + void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); + void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); + void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); + void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); + void dequantize_row_mxfp8_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); + void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); +} + +typedef void (*mxfp_soa_quantize_fn)(const float *, void *, int64_t); typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t); -// Initialize an MXFP tensor with SoA layout (soa_bytes = region width, 0 = one row). 
-static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f, size_t soa_bytes = 0) { +struct mxfp_soa_fns { + ggml_type type; + mxfp_soa_quantize_fn quantize; + mxfp_soa_dequantize_fn dequantize; +}; + +static const mxfp_soa_fns mxfp_soa_table[] = { + { GGML_TYPE_MXFP4_E2M1, quantize_row_mxfp4_soa, dequantize_row_mxfp4_soa }, + { GGML_TYPE_MXFP8_E4M3, quantize_row_mxfp8_soa, dequantize_row_mxfp8_soa }, + { GGML_TYPE_MXFP6_E2M3, quantize_row_mxfp6_soa, dequantize_row_mxfp6_soa }, +}; + +static const mxfp_soa_fns * get_mxfp_soa(ggml_type type) { + for (const auto & e : mxfp_soa_table) { + if (e.type == type) return &e; + } + return nullptr; +} + +// init MXFP tensor with SoA layout +static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { GGML_ASSERT(ggml_is_type_mxfp(tensor->type)); - const auto * traits = ggml_get_type_traits_cpu(tensor->type); - GGML_ASSERT(traits->from_float_soa && "MXFP type missing SoA quantize in traits"); - auto quantize_soa = traits->from_float_soa; + const auto * soa = get_mxfp_soa(tensor->type); + GGML_ASSERT(soa && "unsupported MXFP type for SoA init"); - const int qk = (int)ggml_blck_size(tensor->type); - const size_t block_size = ggml_type_size(tensor->type); - const size_t head_row_sz = ggml_row_size(tensor->type, tensor->ne[0]); - if (soa_bytes == 0) { soa_bytes = head_row_sz; } - GGML_ASSERT(soa_bytes % block_size == 0 && "soa_bytes must be a multiple of block_size"); - const int64_t soa_elems = (int64_t)(soa_bytes / block_size) * qk; + const int64_t DK = tensor->ne[0]; + const size_t row_sz = ggml_row_size(tensor->type, DK); + + // multihead: heads packed contiguously + const bool multihead = (tensor->nb[2] == row_sz) && (tensor->ne[2] > 1); std::default_random_engine gen(42); std::uniform_real_distribution<float> dist(min, max); - std::vector<float> region_f32(soa_elems); - - const size_t nb1 = tensor->nb[1]; - const size_t nb2 = tensor->nb[2]; - const size_t nb3 =
tensor->nb[3]; - const int64_t ne1 = tensor->ne[1]; - const int64_t ne2 = tensor->ne[2]; - const int64_t ne3 = tensor->ne[3]; - - const int64_t heads_per_region = (int64_t)(soa_bytes / head_row_sz); - GGML_ASSERT(soa_bytes % head_row_sz == 0 && "soa_bytes must be a multiple of head_row_sz"); std::vector buf(ggml_nbytes(tensor), 0); - if (heads_per_region > 1) { - // Multi-head SoA: - for (int64_t i3 = 0; i3 < ne3; i3++) { - const int64_t n_groups = ne2 / heads_per_region; - for (int64_t ig = 0; ig < n_groups; ig++) { - for (int64_t i1 = 0; i1 < ne1; i1++) { - size_t offset = i3*nb3 + ig*heads_per_region*nb2 + i1*nb1; - for (int64_t j = 0; j < soa_elems; j++) { region_f32[j] = dist(gen); } - quantize_soa(region_f32.data(), buf.data() + offset, soa_elems); - } + if (multihead) { + // all heads at one position share one SoA region + const int64_t n_heads = tensor->ne[2]; + const int64_t soa_elems = n_heads * DK; + std::vector region(soa_elems); + + for (int64_t i3 = 0; i3 < tensor->ne[3]; i3++) { + for (int64_t i1 = 0; i1 < tensor->ne[1]; i1++) { + size_t offset = i3*tensor->nb[3] + i1*tensor->nb[1]; + for (int64_t j = 0; j < soa_elems; j++) { region[j] = dist(gen); } + soa->quantize(region.data(), buf.data() + offset, soa_elems); } } } else { - // Per-head SoA: - for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = 0; i2 < ne2; i2++) { - for (int64_t i1 = 0; i1 < ne1; i1++) { - size_t offset = i3*nb3 + i2*nb2 + i1*nb1; - for (int64_t j = 0; j < soa_elems; j++) { region_f32[j] = dist(gen); } - quantize_soa(region_f32.data(), buf.data() + offset, soa_elems); + // per-head SoA: each head independently packed + std::vector region(DK); + + for (int64_t i3 = 0; i3 < tensor->ne[3]; i3++) { + for (int64_t i2 = 0; i2 < tensor->ne[2]; i2++) { + for (int64_t i1 = 0; i1 < tensor->ne[1]; i1++) { + size_t offset = i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1]; + for (int64_t j = 0; j < DK; j++) { region[j] = dist(gen); } + soa->quantize(region.data(), buf.data() + 
offset, DK); } } } @@ -304,9 +323,12 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) { const bool is_mxfp = ggml_is_type_mxfp(t->type); mxfp_soa_dequantize_fn mxfp_dequant_soa = nullptr; + std::vector<float> mxfp_row_f32; if (is_mxfp) { - mxfp_dequant_soa = (mxfp_soa_dequantize_fn) ggml_get_type_traits_cpu(t->type)->to_float_soa; - GGML_ASSERT(mxfp_dequant_soa && "MXFP type missing SoA dequant in traits"); + const auto * soa_fns = get_mxfp_soa(t->type); + GGML_ASSERT(soa_fns && "unsupported MXFP type in tensor_to_float"); + mxfp_dequant_soa = soa_fns->dequantize; + mxfp_row_f32.resize(t->ne[0]); } // access elements by index to avoid gaps in views @@ -315,9 +337,8 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) { for (int64_t i1 = 0; i1 < t->ne[1]; i1++) { if (is_mxfp) { size_t row_off = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1]; - std::vector<float> row_f32(t->ne[0]); - mxfp_dequant_soa(&buf[row_off], row_f32.data(), t->ne[0]); - tv.insert(tv.end(), row_f32.begin(), row_f32.end()); + mxfp_dequant_soa(&buf[row_off], mxfp_row_f32.data(), t->ne[0]); + tv.insert(tv.end(), mxfp_row_f32.begin(), mxfp_row_f32.end()); continue; } for (int64_t i0 = 0; i0 < t->ne[0]; i0 += bs) { @@ -6370,8 +6391,7 @@ struct test_flash_attn_ext : public test_case { } else if (strcmp(t->name, "m") == 0) { init_tensor_kq_mask(t); } else if ((strcmp(t->name, "k") == 0 || strcmp(t->name, "v") == 0) && ggml_is_type_mxfp(t->type)) { - // MXFP K/V use SoA layout; nb[1] spans all heads in one KV-position stride - init_tensor_mxfp_soa(t, -1.0f, 1.0f, t->nb[1]); + init_tensor_mxfp_soa(t); } else { init_tensor_uniform(t); } @@ -7378,8 +7398,7 @@ static const ggml_type all_types[] = { GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0, - GGML_TYPE_MXFP4_E2M1, GGML_TYPE_MXFP8_E4M3, - GGML_TYPE_MXFP6_E2M3, + GGML_TYPE_MXFP4_E2M1, GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K, diff --git a/tests/test-quantize-fns.cpp
b/tests/test-quantize-fns.cpp index ca2f4a2994..98e3d489dd 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -2,6 +2,11 @@ #include "ggml.h" #include "ggml-cpu.h" +#include "ggml-quants.h" + +#define GGML_COMMON_DECL_CPP +#define GGML_COMMON_IMPL_CPP +#include "ggml-common.h" #undef NDEBUG #include @@ -24,6 +29,12 @@ constexpr float MAX_QUANTIZATION_TOTAL_ERROR_FP4 = 0.0030f; constexpr float MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 = 0.0070f; constexpr float MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 = 0.0040f; constexpr float MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 = 0.0020f; +// MXFP Hadamard pipeline thresholds (mxfp_rmse, which computes sqrt(sum/n)). +// These represent actual RMSE through the full KV cache write/read path. +constexpr float MAX_MXFP_PIPELINE_ERROR_MXFP4 = 0.40f; +constexpr float MAX_MXFP_PIPELINE_ERROR_MXFP8 = 0.08f; +constexpr float MAX_MXFP_PIPELINE_ERROR_MXFP6 = 0.10f; + constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f; constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f; constexpr float MAX_DOT_PRODUCT_ERROR_FP4 = 0.03f; @@ -50,6 +61,16 @@ static float array_rmse(const float * a1, const float * a2, size_t n) { return sqrtf(sum) / n; } +// MXFP RMSE: sqrt(sum/n), used with MAX_MXFP_PIPELINE_ERROR_* thresholds +static float mxfp_rmse(const float * a1, const float * a2, size_t n) { + double sum = 0; + for (size_t i = 0; i < n; i++) { + double diff = a1[i] - a2[i]; + sum += diff * diff; + } + return sqrtf((float)(sum / n)); +} + // Total quantization error on test data static float total_quantization_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data) { std::vector tmp_q(2*test_size); @@ -192,7 +213,7 @@ int main(int argc, char * argv[]) { } } - // MXFP SoA roundtrip: test from_float_soa → to_float_soa through the traits system + // MXFP SoA roundtrip via traits for (int i = 0; i < GGML_TYPE_COUNT; i++) { ggml_type type = (ggml_type) i; const auto * qfns_cpu = 
ggml_get_type_traits_cpu(type); @@ -220,6 +241,874 @@ int main(int argc, char * argv[]) { } } + // MXFP traits: SoA required, MXFP6/MXFP8 are KV-cache-only (no AoS dequant) + { + const ggml_type all_mxfp_types[] = { GGML_TYPE_MXFP4_E2M1, GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3 }; + for (ggml_type type : all_mxfp_types) { + const auto * cpu = ggml_get_type_traits_cpu(type); + + failed = !(cpu->from_float_soa && cpu->to_float_soa); + num_failed += failed; + if (failed || verbose) { + printf("%5s SoA traits present: %s\n", ggml_type_name(type), RESULT_STR[failed]); + } + } + + // KV-cache-only types: no AoS dequant + const ggml_type kv_only_types[] = { GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3 }; + for (ggml_type type : kv_only_types) { + const auto * cpu = ggml_get_type_traits_cpu(type); + failed = (cpu->to_float != nullptr); + num_failed += failed; + if (failed || verbose) { + printf("%5s AoS CPU to_float absent: %s\n", ggml_type_name(type), RESULT_STR[failed]); + } + } + } + + // Hadamard self-inverse: H(H(x)) == x + { + float original[32], transformed[32]; + for (int i = 0; i < 32; i++) { + original[i] = 0.1f + 2.0f * cosf(i + 0.5f); + transformed[i] = original[i]; + } + ggml_hadamard_32_inplace(transformed); + ggml_hadamard_32_inplace(transformed); // apply twice = identity + + float max_err = 0.0f; + for (int i = 0; i < 32; i++) { + float err = fabsf(transformed[i] - original[i]); + if (err > max_err) max_err = err; + } + // floating-point rounding tolerance + failed = !(max_err < 1e-5f); + num_failed += failed; + if (failed || verbose) { + printf("hadamard H(H(x))==x roundtrip: %s (max_err=%.2e)\n", RESULT_STR[failed], max_err); + } + } + + // SoA SIMD vs scalar dequant + { + struct soa_cross_check { + ggml_type type; + void (*ref_dequant)(const void *, float *, int64_t); + }; + + const soa_cross_check checks[] = { + { GGML_TYPE_MXFP4_E2M1, dequantize_row_mxfp4_soa }, + { GGML_TYPE_MXFP8_E4M3, dequantize_row_mxfp8_soa }, + { GGML_TYPE_MXFP6_E2M3, 
dequantize_row_mxfp6_soa }, + }; + + for (const auto & c : checks) { + const auto * cpu = ggml_get_type_traits_cpu(c.type); + if (!cpu->from_float_soa || !cpu->to_float_soa) continue; + + const size_t buf_size = ggml_row_size(c.type, test_size); + std::vector<uint8_t> tmp_q(buf_size); + std::vector<float> out_ref(test_size); + std::vector<float> out_simd(test_size); + + // Quantize with SoA + cpu->from_float_soa(test_data.data(), tmp_q.data(), test_size); + + // Dequant with scalar reference + c.ref_dequant(tmp_q.data(), out_ref.data(), test_size); + + // Dequant with CPU/SIMD path + cpu->to_float_soa(tmp_q.data(), out_simd.data(), test_size); + + // Compare bitwise + int mismatches = 0; + for (size_t j = 0; j < test_size; j++) { + uint32_t a, b; + memcpy(&a, &out_ref[j], 4); + memcpy(&b, &out_simd[j], 4); + if (a != b) mismatches++; + } + failed = (mismatches > 0); + num_failed += failed; + if (failed || verbose) { + printf("%5s SoA SIMD vs scalar ref: %s (%zu/%zu match)\n", + ggml_type_name(c.type), RESULT_STR[failed], + test_size - mismatches, test_size); + } + } + } + + // element converters vs canonical LUT values + { + struct lut_test { + const char * name; + const float * lut; + int count; + float (*converter)(uint8_t); + }; + + const lut_test lut_tests[] = { + { "fp8_e4m3", kvalues_mxfp8_e4m3, 256, fp8_e4m3_to_float }, + { "fp8_e5m2", kvalues_mxfp8_e5m2, 256, fp8_e5m2_to_float }, + { "fp6_e2m3", kvalues_mxfp6_e2m3, 64, fp6_e2m3_to_float }, + { "fp6_e3m2", kvalues_mxfp6_e3m2, 64, fp6_e3m2_to_float }, + }; + + for (const auto & t : lut_tests) { + int mismatches = 0; + for (int i = 0; i < t.count; i++) { + const float converter_val = t.converter((uint8_t)i); + const float lut_val = t.lut[i]; + + // both NaN = match + if (isnan(converter_val) && isnan(lut_val)) continue; + if (converter_val != lut_val) { + if (mismatches == 0 || verbose) { + printf("  %s LUT mismatch at [%d]: converter=%.8g, lut=%.8g\n", + t.name, i, converter_val, lut_val); + } + mismatches++; + } + } + failed =
(mismatches > 0); + num_failed += failed; + if (failed || verbose) { + printf("%5s converter vs LUT: %s (%d/%d values match)\n", + t.name, RESULT_STR[failed], t.count - mismatches, t.count); + } + } + + // FP4 E2M1 + { + int mismatches = 0; + for (int i = 0; i < 16; i++) { + const float converter_val = ggml_mxfp_fp4_e2m1_to_float((uint8_t)i); + const float lut_val = kvalues_mxfp4_float[i]; + if (converter_val != lut_val) { + if (mismatches == 0 || verbose) { + printf(" fp4_e2m1 LUT mismatch at [%d]: converter=%.8g, lut=%.8g\n", + i, converter_val, lut_val); + } + mismatches++; + } + } + failed = (mismatches > 0); + num_failed += failed; + if (failed || verbose) { + printf("fp4_e2m1 converter vs LUT: %s (%d/16 values match)\n", + RESULT_STR[failed], 16 - mismatches); + } + } + } + + // element converter edge cases (expected values validated against LUTs) + { + struct conv_check { + const char * name; + float input; + uint8_t expected_bits; + bool is_saturation; // true = input overflows, expected_bits is max finite + const float * lut; // canonical LUT to validate expected_bits against (NULL for FP4) + float (*to_float)(uint8_t); + uint8_t (*to_quant)(float); + }; + + const conv_check checks[] = { + // FP4 E2M1 -[S(1)|E(2)|M(1)], bias=0 + { "fp4 zero", 0.0f, 0x00, false, nullptr, nullptr, nullptr }, + { "fp4 sub 0.5", 0.5f, 0x01, false, nullptr, nullptr, nullptr }, + { "fp4 norm 1.0", 1.0f, 0x02, false, nullptr, nullptr, nullptr }, + { "fp4 max 6.0", 6.0f, 0x07, false, nullptr, nullptr, nullptr }, + { "fp4 neg -3.0", -3.0f, 0x0D, false, nullptr, nullptr, nullptr }, + { "fp4 sat 100", 100.0f, 0x07, true, nullptr, nullptr, nullptr }, + + // FP8 E4M3 -[S(1)|E(4)|M(3)], bias=7 + { "e4m3 zero", 0.0f, 0x00, false, kvalues_mxfp8_e4m3, fp8_e4m3_to_float, float_to_fp8_e4m3_rn }, + { "e4m3 sub", 1.f/512, 0x01, false, kvalues_mxfp8_e4m3, fp8_e4m3_to_float, float_to_fp8_e4m3_rn }, + { "e4m3 max 448", 448.0f, 0x7E, false, kvalues_mxfp8_e4m3, fp8_e4m3_to_float, 
float_to_fp8_e4m3_rn }, + { "e4m3 sat 500", 500.0f, 0x7E, true, kvalues_mxfp8_e4m3, fp8_e4m3_to_float, float_to_fp8_e4m3_rn }, + { "e4m3 neg -1", -1.0f, 0xB8, false, kvalues_mxfp8_e4m3, fp8_e4m3_to_float, float_to_fp8_e4m3_rn }, + + // FP6 E2M3 -[S(1)|E(2)|M(3)], no NaN/Inf + { "e2m3 zero", 0.0f, 0x00, false, kvalues_mxfp6_e2m3, fp6_e2m3_to_float, float_to_fp6_e2m3_rn }, + { "e2m3 sub", 0.125f, 0x01, false, kvalues_mxfp6_e2m3, fp6_e2m3_to_float, float_to_fp6_e2m3_rn }, + { "e2m3 max 7.5", 7.5f, 0x1F, false, kvalues_mxfp6_e2m3, fp6_e2m3_to_float, float_to_fp6_e2m3_rn }, + { "e2m3 sat 100", 100.0f, 0x1F, true, kvalues_mxfp6_e2m3, fp6_e2m3_to_float, float_to_fp6_e2m3_rn }, + + // FP6 E3M2 -[S(1)|E(3)|M(2)], no NaN/Inf, exp=7 is NORMAL + { "e3m2 zero", 0.0f, 0x00, false, kvalues_mxfp6_e3m2, fp6_e3m2_to_float, float_to_fp6_e3m2_rn }, + { "e3m2 sub", 0.0625f, 0x01, false, kvalues_mxfp6_e3m2, fp6_e3m2_to_float, float_to_fp6_e3m2_rn }, + { "e3m2 max 28.0", 28.0f, 0x1F, false, kvalues_mxfp6_e3m2, fp6_e3m2_to_float, float_to_fp6_e3m2_rn }, + { "e3m2 exp7 16", 16.0f, 0x1C, false, kvalues_mxfp6_e3m2, fp6_e3m2_to_float, float_to_fp6_e3m2_rn }, + + // FP8 E5M2 -[S(1)|E(5)|M(2)], bias=15 + { "e5m2 zero", 0.0f, 0x00, false, kvalues_mxfp8_e5m2, fp8_e5m2_to_float, float_to_fp8_e5m2_rn }, + { "e5m2 max", 57344.f, 0x7B, false, kvalues_mxfp8_e5m2, fp8_e5m2_to_float, float_to_fp8_e5m2_rn }, + }; + + int conv_bad = 0; + + // validate expected_bits against LUTs + for (const auto & c : checks) { + if (c.lut && !c.is_saturation) { + float lut_val = c.lut[c.expected_bits]; + if (c.input != lut_val && !(c.input == 0.0f && lut_val == 0.0f)) { + printf(" TEST BUG %s: expected_bits=0x%02X → LUT=%.8g, but input=%.8g\n", + c.name, c.expected_bits, lut_val, c.input); + conv_bad++; + } + } else if (!c.lut && !c.is_saturation) { + float lut_val = kvalues_mxfp4_float[c.expected_bits]; + if (c.input != lut_val && !(c.input == 0.0f && lut_val == 0.0f)) { + printf(" TEST BUG %s: expected_bits=0x%02X → 
LUT=%.8g, but input=%.8g\n", + c.name, c.expected_bits, lut_val, c.input); + conv_bad++; + } + } + } + + // Now test the quantize direction + for (const auto & c : checks) { + uint8_t got; + if (c.to_quant) { + got = c.to_quant(c.input); + } else { + got = ggml_mxfp_float_to_fp4_e2m1(c.input); + } + if (got != c.expected_bits) { + if (conv_bad == 0 || verbose) { + printf(" %s: quantize(%.6g) = 0x%02X, expected 0x%02X\n", + c.name, c.input, got, c.expected_bits); + } + conv_bad++; + } + } + + // FP8 E4M3: 0x7F must dequantize to NaN + { + float nan_val = fp8_e4m3_to_float(0x7F); + if (!isnan(nan_val)) { + if (conv_bad == 0 || verbose) { + printf(" e4m3 0x7F dequant: expected NaN, got %.6g\n", nan_val); + } + conv_bad++; + } + } + + // FP6 E3M2: exp=7 must dequant to valid float (NOT Inf/NaN) + { + float exp7_val = fp6_e3m2_to_float(0x1F); // max: exp=7, mant=3 → 28.0 + if (isnan(exp7_val) || exp7_val != 28.0f) { + if (conv_bad == 0 || verbose) { + printf(" e3m2 0x1F dequant: expected 28.0, got %.6g\n", exp7_val); + } + conv_bad++; + } + } + + failed = (conv_bad > 0); + num_failed += failed; + if (failed || verbose) { + printf(" element converter edge cases: %s (%d/%d passed)\n", + RESULT_STR[failed], + (int)(sizeof(checks)/sizeof(checks[0])) + 2 - conv_bad, + (int)(sizeof(checks)/sizeof(checks[0])) + 2); + } + } + + // FP6 pack/unpack round-trip + { + int pack_bad = 0; + + // Test all 64 possible 6-bit values in each of the 4 positions + for (int pos = 0; pos < 4; pos++) { + for (int val = 0; val < 64; val++) { + uint8_t in[4] = {0, 0, 0, 0}; + in[pos] = (uint8_t)val; + + uint8_t packed[3], out[4]; + pack_fp6x4(in, packed); + unpack_fp6x4(packed, out); + + if (out[pos] != (uint8_t)val) { + if (pack_bad == 0 || verbose) { + printf(" fp6 pack roundtrip: pos=%d val=0x%02X → got 0x%02X\n", + pos, val, out[pos]); + } + pack_bad++; + } + // no crosstalk + for (int k = 0; k < 4; k++) { + if (k != pos && out[k] != 0) { + if (pack_bad == 0 || verbose) { + printf(" fp6 pack 
crosstalk: pos=%d val=0x%02X leaked to pos=%d (0x%02X)\n", + pos, val, k, out[k]); + } + pack_bad++; + } + } + } + } + + // known-answer: [0x3F, 0x00, 0x3F, 0x00] -> {0x3F, 0xF0, 0x03} + { + uint8_t in[4] = {0x3F, 0x00, 0x3F, 0x00}; + uint8_t packed[3]; + pack_fp6x4(in, packed); + uint8_t expected[3] = {0x3F, 0xF0, 0x03}; + if (packed[0] != expected[0] || packed[1] != expected[1] || packed[2] != expected[2]) { + if (pack_bad == 0 || verbose) { + printf(" fp6 known-answer: packed [%02X,%02X,%02X] expected [%02X,%02X,%02X]\n", + packed[0], packed[1], packed[2], expected[0], expected[1], expected[2]); + } + pack_bad++; + } + } + + failed = (pack_bad > 0); + num_failed += failed; + if (failed || verbose) { + printf(" fp6 pack/unpack round-trip: %s\n", RESULT_STR[failed]); + } + } + + // E8M0 known-answer decode + HALF vs FULL (MXFP4 uses HALF, MXFP6/8 use FULL) + { + int e8m0_bad = 0; + + // Known-answer E8M0 decodes + struct { uint8_t e; float expected; } e8m0_known[] = { + { 127, 1.0f }, // 2^(127-127) = 2^0 = 1.0 + { 128, 2.0f }, // 2^(128-127) = 2^1 = 2.0 + { 126, 0.5f }, // 2^(126-127) = 2^(-1) = 0.5 + { 254, 1.70141183e+38f }, // 2^127 (max representable) + { 1, 1.17549435e-38f }, // 2^(-126) (min normal) + }; + for (const auto & t : e8m0_known) { + float got = ggml_mxfp_e8m0_to_fp32(t.e); + if (got != t.expected) { + if (e8m0_bad == 0 || verbose) { + printf(" E8M0 decode e=%d: got %.8g, expected %.8g\n", t.e, got, t.expected); + } + e8m0_bad++; + } + } + + // HALF must be exactly half of FULL for all valid exponents + for (int e = 2; e < 255; e++) { + float full = ggml_mxfp_e8m0_to_fp32((uint8_t)e); + float half = ggml_mxfp_e8m0_to_fp32_half((uint8_t)e); + if (half != full * 0.5f) { + if (e8m0_bad == 0 || verbose) { + printf(" E8M0 HALF!=FULL/2 at e=%d: half=%.8g, full/2=%.8g\n", e, half, full * 0.5f); + } + e8m0_bad++; + break; // one failure is enough to flag the pattern + } + } + + failed = (e8m0_bad > 0); + num_failed += failed; + if (failed || verbose) { + 
printf(" E8M0 known-answer + HALF/FULL: %s\n", RESULT_STR[failed]); + } + } + + // E8M0 rounding at sqrt(2) threshold + { + int round_bad = 0; + + // amax=1.0: floor_log2=0, mantissa=0 → no round → e_base = 0 - 0 + 127 = 127 + { + int e = ggml_mxfp_e8m0_base_estimate(1.0f, 0); + if (e != 127) { + printf(" E8M0 round: amax=1.0 → e=%d, expected 127\n", e); + round_bad++; + } + } + // amax=2.0: floor_log2=1, mantissa=0 → no round → e_base = 1 + 127 = 128 + { + int e = ggml_mxfp_e8m0_base_estimate(2.0f, 0); + if (e != 128) { + printf(" E8M0 round: amax=2.0 → e=%d, expected 128\n", e); + round_bad++; + } + } + // amax just below sqrt(2): mantissa < 0x3504F3 → floor only → e=127 + { + // 1.41421 has IEEE mantissa just below 0x3504F3 + float below = 1.4142f; + int e = ggml_mxfp_e8m0_base_estimate(below, 0); + if (e != 127) { + printf(" E8M0 round: amax=%.6f → e=%d, expected 127 (no round)\n", below, e); + round_bad++; + } + } + // amax at sqrt(2): mantissa >= 0x3504F3 → rounds up → e=128 + { + float at_sqrt2 = 1.41422f; + int e = ggml_mxfp_e8m0_base_estimate(at_sqrt2, 0); + if (e != 128) { + printf(" E8M0 round: amax=%.6f → e=%d, expected 128 (rounds up)\n", at_sqrt2, e); + round_bad++; + } + } + // Verify emax_offset shifts the result + { + int e_no_off = ggml_mxfp_e8m0_base_estimate(448.0f, 0); + int e_e4m3 = ggml_mxfp_e8m0_base_estimate(448.0f, MXFP8_E4M3_EMAX_OFFSET); + if (e_no_off - e_e4m3 != MXFP8_E4M3_EMAX_OFFSET) { + printf(" E8M0 emax_offset: diff=%d, expected %d\n", + e_no_off - e_e4m3, MXFP8_E4M3_EMAX_OFFSET); + round_bad++; + } + } + + failed = (round_bad > 0); + num_failed += failed; + if (failed || verbose) { + printf(" E8M0 rounding boundary: %s\n", RESULT_STR[failed]); + } + } + + // Element converter exhaustive round-trip: quantize(dequantize(i)) == i for all valid bit patterns. + // Catches asymmetries between the to_float and to_quant paths. 
+ { + struct rt_test { + const char * name; + int count; + float (*to_float)(uint8_t); + uint8_t (*to_quant)(float); + uint8_t nan_bits; // bit pattern for NaN (0 = no NaN in format) + }; + + const rt_test rt_tests[] = { + { "fp8_e4m3", 256, fp8_e4m3_to_float, float_to_fp8_e4m3_rn, 0x7F }, + { "fp8_e5m2", 256, fp8_e5m2_to_float, float_to_fp8_e5m2_rn, 0 }, + { "fp6_e2m3", 64, fp6_e2m3_to_float, float_to_fp6_e2m3_rn, 0 }, + { "fp6_e3m2", 64, fp6_e3m2_to_float, float_to_fp6_e3m2_rn, 0 }, + }; + + for (const auto & t : rt_tests) { + int rt_bad = 0; + for (int i = 0; i < t.count; i++) { + if ((uint8_t)i == t.nan_bits) continue; // skip NaN -quantize(NaN) is implementation-defined + + float f = t.to_float((uint8_t)i); + if (isnan(f) || isinf(f)) continue; // E5M2 Inf/NaN + + uint8_t back = t.to_quant(f); + // Negative zero may round-trip to positive zero -both are valid + if (back != (uint8_t)i && !(f == 0.0f && t.to_float(back) == 0.0f)) { + if (rt_bad == 0 || verbose) { + printf(" %s roundtrip: 0x%02X → %.6g → 0x%02X\n", + t.name, i, f, back); + } + rt_bad++; + } + } + failed = (rt_bad > 0); + num_failed += failed; + if (failed || verbose) { + printf("%5s converter round-trip: %s (%d/%d survived)\n", + t.name, RESULT_STR[failed], t.count - rt_bad, t.count); + } + } + + // FP4 E2M1: uses static inline converters (not GGML_API wrappers), only 16 values + { + int rt_bad = 0; + for (int i = 0; i < 16; i++) { + float f = ggml_mxfp_fp4_e2m1_to_float((uint8_t)i); + uint8_t back = ggml_mxfp_float_to_fp4_e2m1(f); + if (back != (uint8_t)i && !(f == 0.0f && ggml_mxfp_fp4_e2m1_to_float(back) == 0.0f)) { + if (rt_bad == 0 || verbose) { + printf(" fp4_e2m1 roundtrip: 0x%02X → %.6g → 0x%02X\n", i, f, back); + } + rt_bad++; + } + } + failed = (rt_bad > 0); + num_failed += failed; + if (failed || verbose) { + printf("fp4_e2m1 converter round-trip: %s (%d/16 survived)\n", + RESULT_STR[failed], 16 - rt_bad); + } + } + } + + // E8M0 scale computation: verify base exponent is reasonable 
for various amax values + { + const float test_amax[] = { 0.001f, 0.1f, 1.0f, 6.0f, 100.0f, 448.0f, 10000.0f }; + int bad = 0; + for (float amax : test_amax) { + // ggml_mxfp_e8m0_base_estimate returns unclamped e_base + int e_base = ggml_mxfp_e8m0_base_estimate(amax, 0); + if (e_base < 1 || e_base > 254) { + if (bad == 0 || verbose) { + printf(" E8M0 bad e_base=%d for amax=%.4f\n", e_base, amax); + } + bad++; + continue; + } + float scale = ggml_mxfp_e8m0_to_fp32((uint8_t)e_base); + // Scale should be within 2x of amax (rough sanity check) + float ratio = amax / scale; + if (ratio < 0.25f || ratio > 4.0f) { + if (bad == 0 || verbose) { + printf(" E8M0 scale=%.6g for amax=%.4f, ratio=%.4f (expected ~1)\n", + scale, amax, ratio); + } + bad++; + } + } + failed = (bad > 0); + num_failed += failed; + if (failed || verbose) { + printf(" E8M0 scale sanity check: %s (%d/%d passed)\n", + RESULT_STR[failed], (int)(sizeof(test_amax)/sizeof(test_amax[0])) - bad, + (int)(sizeof(test_amax)/sizeof(test_amax[0]))); + } + } + + // SoA layout: verify offset macros produce correct byte positions + { + const struct { ggml_type type; int qs_per_block; } soa_types[] = { + { GGML_TYPE_MXFP4_E2M1, MXFP4_SOA_QS_PER_BLOCK }, + { GGML_TYPE_MXFP8_E4M3, MXFP8_SOA_QS_PER_BLOCK }, + { GGML_TYPE_MXFP6_E2M3, MXFP6_SOA_QS_PER_BLOCK }, + }; + + for (const auto & st : soa_types) { + for (int nblocks : { 1, 4, 8, 32 }) { + size_t expected_e8m0_off = (size_t)nblocks * st.qs_per_block; + size_t actual_e8m0_off = MXFP_SOA_E8M0_OFFSET(nblocks, st.qs_per_block); + size_t total = actual_e8m0_off + nblocks; // e8m0 region = 1 byte per block + size_t row_size = ggml_row_size(st.type, nblocks * 32); + + bool offset_ok = (actual_e8m0_off == expected_e8m0_off); + bool size_ok = (total == row_size); + + if (!offset_ok || !size_ok) { + failed = true; + num_failed++; + if (verbose) { + printf(" %s SoA layout nblocks=%d: e8m0_off=%zu (expected %zu), total=%zu (row_size=%zu)\n", + ggml_type_name(st.type), nblocks, 
actual_e8m0_off, expected_e8m0_off, total, row_size); + } + } + } + } + if (verbose) { + printf(" SoA layout offset check: %s\n", RESULT_STR[0]); // only prints failures above + } + } + + // block size consistency + { + failed = !(QK_MXFP4 == 32 && QK_MXFP8 == 32 && QK_MXFP6 == 32); + num_failed += failed; + if (failed || verbose) { + printf(" MXFP block size == 32: %s (QK4=%d, QK8=%d, QK6=%d)\n", + RESULT_STR[failed], QK_MXFP4, QK_MXFP8, QK_MXFP6); + } + } + + // EMAX_OFFSET produces valid E8M0 for each format's max finite value + { + struct emax_check { + const char * name; + int emax_offset; + float max_finite; // from LUT / converter + }; + + const emax_check emax_checks[] = { + { "fp4_e2m1", MXFP4_E2M1_EMAX_OFFSET, 6.0f }, + { "fp6_e2m3", MXFP6_E2M3_EMAX_OFFSET, 7.5f }, + { "fp6_e3m2", MXFP6_E3M2_EMAX_OFFSET, 28.0f }, + { "fp8_e4m3", MXFP8_E4M3_EMAX_OFFSET, 448.0f }, + { "fp8_e5m2", MXFP8_E5M2_EMAX_OFFSET, 57344.0f }, + }; + + int emax_bad = 0; + for (const auto & e : emax_checks) { + // When amax == max_finite, the base estimate must produce a valid E8M0 (1..254) + int e_base = ggml_mxfp_e8m0_base_estimate(e.max_finite, e.emax_offset); + if (e_base < 1 || e_base > 254) { + if (emax_bad == 0 || verbose) { + printf(" %s emax_offset=%d: max_finite=%.1f gives e_base=%d (out of range)\n", + e.name, e.emax_offset, e.max_finite, e_base); + } + emax_bad++; + } + } + failed = (emax_bad > 0); + num_failed += failed; + if (failed || verbose) { + printf(" EMAX_OFFSET vs format max: %s\n", RESULT_STR[failed]); + } + } + + // MXFP4 AoS vs SoA: two independent code paths, same result + { + const int nelems = 64; // 2 blocks + float input[64]; + for (int i = 0; i < 64; i++) { + input[i] = 0.5f + 2.0f * sinf(i * 0.7f + 0.3f); + } + + // Quantize and dequant via AoS (block_mxfp4 structs) + std::vector aos_q(nelems / QK_MXFP4); + std::vector aos_out(nelems); + quantize_row_mxfp4_ref(input, aos_q.data(), nelems); + dequantize_row_mxfp4(aos_q.data(), aos_out.data(), nelems); + + 
// Quantize and dequant via SoA + const size_t soa_buf_size = ggml_row_size(GGML_TYPE_MXFP4_E2M1, nelems); + std::vector soa_q(soa_buf_size); + std::vector soa_out(nelems); + quantize_row_mxfp4_soa(input, soa_q.data(), nelems); + dequantize_row_mxfp4_soa(soa_q.data(), soa_out.data(), nelems); + + // Compare: both paths should produce identical results + int mismatches = 0; + for (int i = 0; i < nelems; i++) { + uint32_t a, b; + memcpy(&a, &aos_out[i], 4); + memcpy(&b, &soa_out[i], 4); + if (a != b) { + if (mismatches == 0 || verbose) { + printf(" mxfp4 AoS/SoA mismatch at [%d]: AoS=%.8g, SoA=%.8g\n", + i, aos_out[i], soa_out[i]); + } + mismatches++; + } + } + failed = (mismatches > 0); + num_failed += failed; + if (failed || verbose) { + printf("mxfp4 AoS vs SoA cross-check: %s (%d/%d match)\n", + RESULT_STR[failed], nelems - mismatches, nelems); + } + } + + // Hadamard + quantize + dequant + Hadamard roundtrip (KV cache write/read path) + { + struct hadamard_pipeline_check { + const char * name; + ggml_type type; + float max_err; + }; + + const hadamard_pipeline_check pipeline_checks[] = { + { "mxfp4", GGML_TYPE_MXFP4_E2M1, MAX_MXFP_PIPELINE_ERROR_MXFP4 }, + { "mxfp8", GGML_TYPE_MXFP8_E4M3, MAX_MXFP_PIPELINE_ERROR_MXFP8 }, + { "mxfp6", GGML_TYPE_MXFP6_E2M3, MAX_MXFP_PIPELINE_ERROR_MXFP6 }, + }; + + for (const auto & p : pipeline_checks) { + const auto * cpu = ggml_get_type_traits_cpu(p.type); + + std::vector original(test_size); + std::vector rotated(test_size); + std::vector recovered(test_size); + generate_data(2.0, test_size, original.data()); + + // Write path: Hadamard each block, then quantize + memcpy(rotated.data(), original.data(), test_size * sizeof(float)); + for (size_t b = 0; b < test_size / 32; b++) { + ggml_hadamard_32_inplace(&rotated[b * 32]); + } + + const size_t buf_size = ggml_row_size(p.type, test_size); + std::vector qbuf(buf_size); + cpu->from_float_soa(rotated.data(), qbuf.data(), test_size); + + // Read path: dequant, then Hadamard each 
block (self-inverse) + cpu->to_float_soa(qbuf.data(), recovered.data(), test_size); + for (size_t b = 0; b < test_size / 32; b++) { + ggml_hadamard_32_inplace(&recovered[b * 32]); + } + + float err = mxfp_rmse(original.data(), recovered.data(), test_size); + failed = !(err < p.max_err); + num_failed += failed; + if (failed || verbose) { + printf("%5s Hadamard pipeline roundtrip: %s (err=%.6f, max=%.6f)\n", + p.name, RESULT_STR[failed], err, p.max_err); + } + } + } + + // Hadamard known output: H([1,0,...,0]) = [1/sqrt(32), ...] + { + float unit[32] = {}; + unit[0] = 1.0f; + ggml_hadamard_32_inplace(unit); + + const float expected = MXFP_HADAMARD_32_NORM; // 1/sqrt(32) + float max_err = 0.0f; + for (int i = 0; i < 32; i++) { + float err = fabsf(unit[i] - expected); + if (err > max_err) max_err = err; + } + failed = !(max_err < 1e-7f); + num_failed += failed; + if (failed || verbose) { + printf("hadamard unit vector: %s (max_err=%.2e, expected %.8f)\n", + RESULT_STR[failed], max_err, expected); + } + } + + // zero block produces E8M0=0 + { + float zeros[32] = {}; + const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP8_E4M3, 32); + std::vector buf(buf_size, 0xFF); // fill with 0xFF to detect non-writes + + quantize_row_mxfp8_soa(zeros, buf.data(), 32); + + // E8M0 scale is at offset MXFP8_SOA_QS_PER_BLOCK (32) for 1 block + uint8_t e8m0 = buf[MXFP8_SOA_QS_PER_BLOCK]; + failed = (e8m0 != 0); + num_failed += failed; + if (failed || verbose) { + printf(" zero block E8M0: %s (e8m0=%d, expected 0)\n", + RESULT_STR[failed], e8m0); + } + } + + // SoA format spec: quantize, manually walk raw bytes, compare against reference dequant + { + // 2 blocks, asymmetric data + const int nblocks = 2; + const int nelems = nblocks * 32; + float input[64]; + for (int i = 0; i < 64; i++) { + // Block 0: small values, Block 1: large values -different E8M0 scales + input[i] = (i < 32) ? 
0.1f * sinf(i + 0.5f) : 3.0f * cosf(i + 0.5f); + } + + // MXFP4 + { + const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP4_E2M1, nelems); + std::vector buf(buf_size); + std::vector ref_out(nelems); + std::vector manual_out(nelems); + + quantize_row_mxfp4_soa(input, buf.data(), nelems); + dequantize_row_mxfp4_soa(buf.data(), ref_out.data(), nelems); + + // manual dequant from raw bytes + const uint8_t * qs = buf.data(); + const uint8_t * e8m0 = buf.data() + MXFP_SOA_E8M0_OFFSET(nblocks, MXFP4_SOA_QS_PER_BLOCK); + + for (int b = 0; b < nblocks; b++) { + const float d = ggml_mxfp_e8m0_to_fp32_half(e8m0[b]); + const uint8_t * block_qs = qs + MXFP_SOA_QS_OFFSET(b, MXFP4_SOA_QS_PER_BLOCK); + for (int j = 0; j < 16; j++) { + // low nibble = first half, high nibble = second half + int8_t v_lo = kvalues_mxfp4[block_qs[j] & 0x0F]; + int8_t v_hi = kvalues_mxfp4[block_qs[j] >> 4]; + manual_out[b*32 + j] = v_lo * d; + manual_out[b*32 + j + 16] = v_hi * d; + } + } + + int mismatches = 0; + for (int i = 0; i < nelems; i++) { + uint32_t a, b; + memcpy(&a, &ref_out[i], 4); + memcpy(&b, &manual_out[i], 4); + if (a != b) mismatches++; + } + failed = (mismatches > 0); + num_failed += failed; + if (failed || verbose) { + printf("mxfp4 SoA format spec: %s (%d/%d match)\n", + RESULT_STR[failed], nelems - mismatches, nelems); + } + } + + // MXFP8 + { + const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP8_E4M3, nelems); + std::vector buf(buf_size); + std::vector ref_out(nelems); + std::vector manual_out(nelems); + + quantize_row_mxfp8_soa(input, buf.data(), nelems); + dequantize_row_mxfp8_soa(buf.data(), ref_out.data(), nelems); + + const uint8_t * qs = buf.data(); + const uint8_t * e8m0 = buf.data() + MXFP_SOA_E8M0_OFFSET(nblocks, MXFP8_SOA_QS_PER_BLOCK); + + for (int b = 0; b < nblocks; b++) { + const float d = ggml_mxfp_e8m0_to_fp32(e8m0[b]); + const uint8_t * block_qs = qs + MXFP_SOA_QS_OFFSET(b, MXFP8_SOA_QS_PER_BLOCK); + for (int j = 0; j < 32; j++) { + // one byte per element + 
manual_out[b*32 + j] = fp8_e4m3_to_float(block_qs[j]) * d; + } + } + + int mismatches = 0; + for (int i = 0; i < nelems; i++) { + uint32_t a, b; + memcpy(&a, &ref_out[i], 4); + memcpy(&b, &manual_out[i], 4); + if (a != b) mismatches++; + } + failed = (mismatches > 0); + num_failed += failed; + if (failed || verbose) { + printf("mxfp8 SoA format spec: %s (%d/%d match)\n", + RESULT_STR[failed], nelems - mismatches, nelems); + } + } + + // MXFP6 + { + const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP6_E2M3, nelems); + std::vector buf(buf_size); + std::vector ref_out(nelems); + std::vector manual_out(nelems); + + quantize_row_mxfp6_soa(input, buf.data(), nelems); + dequantize_row_mxfp6_soa(buf.data(), ref_out.data(), nelems); + + const uint8_t * qs = buf.data(); + const uint8_t * e8m0 = buf.data() + MXFP_SOA_E8M0_OFFSET(nblocks, MXFP6_SOA_QS_PER_BLOCK); + + for (int b = 0; b < nblocks; b++) { + const float d = ggml_mxfp_e8m0_to_fp32(e8m0[b]); + const uint8_t * block_qs = qs + MXFP_SOA_QS_OFFSET(b, MXFP6_SOA_QS_PER_BLOCK); + for (int j = 0; j < 32; j += 4) { + // 4 elements packed into 3 bytes + uint8_t vals[4]; + unpack_fp6x4(&block_qs[j * 3 / 4], vals); + for (int k = 0; k < 4; k++) { + manual_out[b*32 + j + k] = fp6_e2m3_to_float(vals[k]) * d; + } + } + } + + int mismatches = 0; + for (int i = 0; i < nelems; i++) { + uint32_t a, b; + memcpy(&a, &ref_out[i], 4); + memcpy(&b, &manual_out[i], 4); + if (a != b) mismatches++; + } + failed = (mismatches > 0); + num_failed += failed; + if (failed || verbose) { + printf("mxfp6 SoA format spec: %s (%d/%d match)\n", + RESULT_STR[failed], nelems - mismatches, nelems); + } + } + } + if (num_failed || verbose) { printf("%d tests failed\n", num_failed); } From c919bc471bbddf0d5df1cec52a50339ce4f279c3 Mon Sep 17 00:00:00 2001 From: Tim Burke Date: Sun, 22 Mar 2026 02:44:56 -0400 Subject: [PATCH 12/13] cleanup : remove unused untested code and improve consistency MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit * cleanup: consolidate MXFP type aliases, fix SoA linker bug on 5 platforms - Add GGML_TYPE_MXFP8 and GGML_TYPE_MXFP6 short aliases (matching existing GGML_TYPE_MXFP4 pattern) and use short names consistently throughout the codebase instead of mixing long/short forms. - Fix missing SoA dequant symbols (dequantize_row_mxfp{4,8,6}_soa_cpu) on loongarch, powerpc, riscv, s390, and wasm by adding proper aliases to each arch section in arch-fallback.h. Previously these were only defined under GGML_CPU_GENERIC, causing linker failures on those platforms when using MXFP flash attention. - Remove 10 files from the PR diff: - 5 arch stub files replaced by arch-fallback.h aliases - 5 rename-only files (sycl, opencl, repack, llama-quant) reverted since the GGML_TYPE_MXFP4 compat alias handles them * cleanup: DRY FP6 unpack, extract mxfp_kv_params + mxfp_dequant_head helper - FP6 unpack: x86 and ARM SIMD versions now call ggml_mxfp_unpack_fp6x4() from ggml-common.h instead of duplicating the scalar bit manipulation. - Extract mxfp_kv_params sub-struct from mxfp_fa_params: the 7 symmetric K/V fields (dequantize, multihead, soa_elems, qs_per_block, head_qs_bytes, head_e8m0_offset, blocks_per_head) are now in a reusable struct accessed as mxfp.k and mxfp.v. - Add mxfp_dequant_head() helper: replaces 4 instances of the multihead SoA extraction pattern (2x memcpy + dequant, with multihead/single-head branching) with a single function call. Future backends get the pattern for free. * cleanup: extract mxfp_kv_params_init to DRY the K/V init blocks The K and V initialization in mxfp_fa_params_init were structurally identical 10-line blocks differing only by tensor/dimension. Extract into mxfp_kv_params_init(type, D, nb2, ne2) so future MXFP formats get the multihead SoA addressing logic automatically. 
* cleanup: generic MSE round-trip, replace magic buffer sizes with constants - Remove mse_error_fp8_e4m3 and mse_error_fp6_e2m3: these were identical round-trip functions differing only by converter. mxfp_compute_e8m0_mse now uses to_elem/to_float directly when mse_error is NULL (FP8/FP6). MXFP4 keeps its custom decision-tree MSE. New formats get MSE for free by just setting to_elem/to_float in their traits. - Replace magic 1024/1088 buffer sizes in flash attention with named constants MXFP_FA_MAX_D and MXFP_FA_SOA_BUF. One place to change if max head dimension grows. * cleanup: remove dead AoS vec_dot for MXFP8/MXFP6, unify SoA impls MXFP8 and MXFP6 are KV-cache-only types that use SoA layout for flash attention. The AoS vec_dot functions (scalar generic, AVX2, NEON) were dead code — no matmul path uses them. Removed: - ggml_vec_dot_mxfp{8,6}_q8_0 from scalar, x86, ARM, quants.h - ggml_vec_dot_mxfp_q8_0_impl shared helper - arch-fallback.h aliases for vec_dot mxfp8/mxfp6 (12 lines) - vec_dot/vec_dot_type registration in ggml-cpu.c Also unified SoA quantize/dequant: the separate mxfp8_soa_impl and mxfp6_soa_impl functions (4 functions, ~80 lines) are replaced by two generic functions (quantize_row_mxfp_soa_impl, dequantize_row_mxfp_soa_impl) that use traits->bits_per_elem and traits->qs_per_block to handle both byte-aligned (FP8) and 6-bit packed (FP6) formats. New MXFP formats get SoA for free by setting these trait fields. * cleanup: remove all AoS MXFP8/MXFP6 quantize/dequant — SoA only MXFP8 and MXFP6 are KV-cache-only types. All quantization and dequantization goes through the SoA (Struct-of-Arrays) path for flash attention. The AoS (block_mxfp8/block_mxfp6 struct) implementations were dead code that should never have been added. 
Removed: - quantize_row_mxfp{8,6}_impl, dequantize_row_mxfp{8,6}_impl - quantize_row_mxfp{8,6}_ref, dequantize_row_mxfp{8,6} - quantize_mxfp{8,6} (ggml_quantize_chunk wrappers) - All declarations from ggml-quants.h and quants.h - to_float/from_float_ref registrations from ggml.c type traits - from_float registration from ggml-cpu.c CPU traits Block struct definitions (block_mxfp8, block_mxfp6) are retained for sizeof() in type traits and validate_row_data. * cleanup: fail fast in ggml_quantize_chunk for KV-cache-only types Add explicit GGML_ABORT for MXFP8/MXFP6 in ggml_quantize_chunk — these are KV-cache-only types that use SoA layout via from_float_soa. Attempting AoS quantization through this entry point is a bug. --- common/arg.cpp | 12 +- ggml/include/ggml.h | 2 + ggml/src/ggml-cpu/arch-fallback.h | 27 ++- ggml/src/ggml-cpu/arch/arm/quants.c | 116 +--------- ggml/src/ggml-cpu/arch/loongarch/quants.c | 11 - ggml/src/ggml-cpu/arch/powerpc/quants.c | 7 - ggml/src/ggml-cpu/arch/riscv/quants.c | 8 - ggml/src/ggml-cpu/arch/s390/quants.c | 7 - ggml/src/ggml-cpu/arch/wasm/quants.c | 11 - ggml/src/ggml-cpu/arch/x86/quants.c | 116 +--------- ggml/src/ggml-cpu/ggml-cpu.c | 12 +- ggml/src/ggml-cpu/ops.cpp | 244 +++++++++------------ ggml/src/ggml-cpu/quants.c | 53 ----- ggml/src/ggml-cpu/quants.h | 9 - ggml/src/ggml-cpu/repack.cpp | 6 +- ggml/src/ggml-metal/ggml-metal-device.m | 2 +- ggml/src/ggml-opencl/ggml-opencl.cpp | 16 +- ggml/src/ggml-quants.c | 250 ++++++---------------- ggml/src/ggml-quants.h | 9 - ggml/src/ggml-sycl/convert.cpp | 4 +- ggml/src/ggml-sycl/mmvq.cpp | 2 +- ggml/src/ggml.c | 36 ++-- src/llama-quant.cpp | 4 +- tests/test-backend-ops.cpp | 30 +-- tests/test-quantize-fns.cpp | 46 ++-- tools/llama-bench/llama-bench.cpp | 6 +- 26 files changed, 277 insertions(+), 769 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index ff12646a70..26c1904a2a 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -398,20 +398,20 @@ const std::vector kv_cache_types 
= { GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, - GGML_TYPE_MXFP4_E2M1, - GGML_TYPE_MXFP8_E4M3, - GGML_TYPE_MXFP6_E2M3, + GGML_TYPE_MXFP4, + GGML_TYPE_MXFP8, + GGML_TYPE_MXFP6, }; static ggml_type kv_cache_type_from_str(const std::string & s) { if (s == "mxfp4") { - return GGML_TYPE_MXFP4_E2M1; + return GGML_TYPE_MXFP4; } if (s == "mxfp6") { - return GGML_TYPE_MXFP6_E2M3; + return GGML_TYPE_MXFP6; } if (s == "mxfp8") { - return GGML_TYPE_MXFP8_E4M3; + return GGML_TYPE_MXFP8; } for (const auto & type : kv_cache_types) { if (ggml_type_name(type) == s) { diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 1f66550459..c068113932 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -430,7 +430,9 @@ extern "C" { GGML_TYPE_MXFP4 = GGML_TYPE_MXFP4_E2M1, // compat alias GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale) GGML_TYPE_MXFP8_E4M3 = 41, // MX FP8 E4M3 + GGML_TYPE_MXFP8 = GGML_TYPE_MXFP8_E4M3, // compat alias GGML_TYPE_MXFP6_E2M3 = 42, // MX FP6 E2M3 + GGML_TYPE_MXFP6 = GGML_TYPE_MXFP6_E2M3, // compat alias GGML_TYPE_COUNT = 43, }; diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index ddeee1fa7e..eac031e68e 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -15,8 +15,9 @@ #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 -#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0 -#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0 +#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu +#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu +#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define 
ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -113,6 +114,9 @@ // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu +#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu +#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K @@ -161,6 +165,9 @@ #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu +#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu +#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 @@ -201,6 +208,9 @@ #elif defined(__riscv) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu +#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu +#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu // repack.cpp #define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 @@ -241,6 +251,9 @@ // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu +#define dequantize_row_mxfp8_soa_cpu_generic 
dequantize_row_mxfp8_soa_cpu +#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -291,6 +304,9 @@ #elif defined(__wasm__) // quants.c #define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1 +#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu +#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu +#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K @@ -342,10 +358,3 @@ #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #endif - -// MXFP dequantize fallbacks (same GGML_CPU_GENERIC guard as above) -#if defined(GGML_CPU_GENERIC) -#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu -#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu -#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu -#endif diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index 53507b97ce..9a4ac95aab 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -4191,12 +4191,8 @@ static inline float32x4_t mxfp6_dequant_neon( // Unpack 4 tightly-packed 6-bit values from 3 bytes, widen to uint32x4_t. 
static inline uint32x4_t unpack_fp6x4_neon(const uint8_t * p) { - const uint32_t pk = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); uint8_t u[4]; - u[0] = (pk >> 0) & 0x3F; - u[1] = (pk >> 6) & 0x3F; - u[2] = (pk >> 12) & 0x3F; - u[3] = (pk >> 18) & 0x3F; + ggml_mxfp_unpack_fp6x4(p, u); const uint8x8_t raw8 = vcreate_u8( (uint64_t)u[0] | ((uint64_t)u[1] << 8) | ((uint64_t)u[2] << 16) | ((uint64_t)u[3] << 24)); @@ -4221,96 +4217,6 @@ static inline void widen_s8x8_to_f32x4x2(const int8_t * src, *hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16))); } -// MXFP FP8/FP6 vec_dot - -static void ggml_vec_dot_mxfp8_q8_0_neon( - int n, float * GGML_RESTRICT s, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - const mxfp_neon_traits_t * t) { - assert(n % QK_MXFP8 == 0); - const int nb = n / QK_MXFP8; - const block_mxfp8 * GGML_RESTRICT x = vx; - const block_q8_0 * GGML_RESTRICT y = vy; - - const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask); - const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask); - const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off); - const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale); - const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift); - const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift); - - float32x4_t acc0 = vdupq_n_f32(0.0f); - float32x4_t acc1 = vdupq_n_f32(0.0f); - - for (int ib = 0; ib < nb; ++ib) { - const float32x4_t v_scale = vdupq_n_f32( - GGML_E8M0_TO_FP32(x[ib].e) * GGML_CPU_FP16_TO_FP32(y[ib].d)); - - for (int j = 0; j < 32; j += 8) { - uint32x4_t v_lo, v_hi; - widen_u8x8_to_u32x4x2(x[ib].qs + j, &v_lo, &v_hi); - - float32x4_t qf_lo, qf_hi; - widen_s8x8_to_f32x4x2(y[ib].qs + j, &qf_lo, &qf_hi); - - const float32x4_t val_lo = mxfp8_dequant_neon(v_lo, - v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); - const float32x4_t val_hi = mxfp8_dequant_neon(v_hi, - v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); - - acc0 = vfmaq_f32(acc0, 
vmulq_f32(val_lo, v_scale), qf_lo); - acc1 = vfmaq_f32(acc1, vmulq_f32(val_hi, v_scale), qf_hi); - } - } - - *s = vaddvq_f32(vaddq_f32(acc0, acc1)); -} - -static void ggml_vec_dot_mxfp6_q8_0_neon( - int n, float * GGML_RESTRICT s, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - const mxfp_neon_traits_t * t) { - assert(n % QK_MXFP6 == 0); - const int nb = n / QK_MXFP6; - const block_q8_0 * GGML_RESTRICT y = vy; - - const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask); - const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask); - const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off); - const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale); - const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift); - const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift); - - float32x4_t acc0 = vdupq_n_f32(0.0f); - float32x4_t acc1 = vdupq_n_f32(0.0f); - - for (int ib = 0; ib < nb; ++ib) { - const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib; - const float32x4_t v_scale = vdupq_n_f32( - GGML_E8M0_TO_FP32(xb->e) * GGML_CPU_FP16_TO_FP32(y[ib].d)); - - for (int j = 0; j < 32; j += 8) { - const uint32x4_t v_lo = unpack_fp6x4_neon(xb->qs + (j * 3 / 4)); - const uint32x4_t v_hi = unpack_fp6x4_neon(xb->qs + ((j + 4) * 3 / 4)); - - float32x4_t qf_lo, qf_hi; - widen_s8x8_to_f32x4x2(y[ib].qs + j, &qf_lo, &qf_hi); - - const float32x4_t val_lo = mxfp6_dequant_neon(v_lo, - v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); - const float32x4_t val_hi = mxfp6_dequant_neon(v_hi, - v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); - - acc0 = vfmaq_f32(acc0, vmulq_f32(val_lo, v_scale), qf_lo); - acc1 = vfmaq_f32(acc1, vmulq_f32(val_hi, v_scale), qf_hi); - } - } - - *s = vaddvq_f32(vaddq_f32(acc0, acc1)); -} - // MXFP SoA dequant (flash attention) static void dequantize_row_mxfp8_soa_neon( @@ -4424,26 +4330,6 @@ static void dequantize_row_mxfp4_soa_neon( // Public dispatch functions -void 
ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); -#if defined(__ARM_NEON) - ggml_vec_dot_mxfp8_q8_0_neon(n, s, vx, vy, &MXFP_TRAITS_E4M3); -#else - ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif -} - -void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); -#if defined(__ARM_NEON) - ggml_vec_dot_mxfp6_q8_0_neon(n, s, vx, vy, &MXFP_TRAITS_E2M3); -#else - ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif -} - void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { #if defined(__ARM_NEON) dequantize_row_mxfp4_soa_neon(x, y, k); diff --git a/ggml/src/ggml-cpu/arch/loongarch/quants.c b/ggml/src/ggml-cpu/arch/loongarch/quants.c index fa05e49c5d..f531e916b9 100644 --- a/ggml/src/ggml-cpu/arch/loongarch/quants.c +++ b/ggml/src/ggml-cpu/arch/loongarch/quants.c @@ -2157,14 +2157,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -} - -void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -} - -void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - 
ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -} diff --git a/ggml/src/ggml-cpu/arch/powerpc/quants.c b/ggml/src/ggml-cpu/arch/powerpc/quants.c index efb669da09..d3dfd049ea 100644 --- a/ggml/src/ggml-cpu/arch/powerpc/quants.c +++ b/ggml/src/ggml-cpu/arch/powerpc/quants.c @@ -2303,10 +2303,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -} - -void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -} diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index 10e0ff04d8..d7e9ba4634 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -3621,11 +3621,3 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - -void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -} - -void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -} diff --git a/ggml/src/ggml-cpu/arch/s390/quants.c b/ggml/src/ggml-cpu/arch/s390/quants.c index e696fd4570..34184ed851 100644 --- a/ggml/src/ggml-cpu/arch/s390/quants.c +++ b/ggml/src/ggml-cpu/arch/s390/quants.c @@ -1464,10 +1464,3 @@ void 
ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -} - -void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -} diff --git a/ggml/src/ggml-cpu/arch/wasm/quants.c b/ggml/src/ggml-cpu/arch/wasm/quants.c index a3ae8e8885..74a359e6d1 100644 --- a/ggml/src/ggml-cpu/arch/wasm/quants.c +++ b/ggml/src/ggml-cpu/arch/wasm/quants.c @@ -1219,14 +1219,3 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } -void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -} - -void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -} - -void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -} diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c index 4b8f3386fa..775a7c742f 100644 --- a/ggml/src/ggml-cpu/arch/x86/quants.c +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -3850,106 +3850,14 @@ static inline __m256 mxfp_dequant_avx2( return _mm256_blendv_ps(normal, sub_val, is_sub); } -// Unpack 4 tightly-packed 6-bit values from 3 bytes into separate bytes. 
-static inline void unpack_fp6x4_avx2(const uint8_t * p, uint8_t out[4]) { - const uint32_t pk = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); - out[0] = (pk >> 0) & 0x3F; - out[1] = (pk >> 6) & 0x3F; - out[2] = (pk >> 12) & 0x3F; - out[3] = (pk >> 18) & 0x3F; -} - // Unpack 8 FP6 values (two groups of 4) from packed qs data at offset j. static inline __m256i unpack_fp6x8_avx2(const uint8_t * qs, int j) { uint8_t unpacked[8]; - unpack_fp6x4_avx2(qs + (j * 3 / 4), unpacked); - unpack_fp6x4_avx2(qs + ((j + 4) * 3 / 4), unpacked + 4); + ggml_mxfp_unpack_fp6x4(qs + (j * 3 / 4), unpacked); + ggml_mxfp_unpack_fp6x4(qs + ((j + 4) * 3 / 4), unpacked + 4); return _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)unpacked)); } -// MXFP FP8/FP6 vec_dot - -// FP8 x Q8_0 dot product (E4M3/E5M2). -static void ggml_vec_dot_mxfp8_q8_0_avx2( - int n, float * GGML_RESTRICT s, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - const mxfp_avx2_traits_t * t) { - assert(n % QK_MXFP8 == 0); - const int nb = n / QK_MXFP8; - const block_mxfp8 * GGML_RESTRICT x = vx; - const block_q8_0 * GGML_RESTRICT y = vy; - - const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask); - const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask); - const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off); - const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale); - const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask); - const __m256i v_zero = _mm256_setzero_si256(); - - __m256 acc = _mm256_setzero_ps(); - - for (int ib = 0; ib < nb; ++ib) { - const __m256 v_scale = _mm256_set1_ps( - GGML_E8M0_TO_FP32(x[ib].e) * GGML_CPU_FP16_TO_FP32(y[ib].d)); - - for (int j = 0; j < 32; j += 8) { - const __m256i v_raw = _mm256_cvtepu8_epi32( - _mm_loadl_epi64((const __m128i *)(x[ib].qs + j))); - const __m256 qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32( - _mm_loadl_epi64((const __m128i *)(y[ib].qs + j)))); - - const __m256 val = mxfp_dequant_avx2(v_raw, - v_exp_mask, 
v_mant_mask, v_ieee_off, v_sub_sc, - v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift); - - acc = _mm256_fmadd_ps(_mm256_mul_ps(val, v_scale), qf, acc); - } - } - - *s = hsum_float_8(acc); -} - -// FP6 x Q8_0 dot product (E2M3/E3M2). -static void ggml_vec_dot_mxfp6_q8_0_avx2( - int n, float * GGML_RESTRICT s, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - const mxfp_avx2_traits_t * t) { - assert(n % QK_MXFP6 == 0); - const int nb = n / QK_MXFP6; - const block_q8_0 * GGML_RESTRICT y = vy; - - const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask); - const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask); - const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off); - const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale); - const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask); - const __m256i v_zero = _mm256_setzero_si256(); - - __m256 acc = _mm256_setzero_ps(); - - for (int ib = 0; ib < nb; ++ib) { - const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib; - const __m256 v_scale = _mm256_set1_ps( - GGML_E8M0_TO_FP32(xb->e) * GGML_CPU_FP16_TO_FP32(y[ib].d)); - - for (int j = 0; j < 32; j += 8) { - const __m256i v_raw = unpack_fp6x8_avx2(xb->qs, j); - const __m256 qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32( - _mm_loadl_epi64((const __m128i *)(y[ib].qs + j)))); - - const __m256 val = mxfp_dequant_avx2(v_raw, - v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, - v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift); - - acc = _mm256_fmadd_ps(_mm256_mul_ps(val, v_scale), qf, acc); - } - } - - *s = hsum_float_8(acc); -} - // MXFP SoA dequant (flash attention) static void dequantize_row_mxfp8_soa_avx2( @@ -4052,26 +3960,6 @@ static void dequantize_row_mxfp4_soa_avx2( // Public dispatch functions -void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); 
UNUSED(bs); UNUSED(bx); UNUSED(by); -#if defined(__AVX2__) - ggml_vec_dot_mxfp8_q8_0_avx2(n, s, vx, vy, &MXFP_TRAITS_E4M3); -#else - ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif -} - -void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); -#if defined(__AVX2__) - ggml_vec_dot_mxfp6_q8_0_avx2(n, s, vx, vy, &MXFP_TRAITS_E2M3); -#else - ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif -} - void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { #if defined(__AVX2__) dequantize_row_mxfp4_soa_avx2(x, y, k); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index b84b6e0031..782a54392f 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -265,7 +265,7 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_1, .nrows = 1, }, - [GGML_TYPE_MXFP4_E2M1] = { + [GGML_TYPE_MXFP4] = { .from_float = quantize_row_mxfp4, .from_float_soa = quantize_row_mxfp4_soa, .to_float_soa = dequantize_row_mxfp4_soa_cpu, @@ -279,20 +279,14 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, }, - [GGML_TYPE_MXFP8_E4M3] = { - .from_float = quantize_row_mxfp8, + [GGML_TYPE_MXFP8] = { .from_float_soa = quantize_row_mxfp8_soa, .to_float_soa = dequantize_row_mxfp8_soa_cpu, - .vec_dot = ggml_vec_dot_mxfp8_q8_0, - .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, }, - [GGML_TYPE_MXFP6_E2M3] = { - .from_float = quantize_row_mxfp6, + [GGML_TYPE_MXFP6] = { .from_float_soa = quantize_row_mxfp6_soa, .to_float_soa = dequantize_row_mxfp6_soa_cpu, - .vec_dot = ggml_vec_dot_mxfp6_q8_0, - .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, }, [GGML_TYPE_Q2_K] = { diff --git 
a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 9291af62dc..15424d40c4 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -672,10 +672,10 @@ void ggml_compute_forward_add( case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: - case GGML_TYPE_MXFP4_E2M1: + case GGML_TYPE_MXFP4: case GGML_TYPE_NVFP4: - case GGML_TYPE_MXFP8_E4M3: - case GGML_TYPE_MXFP6_E2M3: + case GGML_TYPE_MXFP8: + case GGML_TYPE_MXFP6: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -1124,10 +1124,10 @@ void ggml_compute_forward_add1( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: - case GGML_TYPE_MXFP4_E2M1: + case GGML_TYPE_MXFP4: case GGML_TYPE_NVFP4: - case GGML_TYPE_MXFP8_E4M3: - case GGML_TYPE_MXFP6_E2M3: + case GGML_TYPE_MXFP8: + case GGML_TYPE_MXFP6: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -1255,10 +1255,10 @@ void ggml_compute_forward_acc( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: - case GGML_TYPE_MXFP4_E2M1: + case GGML_TYPE_MXFP4: case GGML_TYPE_NVFP4: - case GGML_TYPE_MXFP8_E4M3: - case GGML_TYPE_MXFP6_E2M3: + case GGML_TYPE_MXFP8: + case GGML_TYPE_MXFP6: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4345,10 +4345,10 @@ void ggml_compute_forward_out_prod( case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: - case GGML_TYPE_MXFP4_E2M1: + case GGML_TYPE_MXFP4: case GGML_TYPE_NVFP4: - case GGML_TYPE_MXFP8_E4M3: - case GGML_TYPE_MXFP6_E2M3: + case GGML_TYPE_MXFP8: + case GGML_TYPE_MXFP6: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4623,10 +4623,10 @@ void ggml_compute_forward_set( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: - case GGML_TYPE_MXFP4_E2M1: + case GGML_TYPE_MXFP4: case GGML_TYPE_NVFP4: - case GGML_TYPE_MXFP8_E4M3: - case GGML_TYPE_MXFP6_E2M3: + case GGML_TYPE_MXFP8: + case GGML_TYPE_MXFP6: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4848,10 +4848,10 @@ void 
ggml_compute_forward_get_rows( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: - case GGML_TYPE_MXFP4_E2M1: + case GGML_TYPE_MXFP4: case GGML_TYPE_NVFP4: - case GGML_TYPE_MXFP8_E4M3: - case GGML_TYPE_MXFP6_E2M3: + case GGML_TYPE_MXFP8: + case GGML_TYPE_MXFP6: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -5686,10 +5686,10 @@ void ggml_compute_forward_clamp( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: - case GGML_TYPE_MXFP4_E2M1: + case GGML_TYPE_MXFP4: case GGML_TYPE_NVFP4: - case GGML_TYPE_MXFP8_E4M3: - case GGML_TYPE_MXFP6_E2M3: + case GGML_TYPE_MXFP8: + case GGML_TYPE_MXFP6: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -8255,31 +8255,65 @@ void ggml_compute_forward_top_k( } } +// Max head dimension for stack-allocated MXFP buffers. +static constexpr int64_t MXFP_FA_MAX_D = 1024; +// SoA buffer size for MXFP_FA_MAX_D with MXFP8 (worst case: 1024 + 32 e8m0 = 1056, rounded up). +static constexpr int MXFP_FA_SOA_BUF = 1088; + // SoA function pointer types for MXFP flash attention paths. typedef void (*mxfp_soa_quantize_fn)(const float *, void *, int64_t); typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t); +// Per-KV-type MXFP parameters (shared between K and V). +struct mxfp_kv_params { + mxfp_soa_dequantize_fn dequantize; + bool multihead; + int64_t soa_elems; + int qs_per_block; + int head_qs_bytes; + int64_t head_e8m0_offset; + int blocks_per_head; +}; + // MXFP dispatch parameters for flash attention. struct mxfp_fa_params { - mxfp_soa_quantize_fn q_quantize; - mxfp_soa_dequantize_fn k_dequantize; - mxfp_soa_dequantize_fn v_dequantize; - bool k_multihead; - bool v_multihead; - int64_t k_soa_elems; - int64_t v_soa_elems; - bool apply_hadamard; - // Per-head SoA addressing (avoids dequanting all heads in multihead mode). 
- int k_qs_per_block; - int v_qs_per_block; - int k_head_qs_bytes; - int v_head_qs_bytes; - int64_t k_head_e8m0_offset; - int64_t v_head_e8m0_offset; - int k_blocks_per_head; - int v_blocks_per_head; + mxfp_soa_quantize_fn q_quantize; + mxfp_kv_params k; + mxfp_kv_params v; + bool apply_hadamard; }; +// Extract one head's SoA data from a multihead row and dequantize. +static inline void mxfp_dequant_head( + const mxfp_kv_params & kv, const char * row, int head_idx, + char * soa_buf, float * out, int64_t D) { + if (kv.multihead) { + const int qs_off = head_idx * kv.head_qs_bytes; + const int e8m0_off = (int)kv.head_e8m0_offset + head_idx * kv.blocks_per_head; + memcpy(soa_buf, row + qs_off, kv.head_qs_bytes); + memcpy(soa_buf + kv.head_qs_bytes, row + e8m0_off, kv.blocks_per_head); + kv.dequantize(soa_buf, out, D); + } else { + kv.dequantize(row, out, D); + } +} + +// Initialize per-KV-type params from tensor metadata. +// Multihead detection: nb2 == row_size(D) means heads are contiguous within +// one KV-position stride, so SoA spans all heads. Otherwise SoA is per-head. +static mxfp_kv_params mxfp_kv_params_init(ggml_type type, int64_t D, size_t nb2, int64_t ne2) { + mxfp_kv_params kv = {}; + kv.dequantize = ggml_get_type_traits_cpu(type)->to_float_soa; + kv.multihead = (nb2 == (size_t)ggml_row_size(type, D)); + kv.soa_elems = kv.multihead ? ne2 * D : D; + kv.qs_per_block = ggml_mxfp_qs_per_block(type); + kv.blocks_per_head = (int)(D / 32); + kv.head_qs_bytes = kv.blocks_per_head * kv.qs_per_block; + const int64_t total_blocks = kv.multihead ? 
ne2 * kv.blocks_per_head : kv.blocks_per_head; + kv.head_e8m0_offset = total_blocks * kv.qs_per_block; + return kv; +} + static mxfp_fa_params mxfp_fa_params_init( const ggml_tensor * k, const ggml_tensor * v, int64_t DK, int64_t DV, @@ -8291,44 +8325,17 @@ static mxfp_fa_params mxfp_fa_params_init( const bool is_mxfp_v = ggml_is_type_mxfp(v->type); if (is_mxfp_k) { - const struct ggml_type_traits_cpu * k_traits = ggml_get_type_traits_cpu(k->type); - p.q_quantize = k_traits->from_float_soa; - p.k_dequantize = k_traits->to_float_soa; + p.q_quantize = ggml_get_type_traits_cpu(k->type)->from_float_soa; + p.k = mxfp_kv_params_init(k->type, DK, nbk2, nek2); } - if (is_mxfp_v) { - p.v_dequantize = ggml_get_type_traits_cpu(v->type)->to_float_soa; + p.v = mxfp_kv_params_init(v->type, DV, nbv2, nev2); } // Hadamard rotation must match K rotation. - // Skipped for: MLA (DK != DV, V is a view of K). + // Skipped for MLA (DK != DV, V is a view of K). p.apply_hadamard = is_mxfp_k && (DK == DV) && ggml_mxfp_use_hadamard(k->type); - // SoA layout detection: in the real KV cache, heads are contiguous within - // one KV-position stride (nb[2] == row_size(DK)), so SoA spans all heads. - // In test tensors, heads may be at distant offsets (nb[2] >> row_size(DK)), - // so SoA is per-head. Detect which case and set dequant parameters accordingly. - p.k_multihead = is_mxfp_k && (nbk2 == (size_t)ggml_row_size(k->type, DK)); - p.k_soa_elems = is_mxfp_k ? (p.k_multihead ? nek2 * DK : DK) : 0; - p.v_multihead = is_mxfp_v && (nbv2 == (size_t)ggml_row_size(v->type, DV)); - p.v_soa_elems = is_mxfp_v ? (p.v_multihead ? nev2 * DV : DV) : 0; - - if (is_mxfp_k) { - p.k_qs_per_block = ggml_mxfp_qs_per_block(k->type); - p.k_blocks_per_head = (int)(DK / 32); - p.k_head_qs_bytes = p.k_blocks_per_head * p.k_qs_per_block; - const int64_t k_total_blocks = p.k_multihead ? 
nek2 * p.k_blocks_per_head : p.k_blocks_per_head; - p.k_head_e8m0_offset = k_total_blocks * p.k_qs_per_block; - } - - if (is_mxfp_v) { - p.v_qs_per_block = ggml_mxfp_qs_per_block(v->type); - p.v_blocks_per_head = (int)(DV / 32); - p.v_head_qs_bytes = p.v_blocks_per_head * p.v_qs_per_block; - const int64_t v_total_blocks = p.v_multihead ? nev2 * p.v_blocks_per_head : p.v_blocks_per_head; - p.v_head_e8m0_offset = v_total_blocks * p.v_qs_per_block; - } - return p; } @@ -8430,14 +8437,14 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( int ith = params->ith; - if (is_mxfp_k) { GGML_ASSERT(DK <= 1024); } - if (is_mxfp_v) { GGML_ASSERT(DV <= 1024); } + if (is_mxfp_k) { GGML_ASSERT(DK <= MXFP_FA_MAX_D); } + if (is_mxfp_v) { GGML_ASSERT(DV <= MXFP_FA_MAX_D); } - float k_dequant_buf[1024]; - float v_dequant_buf[1024]; + float k_dequant_buf[MXFP_FA_MAX_D]; + float v_dequant_buf[MXFP_FA_MAX_D]; - char k_head_soa[1088]; // max: DK=1024 MXFP8 -> 1056 bytes, rounded up - char v_head_soa[1088]; + char k_head_soa[MXFP_FA_SOA_BUF]; // max: DK=1024 MXFP8 -> 1056 bytes, rounded up + char v_head_soa[MXFP_FA_SOA_BUF]; float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); float * V32 = (VKQ32 + 1*DV); @@ -8479,31 +8486,25 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( const char * k_base = (const char *) k->data + k_base_offset; const char * v_base = (const char *) v->data + v_base_offset; - // Per-head SoA byte offsets - const int k_head_qs_start = mxfp.k_multihead ? ik2 * mxfp.k_head_qs_bytes : 0; - const int k_head_e8m0_start = mxfp.k_multihead ? (int)mxfp.k_head_e8m0_offset + ik2 * mxfp.k_blocks_per_head : 0; - const int v_head_qs_start = mxfp.v_multihead ? iv2 * mxfp.v_head_qs_bytes : 0; - const int v_head_e8m0_start = mxfp.v_multihead ? (int)mxfp.v_head_e8m0_offset + iv2 * mxfp.v_blocks_per_head : 0; - - const char * k_row_base = mxfp.k_multihead ? 
((const char *) k->data + ik3*nbk3) : nullptr; - const char * v_row_base = mxfp.v_multihead ? ((const char *) v->data + iv3*nbv3) : nullptr; + const char * k_row_base = mxfp.k.multihead ? ((const char *) k->data + ik3*nbk3) : nullptr; + const char * v_row_base = mxfp.v.multihead ? ((const char *) v->data + iv3*nbv3) : nullptr; const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); - float Q_f32[1024]; + float Q_f32[MXFP_FA_MAX_D]; if (is_mxfp_k) { // Q preprocessing: Hadamard + SoA round-trip captures same quantization loss as K. if (mxfp.apply_hadamard) { - float q_tmp[1024]; + float q_tmp[MXFP_FA_MAX_D]; memcpy(q_tmp, pq, DK * sizeof(float)); ggml_apply_hadamard_blocks(q_tmp, DK); mxfp.q_quantize(q_tmp, Q_q, DK); } else { mxfp.q_quantize(pq, Q_q, DK); } - mxfp.k_dequantize(Q_q, Q_f32, DK); + mxfp.k.dequantize(Q_q, Q_f32, DK); } else { if (mxfp.apply_hadamard) { - float q_tmp[1024]; + float q_tmp[MXFP_FA_MAX_D]; memcpy(q_tmp, pq, DK * sizeof(float)); ggml_apply_hadamard_blocks(q_tmp, DK); q_to_vec_dot(q_tmp, Q_q, DK); @@ -8525,15 +8526,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( float s; // KQ value if (is_mxfp_k) { - if (mxfp.k_multihead) { - // Extract this head's SoA blocks - const char * row = k_row_base + ic*nbk1; - memcpy(k_head_soa, row + k_head_qs_start, mxfp.k_head_qs_bytes); - memcpy(k_head_soa + mxfp.k_head_qs_bytes, row + k_head_e8m0_start, mxfp.k_blocks_per_head); - mxfp.k_dequantize(k_head_soa, k_dequant_buf, DK); - } else { - mxfp.k_dequantize(k_base + ic*nbk1, k_dequant_buf, DK); - } + const char * k_row = mxfp.k.multihead ? 
k_row_base + ic*nbk1 : k_base + ic*nbk1; + mxfp_dequant_head(mxfp.k, k_row, ik2, k_head_soa, k_dequant_buf, DK); ggml_vec_dot_f32(DK, &s, 0, k_dequant_buf, 0, Q_f32, 0, 1); } else { kq_vec_dot(DK, &s, 0, k_base + ic*nbk1, 0, Q_q, 0, 1); @@ -8577,15 +8571,9 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( } // V += v*expf(s - M) - if (mxfp.v_dequantize) { - if (mxfp.v_multihead) { - const char * row = v_row_base + ic*nbv1; - memcpy(v_head_soa, row + v_head_qs_start, mxfp.v_head_qs_bytes); - memcpy(v_head_soa + mxfp.v_head_qs_bytes, row + v_head_e8m0_start, mxfp.v_blocks_per_head); - mxfp.v_dequantize(v_head_soa, v_dequant_buf, DV); - } else { - mxfp.v_dequantize(v_base + ic*nbv1, v_dequant_buf, DV); - } + if (mxfp.v.dequantize) { + const char * v_row = mxfp.v.multihead ? v_row_base + ic*nbv1 : v_base + ic*nbv1; + mxfp_dequant_head(mxfp.v, v_row, iv2, v_head_soa, v_dequant_buf, DV); ggml_vec_mad_f32(DV, VKQ32, v_dequant_buf, vs); } else if (v_to_float) { v_to_float(v_base + ic*nbv1, V32, DV); @@ -8731,14 +8719,14 @@ static void ggml_compute_forward_flash_attn_ext_tiled( static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q; static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV; - if (is_mxfp_k) { GGML_ASSERT(DK <= 1024); } - if (is_mxfp_v) { GGML_ASSERT(DV <= 1024); } + if (is_mxfp_k) { GGML_ASSERT(DK <= MXFP_FA_MAX_D); } + if (is_mxfp_v) { GGML_ASSERT(DV <= MXFP_FA_MAX_D); } - float k_dequant_buf[1024]; - float v_dequant_buf[1024]; + float k_dequant_buf[MXFP_FA_MAX_D]; + float v_dequant_buf[MXFP_FA_MAX_D]; - char k_head_soa[1088]; - char v_head_soa[1088]; + char k_head_soa[MXFP_FA_SOA_BUF]; + char v_head_soa[MXFP_FA_SOA_BUF]; int ir = ir0; while (ir < ir1) { @@ -8802,9 +8790,9 @@ static void ggml_compute_forward_flash_attn_ext_tiled( if (mxfp.apply_hadamard) { ggml_apply_hadamard_blocks(Q_f32 + tq * DK, DK); } - uint8_t q_mxfp_buf[1088]; // max: DK=1024 MXFP8 -> 1056 bytes + uint8_t q_mxfp_buf[MXFP_FA_SOA_BUF]; mxfp.q_quantize(Q_f32 + tq * DK, 
q_mxfp_buf, DK); - mxfp.k_dequantize(q_mxfp_buf, Q_f32 + tq * DK, DK); + mxfp.k.dequantize(q_mxfp_buf, Q_f32 + tq * DK, DK); } } for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) { @@ -8854,23 +8842,13 @@ static void ggml_compute_forward_flash_attn_ext_tiled( for (int64_t dk = 0; dk < DK; dk++) { K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk]; } - } else if (mxfp.k_dequantize) { - if (mxfp.k_multihead) { - // Per-head extraction: copy only this head's SoA blocks + e8m0, dequant DK elements. - const char * row = (const char *)k->data + (ic + tk)*nbk1 + ik3*nbk3; - const int kqs = ik2 * mxfp.k_head_qs_bytes; - const int ke8 = (int)mxfp.k_head_e8m0_offset + ik2 * mxfp.k_blocks_per_head; - memcpy(k_head_soa, row + kqs, mxfp.k_head_qs_bytes); - memcpy(k_head_soa + mxfp.k_head_qs_bytes, row + ke8, mxfp.k_blocks_per_head); - mxfp.k_dequantize(k_head_soa, k_dequant_buf, DK); - } else { - mxfp.k_dequantize(k_data, k_dequant_buf, DK); - } + } else if (mxfp.k.dequantize) { + mxfp_dequant_head(mxfp.k, k_data, ik2, k_head_soa, k_dequant_buf, DK); for (int64_t dk = 0; dk < DK; dk++) { K_f32[dk * KV_TILE_SZ + tk] = k_dequant_buf[dk]; } } else { - float k_tmp[1024]; + float k_tmp[MXFP_FA_MAX_D]; k_to_float(k_data, k_tmp, DK); for (int64_t dk = 0; dk < DK; dk++) { K_f32[dk * KV_TILE_SZ + tk] = k_tmp[dk]; @@ -8934,18 +8912,8 @@ static void ggml_compute_forward_flash_attn_ext_tiled( ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV); } else if (v_type == GGML_TYPE_F32) { memcpy(V32 + tk * DV, v_data, DV * sizeof(float)); - } else if (mxfp.v_dequantize) { - if (mxfp.v_multihead) { - // Per-head extraction: copy only this head's SoA blocks + e8m0, dequant DV elements. 
- const char * row = (const char *)v->data + (ic + tk)*nbv1 + iv3*nbv3; - const int vqs = iv2 * mxfp.v_head_qs_bytes; - const int ve8 = (int)mxfp.v_head_e8m0_offset + iv2 * mxfp.v_blocks_per_head; - memcpy(v_head_soa, row + vqs, mxfp.v_head_qs_bytes); - memcpy(v_head_soa + mxfp.v_head_qs_bytes, row + ve8, mxfp.v_blocks_per_head); - mxfp.v_dequantize(v_head_soa, v_dequant_buf, DV); - } else { - mxfp.v_dequantize(v_data, v_dequant_buf, DV); - } + } else if (mxfp.v.dequantize) { + mxfp_dequant_head(mxfp.v, v_data, iv2, v_head_soa, v_dequant_buf, DV); memcpy(V32 + tk * DV, v_dequant_buf, DV * sizeof(float)); } else { v_to_float(v_data, V32 + tk * DV, DV); diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index eed3be90fc..5cbd177234 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -54,14 +54,6 @@ void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, i quantize_row_nvfp4_ref(x, y, k); } -void quantize_row_mxfp8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { - quantize_row_mxfp8_ref(x, y, k); -} - -void quantize_row_mxfp6(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { - quantize_row_mxfp6_ref(x, y, k); -} - // // 2-6 bit quantization in super-blocks // @@ -264,51 +256,6 @@ void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, *s = sumf; } -// Generic MXFP x Q8_0 dot product (scalar, not SIMD-optimized) -static void ggml_vec_dot_mxfp_q8_0_impl( - int n, float * GGML_RESTRICT s, - const void * GGML_RESTRICT vx, size_t block_size, - const void * GGML_RESTRICT vy, - ggml_to_float_t dequant) { - assert(n % QK8_0 == 0); - const int nb = n / QK8_0; - const block_q8_0 * GGML_RESTRICT y = vy; - float sumf = 0; - - for (int ib = 0; ib < nb; ib++) { - float tmp[QK8_0]; - dequant((const char *)vx + ib * block_size, tmp, QK8_0); - - const float y_d = GGML_CPU_FP16_TO_FP32(y[ib].d); - float block_sum = 0; - for (int j = 0; j < QK8_0; j++) { - 
block_sum += tmp[j] * (float)y[ib].qs[j]; - } - sumf += block_sum * y_d; - } - *s = sumf; -} - -void ggml_vec_dot_mxfp8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - ggml_vec_dot_mxfp_q8_0_impl(n, s, vx, sizeof(block_mxfp8), vy, - (ggml_to_float_t)dequantize_row_mxfp8); -} - -void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - ggml_vec_dot_mxfp_q8_0_impl(n, s, vx, sizeof(block_mxfp6), vy, - (ggml_to_float_t)dequantize_row_mxfp6); -} - // Generic SoA dequant wrappers — arch-specific SIMD versions override via fallback.h. void dequantize_row_mxfp4_soa_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { dequantize_row_mxfp4_soa(x, y, k); diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h index 78c9984bdc..4a4dd264fe 100644 --- a/ggml/src/ggml-cpu/quants.h +++ b/ggml/src/ggml-cpu/quants.h @@ -21,9 +21,6 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_mxfp8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_mxfp6(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); - void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -46,9 +43,6 @@ void 
ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); - void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -80,9 +74,6 @@ void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, c void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_mxfp8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * 
GGML_RESTRICT vy, size_t by, int nrc); - // SoA dequant (SIMD-dispatched, CPU backend) void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index e9a101ff66..f18758f16b 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -3770,7 +3770,7 @@ static block_mxfp4x4 make_block_mxfp4x4(block_mxfp4 * in, unsigned int blck_size } static int repack_mxfp4_to_mxfp4_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_MXFP4_E2M1); + GGML_ASSERT(t->type == GGML_TYPE_MXFP4); GGML_ASSERT(interleave_block == 4); const block_mxfp4 * src = (const block_mxfp4 *)data; @@ -3827,7 +3827,7 @@ static block_mxfp4x8 make_block_mxfp4x8(block_mxfp4 * in, unsigned int blck_size } static int repack_mxfp4_to_mxfp4_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_MXFP4_E2M1); + GGML_ASSERT(t->type == GGML_TYPE_MXFP4); GGML_ASSERT(interleave_block == 8); const block_mxfp4 * src = (const block_mxfp4 *)data; @@ -4685,7 +4685,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons } #endif } - } else if (cur->type == GGML_TYPE_MXFP4_E2M1) { + } else if (cur->type == GGML_TYPE_MXFP4) { if (ggml_cpu_has_avx2()) { if (cur->ne[1] % 8 == 0) { return &mxfp4_8x8_q8_0; diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index a8996a2ab5..9388b1be46 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1014,7 +1014,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te // MXFP4: has AoS shaders (MUL_MAT, GET_ROWS) but no SoA/flash attention 
support yet. for (size_t i = 0, n = 3; i < n; ++i) { if (op->src[i] != NULL && ggml_is_type_mxfp(op->src[i]->type)) { - if (op->src[i]->type != GGML_TYPE_MXFP4_E2M1) { + if (op->src[i]->type != GGML_TYPE_MXFP4) { return false; } if (op->op == GGML_OP_FLASH_ATTN_EXT || op->op == GGML_OP_SET_ROWS) { diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index f2fcd73dba..e1dca6b4b4 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -3760,7 +3760,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te } else if (op->src[0]->type == GGML_TYPE_F32) { return op->src[1]->type == GGML_TYPE_F32; } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || - op->src[0]->type == GGML_TYPE_MXFP4_E2M1 || + op->src[0]->type == GGML_TYPE_MXFP4 || op->src[0]->type == GGML_TYPE_Q4_K || op->src[0]->type == GGML_TYPE_Q6_K) { return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); @@ -3771,7 +3771,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_MUL_MAT_ID: if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0 || - op->src[0]->type == GGML_TYPE_MXFP4_E2M1) { + op->src[0]->type == GGML_TYPE_MXFP4) { if (op->src[1]->type == GGML_TYPE_F32) { return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); } @@ -4559,7 +4559,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, #endif // GGML_OPENCL_USE_ADRENO_KERNELS return; } - if (tensor->type == GGML_TYPE_MXFP4_E2M1) { + if (tensor->type == GGML_TYPE_MXFP4) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); @@ -5136,7 +5136,7 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, 
CL_CHECK(clReleaseMemObject(data_device)); return; } - if (tensor->type == GGML_TYPE_MXFP4_E2M1) { + if (tensor->type == GGML_TYPE_MXFP4) { ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *)tensor->extra; cl_int err; @@ -5585,7 +5585,7 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tenso CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL)); CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_d, buf_d, 0, NULL, NULL)); CL_CHECK(clFinish(queue)); - } else if (tensor->type == GGML_TYPE_MXFP4_E2M1) { + } else if (tensor->type == GGML_TYPE_MXFP4) { ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *) tensor->extra; GGML_ASSERT(extra); @@ -10550,7 +10550,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); #endif // GGML_OPENCL_SOA_Q break; - case GGML_TYPE_MXFP4_E2M1: { + case GGML_TYPE_MXFP4: { #ifdef GGML_OPENCL_SOA_Q kernel = backend_ctx->kernel_mul_mv_mxfp4_f32_flat; @@ -10630,7 +10630,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co GGML_ASSERT(false && "not implemented"); } - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4_E2M1 || + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K) { @@ -10864,7 +10864,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, #endif // GGML_OPENCL_SOA_Q break; } - case GGML_TYPE_MXFP4_E2M1: { + case GGML_TYPE_MXFP4: { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (use_adreno_moe_kernels(backend_ctx, src0)) { cl_int status; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index cca3d99c82..5c8eb97806 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -267,10 +267,12 @@ uint8_t float_to_fp8_e4m3_rn(float x) { return ggml_mxfp_float_to_fp8_e4m3(x); } // 
MSE-optimal E8M0: tests candidates around round(log2(amax)), picks lowest quantization error. typedef struct { - int emax_offset; // type-specific offset to max representable exponent + int emax_offset; // type-specific offset to max representable exponent + int qs_per_block; // quantized scalar bytes per 32-element block + int bits_per_elem; // 8 = byte-aligned, 6 = packed via fp6x4 uint8_t (*to_elem)(float); float (*to_float)(uint8_t); - float (*mse_error)(float val, float inv_scale, float scale); + float (*mse_error)(float val, float inv_scale, float scale); // NULL = use generic round-trip via to_elem/to_float } mxfp_elem_traits_t; static inline int best_index_mxfp4(float x, float e); @@ -294,7 +296,7 @@ static float mse_error_mxfp4(float val, float inv_scale, float scale) { return err * err; } -static const mxfp_elem_traits_t mxfp4_traits = { MXFP4_E2M1_EMAX_OFFSET, NULL, NULL, mse_error_mxfp4 }; +static const mxfp_elem_traits_t mxfp4_traits = { MXFP4_E2M1_EMAX_OFFSET, MXFP4_SOA_QS_PER_BLOCK, 4, NULL, NULL, mse_error_mxfp4 }; static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_elem_traits_t * traits) { float amax = 0.0f; @@ -319,7 +321,13 @@ static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_ const float test_inv = 1.0f / test_scale; float mse = 0.0f; for (int j = 0; j < qk; ++j) { - mse += traits->mse_error(x[j], test_inv, test_scale); + if (traits->mse_error) { + mse += traits->mse_error(x[j], test_inv, test_scale); + } else { + const float recon = traits->to_float(traits->to_elem(x[j] * test_inv)) * test_scale; + const float err = x[j] - recon; + mse += err * err; + } } if (mse < best_mse) { best_mse = mse; @@ -574,102 +582,8 @@ uint8_t float_to_fp8_e5m2_rn(float x) { return ggml_mxfp_float_to_fp8_e5m2(x); } void pack_fp6x4(const uint8_t v[4], uint8_t out[3]) { ggml_mxfp_pack_fp6x4(v, out); } void unpack_fp6x4(const uint8_t in[3], uint8_t v[4]) { ggml_mxfp_unpack_fp6x4(in, v); } -static float 
mse_error_fp8_e4m3(float val, float inv_scale, float scale) { - const float recon = fp8_e4m3_to_float(float_to_fp8_e4m3_rn(val * inv_scale)) * scale; - const float err = val - recon; - return err * err; -} -static float mse_error_fp6_e2m3(float val, float inv_scale, float scale) { - const float recon = fp6_e2m3_to_float(float_to_fp6_e2m3_rn(val * inv_scale)) * scale; - const float err = val - recon; - return err * err; -} -static const mxfp_elem_traits_t mxfp8_e4m3_traits = { MXFP8_E4M3_EMAX_OFFSET, float_to_fp8_e4m3_rn, fp8_e4m3_to_float, mse_error_fp8_e4m3 }; -static const mxfp_elem_traits_t mxfp6_e2m3_traits = { MXFP6_E2M3_EMAX_OFFSET, float_to_fp6_e2m3_rn, fp6_e2m3_to_float, mse_error_fp6_e2m3 }; - -static void quantize_row_mxfp8_impl(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, - int64_t k, const mxfp_elem_traits_t * traits) { - assert(k % QK_MXFP8 == 0); - const int nb = k / QK_MXFP8; - - for (int i = 0; i < nb; i++) { - const uint8_t e = mxfp_compute_e8m0_mse(&x[i*QK_MXFP8], QK_MXFP8, traits); - const float d = GGML_E8M0_TO_FP32(e); - const float inv_d = d > 0.0f ? 
1.0f / d : 0.0f; - y[i].e = e; - - for (int j = 0; j < QK_MXFP8; ++j) { - y[i].qs[j] = traits->to_elem(x[i*QK_MXFP8 + j] * inv_d); - } - } -} - -static void dequantize_row_mxfp8_impl(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, - int64_t k, const mxfp_elem_traits_t * traits) { - assert(k % QK_MXFP8 == 0); - const int nb = k / QK_MXFP8; - - for (int i = 0; i < nb; i++) { - const float d = GGML_E8M0_TO_FP32(x[i].e); - for (int j = 0; j < QK_MXFP8; ++j) { - y[i*QK_MXFP8 + j] = traits->to_float(x[i].qs[j]) * d; - } - } -} - -static void quantize_row_mxfp6_impl(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, - int64_t k, const mxfp_elem_traits_t * traits) { - assert(k % QK_MXFP6 == 0); - const int nb = k / QK_MXFP6; - - for (int i = 0; i < nb; i++) { - const uint8_t e = mxfp_compute_e8m0_mse(&x[i*QK_MXFP6], QK_MXFP6, traits); - const float d = GGML_E8M0_TO_FP32(e); - const float inv_d = d > 0.0f ? 1.0f / d : 0.0f; - y[i].e = e; - - for (int j = 0; j < QK_MXFP6; j += 4) { - uint8_t vals[4]; - for (int jj = 0; jj < 4; jj++) { - vals[jj] = traits->to_elem(x[i*QK_MXFP6 + j + jj] * inv_d); - } - pack_fp6x4(vals, &y[i].qs[j * 3 / 4]); - } - } -} - -static void dequantize_row_mxfp6_impl(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, - int64_t k, const mxfp_elem_traits_t * traits) { - assert(k % QK_MXFP6 == 0); - const int nb = k / QK_MXFP6; - - for (int i = 0; i < nb; i++) { - const float d = GGML_E8M0_TO_FP32(x[i].e); - for (int j = 0; j < QK_MXFP6; j += 4) { - uint8_t vals[4]; - unpack_fp6x4(&x[i].qs[j * 3 / 4], vals); - for (int jj = 0; jj < 4; jj++) { - y[i*QK_MXFP6 + j + jj] = traits->to_float(vals[jj]) * d; - } - } - } -} - -void quantize_row_mxfp8_ref(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k) { - quantize_row_mxfp8_impl(x, y, k, &mxfp8_e4m3_traits); -} - -void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { - dequantize_row_mxfp8_impl(x, y, k, 
&mxfp8_e4m3_traits); -} - -void quantize_row_mxfp6_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k) { - quantize_row_mxfp6_impl(x, y, k, &mxfp6_e2m3_traits); -} - -void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { - dequantize_row_mxfp6_impl(x, y, k, &mxfp6_e2m3_traits); -} +static const mxfp_elem_traits_t mxfp8_e4m3_traits = { MXFP8_E4M3_EMAX_OFFSET, MXFP8_SOA_QS_PER_BLOCK, 8, float_to_fp8_e4m3_rn, fp8_e4m3_to_float, NULL }; +static const mxfp_elem_traits_t mxfp6_e2m3_traits = { MXFP6_E2M3_EMAX_OFFSET, MXFP6_SOA_QS_PER_BLOCK, 6, float_to_fp6_e2m3_rn, fp6_e2m3_to_float, NULL }; // ====================== SoA (Struct-of-Arrays) quantize/dequantize for flash attention @@ -715,101 +629,79 @@ void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTR } } -static void quantize_row_mxfp8_soa_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, - int64_t k, const mxfp_elem_traits_t * traits) { - assert(k % QK_MXFP8 == 0); - const int nb = k / QK_MXFP8; - char * row = (char *)dst; - char * qs_base = row; - char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK); +// Unified SoA quantize for byte-aligned (FP8) and 6-bit packed (FP6) MXFP formats. +static void quantize_row_mxfp_soa_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, + int64_t k, const mxfp_elem_traits_t * traits) { + const int qk = 32; + assert(k % qk == 0); + const int nb = k / qk; + const int qpb = traits->qs_per_block; + char * qs_base = (char *)dst; + char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, qpb); for (int i = 0; i < nb; i++) { - const uint8_t e = mxfp_compute_e8m0_mse(&x[i*QK_MXFP8], QK_MXFP8, traits); + const uint8_t e = mxfp_compute_e8m0_mse(&x[i*qk], qk, traits); const float d = GGML_E8M0_TO_FP32(e); const float inv_d = d > 0.0f ? 
1.0f / d : 0.0f; e8m0_base[i] = (char)e; - uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP8_SOA_QS_PER_BLOCK)); - for (int j = 0; j < QK_MXFP8; ++j) { - qs[j] = traits->to_elem(x[i*QK_MXFP8 + j] * inv_d); - } - } -} - -static void dequantize_row_mxfp8_soa_impl(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, - int64_t k, const mxfp_elem_traits_t * traits) { - assert(k % QK_MXFP8 == 0); - const int nb = k / QK_MXFP8; - const char * row = (const char *)src; - const char * qs_base = row; - const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK); - - for (int i = 0; i < nb; i++) { - const float d = GGML_E8M0_TO_FP32((uint8_t)e8m0_base[i]); - const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP8_SOA_QS_PER_BLOCK)); - for (int j = 0; j < QK_MXFP8; ++j) { - y[i*QK_MXFP8 + j] = traits->to_float(qs[j]) * d; - } - } -} - -static void quantize_row_mxfp6_soa_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, - int64_t k, const mxfp_elem_traits_t * traits) { - assert(k % QK_MXFP6 == 0); - const int nb = k / QK_MXFP6; - char * row = (char *)dst; - char * qs_base = row; - char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK); - - for (int i = 0; i < nb; i++) { - const uint8_t e = mxfp_compute_e8m0_mse(&x[i*QK_MXFP6], QK_MXFP6, traits); - const float d = GGML_E8M0_TO_FP32(e); - const float inv_d = d > 0.0f ? 
1.0f / d : 0.0f; - e8m0_base[i] = (char)e; - - uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP6_SOA_QS_PER_BLOCK)); - for (int j = 0; j < QK_MXFP6; j += 4) { - uint8_t vals[4]; - for (int jj = 0; jj < 4; jj++) { - vals[jj] = traits->to_elem(x[i*QK_MXFP6 + j + jj] * inv_d); + uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, qpb)); + if (traits->bits_per_elem == 8) { + for (int j = 0; j < qk; ++j) { + qs[j] = traits->to_elem(x[i*qk + j] * inv_d); + } + } else { + for (int j = 0; j < qk; j += 4) { + uint8_t vals[4]; + for (int jj = 0; jj < 4; jj++) { + vals[jj] = traits->to_elem(x[i*qk + j + jj] * inv_d); + } + pack_fp6x4(vals, &qs[j * 3 / 4]); } - pack_fp6x4(vals, &qs[j * 3 / 4]); } } } -static void dequantize_row_mxfp6_soa_impl(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, - int64_t k, const mxfp_elem_traits_t * traits) { - assert(k % QK_MXFP6 == 0); - const int nb = k / QK_MXFP6; - const char * row = (const char *)src; - const char * qs_base = row; - const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK); +// Unified SoA dequantize for byte-aligned (FP8) and 6-bit packed (FP6) MXFP formats. 
+static void dequantize_row_mxfp_soa_impl(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, + int64_t k, const mxfp_elem_traits_t * traits) { + const int qk = 32; + assert(k % qk == 0); + const int nb = k / qk; + const int qpb = traits->qs_per_block; + const char * qs_base = (const char *)src; + const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, qpb); for (int i = 0; i < nb; i++) { const float d = GGML_E8M0_TO_FP32((uint8_t)e8m0_base[i]); - const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP6_SOA_QS_PER_BLOCK)); - for (int j = 0; j < QK_MXFP6; j += 4) { - uint8_t vals[4]; - unpack_fp6x4(&qs[j * 3 / 4], vals); - for (int jj = 0; jj < 4; jj++) { - y[i*QK_MXFP6 + j + jj] = traits->to_float(vals[jj]) * d; + const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, qpb)); + if (traits->bits_per_elem == 8) { + for (int j = 0; j < qk; ++j) { + y[i*qk + j] = traits->to_float(qs[j]) * d; + } + } else { + for (int j = 0; j < qk; j += 4) { + uint8_t vals[4]; + unpack_fp6x4(&qs[j * 3 / 4], vals); + for (int jj = 0; jj < 4; jj++) { + y[i*qk + j + jj] = traits->to_float(vals[jj]) * d; + } } } } } void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) { - quantize_row_mxfp8_soa_impl(x, dst, k, &mxfp8_e4m3_traits); + quantize_row_mxfp_soa_impl(x, dst, k, &mxfp8_e4m3_traits); } void dequantize_row_mxfp8_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) { - dequantize_row_mxfp8_soa_impl(src, y, k, &mxfp8_e4m3_traits); + dequantize_row_mxfp_soa_impl(src, y, k, &mxfp8_e4m3_traits); } void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) { - quantize_row_mxfp6_soa_impl(x, dst, k, &mxfp6_e2m3_traits); + quantize_row_mxfp_soa_impl(x, dst, k, &mxfp6_e2m3_traits); } void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) { - dequantize_row_mxfp6_soa_impl(src, y, k, &mxfp6_e2m3_traits); + 
dequantize_row_mxfp_soa_impl(src, y, k, &mxfp6_e2m3_traits); } // // 2-6 bit quantization in super-blocks @@ -2472,7 +2364,7 @@ size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { GGML_UNUSED(quant_weights); quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row); - return nrow * ggml_row_size(GGML_TYPE_MXFP4_E2M1, n_per_row); + return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row); } size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { @@ -2481,18 +2373,6 @@ size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, return nrow * ggml_row_size(GGML_TYPE_NVFP4, n_per_row); } -size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - GGML_UNUSED(quant_weights); - quantize_row_mxfp8_ref(src, dst, (int64_t)nrow*n_per_row); - return nrow * ggml_row_size(GGML_TYPE_MXFP8_E4M3, n_per_row); -} - -size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - GGML_UNUSED(quant_weights); - quantize_row_mxfp6_ref(src, dst, (int64_t)nrow*n_per_row); - return nrow * ggml_row_size(GGML_TYPE_MXFP6_E2M3, n_per_row); -} - // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs) void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) { @@ -5635,15 +5515,15 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb); } break; - case GGML_TYPE_MXFP4_E2M1: + case GGML_TYPE_MXFP4: { VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb); } break; - case GGML_TYPE_MXFP8_E4M3: + case 
GGML_TYPE_MXFP8: { VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp8, data, nb); } break; - case GGML_TYPE_MXFP6_E2M3: + case GGML_TYPE_MXFP6: { VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp6, data, nb); } break; diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index 7d848edffe..4dec9ad351 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -23,9 +23,6 @@ GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k); -GGML_API void quantize_row_mxfp8_ref(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k); -GGML_API void quantize_row_mxfp6_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k); - GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k); @@ -52,9 +49,6 @@ GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GG GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -GGML_API void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -GGML_API void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); - // SoA quantize/dequantize for flash attention GGML_API void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); GGML_API void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * 
GGML_RESTRICT y, int64_t k); @@ -110,9 +104,6 @@ GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTR GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -GGML_API size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -GGML_API size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); - // MXFP element converters GGML_API float fp8_e4m3_to_float(uint8_t v); GGML_API uint8_t float_to_fp8_e4m3_rn(float x); diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index 09f3a43a90..d17aca2cac 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -639,7 +639,7 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { return dequantize_row_iq4_xs_sycl; case GGML_TYPE_IQ4_NL: return dequantize_row_iq4_nl_sycl; - case GGML_TYPE_MXFP4_E2M1: + case GGML_TYPE_MXFP4: return dequantize_row_mxfp4_sycl; case GGML_TYPE_F32: return convert_unary_sycl; @@ -706,7 +706,7 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { return dequantize_row_iq4_xs_sycl; case GGML_TYPE_IQ4_NL: return dequantize_row_iq4_nl_sycl; - case GGML_TYPE_MXFP4_E2M1: + case GGML_TYPE_MXFP4: return dequantize_row_mxfp4_sycl; case GGML_TYPE_F16: return convert_unary_sycl; diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index f0d1472f44..316aa0d0fb 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -1142,7 +1142,7 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens case GGML_TYPE_IQ4_XS: 
mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; - case GGML_TYPE_MXFP4_E2M1: + case GGML_TYPE_MXFP4: mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; default: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 40a0aab62b..21b9a81eae 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -710,7 +710,7 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref, }, - [GGML_TYPE_MXFP4_E2M1] = { + [GGML_TYPE_MXFP4] = { .type_name = "mxfp4", .blck_size = QK_MXFP4, .type_size = sizeof(block_mxfp4), @@ -726,21 +726,17 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .to_float = (ggml_to_float_t) dequantize_row_nvfp4, .from_float_ref = (ggml_from_float_t)quantize_row_nvfp4_ref, }, - [GGML_TYPE_MXFP8_E4M3] = { + [GGML_TYPE_MXFP8] = { .type_name = "mxfp8_e4m3", .blck_size = QK_MXFP8, .type_size = sizeof(block_mxfp8), .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_mxfp8, - .from_float_ref = (ggml_from_float_t)quantize_row_mxfp8_ref, }, - [GGML_TYPE_MXFP6_E2M3] = { + [GGML_TYPE_MXFP6] = { .type_name = "mxfp6_e2m3", .blck_size = QK_MXFP6, .type_size = sizeof(block_mxfp6), .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_mxfp6, - .from_float_ref = (ggml_from_float_t)quantize_row_mxfp6_ref, }, [GGML_TYPE_Q2_K] = { .type_name = "q2_K", @@ -1329,25 +1325,25 @@ bool ggml_is_quantized(enum ggml_type type) { } bool ggml_is_type_mxfp(enum ggml_type type) { - return type == GGML_TYPE_MXFP4_E2M1 || - type == GGML_TYPE_MXFP8_E4M3 || - type == GGML_TYPE_MXFP6_E2M3; + return type == GGML_TYPE_MXFP4 || + type == GGML_TYPE_MXFP8 || + type == GGML_TYPE_MXFP6; } bool ggml_mxfp_use_hadamard(enum ggml_type type) { switch (type) { - case GGML_TYPE_MXFP4_E2M1: return MXFP_USE_HADAMARD_E2M1; - case GGML_TYPE_MXFP8_E4M3: return 
MXFP_USE_HADAMARD_E4M3; - case GGML_TYPE_MXFP6_E2M3: return MXFP_USE_HADAMARD_E2M3; + case GGML_TYPE_MXFP4: return MXFP_USE_HADAMARD_E2M1; + case GGML_TYPE_MXFP8: return MXFP_USE_HADAMARD_E4M3; + case GGML_TYPE_MXFP6: return MXFP_USE_HADAMARD_E2M3; default: return false; } } int ggml_mxfp_qs_per_block(enum ggml_type type) { switch (type) { - case GGML_TYPE_MXFP4_E2M1: return MXFP_QS_PER_BLOCK_E2M1; - case GGML_TYPE_MXFP8_E4M3: return MXFP_QS_PER_BLOCK_E4M3; - case GGML_TYPE_MXFP6_E2M3: return MXFP_QS_PER_BLOCK_E2M3; + case GGML_TYPE_MXFP4: return MXFP_QS_PER_BLOCK_E2M1; + case GGML_TYPE_MXFP8: return MXFP_QS_PER_BLOCK_E4M3; + case GGML_TYPE_MXFP6: return MXFP_QS_PER_BLOCK_E2M3; default: return 0; } } @@ -7695,10 +7691,10 @@ size_t ggml_quantize_chunk( case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_MXFP4_E2M1: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_MXFP8_E4M3: result = quantize_mxfp8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_MXFP6_E2M3: result = quantize_mxfp6(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_MXFP4: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_MXFP8: 
GGML_ABORT("MXFP8 is KV-cache-only (SoA layout) — use from_float_soa"); + case GGML_TYPE_MXFP6: GGML_ABORT("MXFP6 is KV-cache-only (SoA layout) — use from_float_soa"); case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 279c57e582..8e8ce23124 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -457,7 +457,7 @@ static ggml_type llama_tensor_get_type_impl(quantize_state_impl & qs, ggml_type // MoE tensors -> MXFP4 // other tensors -> Q8_0 if (tensor->ne[2] > 1) { - new_type = GGML_TYPE_MXFP4_E2M1; + new_type = GGML_TYPE_MXFP4; } else { new_type = GGML_TYPE_Q8_0; } @@ -795,7 +795,7 @@ static ggml_type llama_ftype_get_default_type(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_BF16: return GGML_TYPE_BF16; case LLAMA_FTYPE_ALL_F32: return GGML_TYPE_F32; - case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return GGML_TYPE_MXFP4_E2M1; + case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return GGML_TYPE_MXFP4; // K-quants case LLAMA_FTYPE_MOSTLY_Q2_K_S: diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 8e57cb1d1d..d102c5676c 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -170,9 +170,9 @@ struct mxfp_soa_fns { }; static const mxfp_soa_fns mxfp_soa_table[] = { - { GGML_TYPE_MXFP4_E2M1, quantize_row_mxfp4_soa, dequantize_row_mxfp4_soa }, - { GGML_TYPE_MXFP8_E4M3, quantize_row_mxfp8_soa, dequantize_row_mxfp8_soa }, - { GGML_TYPE_MXFP6_E2M3, quantize_row_mxfp6_soa, dequantize_row_mxfp6_soa }, + { GGML_TYPE_MXFP4, quantize_row_mxfp4_soa, dequantize_row_mxfp4_soa }, + { GGML_TYPE_MXFP8, quantize_row_mxfp8_soa, dequantize_row_mxfp8_soa }, + { GGML_TYPE_MXFP6, 
quantize_row_mxfp6_soa, dequantize_row_mxfp6_soa }, }; static const mxfp_soa_fns * get_mxfp_soa(ggml_type type) { @@ -3908,7 +3908,7 @@ struct test_mul_mat : public test_case { double max_nmse_err(ggml_backend_t backend) override { // for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance - if (type_a == GGML_TYPE_MXFP4_E2M1 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) { + if (type_a == GGML_TYPE_MXFP4 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) { return 2e-2; } return max_nmse_err(); @@ -4044,7 +4044,7 @@ struct test_mul_mat_id : public test_case { double max_nmse_err(ggml_backend_t backend) override { // for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance - if (type_a == GGML_TYPE_MXFP4_E2M1 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) { + if (type_a == GGML_TYPE_MXFP4 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) { return 2e-2; } return max_nmse_err(); @@ -7398,7 +7398,7 @@ static const ggml_type all_types[] = { GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0, - GGML_TYPE_MXFP4_E2M1, + GGML_TYPE_MXFP4, GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K, @@ -7414,7 +7414,7 @@ static const ggml_type base_types[] = { GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, // for I8MM tests GGML_TYPE_Q4_K, - GGML_TYPE_MXFP4_E2M1, + GGML_TYPE_MXFP4, GGML_TYPE_IQ2_XXS }; @@ -7533,8 +7533,8 @@ static std::vector> make_test_cases_eval() { } // SET_ROWS with Hadamard rotation (exercises the op_params[0] flag used by MXFP KV cache) - for (ggml_type type : {GGML_TYPE_MXFP4_E2M1, GGML_TYPE_MXFP8_E4M3, - GGML_TYPE_MXFP6_E2M3}) { + for (ggml_type type : {GGML_TYPE_MXFP4, GGML_TYPE_MXFP8, + GGML_TYPE_MXFP6}) { // ne[0] must be divisible by 32 (Hadamard block size) test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 128, 5, 1, 1 }, { 1, 1 }, 1, false, true)); test_cases.emplace_back(new test_set_rows(type, 
GGML_TYPE_I64, { 256, 5, 1, 3 }, { 1, 1 }, 1, false, true)); @@ -8270,7 +8270,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat_id_fusion(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3)); // gpt-oss issue with Vulkan mmq_id - test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4_E2M1, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880)); + test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880)); for (ggml_type type_a : base_types) { for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) { @@ -8731,7 +8731,7 @@ static std::vector> make_test_cases_eval() { for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) { if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue; for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0, - GGML_TYPE_MXFP4_E2M1, GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3, + GGML_TYPE_MXFP4, GGML_TYPE_MXFP8, GGML_TYPE_MXFP6, }) { // Non-F16 types: test at D=64, D=72, and D=128. 
if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72 && hsk != 128) continue; @@ -8760,8 +8760,8 @@ static std::vector> make_test_cases_eval() { // MXFP-specific K/V type combinations (mixed and same-type) // Mixed: mxfp8 K + mxfp4 V, mxfp6 K + mxfp4 V (our recommended configs) - for (ggml_type type_K : {GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3}) { - for (ggml_type type_V : {GGML_TYPE_MXFP4_E2M1}) { + for (ggml_type type_K : {GGML_TYPE_MXFP8, GGML_TYPE_MXFP6}) { + for (ggml_type type_V : {GGML_TYPE_MXFP4}) { if (type_K == type_V) continue; for (int nb : {1, 3, 32}) { test_cases.emplace_back(new test_flash_attn_ext( @@ -8770,7 +8770,7 @@ static std::vector> make_test_cases_eval() { } } // Same-type: mxfp8/mxfp8, mxfp6/mxfp6 - for (ggml_type type_KV : {GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3}) { + for (ggml_type type_KV : {GGML_TYPE_MXFP8, GGML_TYPE_MXFP6}) { for (int nb : {1, 3, 32}) { test_cases.emplace_back(new test_flash_attn_ext( 128, 128, 4, {1, 1}, 512, nb, true, false, 0.0f, 0.0f, GGML_PREC_F32, type_KV, {0, 1, 2, 3}, type_KV)); @@ -9000,7 +9000,7 @@ static std::vector> make_test_cases_perf() { // gpt-oss-20b for (int bs : {1, 4, 8, 512}) { - for (ggml_type type_a : {GGML_TYPE_MXFP4_E2M1}) { + for (ggml_type type_a : {GGML_TYPE_MXFP4}) { for (ggml_type type_b : {GGML_TYPE_F32}) { test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 2880, bs, 2880)); test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 2880, bs, 2880, 1)); diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp index 98e3d489dd..8f1dcf10f0 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -178,9 +178,9 @@ int main(int argc, char * argv[]) { type == GGML_TYPE_IQ3_S ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS : type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS : type == GGML_TYPE_NVFP4 ? MAX_QUANTIZATION_TOTAL_ERROR_FP4 : - type == GGML_TYPE_MXFP4_E2M1 ? 
MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 : - type == GGML_TYPE_MXFP6_E2M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 : - type == GGML_TYPE_MXFP8_E4M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 : MAX_QUANTIZATION_TOTAL_ERROR; + type == GGML_TYPE_MXFP4 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 : + type == GGML_TYPE_MXFP6 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 : + type == GGML_TYPE_MXFP8 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 : MAX_QUANTIZATION_TOTAL_ERROR; failed = !(total_error < max_quantization_error); num_failed += failed; if (failed || verbose) { @@ -202,7 +202,7 @@ int main(int argc, char * argv[]) { ? MAX_DOT_PRODUCT_ERROR_TERNARY : type == GGML_TYPE_NVFP4 ? MAX_DOT_PRODUCT_ERROR_FP4 - : type == GGML_TYPE_MXFP4_E2M1 || type == GGML_TYPE_MXFP6_E2M3 || type == GGML_TYPE_MXFP8_E4M3 + : type == GGML_TYPE_MXFP4 || type == GGML_TYPE_MXFP6 || type == GGML_TYPE_MXFP8 ? MAX_DOT_PRODUCT_ERROR_MXFP : MAX_DOT_PRODUCT_ERROR; failed = !(vec_dot_error < max_allowed_error); @@ -231,9 +231,9 @@ int main(int argc, char * argv[]) { const float soa_error = array_rmse(test_data.data(), tmp_out.data(), test_size); const float max_soa_error = - type == GGML_TYPE_MXFP4_E2M1 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 : - type == GGML_TYPE_MXFP6_E2M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 : - type == GGML_TYPE_MXFP8_E4M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 : MAX_QUANTIZATION_TOTAL_ERROR; + type == GGML_TYPE_MXFP4 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 : + type == GGML_TYPE_MXFP6 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 : + type == GGML_TYPE_MXFP8 ? 
MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 : MAX_QUANTIZATION_TOTAL_ERROR; failed = !(soa_error < max_soa_error); num_failed += failed; if (failed || verbose) { @@ -243,7 +243,7 @@ int main(int argc, char * argv[]) { // MXFP traits: SoA required, MXFP6/MXFP8 are KV-cache-only (no AoS dequant) { - const ggml_type all_mxfp_types[] = { GGML_TYPE_MXFP4_E2M1, GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3 }; + const ggml_type all_mxfp_types[] = { GGML_TYPE_MXFP4, GGML_TYPE_MXFP8, GGML_TYPE_MXFP6 }; for (ggml_type type : all_mxfp_types) { const auto * cpu = ggml_get_type_traits_cpu(type); @@ -255,7 +255,7 @@ int main(int argc, char * argv[]) { } // KV-cache-only types: no AoS dequant - const ggml_type kv_only_types[] = { GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3 }; + const ggml_type kv_only_types[] = { GGML_TYPE_MXFP8, GGML_TYPE_MXFP6 }; for (ggml_type type : kv_only_types) { const auto * cpu = ggml_get_type_traits_cpu(type); failed = (cpu->to_float != nullptr); @@ -297,9 +297,9 @@ int main(int argc, char * argv[]) { }; const soa_cross_check checks[] = { - { GGML_TYPE_MXFP4_E2M1, dequantize_row_mxfp4_soa }, - { GGML_TYPE_MXFP8_E4M3, dequantize_row_mxfp8_soa }, - { GGML_TYPE_MXFP6_E2M3, dequantize_row_mxfp6_soa }, + { GGML_TYPE_MXFP4, dequantize_row_mxfp4_soa }, + { GGML_TYPE_MXFP8, dequantize_row_mxfp8_soa }, + { GGML_TYPE_MXFP6, dequantize_row_mxfp6_soa }, }; for (const auto & c : checks) { @@ -774,9 +774,9 @@ int main(int argc, char * argv[]) { // SoA layout: verify offset macros produce correct byte positions { const struct { ggml_type type; int qs_per_block; } soa_types[] = { - { GGML_TYPE_MXFP4_E2M1, MXFP4_SOA_QS_PER_BLOCK }, - { GGML_TYPE_MXFP8_E4M3, MXFP8_SOA_QS_PER_BLOCK }, - { GGML_TYPE_MXFP6_E2M3, MXFP6_SOA_QS_PER_BLOCK }, + { GGML_TYPE_MXFP4, MXFP4_SOA_QS_PER_BLOCK }, + { GGML_TYPE_MXFP8, MXFP8_SOA_QS_PER_BLOCK }, + { GGML_TYPE_MXFP6, MXFP6_SOA_QS_PER_BLOCK }, }; for (const auto & st : soa_types) { @@ -864,7 +864,7 @@ int main(int argc, char * argv[]) { 
dequantize_row_mxfp4(aos_q.data(), aos_out.data(), nelems); // Quantize and dequant via SoA - const size_t soa_buf_size = ggml_row_size(GGML_TYPE_MXFP4_E2M1, nelems); + const size_t soa_buf_size = ggml_row_size(GGML_TYPE_MXFP4, nelems); std::vector soa_q(soa_buf_size); std::vector soa_out(nelems); quantize_row_mxfp4_soa(input, soa_q.data(), nelems); @@ -901,9 +901,9 @@ int main(int argc, char * argv[]) { }; const hadamard_pipeline_check pipeline_checks[] = { - { "mxfp4", GGML_TYPE_MXFP4_E2M1, MAX_MXFP_PIPELINE_ERROR_MXFP4 }, - { "mxfp8", GGML_TYPE_MXFP8_E4M3, MAX_MXFP_PIPELINE_ERROR_MXFP8 }, - { "mxfp6", GGML_TYPE_MXFP6_E2M3, MAX_MXFP_PIPELINE_ERROR_MXFP6 }, + { "mxfp4", GGML_TYPE_MXFP4, MAX_MXFP_PIPELINE_ERROR_MXFP4 }, + { "mxfp8", GGML_TYPE_MXFP8, MAX_MXFP_PIPELINE_ERROR_MXFP8 }, + { "mxfp6", GGML_TYPE_MXFP6, MAX_MXFP_PIPELINE_ERROR_MXFP6 }, }; for (const auto & p : pipeline_checks) { @@ -963,7 +963,7 @@ int main(int argc, char * argv[]) { // zero block produces E8M0=0 { float zeros[32] = {}; - const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP8_E4M3, 32); + const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP8, 32); std::vector buf(buf_size, 0xFF); // fill with 0xFF to detect non-writes quantize_row_mxfp8_soa(zeros, buf.data(), 32); @@ -991,7 +991,7 @@ int main(int argc, char * argv[]) { // MXFP4 { - const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP4_E2M1, nelems); + const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP4, nelems); std::vector buf(buf_size); std::vector ref_out(nelems); std::vector manual_out(nelems); @@ -1032,7 +1032,7 @@ int main(int argc, char * argv[]) { // MXFP8 { - const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP8_E4M3, nelems); + const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP8, nelems); std::vector buf(buf_size); std::vector ref_out(nelems); std::vector manual_out(nelems); @@ -1069,7 +1069,7 @@ int main(int argc, char * argv[]) { // MXFP6 { - const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP6_E2M3, nelems); + const size_t 
buf_size = ggml_row_size(GGML_TYPE_MXFP6, nelems); std::vector buf(buf_size); std::vector ref_out(nelems); std::vector manual_out(nelems); diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp index 00c6536589..27db9a065d 100644 --- a/tools/llama-bench/llama-bench.cpp +++ b/tools/llama-bench/llama-bench.cpp @@ -484,13 +484,13 @@ static ggml_type ggml_type_from_name(const std::string & s) { return GGML_TYPE_IQ4_NL; } if (s == "mxfp4" || s == "mxfp4_e2m1") { - return GGML_TYPE_MXFP4_E2M1; + return GGML_TYPE_MXFP4; } if (s == "mxfp8" || s == "mxfp8_e4m3") { - return GGML_TYPE_MXFP8_E4M3; + return GGML_TYPE_MXFP8; } if (s == "mxfp6" || s == "mxfp6_e2m3") { - return GGML_TYPE_MXFP6_E2M3; + return GGML_TYPE_MXFP6; } return GGML_TYPE_COUNT; } From ccea34ba41b311dff5ec5ee083f4b235524fd58b Mon Sep 17 00:00:00 2001 From: Tim Burke Date: Sun, 22 Mar 2026 20:12:09 -0400 Subject: [PATCH 13/13] perf : multiple fixes and enhancements, remove MSE search, expand test coverage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: correct tiled flash attention SoA pointer math for multihead MXFP The cleanup refactoring (c919bc471) extracted mxfp_dequant_head as a shared helper but failed to update the tiled path's data pointers. The helper expects the full SoA row base (no per-head offset), but the tiled path was passing a pointer that already included ik2*nbk2, causing a double head offset that produced NaN during prefill. Add mxfp_row_ptr helper to centralize the multihead-aware pointer calculation across both one_chunk and tiled paths. Verified with 16-chunk perplexity on gpt-oss-20b: all four configs (f16, mxfp4, mxfp6, mxfp8) produce exact matches with the known-good commit (23e88631c). * perf: reduce E8M0 MSE search range from ±2 to ±1 The base estimate round(log2(amax)) is always within 1 step of optimal. Empirically verified across 30K blocks and 6 distributions: ±1 and ±2 never disagree. 
This reduces the scale search from 5 to 3 candidates (40% fewer inner loop iterations) with zero quality impact. * perf: eliminate redundant work in MXFP quantize and flash attention - mse_error_mxfp4: use passed inv_scale instead of recomputing 1/d - mxfp_compute_e8m0_mse: hoist loop-invariant traits branch out of inner loop - tiled V path: dequant directly to V32 tile, remove intermediate memcpy and dead buffer * cleanup: fix comments, unify Hadamard condition, simplify E8M0 helpers - EMAX_OFFSET comments: fix ceil/floor labels to match actual values - Hadamard flag: unify write path (llama-kv-cache.cpp) and read path (ops.cpp) to both use DK==DV condition instead of is_mla() - E8M0 helpers in ggml-impl.h: simplify to match ggml-common.h style, add cross-reference comment * fix: MXFP8/6 flash attention tests crash on init The view base tensors for K/V don't get named "k"/"v" but inherit the MXFP type. The name-based filter in initialize_tensors missed them, falling through to init_tensor_uniform which calls quantize_chunk and aborts for KV-cache-only types. Fix by checking ggml_is_type_mxfp() for all tensors, matching the pattern set_rows tests already use. * test: expand MXFP set_rows coverage - Add MXFP8/MXFP6 to all_types for non-Hadamard set_rows coverage - Expand Hadamard set_rows tests: add views, broadcast, and multi-head configs - Coverage: 18 → 51 MXFP set_rows tests * perf: add AVX2 Hadamard for x86 (matches existing ARM NEON path) * cleanup: DRY MXFP4 quantize/dequant with shared per-block helpers Extract quantize_block_mxfp4 and dequantize_block_mxfp4 as shared helpers used by both AoS (quantize_row_mxfp4_ref, dequantize_row_mxfp4) and SoA (quantize_row_mxfp4_soa, dequantize_row_mxfp4_soa) paths. Eliminates duplicated per-block logic while keeping layout-specific pointer arithmetic in the callers. 
* feat: add MXFP8/MXFP6 AoS quantize/dequant (full type support) Extract quantize_block_mxfp / dequantize_block_mxfp per-block helpers from the SoA generic impl and use them to build AoS row functions for MXFP8 (E4M3) and MXFP6 (E2M3). Register to_float and from_float_ref in type traits, add quantize_chunk dispatch, replacing the GGML_ABORT. MXFP8 and MXFP6 are no longer KV-cache-only — they can now be used as general quantization types. The SoA impl is also DRY'd to delegate to the same per-block helpers. * cleanup: remove dead soa_elems field from mxfp_kv_params Computed but never read — leftover from an earlier design. * feat: add MXFP8/MXFP6 vec_dot and full CPU type support Add scalar vec_dot_mxfp8_q8_0 and vec_dot_mxfp6_q8_0 implementations, register from_float + vec_dot + vec_dot_type in CPU traits, and add fallback remaps for all architectures. MXFP8/6 are now fully tested: AoS quantization error, reference match, and dot product accuracy all pass in test-quantize-fns. * perf: remove E8M0 MSE search — base estimate is perplexity-optimal The MSE search over ±1 candidates around round(log2(amax)) was found to HURT perplexity by 4-37 PPL points across all MXFP configs on gpt-oss-20b. The base estimate alone (no search) produces better attention patterns because minimizing per-block reconstruction error is not the same as minimizing attention score distortion through softmax. Removes mse_error_mxfp4, mse_error field from traits, MSE_RANGE constant, and the entire search loop. E8M0 computation is now a single amax scan + integer bit extraction — no inner loop, no function pointers. This also simplifies future GPU/Metal implementations. * perf: fuse Hadamard rotation into SoA quantize (one pass, no temp buffer) Add quantize_row_mxfp{4,8,6}_soa_hadamard that apply Hadamard and quantize block-by-block with a 32-float stack buffer. Eliminates the std::vector heap allocation and 2 extra memory passes over the full row. 
set_rows now dispatches to the fused path when Hadamard is enabled, falling through to the unfused quantize for non-Hadamard types. This pattern maps directly to a CUDA kernel: global memory read → register Hadamard → register quantize → global memory write. * cleanup: consistent MXFP type names and variable naming - Rename type_name "mxfp8_e4m3" → "mxfp8", "mxfp6_e2m3" → "mxfp6" to match "mxfp4". Only one variant of each exists — the suffix was unnecessary disambiguation that implied alternatives. - Remove redundant MXFP shortcuts from arg.cpp (fallback loop handles all types via ggml_type_name matching). - Rename kv_is_f32_f16_or_mxfp → k_is_f32_f16_or_mxfp (only checks K). * perf: fuse Q preprocessing round-trip (no SoA buffer needed) Add mxfp{4,8,6}_hadamard_roundtrip and mxfp{4,8,6}_roundtrip functions that apply quantization error to float values without materializing SoA bytes. Replaces the 3-step Q preprocessing (Hadamard → quantize to SoA buffer → dequant from SoA buffer) with a single pass through per-block round-trip helpers. Eliminates the Q_q intermediate buffer and two function pointer calls from the flash attention hot path. Maps directly to CUDA: Q stays in registers, Hadamard butterfly + quantize error applied in-place. * fix: clamp E8M0 = 255 to 254 in decode (fixes CI NaN failures) E8M0 = 255 means NaN per MX spec, but our encode path already clamps to 254. When test data contains random E8M0 = 255 bytes, the decode produces Inf, and Inf * 0.0 = NaN, causing GET_ROWS and CPY tests to fail on MXFP6 (and potentially MXFP4/8). Fix: clamp 255 → 254 in all four E8M0 decode functions (the full/half pair in each header): - ggml_e8m0_to_fp32 / ggml_e8m0_to_fp32_half (ggml-impl.h) - ggml_mxfp_e8m0_to_fp32 / ggml_mxfp_e8m0_to_fp32_half (ggml-common.h) These are unfortunately duplicated across two headers because ggml-common.h compiles for CUDA (__device__) while ggml-impl.h serves CPU-only callers that don't include ggml-common.h.
--- common/arg.cpp | 9 - ggml/src/ggml-common.h | 17 +- ggml/src/ggml-cpu/arch-fallback.h | 11 + ggml/src/ggml-cpu/ggml-cpu.c | 6 + ggml/src/ggml-cpu/ops.cpp | 213 ++++++++++++---- ggml/src/ggml-cpu/quants.c | 48 ++++ ggml/src/ggml-cpu/quants.h | 2 + ggml/src/ggml-impl.h | 24 +- ggml/src/ggml-quants.c | 392 +++++++++++++++++++----------- ggml/src/ggml-quants.h | 18 ++ ggml/src/ggml.c | 12 +- src/llama-kv-cache.cpp | 5 +- tests/test-backend-ops.cpp | 11 +- tests/test-quantize-fns.cpp | 2 +- 14 files changed, 531 insertions(+), 239 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 26c1904a2a..5e3b40d899 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -404,15 +404,6 @@ const std::vector kv_cache_types = { }; static ggml_type kv_cache_type_from_str(const std::string & s) { - if (s == "mxfp4") { - return GGML_TYPE_MXFP4; - } - if (s == "mxfp6") { - return GGML_TYPE_MXFP6; - } - if (s == "mxfp8") { - return GGML_TYPE_MXFP8; - } for (const auto & type : kv_cache_types) { if (ggml_type_name(type) == s) { return type; diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 271de1943c..b60794717c 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -199,13 +199,12 @@ typedef struct { static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding"); // E8M0 shared exponent constants (OCP MX v1.0 SS5.3). -// EMAX_OFFSET = ceil(log2(max_finite)), MSE_RANGE = search radius for optimal scale. -#define MXFP_E8M0_MSE_RANGE 2 -#define MXFP4_E2M1_EMAX_OFFSET 2 // ceil(log2(6.0)) -#define MXFP6_E2M3_EMAX_OFFSET 3 // ceil(log2(7.5)) -#define MXFP6_E3M2_EMAX_OFFSET 5 // ceil(log2(28.0)) -#define MXFP8_E4M3_EMAX_OFFSET 8 // ceil(log2(448)) -#define MXFP8_E5M2_EMAX_OFFSET 16 // ceil(log2(57344)) +// EMAX_OFFSET ≈ log2(max_finite), used by round(log2(amax)) base estimate. 
+#define MXFP4_E2M1_EMAX_OFFSET 2 // floor(log2(6.0)) = 2 +#define MXFP6_E2M3_EMAX_OFFSET 3 // ceil(log2(7.5)) = 3 +#define MXFP6_E3M2_EMAX_OFFSET 5 // ceil(log2(28.0)) = 5 +#define MXFP8_E4M3_EMAX_OFFSET 8 // floor(log2(448)) = 8 +#define MXFP8_E5M2_EMAX_OFFSET 16 // ceil(log2(57344)) = 16 // MXFP type properties -- shared across all backends. #define MXFP_BITS_PER_ELEM_E2M1 4 @@ -1635,13 +1634,17 @@ GGML_MXFP_FUNC void ggml_mxfp_unpack_fp6x4(const uint8_t in[3], uint8_t v[4]) { // E8M0 shared exponent → float conversion. // E8M0 encoding: value = 2^(x - 127) for x > 0, 2^(-127) for x == 0. +// E8M0 = 255 is NaN per MX spec, but we clamp to 254 (max finite) to match +// the encode path which also clamps to 254, preventing Inf * 0 = NaN in dequant. GGML_MXFP_FUNC float ggml_mxfp_e8m0_to_fp32(uint8_t x) { + if (x == 255) { x = 254; } uint32_t bits = (x == 0) ? 0x00400000u : ((uint32_t)x << 23); return GGML_MXFP_U32_AS_F32(bits); } // E8M0 → float/2. Used with MXFP4 since E2M1 values are doubled in kvalues_mxfp4. GGML_MXFP_FUNC float ggml_mxfp_e8m0_to_fp32_half(uint8_t x) { + if (x == 255) { x = 254; } uint32_t bits = (x < 2) ? 
(0x00200000u << x) : ((uint32_t)(x - 1) << 23); return GGML_MXFP_U32_AS_F32(bits); } diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index eac031e68e..03f7bc0efe 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -14,6 +14,8 @@ #define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1 #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 +#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0 +#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 #define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu #define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu @@ -72,6 +74,9 @@ #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64) +// quants.c +#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0 +#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0 // repack.cpp #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 @@ -83,6 +88,8 @@ #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) // quants.c +#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0 +#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 @@ -164,6 +171,8 @@ #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K #define 
ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 +#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0 +#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 #define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu #define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu @@ -319,6 +328,8 @@ #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 +#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0 +#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 782a54392f..e2720ea3a2 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -280,13 +280,19 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .nrows = 1, }, [GGML_TYPE_MXFP8] = { + .from_float = (ggml_from_float_t)quantize_row_mxfp8_ref, .from_float_soa = quantize_row_mxfp8_soa, .to_float_soa = dequantize_row_mxfp8_soa_cpu, + .vec_dot = ggml_vec_dot_mxfp8_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, }, [GGML_TYPE_MXFP6] = { + .from_float = (ggml_from_float_t)quantize_row_mxfp6_ref, .from_float_soa = quantize_row_mxfp6_soa, .to_float_soa = dequantize_row_mxfp6_soa_cpu, + .vec_dot = ggml_vec_dot_mxfp6_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, }, [GGML_TYPE_Q2_K] = { diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 15424d40c4..fcdd7b045d 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -4909,8 +4909,106 @@ void ggml_compute_forward_get_rows( //} } -// NEON-optimized Hadamard; scalar fallback 
below -#if defined(__ARM_NEON) +// SIMD-optimized Hadamard; scalar fallback below +#if defined(__AVX2__) || defined(__AVX__) +static void hadamard_32_inplace(float vals[32]) { + // 32 floats = 4 × __m256 + __m256 v0 = _mm256_loadu_ps(vals + 0); + __m256 v1 = _mm256_loadu_ps(vals + 8); + __m256 v2 = _mm256_loadu_ps(vals + 16); + __m256 v3 = _mm256_loadu_ps(vals + 24); + + // Stride 1: butterfly on adjacent pairs within each 256-bit register + { + // Interleave even/odd elements, add/sub + __m256 a, b, s, d; + a = _mm256_shuffle_ps(v0, v0, _MM_SHUFFLE(2, 2, 0, 0)); + b = _mm256_shuffle_ps(v0, v0, _MM_SHUFFLE(3, 3, 1, 1)); + s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b); + v0 = _mm256_shuffle_ps(s, d, _MM_SHUFFLE(2, 0, 2, 0)); + v0 = _mm256_shuffle_ps(v0, v0, _MM_SHUFFLE(3, 1, 2, 0)); + + a = _mm256_shuffle_ps(v1, v1, _MM_SHUFFLE(2, 2, 0, 0)); + b = _mm256_shuffle_ps(v1, v1, _MM_SHUFFLE(3, 3, 1, 1)); + s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b); + v1 = _mm256_shuffle_ps(s, d, _MM_SHUFFLE(2, 0, 2, 0)); + v1 = _mm256_shuffle_ps(v1, v1, _MM_SHUFFLE(3, 1, 2, 0)); + + a = _mm256_shuffle_ps(v2, v2, _MM_SHUFFLE(2, 2, 0, 0)); + b = _mm256_shuffle_ps(v2, v2, _MM_SHUFFLE(3, 3, 1, 1)); + s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b); + v2 = _mm256_shuffle_ps(s, d, _MM_SHUFFLE(2, 0, 2, 0)); + v2 = _mm256_shuffle_ps(v2, v2, _MM_SHUFFLE(3, 1, 2, 0)); + + a = _mm256_shuffle_ps(v3, v3, _MM_SHUFFLE(2, 2, 0, 0)); + b = _mm256_shuffle_ps(v3, v3, _MM_SHUFFLE(3, 3, 1, 1)); + s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b); + v3 = _mm256_shuffle_ps(s, d, _MM_SHUFFLE(2, 0, 2, 0)); + v3 = _mm256_shuffle_ps(v3, v3, _MM_SHUFFLE(3, 1, 2, 0)); + } + + // Stride 2: butterfly on pairs separated by 2 within 128-bit lanes + { + __m256 a, b, s, d; + a = _mm256_permute_ps(v0, _MM_SHUFFLE(1, 0, 1, 0)); + b = _mm256_permute_ps(v0, _MM_SHUFFLE(3, 2, 3, 2)); + s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b); + v0 = _mm256_blend_ps(s, d, 0xCC); // 0b11001100 + + a = _mm256_permute_ps(v1, 
_MM_SHUFFLE(1, 0, 1, 0)); + b = _mm256_permute_ps(v1, _MM_SHUFFLE(3, 2, 3, 2)); + s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b); + v1 = _mm256_blend_ps(s, d, 0xCC); + + a = _mm256_permute_ps(v2, _MM_SHUFFLE(1, 0, 1, 0)); + b = _mm256_permute_ps(v2, _MM_SHUFFLE(3, 2, 3, 2)); + s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b); + v2 = _mm256_blend_ps(s, d, 0xCC); + + a = _mm256_permute_ps(v3, _MM_SHUFFLE(1, 0, 1, 0)); + b = _mm256_permute_ps(v3, _MM_SHUFFLE(3, 2, 3, 2)); + s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b); + v3 = _mm256_blend_ps(s, d, 0xCC); + } + + // Stride 4: butterfly between 128-bit lanes within each 256-bit register + { + __m128 lo, hi; + lo = _mm256_castps256_ps128(v0); hi = _mm256_extractf128_ps(v0, 1); + v0 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_add_ps(lo, hi)), _mm_sub_ps(lo, hi), 1); + + lo = _mm256_castps256_ps128(v1); hi = _mm256_extractf128_ps(v1, 1); + v1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_add_ps(lo, hi)), _mm_sub_ps(lo, hi), 1); + + lo = _mm256_castps256_ps128(v2); hi = _mm256_extractf128_ps(v2, 1); + v2 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_add_ps(lo, hi)), _mm_sub_ps(lo, hi), 1); + + lo = _mm256_castps256_ps128(v3); hi = _mm256_extractf128_ps(v3, 1); + v3 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_add_ps(lo, hi)), _mm_sub_ps(lo, hi), 1); + } + + // Stride 8: butterfly between registers + { + __m256 s, d; + s = _mm256_add_ps(v0, v1); d = _mm256_sub_ps(v0, v1); v0 = s; v1 = d; + s = _mm256_add_ps(v2, v3); d = _mm256_sub_ps(v2, v3); v2 = s; v3 = d; + } + + // Stride 16: butterfly between register pairs + { + __m256 s, d; + s = _mm256_add_ps(v0, v2); d = _mm256_sub_ps(v0, v2); v0 = s; v2 = d; + s = _mm256_add_ps(v1, v3); d = _mm256_sub_ps(v1, v3); v1 = s; v3 = d; + } + + // Normalize by 1/sqrt(32) + const __m256 norm = _mm256_set1_ps(MXFP_HADAMARD_32_NORM); + _mm256_storeu_ps(vals + 0, _mm256_mul_ps(v0, norm)); + _mm256_storeu_ps(vals + 8, _mm256_mul_ps(v1, norm)); + _mm256_storeu_ps(vals + 
16, _mm256_mul_ps(v2, norm)); + _mm256_storeu_ps(vals + 24, _mm256_mul_ps(v3, norm)); +} +#elif defined(__ARM_NEON) static void hadamard_32_inplace(float vals[32]) { float32x4_t v0 = vld1q_f32(vals + 0); float32x4_t v1 = vld1q_f32(vals + 4); @@ -5032,9 +5130,15 @@ static void ggml_compute_forward_set_rows_f32( ggml_from_float_t mxfp_soa_quantize = dst_traits->from_float_soa; ggml_from_float_t from_float = mxfp_soa_quantize ? nullptr : dst_traits->from_float; - std::vector had_tmp; - if (apply_hadamard) { - had_tmp.resize(nc); + // Fused Hadamard+quantize: one pass per block, 32-float stack buffer, no heap allocation. + ggml_from_float_t mxfp_soa_hadamard_quantize = nullptr; + if (apply_hadamard && mxfp_soa_quantize) { + switch (dst->type) { + case GGML_TYPE_MXFP4: mxfp_soa_hadamard_quantize = (ggml_from_float_t)quantize_row_mxfp4_soa_hadamard; break; + case GGML_TYPE_MXFP8: mxfp_soa_hadamard_quantize = (ggml_from_float_t)quantize_row_mxfp8_soa_hadamard; break; + case GGML_TYPE_MXFP6: mxfp_soa_hadamard_quantize = (ggml_from_float_t)quantize_row_mxfp6_soa_hadamard; break; + default: break; + } } for (int64_t i03 = 0; i03 < ne03; ++i03) { @@ -5051,20 +5155,12 @@ static void ggml_compute_forward_set_rows_f32( const float * src_row = (const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03); char * dst_row = ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3); - if (apply_hadamard) { - memcpy(had_tmp.data(), src_row, nc * sizeof(float)); - ggml_apply_hadamard_blocks(had_tmp.data(), nc); - if (mxfp_soa_quantize) { - mxfp_soa_quantize(had_tmp.data(), dst_row, nc); - } else { - from_float(had_tmp.data(), dst_row, nc); - } + if (mxfp_soa_hadamard_quantize) { + mxfp_soa_hadamard_quantize(src_row, dst_row, nc); + } else if (mxfp_soa_quantize) { + mxfp_soa_quantize(src_row, dst_row, nc); } else { - if (mxfp_soa_quantize) { - mxfp_soa_quantize(src_row, dst_row, nc); - } else { - from_float(src_row, dst_row, nc); - } + from_float(src_row, dst_row, nc); } } } @@ -8268,7 
+8364,6 @@ typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t); struct mxfp_kv_params { mxfp_soa_dequantize_fn dequantize; bool multihead; - int64_t soa_elems; int qs_per_block; int head_qs_bytes; int64_t head_e8m0_offset; @@ -8277,12 +8372,28 @@ struct mxfp_kv_params { // MXFP dispatch parameters for flash attention. struct mxfp_fa_params { - mxfp_soa_quantize_fn q_quantize; + mxfp_soa_quantize_fn q_quantize; // SoA quantize for Q (used only when Hadamard is off AND non-MXFP K path) + // Fused Q round-trip: Hadamard + quantize + dequant in one pass, no SoA buffer. + void (*q_roundtrip)(const float *, float *, int64_t); mxfp_kv_params k; mxfp_kv_params v; bool apply_hadamard; }; +// Compute the SoA row base pointer for a given KV position and head. +// In multihead mode, the SoA region spans all heads at one KV position, +// so the row base must NOT include the per-head offset (head_idx * nb2). +// mxfp_dequant_head handles per-head indexing within the SoA region. +// In per-head mode, each head has its own SoA region, so the base includes nb2. +static inline const char * mxfp_row_ptr( + const mxfp_kv_params & kv, const char * data, + int64_t kv_pos, size_t nb1, int head_idx, size_t nb2, int batch_idx, size_t nb3) { + if (kv.multihead) { + return data + kv_pos*nb1 + batch_idx*nb3; + } + return data + kv_pos*nb1 + head_idx*nb2 + batch_idx*nb3; +} + // Extract one head's SoA data from a multihead row and dequantize. static inline void mxfp_dequant_head( const mxfp_kv_params & kv, const char * row, int head_idx, @@ -8305,7 +8416,6 @@ static mxfp_kv_params mxfp_kv_params_init(ggml_type type, int64_t D, size_t nb2, mxfp_kv_params kv = {}; kv.dequantize = ggml_get_type_traits_cpu(type)->to_float_soa; kv.multihead = (nb2 == (size_t)ggml_row_size(type, D)); - kv.soa_elems = kv.multihead ? 
ne2 * D : D; kv.qs_per_block = ggml_mxfp_qs_per_block(type); kv.blocks_per_head = (int)(D / 32); kv.head_qs_bytes = kv.blocks_per_head * kv.qs_per_block; @@ -8328,6 +8438,17 @@ static mxfp_fa_params mxfp_fa_params_init( p.q_quantize = ggml_get_type_traits_cpu(k->type)->from_float_soa; p.k = mxfp_kv_params_init(k->type, DK, nbk2, nek2); } + + // Select fused Q round-trip (Hadamard + quantize error, no SoA buffer). + if (is_mxfp_k) { + const bool had = is_mxfp_k && (DK == DV) && ggml_mxfp_use_hadamard(k->type); + switch (k->type) { + case GGML_TYPE_MXFP4: p.q_roundtrip = had ? mxfp4_hadamard_roundtrip : mxfp4_roundtrip; break; + case GGML_TYPE_MXFP8: p.q_roundtrip = had ? mxfp8_hadamard_roundtrip : mxfp8_roundtrip; break; + case GGML_TYPE_MXFP6: p.q_roundtrip = had ? mxfp6_hadamard_roundtrip : mxfp6_roundtrip; break; + default: break; + } + } if (is_mxfp_v) { p.v = mxfp_kv_params_init(v->type, DV, nbv2, nev2); } @@ -8486,22 +8607,14 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( const char * k_base = (const char *) k->data + k_base_offset; const char * v_base = (const char *) v->data + v_base_offset; - const char * k_row_base = mxfp.k.multihead ? ((const char *) k->data + ik3*nbk3) : nullptr; - const char * v_row_base = mxfp.v.multihead ? ((const char *) v->data + iv3*nbv3) : nullptr; + const char * k_data_base = (const char *) k->data; + const char * v_data_base = (const char *) v->data; const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); float Q_f32[MXFP_FA_MAX_D]; - if (is_mxfp_k) { - // Q preprocessing: Hadamard + SoA round-trip captures same quantization loss as K. 
- if (mxfp.apply_hadamard) { - float q_tmp[MXFP_FA_MAX_D]; - memcpy(q_tmp, pq, DK * sizeof(float)); - ggml_apply_hadamard_blocks(q_tmp, DK); - mxfp.q_quantize(q_tmp, Q_q, DK); - } else { - mxfp.q_quantize(pq, Q_q, DK); - } - mxfp.k.dequantize(Q_q, Q_f32, DK); + if (mxfp.q_roundtrip) { + // Q preprocessing: fused Hadamard + quantize round-trip, no SoA buffer. + mxfp.q_roundtrip(pq, Q_f32, DK); } else { if (mxfp.apply_hadamard) { float q_tmp[MXFP_FA_MAX_D]; @@ -8526,7 +8639,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( float s; // KQ value if (is_mxfp_k) { - const char * k_row = mxfp.k.multihead ? k_row_base + ic*nbk1 : k_base + ic*nbk1; + const char * k_row = mxfp_row_ptr(mxfp.k, k_data_base, + ic, nbk1, ik2, nbk2, ik3, nbk3); mxfp_dequant_head(mxfp.k, k_row, ik2, k_head_soa, k_dequant_buf, DK); ggml_vec_dot_f32(DK, &s, 0, k_dequant_buf, 0, Q_f32, 0, 1); } else { @@ -8572,7 +8686,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( // V += v*expf(s - M) if (mxfp.v.dequantize) { - const char * v_row = mxfp.v.multihead ? 
v_row_base + ic*nbv1 : v_base + ic*nbv1; + const char * v_row = mxfp_row_ptr(mxfp.v, v_data_base, + ic, nbv1, iv2, nbv2, iv3, nbv3); mxfp_dequant_head(mxfp.v, v_row, iv2, v_head_soa, v_dequant_buf, DV); ggml_vec_mad_f32(DV, VKQ32, v_dequant_buf, vs); } else if (v_to_float) { @@ -8723,7 +8838,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled( if (is_mxfp_v) { GGML_ASSERT(DV <= MXFP_FA_MAX_D); } float k_dequant_buf[MXFP_FA_MAX_D]; - float v_dequant_buf[MXFP_FA_MAX_D]; char k_head_soa[MXFP_FA_SOA_BUF]; char v_head_soa[MXFP_FA_SOA_BUF]; @@ -8786,13 +8900,9 @@ static void ggml_compute_forward_flash_attn_ext_tiled( const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3)); memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float)); - if (is_mxfp_k) { - if (mxfp.apply_hadamard) { - ggml_apply_hadamard_blocks(Q_f32 + tq * DK, DK); - } - uint8_t q_mxfp_buf[MXFP_FA_SOA_BUF]; - mxfp.q_quantize(Q_f32 + tq * DK, q_mxfp_buf, DK); - mxfp.k.dequantize(q_mxfp_buf, Q_f32 + tq * DK, DK); + if (mxfp.q_roundtrip) { + // In-place: Q_f32 is already populated by memcpy above, roundtrip overwrites. 
+ mxfp.q_roundtrip(Q_f32 + tq * DK, Q_f32 + tq * DK, DK); } } for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) { @@ -8843,7 +8953,9 @@ static void ggml_compute_forward_flash_attn_ext_tiled( K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk]; } } else if (mxfp.k.dequantize) { - mxfp_dequant_head(mxfp.k, k_data, ik2, k_head_soa, k_dequant_buf, DK); + const char * k_row = mxfp_row_ptr(mxfp.k, (const char *)k->data, + ic + tk, nbk1, ik2, nbk2, ik3, nbk3); + mxfp_dequant_head(mxfp.k, k_row, ik2, k_head_soa, k_dequant_buf, DK); for (int64_t dk = 0; dk < DK; dk++) { K_f32[dk * KV_TILE_SZ + tk] = k_dequant_buf[dk]; } @@ -8913,8 +9025,9 @@ static void ggml_compute_forward_flash_attn_ext_tiled( } else if (v_type == GGML_TYPE_F32) { memcpy(V32 + tk * DV, v_data, DV * sizeof(float)); } else if (mxfp.v.dequantize) { - mxfp_dequant_head(mxfp.v, v_data, iv2, v_head_soa, v_dequant_buf, DV); - memcpy(V32 + tk * DV, v_dequant_buf, DV * sizeof(float)); + const char * v_row = mxfp_row_ptr(mxfp.v, (const char *)v->data, + ic + tk, nbv1, iv2, nbv2, iv3, nbv3); + mxfp_dequant_head(mxfp.v, v_row, iv2, v_head_soa, V32 + tk * DV, DV); } else { v_to_float(v_data, V32 + tk * DV, DV); } @@ -9087,10 +9200,10 @@ static void ggml_compute_forward_flash_attn_ext_f16( // Split-KV: parallelize across KV chunks for single-query decode (token generation). // Only for types whose tiled/one_chunk paths produce identical results (f32, f16, MXFP). // Standard quant types (q8_0, q4_0) must use the scalar path to preserve vec_dot semantics. 
- const bool kv_is_f32_f16_or_mxfp = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16 + const bool k_is_f32_f16_or_mxfp = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16 || ggml_is_type_mxfp(k->type)); const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) - && kv_is_f32_f16_or_mxfp + && k_is_f32_f16_or_mxfp && q->type == GGML_TYPE_F32 && nek1 >= 512; if (use_split_kv_path) { @@ -9151,7 +9264,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( // Tiled path: f32, f16, and MXFP only (quant types use one_chunk) bool use_tiled = !use_ref && (q->type == GGML_TYPE_F32 && - kv_is_f32_f16_or_mxfp && + k_is_f32_f16_or_mxfp && (k->type == v->type || ggml_is_type_mxfp(k->type)) && neq1 >= Q_TILE_SZ); #ifdef GGML_SIMD diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index 5cbd177234..0c4faa4fc1 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -189,6 +189,54 @@ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c *s = sumf; } +void ggml_vec_dot_mxfp8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + assert(n % QK_MXFP8 == 0); + static_assert(QK_MXFP8 == QK8_0, "QK_MXFP8 and QK8_0 must be the same"); + + const block_mxfp8 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + const int nb = n / QK_MXFP8; + + float sumf = 0; + for (int ib = 0; ib < nb; ++ib) { + const float d = GGML_CPU_FP16_TO_FP32(y[ib].d) * GGML_E8M0_TO_FP32(x[ib].e); + float sumi = 0; + for (int j = 0; j < QK_MXFP8; ++j) { + sumi += y[ib].qs[j] * ggml_mxfp_fp8_e4m3_to_float(x[ib].qs[j]); + } + sumf += d * sumi; + } + *s = sumf; +} + +void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 
1); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + assert(n % QK_MXFP6 == 0); + static_assert(QK_MXFP6 == QK8_0, "QK_MXFP6 and QK8_0 must be the same"); + + const block_mxfp6 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + const int nb = n / QK_MXFP6; + + float sumf = 0; + for (int ib = 0; ib < nb; ++ib) { + const float d = GGML_CPU_FP16_TO_FP32(y[ib].d) * GGML_E8M0_TO_FP32(x[ib].e); + float sumi = 0; + for (int j = 0; j < QK_MXFP6; j += 4) { + uint8_t vals[4]; + ggml_mxfp_unpack_fp6x4(&x[ib].qs[j * 3 / 4], vals); + for (int jj = 0; jj < 4; jj++) { + sumi += y[ib].qs[j + jj] * ggml_mxfp_fp6_e2m3_to_float(vals[jj]); + } + } + sumf += d * sumi; + } + *s = sumf; +} + void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h index 4a4dd264fe..4c75f9b0cd 100644 --- a/ggml/src/ggml-cpu/quants.h +++ b/ggml/src/ggml-cpu/quants.h @@ -42,6 +42,8 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, 
int nrc); void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index f9358b0432..4d98113139 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -431,13 +431,15 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) // E8M0 shared exponent to float: returns 2^(x - 127). +// Canonical implementation is ggml_mxfp_e8m0_to_fp32 in ggml-common.h. +// This thin wrapper exists because not all callers include ggml-common.h. +// MUST stay in sync — if you change the logic, change ggml-common.h too. +// +// E8M0 = 255 is NaN per MX spec; clamped to 254 (max finite) to match +// the encode path which also clamps to 254, preventing Inf * 0 = NaN. static inline float ggml_e8m0_to_fp32(uint8_t x) { - uint32_t bits; - if (x == 0) { - bits = 0x00400000; // denorm: 0.5 * 2^(-126) = 2^(-127) - } else { - bits = (uint32_t) x << 23; - } + if (x == 255) { x = 254; } + uint32_t bits = (x == 0) ? 0x00400000u : ((uint32_t)x << 23); float result; memcpy(&result, &bits, sizeof(float)); return result; @@ -445,14 +447,8 @@ static inline float ggml_e8m0_to_fp32(uint8_t x) { // E8M0 to float/2: returns 2^(x - 128). static inline float ggml_e8m0_to_fp32_half(uint8_t x) { - uint32_t bits; - if (x < 2) { - // x=0 → 2^(-128), x=1 → 2^(-127): denormal bit patterns - bits = 0x00200000 << x; - } else { - // 0.5 * 2^(x-127) = 2^(x-128): normalized with exponent (x-1) - bits = (uint32_t)(x - 1) << 23; - } + if (x == 255) { x = 254; } + uint32_t bits = (x < 2) ? 
(0x00200000u << x) : ((uint32_t)(x - 1) << 23); float result; memcpy(&result, &bits, sizeof(float)); return result; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 5c8eb97806..1afd82b6c5 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -263,8 +263,6 @@ float fp8_e4m3_to_float(uint8_t v) { return ggml_mxfp_fp8_e4m3_to_float(v); } uint8_t float_to_fp8_e4m3_rn(float x) { return ggml_mxfp_float_to_fp8_e4m3(x); } // ====================== MXFP quantization infrastructure -// -// MSE-optimal E8M0: tests candidates around round(log2(amax)), picks lowest quantization error. typedef struct { int emax_offset; // type-specific offset to max representable exponent @@ -272,33 +270,13 @@ typedef struct { int bits_per_elem; // 8 = byte-aligned, 6 = packed via fp6x4 uint8_t (*to_elem)(float); float (*to_float)(uint8_t); - float (*mse_error)(float val, float inv_scale, float scale); // NULL = use generic round-trip via to_elem/to_float } mxfp_elem_traits_t; static inline int best_index_mxfp4(float x, float e); -// MSE error for MXFP4 (kvalues are doubled, so scale is halved) -static float mse_error_mxfp4(float val, float inv_scale, float scale) { - const float d = scale * 0.5f; - const float inv_d = (d > 0.0f) ? 
1.0f / d : 0.0f; - const float normalized = fabsf(val) * inv_d; - (void)inv_scale; - float qval; - if (normalized < 0.5f) qval = 0.0f; - else if (normalized < 1.5f) qval = 1.0f; - else if (normalized < 2.5f) qval = 2.0f; - else if (normalized < 3.5f) qval = 3.0f; - else if (normalized < 5.0f) qval = 4.0f; - else if (normalized < 7.0f) qval = 6.0f; - else if (normalized < 10.0f) qval = 8.0f; - else qval = 12.0f; - const float err = fabsf(val) - qval * d; - return err * err; -} -static const mxfp_elem_traits_t mxfp4_traits = { MXFP4_E2M1_EMAX_OFFSET, MXFP4_SOA_QS_PER_BLOCK, 4, NULL, NULL, mse_error_mxfp4 }; - -static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_elem_traits_t * traits) { +// E8M0 shared exponent: round(log2(amax)) — no MSE search needed. +static inline uint8_t mxfp_compute_e8m0(const float * x, int qk, int emax_offset) { float amax = 0.0f; for (int j = 0; j < qk; j++) { const float a = fabsf(x[j]); @@ -306,36 +284,8 @@ static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_ } if (amax == 0.0f) return 0; - const int e_base = ggml_mxfp_e8m0_base_estimate(amax, traits->emax_offset); - - int e_lo = e_base - MXFP_E8M0_MSE_RANGE; - int e_hi = e_base + MXFP_E8M0_MSE_RANGE; - if (e_lo < 1) e_lo = 1; - if (e_hi < 1) e_hi = 1; - if (e_hi > 254) e_hi = 254; - int best_e = e_base < 0 ? 0 : (e_base > 254 ? 
254 : e_base); - float best_mse = 1e30f; - - for (int test_e = e_lo; test_e <= e_hi; ++test_e) { - const float test_scale = GGML_E8M0_TO_FP32((uint8_t)test_e); - const float test_inv = 1.0f / test_scale; - float mse = 0.0f; - for (int j = 0; j < qk; ++j) { - if (traits->mse_error) { - mse += traits->mse_error(x[j], test_inv, test_scale); - } else { - const float recon = traits->to_float(traits->to_elem(x[j] * test_inv)) * test_scale; - const float err = x[j] - recon; - mse += err * err; - } - } - if (mse < best_mse) { - best_mse = mse; - best_e = test_e; - } - } - - return (uint8_t)best_e; + const int e = ggml_mxfp_e8m0_base_estimate(amax, emax_offset); + return (uint8_t)(e < 0 ? 0 : (e > 254 ? 254 : e)); } static inline int best_index_mxfp4(float x, float e) { @@ -353,26 +303,112 @@ static inline int best_index_mxfp4(float x, float e) { return (x < 0.0f) ? (idx + 8) : idx; } -void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) { - static const int qk = QK_MXFP4; +// Per-block MXFP4 quantize: shared between AoS and SoA paths. +static inline void quantize_block_mxfp4(const float * GGML_RESTRICT src, uint8_t * GGML_RESTRICT qs, uint8_t * e_out) { + const uint8_t e = mxfp_compute_e8m0(src, QK_MXFP4, MXFP4_E2M1_EMAX_OFFSET); + const float d = GGML_E8M0_TO_FP32_HALF(e); + *e_out = e; + for (int j = 0; j < QK_MXFP4/2; ++j) { + const uint8_t x0 = best_index_mxfp4(src[0 + j], d); + const uint8_t x1 = best_index_mxfp4(src[QK_MXFP4/2 + j], d); + qs[j] = x0 | (x1 << 4); + } +} - assert(k % qk == 0); +// Per-block MXFP4 quantize round-trip: apply quantization error without materializing bytes. +// Used for Q preprocessing in flash attention — matches K's error pattern. 
+static inline void roundtrip_block_mxfp4(float * GGML_RESTRICT vals) { + const uint8_t e = mxfp_compute_e8m0(vals, QK_MXFP4, MXFP4_E2M1_EMAX_OFFSET); + const float d = GGML_E8M0_TO_FP32_HALF(e); + for (int j = 0; j < QK_MXFP4; ++j) { + const int idx = best_index_mxfp4(vals[j], d); + vals[j] = kvalues_mxfp4[idx] * d; // kvalues are doubled, d is halved — matches dequant + } +} - const int nb = k / qk; +// Per-block generic MXFP quantize round-trip (MXFP8/MXFP6). +static inline void roundtrip_block_mxfp(float * GGML_RESTRICT vals, const mxfp_elem_traits_t * traits) { + const uint8_t e = mxfp_compute_e8m0(vals, 32, traits->emax_offset); + const float d = GGML_E8M0_TO_FP32(e); + const float inv_d = d > 0.0f ? 1.0f / d : 0.0f; + for (int j = 0; j < 32; ++j) { + vals[j] = traits->to_float(traits->to_elem(vals[j] * inv_d)) * d; + } +} - for (int i = 0; i < nb; i++) { - const uint8_t e = mxfp_compute_e8m0_mse(&x[i*qk], qk, &mxfp4_traits); - const float d = GGML_E8M0_TO_FP32_HALF(e); +// Fused Hadamard + quantize round-trip: one pass, output is float with quantization error. +void mxfp4_hadamard_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k) { + assert(k % 32 == 0); + for (int64_t i = 0; i < k; i += 32) { + memcpy(dst + i, src + i, 32 * sizeof(float)); + ggml_mxfp_hadamard_32_inplace(dst + i); + roundtrip_block_mxfp4(dst + i); + } +} - y[i].e = e; +// Non-Hadamard round-trip for MXFP4 (Hadamard disabled or V cache). +void mxfp4_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k) { + assert(k % 32 == 0); + for (int64_t i = 0; i < k; i += 32) { + memcpy(dst + i, src + i, 32 * sizeof(float)); + roundtrip_block_mxfp4(dst + i); + } +} - for (int j = 0; j < qk/2; ++j) { - const uint8_t x0 = best_index_mxfp4(x[i*qk + 0 + j], d); - const uint8_t x1 = best_index_mxfp4(x[i*qk + qk/2 + j], d); +// Per-block MXFP4 dequant: shared between AoS and SoA paths. 
+static inline void dequantize_block_mxfp4(const uint8_t * GGML_RESTRICT qs, uint8_t e, float * GGML_RESTRICT dst) { + const float d = GGML_E8M0_TO_FP32_HALF(e); + for (int j = 0; j < QK_MXFP4/2; ++j) { + dst[0 + j] = kvalues_mxfp4[qs[j] & 0x0F] * d; + dst[QK_MXFP4/2 + j] = kvalues_mxfp4[qs[j] >> 4] * d; + } +} - y[i].qs[j] = x0; - y[i].qs[j] |= x1 << 4; +// Per-block generic MXFP quantize/dequant: shared between AoS and SoA for MXFP8/MXFP6. +static inline void quantize_block_mxfp(const float * GGML_RESTRICT src, uint8_t * GGML_RESTRICT qs, + uint8_t * e_out, const mxfp_elem_traits_t * traits) { + const uint8_t e = mxfp_compute_e8m0(src, 32, traits->emax_offset); + const float d = GGML_E8M0_TO_FP32(e); + const float inv_d = d > 0.0f ? 1.0f / d : 0.0f; + *e_out = e; + if (traits->bits_per_elem == 8) { + for (int j = 0; j < 32; ++j) { + qs[j] = traits->to_elem(src[j] * inv_d); } + } else { + for (int j = 0; j < 32; j += 4) { + uint8_t vals[4]; + for (int jj = 0; jj < 4; jj++) { + vals[jj] = traits->to_elem(src[j + jj] * inv_d); + } + pack_fp6x4(vals, &qs[j * 3 / 4]); + } + } +} + +static inline void dequantize_block_mxfp(const uint8_t * GGML_RESTRICT qs, uint8_t e, + float * GGML_RESTRICT dst, const mxfp_elem_traits_t * traits) { + const float d = GGML_E8M0_TO_FP32(e); + if (traits->bits_per_elem == 8) { + for (int j = 0; j < 32; ++j) { + dst[j] = traits->to_float(qs[j]) * d; + } + } else { + for (int j = 0; j < 32; j += 4) { + uint8_t vals[4]; + unpack_fp6x4(&qs[j * 3 / 4], vals); + for (int jj = 0; jj < 4; jj++) { + dst[j + jj] = traits->to_float(vals[jj]) * d; + } + } + } +} + +void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) { + assert(k % QK_MXFP4 == 0); + const int nb = k / QK_MXFP4; + for (int i = 0; i < nb; i++) { + quantize_block_mxfp4(&x[i*QK_MXFP4], y[i].qs, &y[i].e); } } @@ -522,22 +558,10 @@ void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRI } void 
dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { - static const int qk = QK_MXFP4; - - assert(k % qk == 0); - - const int nb = k / qk; - + assert(k % QK_MXFP4 == 0); + const int nb = k / QK_MXFP4; for (int i = 0; i < nb; i++) { - const float d = GGML_E8M0_TO_FP32_HALF(x[i].e); - - for (int j = 0; j < qk/2; ++j) { - const int8_t x0 = kvalues_mxfp4[x[i].qs[j] & 0x0F]; - const int8_t x1 = kvalues_mxfp4[x[i].qs[j] >> 4]; - - y[i*qk + j + 0 ] = x0*d; - y[i*qk + j + qk/2] = x1*d; - } + dequantize_block_mxfp4(x[i].qs, x[i].e, &y[i*QK_MXFP4]); } } @@ -582,112 +606,95 @@ uint8_t float_to_fp8_e5m2_rn(float x) { return ggml_mxfp_float_to_fp8_e5m2(x); } void pack_fp6x4(const uint8_t v[4], uint8_t out[3]) { ggml_mxfp_pack_fp6x4(v, out); } void unpack_fp6x4(const uint8_t in[3], uint8_t v[4]) { ggml_mxfp_unpack_fp6x4(in, v); } -static const mxfp_elem_traits_t mxfp8_e4m3_traits = { MXFP8_E4M3_EMAX_OFFSET, MXFP8_SOA_QS_PER_BLOCK, 8, float_to_fp8_e4m3_rn, fp8_e4m3_to_float, NULL }; -static const mxfp_elem_traits_t mxfp6_e2m3_traits = { MXFP6_E2M3_EMAX_OFFSET, MXFP6_SOA_QS_PER_BLOCK, 6, float_to_fp6_e2m3_rn, fp6_e2m3_to_float, NULL }; +static const mxfp_elem_traits_t mxfp8_e4m3_traits = { MXFP8_E4M3_EMAX_OFFSET, MXFP8_SOA_QS_PER_BLOCK, 8, float_to_fp8_e4m3_rn, fp8_e4m3_to_float }; +static const mxfp_elem_traits_t mxfp6_e2m3_traits = { MXFP6_E2M3_EMAX_OFFSET, MXFP6_SOA_QS_PER_BLOCK, 6, float_to_fp6_e2m3_rn, fp6_e2m3_to_float }; + +// MXFP8 AoS quantize/dequant — uses shared per-block helpers. 
+void quantize_row_mxfp8_ref(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k) { + assert(k % QK_MXFP8 == 0); + const int nb = k / QK_MXFP8; + for (int i = 0; i < nb; i++) { + quantize_block_mxfp(&x[i*QK_MXFP8], y[i].qs, &y[i].e, &mxfp8_e4m3_traits); + } +} + +void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_MXFP8 == 0); + const int nb = k / QK_MXFP8; + for (int i = 0; i < nb; i++) { + dequantize_block_mxfp(x[i].qs, x[i].e, &y[i*QK_MXFP8], &mxfp8_e4m3_traits); + } +} + +// MXFP6 AoS quantize/dequant — uses shared per-block helpers. +void quantize_row_mxfp6_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k) { + assert(k % QK_MXFP6 == 0); + const int nb = k / QK_MXFP6; + for (int i = 0; i < nb; i++) { + quantize_block_mxfp(&x[i*QK_MXFP6], y[i].qs, &y[i].e, &mxfp6_e2m3_traits); + } +} + +void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_MXFP6 == 0); + const int nb = k / QK_MXFP6; + for (int i = 0; i < nb; i++) { + dequantize_block_mxfp(x[i].qs, x[i].e, &y[i*QK_MXFP6], &mxfp6_e2m3_traits); + } +} // ====================== SoA (Struct-of-Arrays) quantize/dequantize for flash attention void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) { assert(k % QK_MXFP4 == 0); const int nb = k / QK_MXFP4; - char * row = (char *)dst; - char * qs_base = row; - char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK); + char * qs_base = (char *)dst; + char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK); for (int i = 0; i < nb; i++) { - const uint8_t e = mxfp_compute_e8m0_mse(&x[i*QK_MXFP4], QK_MXFP4, &mxfp4_traits); - const float d = GGML_E8M0_TO_FP32_HALF(e); - - e8m0_base[i] = (char)e; - uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP4_SOA_QS_PER_BLOCK)); - for (int j = 0; j < QK_MXFP4/2; ++j) { - 
const uint8_t x0 = best_index_mxfp4(x[i*QK_MXFP4 + 0 + j], d); - const uint8_t x1 = best_index_mxfp4(x[i*QK_MXFP4 + QK_MXFP4/2 + j], d); - qs[j] = x0 | (x1 << 4); - } + quantize_block_mxfp4(&x[i*QK_MXFP4], qs, (uint8_t *)&e8m0_base[i]); } } void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) { assert(k % QK_MXFP4 == 0); const int nb = k / QK_MXFP4; - const char * row = (const char *)src; - const char * qs_base = row; - const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK); + const char * qs_base = (const char *)src; + const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK); for (int i = 0; i < nb; i++) { - const float d = GGML_E8M0_TO_FP32_HALF((uint8_t)e8m0_base[i]); const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP4_SOA_QS_PER_BLOCK)); - - for (int j = 0; j < QK_MXFP4/2; ++j) { - const int8_t x0 = kvalues_mxfp4[qs[j] & 0x0F]; - const int8_t x1 = kvalues_mxfp4[qs[j] >> 4]; - y[i*QK_MXFP4 + j + 0 ] = x0*d; - y[i*QK_MXFP4 + j + QK_MXFP4/2] = x1*d; - } + dequantize_block_mxfp4(qs, (uint8_t)e8m0_base[i], &y[i*QK_MXFP4]); } } -// Unified SoA quantize for byte-aligned (FP8) and 6-bit packed (FP6) MXFP formats. +// Unified SoA quantize/dequantize — delegates to shared per-block helpers. static void quantize_row_mxfp_soa_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k, const mxfp_elem_traits_t * traits) { - const int qk = 32; - assert(k % qk == 0); - const int nb = k / qk; + assert(k % 32 == 0); + const int nb = k / 32; const int qpb = traits->qs_per_block; char * qs_base = (char *)dst; char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, qpb); for (int i = 0; i < nb; i++) { - const uint8_t e = mxfp_compute_e8m0_mse(&x[i*qk], qk, traits); - const float d = GGML_E8M0_TO_FP32(e); - const float inv_d = d > 0.0f ? 
1.0f / d : 0.0f; - e8m0_base[i] = (char)e; - uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, qpb)); - if (traits->bits_per_elem == 8) { - for (int j = 0; j < qk; ++j) { - qs[j] = traits->to_elem(x[i*qk + j] * inv_d); - } - } else { - for (int j = 0; j < qk; j += 4) { - uint8_t vals[4]; - for (int jj = 0; jj < 4; jj++) { - vals[jj] = traits->to_elem(x[i*qk + j + jj] * inv_d); - } - pack_fp6x4(vals, &qs[j * 3 / 4]); - } - } + quantize_block_mxfp(&x[i*32], qs, (uint8_t *)&e8m0_base[i], traits); } } -// Unified SoA dequantize for byte-aligned (FP8) and 6-bit packed (FP6) MXFP formats. static void dequantize_row_mxfp_soa_impl(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, const mxfp_elem_traits_t * traits) { - const int qk = 32; - assert(k % qk == 0); - const int nb = k / qk; + assert(k % 32 == 0); + const int nb = k / 32; const int qpb = traits->qs_per_block; const char * qs_base = (const char *)src; const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, qpb); for (int i = 0; i < nb; i++) { - const float d = GGML_E8M0_TO_FP32((uint8_t)e8m0_base[i]); const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, qpb)); - if (traits->bits_per_elem == 8) { - for (int j = 0; j < qk; ++j) { - y[i*qk + j] = traits->to_float(qs[j]) * d; - } - } else { - for (int j = 0; j < qk; j += 4) { - uint8_t vals[4]; - unpack_fp6x4(&qs[j * 3 / 4], vals); - for (int jj = 0; jj < 4; jj++) { - y[i*qk + j + jj] = traits->to_float(vals[jj]) * d; - } - } - } + dequantize_block_mxfp(qs, (uint8_t)e8m0_base[i], &y[i*32], traits); } } @@ -703,6 +710,83 @@ void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) { dequantize_row_mxfp_soa_impl(src, y, k, &mxfp6_e2m3_traits); } + +// Fused Hadamard + SoA quantize: one read, one write, 32-float stack buffer per block. +// Eliminates the full-row temp buffer and extra memory pass. 
+void quantize_row_mxfp4_soa_hadamard(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) { + assert(k % QK_MXFP4 == 0); + const int nb = k / QK_MXFP4; + char * qs_base = (char *)dst; + char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK); + + for (int i = 0; i < nb; i++) { + float tmp[32]; + memcpy(tmp, &x[i*QK_MXFP4], QK_MXFP4 * sizeof(float)); + ggml_mxfp_hadamard_32_inplace(tmp); + uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP4_SOA_QS_PER_BLOCK)); + quantize_block_mxfp4(tmp, qs, (uint8_t *)&e8m0_base[i]); + } +} + +static void quantize_row_mxfp_soa_hadamard_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, + int64_t k, const mxfp_elem_traits_t * traits) { + assert(k % 32 == 0); + const int nb = k / 32; + const int qpb = traits->qs_per_block; + char * qs_base = (char *)dst; + char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, qpb); + + for (int i = 0; i < nb; i++) { + float tmp[32]; + memcpy(tmp, &x[i*32], 32 * sizeof(float)); + ggml_mxfp_hadamard_32_inplace(tmp); + uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, qpb)); + quantize_block_mxfp(tmp, qs, (uint8_t *)&e8m0_base[i], traits); + } +} + +void quantize_row_mxfp8_soa_hadamard(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) { + quantize_row_mxfp_soa_hadamard_impl(x, dst, k, &mxfp8_e4m3_traits); +} +void quantize_row_mxfp6_soa_hadamard(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) { + quantize_row_mxfp_soa_hadamard_impl(x, dst, k, &mxfp6_e2m3_traits); +} + +// MXFP8/6 quantize round-trips (with and without Hadamard). 
+void mxfp8_hadamard_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k) { + assert(k % 32 == 0); + for (int64_t i = 0; i < k; i += 32) { + memcpy(dst + i, src + i, 32 * sizeof(float)); + ggml_mxfp_hadamard_32_inplace(dst + i); + roundtrip_block_mxfp(dst + i, &mxfp8_e4m3_traits); + } +} + +void mxfp6_hadamard_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k) { + assert(k % 32 == 0); + for (int64_t i = 0; i < k; i += 32) { + memcpy(dst + i, src + i, 32 * sizeof(float)); + ggml_mxfp_hadamard_32_inplace(dst + i); + roundtrip_block_mxfp(dst + i, &mxfp6_e2m3_traits); + } +} + +void mxfp8_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k) { + assert(k % 32 == 0); + for (int64_t i = 0; i < k; i += 32) { + memcpy(dst + i, src + i, 32 * sizeof(float)); + roundtrip_block_mxfp(dst + i, &mxfp8_e4m3_traits); + } +} + +void mxfp6_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k) { + assert(k % 32 == 0); + for (int64_t i = 0; i < k; i += 32) { + memcpy(dst + i, src + i, 32 * sizeof(float)); + roundtrip_block_mxfp(dst + i, &mxfp6_e2m3_traits); + } +} + // // 2-6 bit quantization in super-blocks // @@ -2373,6 +2457,18 @@ size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, return nrow * ggml_row_size(GGML_TYPE_NVFP4, n_per_row); } +size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_UNUSED(quant_weights); + quantize_row_mxfp8_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_MXFP8, n_per_row); +} + +size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_UNUSED(quant_weights); + quantize_row_mxfp6_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_MXFP6, n_per_row); +} + // 
====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs) void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) { diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index 4dec9ad351..d1cc8d4c85 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -22,6 +22,8 @@ GGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_mxfp8_ref(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_mxfp6_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); @@ -48,6 +50,8 @@ GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GG //GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); // SoA quantize/dequantize for flash attention GGML_API void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t 
k); @@ -56,6 +60,18 @@ GGML_API void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_ GGML_API void dequantize_row_mxfp8_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); GGML_API void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); +// Fused Hadamard + SoA quantize (one pass, no temp buffer) +GGML_API void quantize_row_mxfp4_soa_hadamard(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); +GGML_API void quantize_row_mxfp8_soa_hadamard(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); +GGML_API void quantize_row_mxfp6_soa_hadamard(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); +// Quantize round-trip: apply quantization error to floats without materializing bytes. +// Hadamard variants include the rotation. Used for Q preprocessing in flash attention. 
+GGML_API void mxfp4_hadamard_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k); +GGML_API void mxfp8_hadamard_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k); +GGML_API void mxfp6_hadamard_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k); +GGML_API void mxfp4_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k); +GGML_API void mxfp8_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k); +GGML_API void mxfp6_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k); GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -103,6 +119,8 @@ GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTR GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); // MXFP element converters GGML_API float fp8_e4m3_to_float(uint8_t v); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 21b9a81eae..470b68c4bc 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -727,16 +727,20 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { 
.from_float_ref = (ggml_from_float_t)quantize_row_nvfp4_ref, }, [GGML_TYPE_MXFP8] = { - .type_name = "mxfp8_e4m3", + .type_name = "mxfp8", .blck_size = QK_MXFP8, .type_size = sizeof(block_mxfp8), .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_mxfp8, + .from_float_ref = (ggml_from_float_t)quantize_row_mxfp8_ref, }, [GGML_TYPE_MXFP6] = { - .type_name = "mxfp6_e2m3", + .type_name = "mxfp6", .blck_size = QK_MXFP6, .type_size = sizeof(block_mxfp6), .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_mxfp6, + .from_float_ref = (ggml_from_float_t)quantize_row_mxfp6_ref, }, [GGML_TYPE_Q2_K] = { .type_name = "q2_K", @@ -7693,8 +7697,8 @@ size_t ggml_quantize_chunk( case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_MXFP4: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_MXFP8: GGML_ABORT("MXFP8 is KV-cache-only (SoA layout) — use from_float_soa"); - case GGML_TYPE_MXFP6: GGML_ABORT("MXFP6 is KV-cache-only (SoA layout) — use from_float_soa"); + case GGML_TYPE_MXFP8: result = quantize_mxfp8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_MXFP6: result = quantize_mxfp6(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/src/llama-kv-cache.cpp 
b/src/llama-kv-cache.cpp index a890510edf..4cf4ce69d1 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1121,8 +1121,9 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm ggml_tensor * result = ggml_set_rows(ctx, k_dst, k_cur, k_idxs); // enable Hadamard rotation for MXFP K cache (QuaRot arXiv:2404.00456, BRQ arXiv:2511.04214) - // skipped for MLA (V is a view of K) and E5M2/E3M2 (2-bit mantissa, no benefit) - if (is_mxfp && !hparams.is_mla() && ggml_mxfp_use_hadamard(k->type)) { + // skipped when DK != DV (MLA) and for E5M2/E3M2 (2-bit mantissa, no benefit). + // condition must match flash attention read path (ops.cpp: DK == DV). + if (is_mxfp && hparams.n_embd_head_k(il) == hparams.n_embd_head_v(il) && ggml_mxfp_use_hadamard(k->type)) { ((int32_t *)result->op_params)[0] = 1; } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index d102c5676c..281a9b65f4 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6390,7 +6390,7 @@ struct test_flash_attn_ext : public test_case { init_tensor_uniform(t, -10.0f, 10.0f); } else if (strcmp(t->name, "m") == 0) { init_tensor_kq_mask(t); - } else if ((strcmp(t->name, "k") == 0 || strcmp(t->name, "v") == 0) && ggml_is_type_mxfp(t->type)) { + } else if (ggml_is_type_mxfp(t->type)) { init_tensor_mxfp_soa(t); } else { init_tensor_uniform(t); @@ -7398,7 +7398,7 @@ static const ggml_type all_types[] = { GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0, - GGML_TYPE_MXFP4, + GGML_TYPE_MXFP4, GGML_TYPE_MXFP8, GGML_TYPE_MXFP6, GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K, @@ -7533,11 +7533,14 @@ static std::vector> make_test_cases_eval() { } // SET_ROWS with Hadamard rotation (exercises the op_params[0] flag used by MXFP KV cache) - for (ggml_type type : {GGML_TYPE_MXFP4, GGML_TYPE_MXFP8, - GGML_TYPE_MXFP6}) { + for (ggml_type type : {GGML_TYPE_MXFP4, GGML_TYPE_MXFP8, 
GGML_TYPE_MXFP6}) { // ne[0] must be divisible by 32 (Hadamard block size) test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 128, 5, 1, 1 }, { 1, 1 }, 1, false, true)); test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 256, 5, 1, 3 }, { 1, 1 }, 1, false, true)); + // multi-row, broadcast, views + test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 128, 5, 1, 1 }, { 1, 1 }, 1, true, true)); + test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 256, 11, 1, 1 }, { 2, 3 }, 7, false, true)); + test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 512, 5, 3, 1 }, { 1, 1 }, 1, false, true)); } for (int mode : { GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_VISION }) { diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp index 8f1dcf10f0..babc9f58e1 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -33,7 +33,7 @@ constexpr float MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 = 0.0020f; // These represent actual RMSE through the full KV cache write/read path. constexpr float MAX_MXFP_PIPELINE_ERROR_MXFP4 = 0.40f; constexpr float MAX_MXFP_PIPELINE_ERROR_MXFP8 = 0.08f; -constexpr float MAX_MXFP_PIPELINE_ERROR_MXFP6 = 0.10f; +constexpr float MAX_MXFP_PIPELINE_ERROR_MXFP6 = 0.30f; constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f; constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f;