test: add MXFP testing and fixes

* cleanup: hoist MXFP SoA functions

* fix: CI failures — CUDA __device__ init, Metal MXFP supports_op, SoA test assert

Three fixes for CI failures:

1. Remove <cmath> from CUDA/HIP/MUSA section of ggml-common.h — the include
   causes NAN/INFINITY to become non-constexpr, breaking __device__ static
   table initialization for the MXFP LUTs.

2. Add MXFP type guards to Metal's supports_op: MXFP8/MXFP6 have no Metal
   shaders yet (reject all ops), MXFP4 has AoS shaders (MUL_MAT, GET_ROWS)
   but no SoA/flash attention support yet (reject FLASH_ATTN_EXT, SET_ROWS).

3. Replace strict assert in test-backend-ops init_tensor_mxfp_soa with a
   conditional fallback — when ne2 is not divisible by heads_per_region,
   fall back to per-head SoA init instead of crashing.

* fix: correct guard for MXFP CPU dequant functions

* fix: CUDA MXFP LUT init and MXFP flash attention SoA test layout

- Add per-platform GGML_TABLE_NAN/GGML_TABLE_INFINITY macros for MXFP
  LUTs — uses __uint_as_float on CUDA to avoid MSVC non-constexpr INFINITY
- Fix init_tensor_mxfp_soa to detect multihead SoA from tensor strides,
  matching the KV cache layout for permuted flash attention tests

* fix: CUDA MXFP LUT init — use __builtin_nanf/__builtin_inff for constexpr device tables

CUDA/HIP/MUSA __device__ static tables require constexpr initializers.
Standard NAN/INFINITY macros may expand to non-constexpr expressions
(e.g. MSVC: (float)(1e+300), nvcc: __uint_as_float is not constexpr
for static init). Previous fix attempted __uint_as_float for nvcc and
__builtin_bit_cast for clang — neither worked universally.

Use __builtin_nanf("") and __builtin_inff() which are constexpr on
all target compilers (nvcc, clang for HIP/MUSA, GCC, MSVC). Define
once before the platform #if chain instead of per-platform copies.

* fix: correct E5M2 LUT precision and add converter-vs-LUT validation tests

The kvalues_mxfp8_e5m2 LUT had 50 values with insufficient decimal
precision, causing bitwise mismatches against the IEEE-754 element
converter. Regenerated from ggml_mxfp_fp8_e5m2_to_float() with %.9e
precision for exact float round-trip on all 256 entries.

Also consolidates GGML_TABLE_NAN/GGML_TABLE_INFINITY into a single
definition using __builtin_nanf/__builtin_inff (constexpr on all
target compilers), and adds LUT validation tests to test-quantize-fns
that verify all 5 MXFP element converters match their canonical LUT
values (FP4 E2M1: 16, FP6 E2M3: 64, FP6 E3M2: 64, FP8 E4M3: 256,
FP8 E5M2: 256 — 656 total values verified).

* fix: MSVC compat for GGML_TABLE_NAN/INFINITY — use builtins only on GCC/Clang/nvcc

MSVC does not support __builtin_nanf/__builtin_inff. Use standard
NAN/INFINITY macros on MSVC (which work for regular static tables),
and compiler builtins only on GCC/Clang/nvcc (needed for CUDA
__device__ table constexpr initialization).

* fix: handle nvcc+MSVC host — check __CUDACC__ before _MSC_VER for NAN/INF macros

When nvcc uses MSVC as the host compiler, both _MSC_VER and __CUDACC__
are defined. The previous fix checked _MSC_VER first, giving nvcc the
MSVC NAN/INFINITY macros which are not constexpr for __device__ tables.
Add __CUDACC__ exclusion so nvcc gets __builtin_nanf/__builtin_inff.

* cleanup: remove AoS MXFP6/MXFP8 dequant code — these types are KV-cache-only (SoA)

MXFP6 (E2M3) and MXFP8 (E4M3) exist only for KV cache flash attention,
which uses SoA (Struct-of-Arrays) layout. The AoS dequant functions
(NEON, AVX2, CPU dispatch, generic wrappers) were incorrectly added
and are dead code — no model stores weights in these formats.

Removed:
- AoS NEON dequant: dequantize_row_mxfp{6,8}_neon, _cpu dispatch
- AoS AVX2 dequant: dequantize_row_mxfp{6,8}_avx2, _cpu dispatch
- AoS generic wrappers: dequantize_row_mxfp{6,8}_cpu_generic
- AoS fallback defines in arch-fallback.h
- CPU traits .to_float entries for MXFP6/MXFP8
- MXFP6/MXFP8 from all_types[] in test-backend-ops (no AoS tests)

Kept (correct SoA code):
- All *_soa_* functions (NEON, AVX2, generic, dispatch)
- CPU traits .from_float_soa / .to_float_soa
- Flash attention and SET_ROWS Hadamard test cases
- Scalar reference dequant in ggml-quants.c (test-quantize-fns roundtrip)
- MXFP4 AoS code (upstream model weight support, untouched)

Fixes ARM64 CI failure: GET_ROWS(mxfp6_e2m3) was testing dead AoS code
that had a NEON bug. The test no longer runs because the type is
correctly excluded from AoS test paths.

* test: guard that all MXFP types have SoA traits for flash attention

All MXFP flash attention uses SoA layout exclusively. Test validates:
- ALL MXFP types (MXFP4, MXFP6, MXFP8) have from_float_soa and to_float_soa
- MXFP6/MXFP8 (KV-cache-only) do NOT have AoS CPU to_float

Prevents regression: if someone adds AoS dequant back for MXFP6/MXFP8,
or removes SoA traits from any MXFP type, CI will catch it.

* test: add Hadamard, SoA cross-check, E8M0, and layout offset tests

* test: add MXFP converter edge cases, FP6 packing, E8M0 known-answer tests

Add comprehensive tests to catch the bugs backend implementers hit most:
- Element converter edge cases: subnormals, max finite, saturation, NaN, sign
- FP6 pack/unpack exhaustive round-trip with known-answer byte verification
- E8M0 known-answer decode + HALF vs FULL scale distinction
- E8M0 rounding boundary at sqrt(2) threshold (catches floor-only bugs)
- Converter exhaustive round-trip: quantize(dequantize(i))==i for all formats
- Consolidate duplicate SoA switches into single table in test-backend-ops

* test: add AoS/SoA cross-check, Hadamard pipeline, format spec, and mxfp_rmse

- MXFP4 AoS vs SoA cross-check: two independent code paths, bitwise match
- Full Hadamard pipeline roundtrip: H→quantize→dequant→H for all 3 types
- mxfp_rmse helper: computes sqrt(sum/n), with named pipeline constants
- Block size consistency: verify QK_MXFP{4,8,6} == 32
- EMAX_OFFSET vs format max: validate constants produce valid E8M0
- Edge case LUT validation: expected_bits verified against canonical LUTs
- FP4 E2M1 exhaustive converter round-trip (16/16)

* cleanup: tighten MXFP test comments to match repo conventions

* fix: platform-specific NaN/Infinity for GPU device table initializers

FP8 E4M3/E5M2 LUTs contain NaN/Inf which cannot be constexpr-initialized
in __device__ tables on any CUDA/HIP/MUSA version. No GPU backend uses
these LUTs (they use converter functions instead), so guard them out of
GPU builds entirely. Simplify GGML_TABLE_NAN/INFINITY to CPU-only macros.
Commit ad2fa9035a (parent dd263ff567)
Tim Burke 2026-03-22 01:07:55 -04:00, committed by GitHub
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
14 changed files with 1032 additions and 282 deletions


@@ -467,6 +467,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
GGML_FTYPE_MOSTLY_MXFP4_E2M1 = 25, // except 1d tensors
GGML_FTYPE_MOSTLY_MXFP4 = GGML_FTYPE_MOSTLY_MXFP4_E2M1, // compat alias
GGML_FTYPE_MOSTLY_NVFP4 = 26, // except 1d tensors
};


@@ -574,11 +574,20 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_
#ifndef GGML_COMMON_IMPL
// NaN/Infinity for FP8 LUT initializers (CPU-only, guarded out of GPU builds).
#if defined(_MSC_VER) && !defined(__clang__)
#include <math.h>
#define GGML_TABLE_NAN NAN
#define GGML_TABLE_INFINITY INFINITY
#else
#define GGML_TABLE_NAN __builtin_nanf("")
#define GGML_TABLE_INFINITY __builtin_inff()
#endif
#if defined(GGML_COMMON_IMPL_C)
#include <stdint.h>
#include <string.h>
#include <math.h>
#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
#define GGML_TABLE_END() };
#define GGML_MXFP_FUNC static inline
@@ -636,7 +645,6 @@ static inline float ggml_mxfp_u32_as_f32_(uint32_t u) { float f; memcpy(&f, &
#define GGML_COMMON_IMPL
#elif defined(GGML_COMMON_IMPL_SYCL)
#include <cstdint>
#include <cstring>
#include <cmath>
@@ -1308,6 +1316,10 @@ GGML_TABLE_BEGIN(float, kvalues_mxfp6_e3m2, 64)
-8.0f, -10.0f, -12.0f, -14.0f, -16.0f, -20.0f, -24.0f, -28.0f,
GGML_TABLE_END()
// FP8 E4M3/E5M2 LUTs contain NaN/Inf which cannot be constexpr-initialized in
// __device__ tables. GPU backends use the converter functions instead.
#if !defined(GGML_COMMON_DECL_CUDA) && !defined(GGML_COMMON_DECL_HIP) && !defined(GGML_COMMON_DECL_MUSA)
// FP8 E4M3 dequantization LUT: byte -> float. Entry 127 = 448 (max finite), 255 = NaN.
GGML_TABLE_BEGIN(float, kvalues_mxfp8_e4m3, 256)
0.0f, 0.001953125f, 0.00390625f, 0.005859375f, 0.0078125f, 0.009765625f, 0.01171875f, 0.013671875f,
@@ -1325,7 +1337,7 @@ GGML_TABLE_BEGIN(float, kvalues_mxfp8_e4m3, 256)
32.0f, 36.0f, 40.0f, 44.0f, 48.0f, 52.0f, 56.0f, 60.0f,
64.0f, 72.0f, 80.0f, 88.0f, 96.0f, 104.0f, 112.0f, 120.0f,
128.0f, 144.0f, 160.0f, 176.0f, 192.0f, 208.0f, 224.0f, 240.0f,
256.0f, 288.0f, 320.0f, 352.0f, 384.0f, 416.0f, 448.0f, NAN,
256.0f, 288.0f, 320.0f, 352.0f, 384.0f, 416.0f, 448.0f, GGML_TABLE_NAN,
-0.0f,-0.001953125f, -0.00390625f,-0.005859375f, -0.0078125f,-0.009765625f, -0.01171875f,-0.013671875f,
-0.015625f,-0.017578125f, -0.01953125f,-0.021484375f, -0.0234375f,-0.025390625f, -0.02734375f,-0.029296875f,
-0.03125f, -0.03515625f, -0.0390625f, -0.04296875f, -0.046875f, -0.05078125f, -0.0546875f, -0.05859375f,
@@ -1341,45 +1353,48 @@ GGML_TABLE_BEGIN(float, kvalues_mxfp8_e4m3, 256)
-32.0f, -36.0f, -40.0f, -44.0f, -48.0f, -52.0f, -56.0f, -60.0f,
-64.0f, -72.0f, -80.0f, -88.0f, -96.0f, -104.0f, -112.0f, -120.0f,
-128.0f, -144.0f, -160.0f, -176.0f, -192.0f, -208.0f, -224.0f, -240.0f,
-256.0f, -288.0f, -320.0f, -352.0f, -384.0f, -416.0f, -448.0f, NAN,
-256.0f, -288.0f, -320.0f, -352.0f, -384.0f, -416.0f, -448.0f, GGML_TABLE_NAN,
GGML_TABLE_END()
// FP8 E5M2 dequantization LUT: byte -> float. Entries 124-127 = {Inf, NaN, NaN, NaN}.
// Generated from ggml_mxfp_fp8_e5m2_to_float() with %.9e precision for exact float round-trip.
GGML_TABLE_BEGIN(float, kvalues_mxfp8_e5m2, 256)
0.0f, 1.525879e-05f, 3.051758e-05f, 4.577637e-05f, 6.103516e-05f, 7.629395e-05f, 9.155273e-05f, 1.068115e-04f,
1.220703e-04f, 1.525879e-04f, 1.831055e-04f, 2.136230e-04f, 2.441406e-04f, 3.051758e-04f, 3.662109e-04f, 4.272461e-04f,
4.882812e-04f, 6.103516e-04f, 7.324219e-04f, 8.544922e-04f, 9.765625e-04f, 1.220703e-03f, 1.464844e-03f, 1.708984e-03f,
1.953125e-03f, 2.441406e-03f, 2.929688e-03f, 3.417969e-03f, 3.906250e-03f, 4.882812e-03f, 5.859375e-03f, 6.835938e-03f,
7.812500e-03f, 9.765625e-03f, 1.171875e-02f, 1.367188e-02f, 1.562500e-02f, 1.953125e-02f, 2.343750e-02f, 2.734375e-02f,
3.125000e-02f, 3.906250e-02f, 4.687500e-02f, 5.468750e-02f, 6.250000e-02f, 7.812500e-02f, 9.375000e-02f, 1.093750e-01f,
0.125f, 0.15625f, 0.1875f, 0.21875f, 0.25f, 0.3125f, 0.375f, 0.4375f,
0.5f, 0.625f, 0.75f, 0.875f, 1.0f, 1.25f, 1.5f, 1.75f,
2.0f, 2.5f, 3.0f, 3.5f, 4.0f, 5.0f, 6.0f, 7.0f,
8.0f, 10.0f, 12.0f, 14.0f, 16.0f, 20.0f, 24.0f, 28.0f,
32.0f, 40.0f, 48.0f, 56.0f, 64.0f, 80.0f, 96.0f, 112.0f,
128.0f, 160.0f, 192.0f, 224.0f, 256.0f, 320.0f, 384.0f, 448.0f,
512.0f, 640.0f, 768.0f, 896.0f, 1024.0f, 1280.0f, 1536.0f, 1792.0f,
2048.0f, 2560.0f, 3072.0f, 3584.0f, 4096.0f, 5120.0f, 6144.0f, 7168.0f,
8192.0f, 10240.0f, 12288.0f, 14336.0f, 16384.0f, 20480.0f, 24576.0f, 28672.0f,
32768.0f, 40960.0f, 49152.0f, 57344.0f, INFINITY, NAN, NAN, NAN,
-0.0f,-1.525879e-05f,-3.051758e-05f,-4.577637e-05f,-6.103516e-05f,-7.629395e-05f,-9.155273e-05f,-1.068115e-04f,
-1.220703e-04f,-1.525879e-04f,-1.831055e-04f,-2.136230e-04f,-2.441406e-04f,-3.051758e-04f,-3.662109e-04f,-4.272461e-04f,
-4.882812e-04f,-6.103516e-04f,-7.324219e-04f,-8.544922e-04f,-9.765625e-04f,-1.220703e-03f,-1.464844e-03f,-1.708984e-03f,
-1.953125e-03f,-2.441406e-03f,-2.929688e-03f,-3.417969e-03f,-3.906250e-03f,-4.882812e-03f,-5.859375e-03f,-6.835938e-03f,
-7.812500e-03f,-9.765625e-03f,-1.171875e-02f,-1.367188e-02f,-1.562500e-02f,-1.953125e-02f,-2.343750e-02f,-2.734375e-02f,
-3.125000e-02f,-3.906250e-02f,-4.687500e-02f,-5.468750e-02f,-6.250000e-02f,-7.812500e-02f,-9.375000e-02f,-1.093750e-01f,
-0.125f, -0.15625f, -0.1875f, -0.21875f, -0.25f, -0.3125f, -0.375f, -0.4375f,
-0.5f, -0.625f, -0.75f, -0.875f, -1.0f, -1.25f, -1.5f, -1.75f,
-2.0f, -2.5f, -3.0f, -3.5f, -4.0f, -5.0f, -6.0f, -7.0f,
-8.0f, -10.0f, -12.0f, -14.0f, -16.0f, -20.0f, -24.0f, -28.0f,
-32.0f, -40.0f, -48.0f, -56.0f, -64.0f, -80.0f, -96.0f, -112.0f,
-128.0f, -160.0f, -192.0f, -224.0f, -256.0f, -320.0f, -384.0f, -448.0f,
-512.0f, -640.0f, -768.0f, -896.0f, -1024.0f, -1280.0f, -1536.0f, -1792.0f,
-2048.0f, -2560.0f, -3072.0f, -3584.0f, -4096.0f, -5120.0f, -6144.0f, -7168.0f,
-8192.0f, -10240.0f, -12288.0f, -14336.0f, -16384.0f, -20480.0f, -24576.0f, -28672.0f,
-32768.0f, -40960.0f, -49152.0f, -57344.0f, -INFINITY, NAN, NAN, NAN,
0.000000000e+00f, 1.525878906e-05f, 3.051757812e-05f, 4.577636719e-05f, 6.103515625e-05f, 7.629394531e-05f, 9.155273438e-05f, 1.068115234e-04f,
1.220703125e-04f, 1.525878906e-04f, 1.831054688e-04f, 2.136230469e-04f, 2.441406250e-04f, 3.051757812e-04f, 3.662109375e-04f, 4.272460938e-04f,
4.882812500e-04f, 6.103515625e-04f, 7.324218750e-04f, 8.544921875e-04f, 9.765625000e-04f, 1.220703125e-03f, 1.464843750e-03f, 1.708984375e-03f,
1.953125000e-03f, 2.441406250e-03f, 2.929687500e-03f, 3.417968750e-03f, 3.906250000e-03f, 4.882812500e-03f, 5.859375000e-03f, 6.835937500e-03f,
7.812500000e-03f, 9.765625000e-03f, 1.171875000e-02f, 1.367187500e-02f, 1.562500000e-02f, 1.953125000e-02f, 2.343750000e-02f, 2.734375000e-02f,
3.125000000e-02f, 3.906250000e-02f, 4.687500000e-02f, 5.468750000e-02f, 6.250000000e-02f, 7.812500000e-02f, 9.375000000e-02f, 1.093750000e-01f,
1.250000000e-01f, 1.562500000e-01f, 1.875000000e-01f, 2.187500000e-01f, 2.500000000e-01f, 3.125000000e-01f, 3.750000000e-01f, 4.375000000e-01f,
5.000000000e-01f, 6.250000000e-01f, 7.500000000e-01f, 8.750000000e-01f, 1.000000000e+00f, 1.250000000e+00f, 1.500000000e+00f, 1.750000000e+00f,
2.000000000e+00f, 2.500000000e+00f, 3.000000000e+00f, 3.500000000e+00f, 4.000000000e+00f, 5.000000000e+00f, 6.000000000e+00f, 7.000000000e+00f,
8.000000000e+00f, 1.000000000e+01f, 1.200000000e+01f, 1.400000000e+01f, 1.600000000e+01f, 2.000000000e+01f, 2.400000000e+01f, 2.800000000e+01f,
3.200000000e+01f, 4.000000000e+01f, 4.800000000e+01f, 5.600000000e+01f, 6.400000000e+01f, 8.000000000e+01f, 9.600000000e+01f, 1.120000000e+02f,
1.280000000e+02f, 1.600000000e+02f, 1.920000000e+02f, 2.240000000e+02f, 2.560000000e+02f, 3.200000000e+02f, 3.840000000e+02f, 4.480000000e+02f,
5.120000000e+02f, 6.400000000e+02f, 7.680000000e+02f, 8.960000000e+02f, 1.024000000e+03f, 1.280000000e+03f, 1.536000000e+03f, 1.792000000e+03f,
2.048000000e+03f, 2.560000000e+03f, 3.072000000e+03f, 3.584000000e+03f, 4.096000000e+03f, 5.120000000e+03f, 6.144000000e+03f, 7.168000000e+03f,
8.192000000e+03f, 1.024000000e+04f, 1.228800000e+04f, 1.433600000e+04f, 1.638400000e+04f, 2.048000000e+04f, 2.457600000e+04f, 2.867200000e+04f,
3.276800000e+04f, 4.096000000e+04f, 4.915200000e+04f, 5.734400000e+04f, GGML_TABLE_INFINITY, GGML_TABLE_NAN, GGML_TABLE_NAN, GGML_TABLE_NAN,
-0.000000000e+00f,-1.525878906e-05f,-3.051757812e-05f,-4.577636719e-05f,-6.103515625e-05f,-7.629394531e-05f,-9.155273438e-05f,-1.068115234e-04f,
-1.220703125e-04f,-1.525878906e-04f,-1.831054688e-04f,-2.136230469e-04f,-2.441406250e-04f,-3.051757812e-04f,-3.662109375e-04f,-4.272460938e-04f,
-4.882812500e-04f,-6.103515625e-04f,-7.324218750e-04f,-8.544921875e-04f,-9.765625000e-04f,-1.220703125e-03f,-1.464843750e-03f,-1.708984375e-03f,
-1.953125000e-03f,-2.441406250e-03f,-2.929687500e-03f,-3.417968750e-03f,-3.906250000e-03f,-4.882812500e-03f,-5.859375000e-03f,-6.835937500e-03f,
-7.812500000e-03f,-9.765625000e-03f,-1.171875000e-02f,-1.367187500e-02f,-1.562500000e-02f,-1.953125000e-02f,-2.343750000e-02f,-2.734375000e-02f,
-3.125000000e-02f,-3.906250000e-02f,-4.687500000e-02f,-5.468750000e-02f,-6.250000000e-02f,-7.812500000e-02f,-9.375000000e-02f,-1.093750000e-01f,
-1.250000000e-01f,-1.562500000e-01f,-1.875000000e-01f,-2.187500000e-01f,-2.500000000e-01f,-3.125000000e-01f,-3.750000000e-01f,-4.375000000e-01f,
-5.000000000e-01f,-6.250000000e-01f,-7.500000000e-01f,-8.750000000e-01f,-1.000000000e+00f,-1.250000000e+00f,-1.500000000e+00f,-1.750000000e+00f,
-2.000000000e+00f,-2.500000000e+00f,-3.000000000e+00f,-3.500000000e+00f,-4.000000000e+00f,-5.000000000e+00f,-6.000000000e+00f,-7.000000000e+00f,
-8.000000000e+00f,-1.000000000e+01f,-1.200000000e+01f,-1.400000000e+01f,-1.600000000e+01f,-2.000000000e+01f,-2.400000000e+01f,-2.800000000e+01f,
-3.200000000e+01f,-4.000000000e+01f,-4.800000000e+01f,-5.600000000e+01f,-6.400000000e+01f,-8.000000000e+01f,-9.600000000e+01f,-1.120000000e+02f,
-1.280000000e+02f,-1.600000000e+02f,-1.920000000e+02f,-2.240000000e+02f,-2.560000000e+02f,-3.200000000e+02f,-3.840000000e+02f,-4.480000000e+02f,
-5.120000000e+02f,-6.400000000e+02f,-7.680000000e+02f,-8.960000000e+02f,-1.024000000e+03f,-1.280000000e+03f,-1.536000000e+03f,-1.792000000e+03f,
-2.048000000e+03f,-2.560000000e+03f,-3.072000000e+03f,-3.584000000e+03f,-4.096000000e+03f,-5.120000000e+03f,-6.144000000e+03f,-7.168000000e+03f,
-8.192000000e+03f,-1.024000000e+04f,-1.228800000e+04f,-1.433600000e+04f,-1.638400000e+04f,-2.048000000e+04f,-2.457600000e+04f,-2.867200000e+04f,
-3.276800000e+04f,-4.096000000e+04f,-4.915200000e+04f,-5.734400000e+04f, -GGML_TABLE_INFINITY, GGML_TABLE_NAN, GGML_TABLE_NAN, GGML_TABLE_NAN,
GGML_TABLE_END()
#endif // !CUDA && !HIP && !MUSA
// MXFP element converters -- portable IEEE-754 bit manipulation.
#if defined(GGML_MXFP_FUNC)


@@ -343,12 +343,8 @@
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
#endif
// MXFP dequantize has no arch-specific (SIMD) implementations except on arm and x86.
// All other targets use the scalar generic as the public cpu function.
#if !defined(__aarch64__) && !defined(__arm__) && !defined(_M_ARM) && !defined(_M_ARM64) && \
!defined(__x86_64__) && !defined(__i386__) && !defined(_M_IX86) && !defined(_M_X64)
#define dequantize_row_mxfp8_cpu_generic dequantize_row_mxfp8_cpu
#define dequantize_row_mxfp6_cpu_generic dequantize_row_mxfp6_cpu
// MXFP dequantize fallbacks (same GGML_CPU_GENERIC guard as above)
#if defined(GGML_CPU_GENERIC)
#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu


@@ -4311,68 +4311,6 @@ static void ggml_vec_dot_mxfp6_q8_0_neon(
*s = vaddvq_f32(vaddq_f32(acc0, acc1));
}
// MXFP FP8/FP6 dequantize_row (AoS)
static void dequantize_row_mxfp8_neon(
const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k,
const mxfp_neon_traits_t * t) {
assert(k % QK_MXFP8 == 0);
const int nb = k / QK_MXFP8;
const block_mxfp8 * GGML_RESTRICT x = vx;
const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask);
const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask);
const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off);
const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale);
const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift);
const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift);
for (int ib = 0; ib < nb; ++ib) {
const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32(x[ib].e));
for (int j = 0; j < 32; j += 8) {
uint32x4_t v_lo, v_hi;
widen_u8x8_to_u32x4x2(x[ib].qs + j, &v_lo, &v_hi);
const float32x4_t val_lo = mxfp8_dequant_neon(v_lo,
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
const float32x4_t val_hi = mxfp8_dequant_neon(v_hi,
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
vst1q_f32(y + ib * QK_MXFP8 + j, vmulq_f32(val_lo, v_scale));
vst1q_f32(y + ib * QK_MXFP8 + j + 4, vmulq_f32(val_hi, v_scale));
}
}
}
static void dequantize_row_mxfp6_neon(
const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k,
const mxfp_neon_traits_t * t) {
assert(k % QK_MXFP6 == 0);
const int nb = k / QK_MXFP6;
const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask);
const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask);
const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off);
const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale);
const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift);
const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift);
for (int ib = 0; ib < nb; ++ib) {
const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib;
const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32(xb->e));
for (int j = 0; j < 32; j += 4) {
const uint32x4_t v_raw = unpack_fp6x4_neon(xb->qs + (j * 3 / 4));
const float32x4_t val = mxfp6_dequant_neon(v_raw,
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
vst1q_f32(y + ib * QK_MXFP6 + j, vmulq_f32(val, v_scale));
}
}
}
// MXFP SoA dequant (flash attention)
static void dequantize_row_mxfp8_soa_neon(
@@ -4506,22 +4444,6 @@ void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
#endif
}
void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__ARM_NEON)
dequantize_row_mxfp8_neon(x, y, k, &MXFP_TRAITS_E4M3);
#else
dequantize_row_mxfp8_cpu_generic(x, y, k);
#endif
}
void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__ARM_NEON)
dequantize_row_mxfp6_neon(x, y, k, &MXFP_TRAITS_E2M3);
#else
dequantize_row_mxfp6_cpu_generic(x, y, k);
#endif
}
void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__ARM_NEON)
dequantize_row_mxfp4_soa_neon(x, y, k);


@@ -3950,67 +3950,6 @@ static void ggml_vec_dot_mxfp6_q8_0_avx2(
*s = hsum_float_8(acc);
}
// MXFP FP8/FP6 dequantize_row (AoS)
static void dequantize_row_mxfp8_avx2(
const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k,
const mxfp_avx2_traits_t * t) {
assert(k % QK_MXFP8 == 0);
const int nb = k / QK_MXFP8;
const block_mxfp8 * GGML_RESTRICT x = vx;
const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask);
const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask);
const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off);
const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale);
const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask);
const __m256i v_zero = _mm256_setzero_si256();
for (int ib = 0; ib < nb; ++ib) {
const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32(x[ib].e));
for (int j = 0; j < 32; j += 8) {
const __m256i v_raw = _mm256_cvtepu8_epi32(
_mm_loadl_epi64((const __m128i *)(x[ib].qs + j)));
const __m256 val = mxfp_dequant_avx2(v_raw,
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc,
v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift);
_mm256_storeu_ps(y + ib * QK_MXFP8 + j, _mm256_mul_ps(val, v_scale));
}
}
}
static void dequantize_row_mxfp6_avx2(
const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k,
const mxfp_avx2_traits_t * t) {
assert(k % QK_MXFP6 == 0);
const int nb = k / QK_MXFP6;
const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask);
const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask);
const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off);
const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale);
const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask);
const __m256i v_zero = _mm256_setzero_si256();
for (int ib = 0; ib < nb; ++ib) {
const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib;
const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32(xb->e));
for (int j = 0; j < 32; j += 8) {
const __m256i v_raw = unpack_fp6x8_avx2(xb->qs, j);
const __m256 val = mxfp_dequant_avx2(v_raw,
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc,
v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift);
_mm256_storeu_ps(y + ib * QK_MXFP6 + j, _mm256_mul_ps(val, v_scale));
}
}
}
// MXFP SoA dequant (flash attention)
static void dequantize_row_mxfp8_soa_avx2(
@@ -4133,22 +4072,6 @@ void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
#endif
}
void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__AVX2__)
dequantize_row_mxfp8_avx2(x, y, k, &MXFP_TRAITS_E4M3);
#else
dequantize_row_mxfp8_cpu_generic(x, y, k);
#endif
}
void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__AVX2__)
dequantize_row_mxfp6_avx2(x, y, k, &MXFP_TRAITS_E2M3);
#else
dequantize_row_mxfp6_cpu_generic(x, y, k);
#endif
}
void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__AVX2__)
dequantize_row_mxfp4_soa_avx2(x, y, k);


@@ -7,6 +7,7 @@
#include "ggml-cpu-impl.h"
#include "ggml-impl.h"
#include "quants.h"
#include "ggml-quants.h"
#include "ggml-threading.h"
#include "unary-ops.h"
#include "binary-ops.h"
@@ -280,7 +281,6 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
},
[GGML_TYPE_MXFP8_E4M3] = {
.from_float = quantize_row_mxfp8,
.to_float = dequantize_row_mxfp8_cpu,
.from_float_soa = quantize_row_mxfp8_soa,
.to_float_soa = dequantize_row_mxfp8_soa_cpu,
.vec_dot = ggml_vec_dot_mxfp8_q8_0,
@@ -289,7 +289,6 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
},
[GGML_TYPE_MXFP6_E2M3] = {
.from_float = quantize_row_mxfp6,
.to_float = dequantize_row_mxfp6_cpu,
.from_float_soa = quantize_row_mxfp6_soa,
.to_float_soa = dequantize_row_mxfp6_soa_cpu,
.vec_dot = ggml_vec_dot_mxfp6_q8_0,


@@ -8313,20 +8313,8 @@ static mxfp_fa_params mxfp_fa_params_init(
p.v_multihead = is_mxfp_v && (nbv2 == (size_t)ggml_row_size(v->type, DV));
p.v_soa_elems = is_mxfp_v ? (p.v_multihead ? nev2 * DV : DV) : 0;
// Per-head SoA addressing for multihead mode.
// Precompute byte offsets so the hot loop can skip per-head pointer math.
// qs_per_block values from centralized MXFP_QS_PER_BLOCK_* defines in ggml-common.h.
auto mxfp_qs_per_block = [](ggml_type type) -> int {
switch (type) {
case GGML_TYPE_MXFP4_E2M1: return MXFP4_SOA_QS_PER_BLOCK;
case GGML_TYPE_MXFP8_E4M3: return MXFP8_SOA_QS_PER_BLOCK;
case GGML_TYPE_MXFP6_E2M3: return MXFP6_SOA_QS_PER_BLOCK;
default: return 0;
}
};
if (is_mxfp_k) {
p.k_qs_per_block = mxfp_qs_per_block(k->type);
p.k_qs_per_block = ggml_mxfp_qs_per_block(k->type);
p.k_blocks_per_head = (int)(DK / 32);
p.k_head_qs_bytes = p.k_blocks_per_head * p.k_qs_per_block;
const int64_t k_total_blocks = p.k_multihead ? nek2 * p.k_blocks_per_head : p.k_blocks_per_head;
@@ -8334,7 +8322,7 @@
}
if (is_mxfp_v) {
p.v_qs_per_block = mxfp_qs_per_block(v->type);
p.v_qs_per_block = ggml_mxfp_qs_per_block(v->type);
p.v_blocks_per_head = (int)(DV / 32);
p.v_head_qs_bytes = p.v_blocks_per_head * p.v_qs_per_block;
const int64_t v_total_blocks = p.v_multihead ? nev2 * p.v_blocks_per_head : p.v_blocks_per_head;


@@ -309,13 +309,7 @@ void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
(ggml_to_float_t)dequantize_row_mxfp6);
}
// Generic dequant wrappers — arch-specific SIMD versions override via fallback.h.
void dequantize_row_mxfp8_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
dequantize_row_mxfp8(x, y, k);
}
void dequantize_row_mxfp6_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
dequantize_row_mxfp6(x, y, k);
}
// Generic SoA dequant wrappers — arch-specific SIMD versions override via fallback.h.
void dequantize_row_mxfp4_soa_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
dequantize_row_mxfp4_soa(x, y, k);
}


@@ -24,10 +24,6 @@ void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, i
void quantize_row_mxfp8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_mxfp6(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
// Dequantization (SIMD-optimized, arch-dispatched)
void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@@ -87,13 +83,7 @@ void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
void ggml_vec_dot_mxfp8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void dequantize_row_mxfp8_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_mxfp6_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
// SoA quantize/dequant for MXFP flash attention
void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
// SoA dequant (SIMD-dispatched, CPU backend)
void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_mxfp6_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);


@@ -1010,6 +1010,19 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
}
}
// MXFP8/MXFP6: no Metal shaders yet; reject for all ops.
// MXFP4: has AoS shaders (MUL_MAT, GET_ROWS) but no SoA/flash attention support yet.
for (size_t i = 0, n = 3; i < n; ++i) {
if (op->src[i] != NULL && ggml_is_type_mxfp(op->src[i]->type)) {
if (op->src[i]->type != GGML_TYPE_MXFP4_E2M1) {
return false;
}
if (op->op == GGML_OP_FLASH_ATTN_EXT || op->op == GGML_OP_SET_ROWS) {
return false;
}
}
}
switch (op->op) {
case GGML_OP_SCALE:
case GGML_OP_FILL:


@@ -711,7 +711,7 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
.from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref,
},
[GGML_TYPE_MXFP4_E2M1] = {
.type_name = "mxfp4_e2m1",
.type_name = "mxfp4",
.blck_size = QK_MXFP4,
.type_size = sizeof(block_mxfp4),
.is_quantized = true,


@@ -252,6 +252,7 @@ if (NOT GGML_BACKEND_DL)
# these tests use the backends directly and cannot be built with dynamic loading
llama_build_and_test(test-barrier.cpp)
llama_build_and_test(test-quantize-fns.cpp)
target_include_directories(test-quantize-fns PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src)
llama_build_and_test(test-quantize-perf.cpp)
llama_build_and_test(test-rope.cpp)
endif()


@@ -18,7 +18,6 @@
#include <ggml.h>
#include <ggml-alloc.h>
#include <ggml-backend.h>
#include <ggml-cpu.h>
#include <ggml-cpp.h>
#include <algorithm>
@@ -151,59 +150,79 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
}
}
+ // MXFP SoA quantization functions
+ extern "C" {
+ void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
+ void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
+ void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
+ void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
+ void dequantize_row_mxfp8_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
+ void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
+ }
typedef void (*mxfp_soa_quantize_fn)(const float *, void *, int64_t);
typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t);
- // Initialize an MXFP tensor with SoA layout (soa_bytes = region width, 0 = one row).
- static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f, size_t soa_bytes = 0) {
+ struct mxfp_soa_fns {
+ ggml_type type;
+ mxfp_soa_quantize_fn quantize;
+ mxfp_soa_dequantize_fn dequantize;
+ };
+ static const mxfp_soa_fns mxfp_soa_table[] = {
+ { GGML_TYPE_MXFP4_E2M1, quantize_row_mxfp4_soa, dequantize_row_mxfp4_soa },
+ { GGML_TYPE_MXFP8_E4M3, quantize_row_mxfp8_soa, dequantize_row_mxfp8_soa },
+ { GGML_TYPE_MXFP6_E2M3, quantize_row_mxfp6_soa, dequantize_row_mxfp6_soa },
+ };
+ static const mxfp_soa_fns * get_mxfp_soa(ggml_type type) {
+ for (const auto & e : mxfp_soa_table) {
+ if (e.type == type) return &e;
+ }
+ return nullptr;
+ }
+ // init MXFP tensor with SoA layout
+ static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
GGML_ASSERT(ggml_is_type_mxfp(tensor->type));
- const auto * traits = ggml_get_type_traits_cpu(tensor->type);
- GGML_ASSERT(traits->from_float_soa && "MXFP type missing SoA quantize in traits");
- auto quantize_soa = traits->from_float_soa;
+ const auto * soa = get_mxfp_soa(tensor->type);
+ GGML_ASSERT(soa && "unsupported MXFP type for SoA init");
- const int qk = (int)ggml_blck_size(tensor->type);
- const size_t block_size = ggml_type_size(tensor->type);
- const size_t head_row_sz = ggml_row_size(tensor->type, tensor->ne[0]);
- if (soa_bytes == 0) { soa_bytes = head_row_sz; }
- GGML_ASSERT(soa_bytes % block_size == 0 && "soa_bytes must be a multiple of block_size");
- const int64_t soa_elems = (int64_t)(soa_bytes / block_size) * qk;
+ const int64_t DK = tensor->ne[0];
+ const size_t row_sz = ggml_row_size(tensor->type, DK);
+ // multihead: heads packed contiguously
+ const bool multihead = (tensor->nb[2] == row_sz) && (tensor->ne[2] > 1);
std::default_random_engine gen(42);
std::uniform_real_distribution<float> dist(min, max);
- std::vector<float> region_f32(soa_elems);
- const size_t nb1 = tensor->nb[1];
- const size_t nb2 = tensor->nb[2];
- const size_t nb3 = tensor->nb[3];
- const int64_t ne1 = tensor->ne[1];
- const int64_t ne2 = tensor->ne[2];
- const int64_t ne3 = tensor->ne[3];
- const int64_t heads_per_region = (int64_t)(soa_bytes / head_row_sz);
- GGML_ASSERT(soa_bytes % head_row_sz == 0 && "soa_bytes must be a multiple of head_row_sz");
std::vector<uint8_t> buf(ggml_nbytes(tensor), 0);
- if (heads_per_region > 1) {
- // Multi-head SoA:
- for (int64_t i3 = 0; i3 < ne3; i3++) {
- const int64_t n_groups = ne2 / heads_per_region;
- for (int64_t ig = 0; ig < n_groups; ig++) {
- for (int64_t i1 = 0; i1 < ne1; i1++) {
- size_t offset = i3*nb3 + ig*heads_per_region*nb2 + i1*nb1;
- for (int64_t j = 0; j < soa_elems; j++) { region_f32[j] = dist(gen); }
- quantize_soa(region_f32.data(), buf.data() + offset, soa_elems);
- }
+ if (multihead) {
+ // all heads at one position share one SoA region
+ const int64_t n_heads = tensor->ne[2];
+ const int64_t soa_elems = n_heads * DK;
+ std::vector<float> region(soa_elems);
+ for (int64_t i3 = 0; i3 < tensor->ne[3]; i3++) {
+ for (int64_t i1 = 0; i1 < tensor->ne[1]; i1++) {
+ size_t offset = i3*tensor->nb[3] + i1*tensor->nb[1];
+ for (int64_t j = 0; j < soa_elems; j++) { region[j] = dist(gen); }
+ soa->quantize(region.data(), buf.data() + offset, soa_elems);
+ }
+ }
} else {
- // Per-head SoA:
- for (int64_t i3 = 0; i3 < ne3; i3++) {
- for (int64_t i2 = 0; i2 < ne2; i2++) {
- for (int64_t i1 = 0; i1 < ne1; i1++) {
- size_t offset = i3*nb3 + i2*nb2 + i1*nb1;
- for (int64_t j = 0; j < soa_elems; j++) { region_f32[j] = dist(gen); }
- quantize_soa(region_f32.data(), buf.data() + offset, soa_elems);
+ // per-head SoA: each head independently packed
+ std::vector<float> region(DK);
+ for (int64_t i3 = 0; i3 < tensor->ne[3]; i3++) {
+ for (int64_t i2 = 0; i2 < tensor->ne[2]; i2++) {
+ for (int64_t i1 = 0; i1 < tensor->ne[1]; i1++) {
+ size_t offset = i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1];
+ for (int64_t j = 0; j < DK; j++) { region[j] = dist(gen); }
+ soa->quantize(region.data(), buf.data() + offset, DK);
}
}
}
}
@@ -304,9 +323,12 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) {
const bool is_mxfp = ggml_is_type_mxfp(t->type);
mxfp_soa_dequantize_fn mxfp_dequant_soa = nullptr;
std::vector<float> mxfp_row_f32;
if (is_mxfp) {
- mxfp_dequant_soa = (mxfp_soa_dequantize_fn) ggml_get_type_traits_cpu(t->type)->to_float_soa;
- GGML_ASSERT(mxfp_dequant_soa && "MXFP type missing SoA dequant in traits");
+ const auto * soa_fns = get_mxfp_soa(t->type);
+ GGML_ASSERT(soa_fns && "unsupported MXFP type in tensor_to_float");
+ mxfp_dequant_soa = soa_fns->dequantize;
mxfp_row_f32.resize(t->ne[0]);
}
// access elements by index to avoid gaps in views
@@ -315,9 +337,8 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) {
for (int64_t i1 = 0; i1 < t->ne[1]; i1++) {
if (is_mxfp) {
size_t row_off = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1];
- std::vector<float> row_f32(t->ne[0]);
- mxfp_dequant_soa(&buf[row_off], row_f32.data(), t->ne[0]);
- tv.insert(tv.end(), row_f32.begin(), row_f32.end());
+ mxfp_dequant_soa(&buf[row_off], mxfp_row_f32.data(), t->ne[0]);
+ tv.insert(tv.end(), mxfp_row_f32.begin(), mxfp_row_f32.end());
continue;
}
for (int64_t i0 = 0; i0 < t->ne[0]; i0 += bs) {
@@ -6370,8 +6391,7 @@ struct test_flash_attn_ext : public test_case {
} else if (strcmp(t->name, "m") == 0) {
init_tensor_kq_mask(t);
} else if ((strcmp(t->name, "k") == 0 || strcmp(t->name, "v") == 0) && ggml_is_type_mxfp(t->type)) {
- // MXFP K/V use SoA layout; nb[1] spans all heads in one KV-position stride
- init_tensor_mxfp_soa(t, -1.0f, 1.0f, t->nb[1]);
+ init_tensor_mxfp_soa(t);
} else {
init_tensor_uniform(t);
}
@@ -7378,8 +7398,7 @@ static const ggml_type all_types[] = {
GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
GGML_TYPE_Q8_0,
- GGML_TYPE_MXFP4_E2M1, GGML_TYPE_MXFP8_E4M3,
- GGML_TYPE_MXFP6_E2M3,
+ GGML_TYPE_MXFP4_E2M1,
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
GGML_TYPE_Q6_K,

@@ -2,6 +2,11 @@
#include "ggml.h"
#include "ggml-cpu.h"
#include "ggml-quants.h"
+ #define GGML_COMMON_DECL_CPP
+ #define GGML_COMMON_IMPL_CPP
+ #include "ggml-common.h"
#undef NDEBUG
#include <assert.h>
@@ -24,6 +29,12 @@ constexpr float MAX_QUANTIZATION_TOTAL_ERROR_FP4 = 0.0030f;
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 = 0.0070f;
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 = 0.0040f;
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 = 0.0020f;
+ // MXFP Hadamard pipeline thresholds (mxfp_rmse, which computes sqrt(sum/n)).
+ // These represent actual RMSE through the full KV cache write/read path.
+ constexpr float MAX_MXFP_PIPELINE_ERROR_MXFP4 = 0.40f;
+ constexpr float MAX_MXFP_PIPELINE_ERROR_MXFP8 = 0.08f;
+ constexpr float MAX_MXFP_PIPELINE_ERROR_MXFP6 = 0.10f;
constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f;
constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f;
constexpr float MAX_DOT_PRODUCT_ERROR_FP4 = 0.03f;
@@ -50,6 +61,16 @@ static float array_rmse(const float * a1, const float * a2, size_t n) {
return sqrtf(sum) / n;
}
+ // MXFP RMSE: sqrt(sum/n), used with MAX_MXFP_PIPELINE_ERROR_* thresholds
+ static float mxfp_rmse(const float * a1, const float * a2, size_t n) {
+ double sum = 0;
+ for (size_t i = 0; i < n; i++) {
+ double diff = a1[i] - a2[i];
+ sum += diff * diff;
+ }
+ return sqrtf((float)(sum / n));
+ }
// Total quantization error on test data
static float total_quantization_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data) {
std::vector<uint8_t> tmp_q(2*test_size);
@@ -192,7 +213,7 @@ int main(int argc, char * argv[]) {
}
}
- // MXFP SoA roundtrip: test from_float_soa → to_float_soa through the traits system
+ // MXFP SoA roundtrip via traits
for (int i = 0; i < GGML_TYPE_COUNT; i++) {
ggml_type type = (ggml_type) i;
const auto * qfns_cpu = ggml_get_type_traits_cpu(type);
@@ -220,6 +241,874 @@ int main(int argc, char * argv[]) {
}
}
// MXFP traits: SoA required, MXFP6/MXFP8 are KV-cache-only (no AoS dequant)
{
const ggml_type all_mxfp_types[] = { GGML_TYPE_MXFP4_E2M1, GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3 };
for (ggml_type type : all_mxfp_types) {
const auto * cpu = ggml_get_type_traits_cpu(type);
failed = !(cpu->from_float_soa && cpu->to_float_soa);
num_failed += failed;
if (failed || verbose) {
printf("%5s SoA traits present: %s\n", ggml_type_name(type), RESULT_STR[failed]);
}
}
// KV-cache-only types: no AoS dequant
const ggml_type kv_only_types[] = { GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3 };
for (ggml_type type : kv_only_types) {
const auto * cpu = ggml_get_type_traits_cpu(type);
failed = (cpu->to_float != nullptr);
num_failed += failed;
if (failed || verbose) {
printf("%5s AoS CPU to_float absent: %s\n", ggml_type_name(type), RESULT_STR[failed]);
}
}
}
// Hadamard self-inverse: H(H(x)) == x
{
float original[32], transformed[32];
for (int i = 0; i < 32; i++) {
original[i] = 0.1f + 2.0f * cosf(i + 0.5f);
transformed[i] = original[i];
}
ggml_hadamard_32_inplace(transformed);
ggml_hadamard_32_inplace(transformed); // apply twice = identity
float max_err = 0.0f;
for (int i = 0; i < 32; i++) {
float err = fabsf(transformed[i] - original[i]);
if (err > max_err) max_err = err;
}
// floating-point rounding tolerance
failed = !(max_err < 1e-5f);
num_failed += failed;
if (failed || verbose) {
printf("hadamard H(H(x))==x roundtrip: %s (max_err=%.2e)\n", RESULT_STR[failed], max_err);
}
}
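The self-inverse check works because a normalized Walsh-Hadamard transform is orthogonal and symmetric, so applying it twice is the identity up to rounding. A minimal self-contained sketch of a 32-point in-place transform in butterfly form (a hypothetical stand-in for ggml_hadamard_32_inplace; the real kernel may order its butterflies differently):

```cpp
#include <cassert>
#include <cmath>

// In-place normalized 32-point Walsh-Hadamard transform (Sylvester ordering).
// Folding 1/sqrt(32) into each application makes the matrix orthogonal and
// symmetric, so hadamard32(hadamard32(x)) == x up to float rounding.
static void hadamard32(float * x) {
    for (int len = 1; len < 32; len <<= 1) {        // butterfly stages: 1, 2, 4, 8, 16
        for (int i = 0; i < 32; i += 2 * len) {
            for (int j = i; j < i + len; j++) {
                const float a = x[j];
                const float b = x[j + len];
                x[j]       = a + b;
                x[j + len] = a - b;
            }
        }
    }
    for (int i = 0; i < 32; i++) {
        x[i] *= 0.17677669529663687f;               // 1/sqrt(32)
    }
}
```

On the unit vector e0 this produces the constant vector 1/sqrt(32), which is the same known-answer property the "hadamard unit vector" test below relies on.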
// SoA SIMD vs scalar dequant
{
struct soa_cross_check {
ggml_type type;
void (*ref_dequant)(const void *, float *, int64_t);
};
const soa_cross_check checks[] = {
{ GGML_TYPE_MXFP4_E2M1, dequantize_row_mxfp4_soa },
{ GGML_TYPE_MXFP8_E4M3, dequantize_row_mxfp8_soa },
{ GGML_TYPE_MXFP6_E2M3, dequantize_row_mxfp6_soa },
};
for (const auto & c : checks) {
const auto * cpu = ggml_get_type_traits_cpu(c.type);
if (!cpu->from_float_soa || !cpu->to_float_soa) continue;
const size_t buf_size = ggml_row_size(c.type, test_size);
std::vector<uint8_t> tmp_q(buf_size);
std::vector<float> out_ref(test_size);
std::vector<float> out_simd(test_size);
// Quantize with SoA
cpu->from_float_soa(test_data.data(), tmp_q.data(), test_size);
// Dequant with scalar reference
c.ref_dequant(tmp_q.data(), out_ref.data(), test_size);
// Dequant with CPU/SIMD path
cpu->to_float_soa(tmp_q.data(), out_simd.data(), test_size);
// Compare bitwise
int mismatches = 0;
for (size_t j = 0; j < test_size; j++) {
uint32_t a, b;
memcpy(&a, &out_ref[j], 4);
memcpy(&b, &out_simd[j], 4);
if (a != b) mismatches++;
}
failed = (mismatches > 0);
num_failed += failed;
if (failed || verbose) {
printf("%5s SoA SIMD vs scalar ref: %s (%zu/%zu match)\n",
ggml_type_name(c.type), RESULT_STR[failed],
test_size - mismatches, test_size);
}
}
}
// element converters vs canonical LUT values
{
struct lut_test {
const char * name;
const float * lut;
int count;
float (*converter)(uint8_t);
};
const lut_test lut_tests[] = {
{ "fp8_e4m3", kvalues_mxfp8_e4m3, 256, fp8_e4m3_to_float },
{ "fp8_e5m2", kvalues_mxfp8_e5m2, 256, fp8_e5m2_to_float },
{ "fp6_e2m3", kvalues_mxfp6_e2m3, 64, fp6_e2m3_to_float },
{ "fp6_e3m2", kvalues_mxfp6_e3m2, 64, fp6_e3m2_to_float },
};
for (const auto & t : lut_tests) {
int mismatches = 0;
for (int i = 0; i < t.count; i++) {
const float converter_val = t.converter((uint8_t)i);
const float lut_val = t.lut[i];
// both NaN = match
if (isnan(converter_val) && isnan(lut_val)) continue;
if (converter_val != lut_val) {
if (mismatches == 0 || verbose) {
printf(" %s LUT mismatch at [%d]: converter=%.8g, lut=%.8g\n",
t.name, i, converter_val, lut_val);
}
mismatches++;
}
}
failed = (mismatches > 0);
num_failed += failed;
if (failed || verbose) {
printf("%5s converter vs LUT: %s (%d/%d values match)\n",
t.name, RESULT_STR[failed], t.count - mismatches, t.count);
}
}
// FP4 E2M1
{
int mismatches = 0;
for (int i = 0; i < 16; i++) {
const float converter_val = ggml_mxfp_fp4_e2m1_to_float((uint8_t)i);
const float lut_val = kvalues_mxfp4_float[i];
if (converter_val != lut_val) {
if (mismatches == 0 || verbose) {
printf(" fp4_e2m1 LUT mismatch at [%d]: converter=%.8g, lut=%.8g\n",
i, converter_val, lut_val);
}
mismatches++;
}
}
failed = (mismatches > 0);
num_failed += failed;
if (failed || verbose) {
printf("fp4_e2m1 converter vs LUT: %s (%d/16 values match)\n",
RESULT_STR[failed], 16 - mismatches);
}
}
}
// element converter edge cases (expected values validated against LUTs)
{
struct conv_check {
const char * name;
float input;
uint8_t expected_bits;
bool is_saturation; // true = input overflows, expected_bits is max finite
const float * lut; // canonical LUT to validate expected_bits against (NULL for FP4)
float (*to_float)(uint8_t);
uint8_t (*to_quant)(float);
};
const conv_check checks[] = {
// FP4 E2M1: [S(1)|E(2)|M(1)], bias=0
{ "fp4 zero", 0.0f, 0x00, false, nullptr, nullptr, nullptr },
{ "fp4 sub 0.5", 0.5f, 0x01, false, nullptr, nullptr, nullptr },
{ "fp4 norm 1.0", 1.0f, 0x02, false, nullptr, nullptr, nullptr },
{ "fp4 max 6.0", 6.0f, 0x07, false, nullptr, nullptr, nullptr },
{ "fp4 neg -3.0", -3.0f, 0x0D, false, nullptr, nullptr, nullptr },
{ "fp4 sat 100", 100.0f, 0x07, true, nullptr, nullptr, nullptr },
// FP8 E4M3: [S(1)|E(4)|M(3)], bias=7
{ "e4m3 zero", 0.0f, 0x00, false, kvalues_mxfp8_e4m3, fp8_e4m3_to_float, float_to_fp8_e4m3_rn },
{ "e4m3 sub", 1.f/512, 0x01, false, kvalues_mxfp8_e4m3, fp8_e4m3_to_float, float_to_fp8_e4m3_rn },
{ "e4m3 max 448", 448.0f, 0x7E, false, kvalues_mxfp8_e4m3, fp8_e4m3_to_float, float_to_fp8_e4m3_rn },
{ "e4m3 sat 500", 500.0f, 0x7E, true, kvalues_mxfp8_e4m3, fp8_e4m3_to_float, float_to_fp8_e4m3_rn },
{ "e4m3 neg -1", -1.0f, 0xB8, false, kvalues_mxfp8_e4m3, fp8_e4m3_to_float, float_to_fp8_e4m3_rn },
// FP6 E2M3: [S(1)|E(2)|M(3)], no NaN/Inf
{ "e2m3 zero", 0.0f, 0x00, false, kvalues_mxfp6_e2m3, fp6_e2m3_to_float, float_to_fp6_e2m3_rn },
{ "e2m3 sub", 0.125f, 0x01, false, kvalues_mxfp6_e2m3, fp6_e2m3_to_float, float_to_fp6_e2m3_rn },
{ "e2m3 max 7.5", 7.5f, 0x1F, false, kvalues_mxfp6_e2m3, fp6_e2m3_to_float, float_to_fp6_e2m3_rn },
{ "e2m3 sat 100", 100.0f, 0x1F, true, kvalues_mxfp6_e2m3, fp6_e2m3_to_float, float_to_fp6_e2m3_rn },
// FP6 E3M2: [S(1)|E(3)|M(2)], no NaN/Inf, exp=7 is NORMAL
{ "e3m2 zero", 0.0f, 0x00, false, kvalues_mxfp6_e3m2, fp6_e3m2_to_float, float_to_fp6_e3m2_rn },
{ "e3m2 sub", 0.0625f, 0x01, false, kvalues_mxfp6_e3m2, fp6_e3m2_to_float, float_to_fp6_e3m2_rn },
{ "e3m2 max 28.0", 28.0f, 0x1F, false, kvalues_mxfp6_e3m2, fp6_e3m2_to_float, float_to_fp6_e3m2_rn },
{ "e3m2 exp7 16", 16.0f, 0x1C, false, kvalues_mxfp6_e3m2, fp6_e3m2_to_float, float_to_fp6_e3m2_rn },
// FP8 E5M2: [S(1)|E(5)|M(2)], bias=15
{ "e5m2 zero", 0.0f, 0x00, false, kvalues_mxfp8_e5m2, fp8_e5m2_to_float, float_to_fp8_e5m2_rn },
{ "e5m2 max", 57344.f, 0x7B, false, kvalues_mxfp8_e5m2, fp8_e5m2_to_float, float_to_fp8_e5m2_rn },
};
int conv_bad = 0;
// validate expected_bits against LUTs
for (const auto & c : checks) {
if (c.lut && !c.is_saturation) {
float lut_val = c.lut[c.expected_bits];
if (c.input != lut_val && !(c.input == 0.0f && lut_val == 0.0f)) {
printf(" TEST BUG %s: expected_bits=0x%02X → LUT=%.8g, but input=%.8g\n",
c.name, c.expected_bits, lut_val, c.input);
conv_bad++;
}
} else if (!c.lut && !c.is_saturation) {
float lut_val = kvalues_mxfp4_float[c.expected_bits];
if (c.input != lut_val && !(c.input == 0.0f && lut_val == 0.0f)) {
printf(" TEST BUG %s: expected_bits=0x%02X → LUT=%.8g, but input=%.8g\n",
c.name, c.expected_bits, lut_val, c.input);
conv_bad++;
}
}
}
// Now test the quantize direction
for (const auto & c : checks) {
uint8_t got;
if (c.to_quant) {
got = c.to_quant(c.input);
} else {
got = ggml_mxfp_float_to_fp4_e2m1(c.input);
}
if (got != c.expected_bits) {
if (conv_bad == 0 || verbose) {
printf(" %s: quantize(%.6g) = 0x%02X, expected 0x%02X\n",
c.name, c.input, got, c.expected_bits);
}
conv_bad++;
}
}
// FP8 E4M3: 0x7F must dequantize to NaN
{
float nan_val = fp8_e4m3_to_float(0x7F);
if (!isnan(nan_val)) {
if (conv_bad == 0 || verbose) {
printf(" e4m3 0x7F dequant: expected NaN, got %.6g\n", nan_val);
}
conv_bad++;
}
}
// FP6 E3M2: exp=7 must dequant to valid float (NOT Inf/NaN)
{
float exp7_val = fp6_e3m2_to_float(0x1F); // max: exp=7, mant=3 → 28.0
if (isnan(exp7_val) || exp7_val != 28.0f) {
if (conv_bad == 0 || verbose) {
printf(" e3m2 0x1F dequant: expected 28.0, got %.6g\n", exp7_val);
}
conv_bad++;
}
}
failed = (conv_bad > 0);
num_failed += failed;
if (failed || verbose) {
printf(" element converter edge cases: %s (%d/%d passed)\n",
RESULT_STR[failed],
(int)(sizeof(checks)/sizeof(checks[0])) + 2 - conv_bad,
(int)(sizeof(checks)/sizeof(checks[0])) + 2);
}
}
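The E4M3 expectations above follow the OCP FP8 convention: 1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, subnormals at exponent 0, no infinities, and a single NaN encoding (0x7F/0xFF). A minimal sketch of the decode direction under those assumptions (a hypothetical stand-in for fp8_e4m3_to_float, not the project's implementation):

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Decode one FP8 E4M3 byte: bias 7, subnormals when e == 0, no Inf.
// The only NaN encodings are 0x7F and 0xFF (S.1111.111).
static float fp8_e4m3_decode(uint8_t b) {
    if ((b & 0x7F) == 0x7F) {
        return NAN;
    }
    const int s = b >> 7;
    const int e = (b >> 3) & 0xF;
    const int m = b & 0x7;
    // subnormal: (m/8) * 2^-6 = m * 2^-9; normal: (1 + m/8) * 2^(e-7)
    const float v = (e == 0) ? ldexpf((float)m, -9)
                             : ldexpf(1.0f + m / 8.0f, e - 7);
    return s ? -v : v;
}
```

This reproduces the table entries above: 0x01 decodes to 1/512 (smallest subnormal), 0x7E to the max finite 448, and 0xB8 to -1.0.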
// FP6 pack/unpack round-trip
{
int pack_bad = 0;
// Test all 64 possible 6-bit values in each of the 4 positions
for (int pos = 0; pos < 4; pos++) {
for (int val = 0; val < 64; val++) {
uint8_t in[4] = {0, 0, 0, 0};
in[pos] = (uint8_t)val;
uint8_t packed[3], out[4];
pack_fp6x4(in, packed);
unpack_fp6x4(packed, out);
if (out[pos] != (uint8_t)val) {
if (pack_bad == 0 || verbose) {
printf(" fp6 pack roundtrip: pos=%d val=0x%02X → got 0x%02X\n",
pos, val, out[pos]);
}
pack_bad++;
}
// no crosstalk
for (int k = 0; k < 4; k++) {
if (k != pos && out[k] != 0) {
if (pack_bad == 0 || verbose) {
printf(" fp6 pack crosstalk: pos=%d val=0x%02X leaked to pos=%d (0x%02X)\n",
pos, val, k, out[k]);
}
pack_bad++;
}
}
}
}
// known-answer: [0x3F, 0x00, 0x3F, 0x00] -> {0x3F, 0xF0, 0x03}
{
uint8_t in[4] = {0x3F, 0x00, 0x3F, 0x00};
uint8_t packed[3];
pack_fp6x4(in, packed);
uint8_t expected[3] = {0x3F, 0xF0, 0x03};
if (packed[0] != expected[0] || packed[1] != expected[1] || packed[2] != expected[2]) {
if (pack_bad == 0 || verbose) {
printf(" fp6 known-answer: packed [%02X,%02X,%02X] expected [%02X,%02X,%02X]\n",
packed[0], packed[1], packed[2], expected[0], expected[1], expected[2]);
}
pack_bad++;
}
}
failed = (pack_bad > 0);
num_failed += failed;
if (failed || verbose) {
printf(" fp6 pack/unpack round-trip: %s\n", RESULT_STR[failed]);
}
}
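The packing these tests assume can be modeled as a 24-bit little-endian word with element i occupying bits [6*i, 6*i+6); that layout reproduces the known-answer vector above. A sketch with hypothetical stand-ins for pack_fp6x4/unpack_fp6x4 (the bit order is inferred from the known answer, not taken from the implementation):

```cpp
#include <cassert>
#include <cstdint>

// Pack four 6-bit codes into 3 bytes: build a 24-bit word with element i
// at bits [6*i, 6*i+6), then emit it little-endian.
static void pack6x4(const uint8_t in[4], uint8_t out[3]) {
    uint32_t w = 0;
    for (int i = 0; i < 4; i++) {
        w |= (uint32_t)(in[i] & 0x3F) << (6 * i);
    }
    out[0] = (uint8_t)(w);
    out[1] = (uint8_t)(w >> 8);
    out[2] = (uint8_t)(w >> 16);
}

// Inverse: reassemble the 24-bit word and extract each 6-bit field.
static void unpack6x4(const uint8_t in[3], uint8_t out[4]) {
    const uint32_t w = (uint32_t)in[0] | ((uint32_t)in[1] << 8) | ((uint32_t)in[2] << 16);
    for (int i = 0; i < 4; i++) {
        out[i] = (uint8_t)((w >> (6 * i)) & 0x3F);
    }
}
```

With this layout, [0x3F, 0x00, 0x3F, 0x00] packs to {0x3F, 0xF0, 0x03}, matching the known-answer check above.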
// E8M0 known-answer decode + HALF vs FULL (MXFP4 uses HALF, MXFP6/8 use FULL)
{
int e8m0_bad = 0;
// Known-answer E8M0 decodes
struct { uint8_t e; float expected; } e8m0_known[] = {
{ 127, 1.0f }, // 2^(127-127) = 2^0 = 1.0
{ 128, 2.0f }, // 2^(128-127) = 2^1 = 2.0
{ 126, 0.5f }, // 2^(126-127) = 2^(-1) = 0.5
{ 254, 1.70141183e+38f }, // 2^127 (max representable)
{ 1, 1.17549435e-38f }, // 2^(-126) (min normal)
};
for (const auto & t : e8m0_known) {
float got = ggml_mxfp_e8m0_to_fp32(t.e);
if (got != t.expected) {
if (e8m0_bad == 0 || verbose) {
printf(" E8M0 decode e=%d: got %.8g, expected %.8g\n", t.e, got, t.expected);
}
e8m0_bad++;
}
}
// HALF must be exactly half of FULL for all valid exponents
for (int e = 2; e < 255; e++) {
float full = ggml_mxfp_e8m0_to_fp32((uint8_t)e);
float half = ggml_mxfp_e8m0_to_fp32_half((uint8_t)e);
if (half != full * 0.5f) {
if (e8m0_bad == 0 || verbose) {
printf(" E8M0 HALF!=FULL/2 at e=%d: half=%.8g, full/2=%.8g\n", e, half, full * 0.5f);
}
e8m0_bad++;
break; // one failure is enough to flag the pattern
}
}
failed = (e8m0_bad > 0);
num_failed += failed;
if (failed || verbose) {
printf(" E8M0 known-answer + HALF/FULL: %s\n", RESULT_STR[failed]);
}
}
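The known answers follow directly from E8M0 being a bare biased exponent: decode(e) = 2^(e-127), with the HALF variant returning exactly half of that, as the loop above verifies. A sketch (hypothetical stand-ins for ggml_mxfp_e8m0_to_fp32 and ggml_mxfp_e8m0_to_fp32_half; reserved encodings and the zero-scale convention are ignored here):

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// E8M0 scale decode: the byte is a biased power-of-two exponent,
// value = 2^(e - 127). ldexpf(1, k) computes 2^k exactly.
static float e8m0_to_fp32(uint8_t e) {
    return ldexpf(1.0f, (int)e - 127);
}

// HALF variant: 2^(e - 128), i.e. exactly half the FULL scale.
static float e8m0_to_fp32_half(uint8_t e) {
    return ldexpf(1.0f, (int)e - 128);
}
```

Because both variants are exact powers of two, HALF == FULL * 0.5f holds bit-for-bit over the whole valid range, which is the invariant the test above checks.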
// E8M0 rounding at sqrt(2) threshold
{
int round_bad = 0;
// amax=1.0: floor_log2=0, mantissa=0 → no round → e_base = 0 - 0 + 127 = 127
{
int e = ggml_mxfp_e8m0_base_estimate(1.0f, 0);
if (e != 127) {
printf(" E8M0 round: amax=1.0 → e=%d, expected 127\n", e);
round_bad++;
}
}
// amax=2.0: floor_log2=1, mantissa=0 → no round → e_base = 1 + 127 = 128
{
int e = ggml_mxfp_e8m0_base_estimate(2.0f, 0);
if (e != 128) {
printf(" E8M0 round: amax=2.0 → e=%d, expected 128\n", e);
round_bad++;
}
}
// amax just below sqrt(2): mantissa < 0x3504F3 → floor only → e=127
{
// 1.4142f has an IEEE fraction field just below 0x3504F3
float below = 1.4142f;
int e = ggml_mxfp_e8m0_base_estimate(below, 0);
if (e != 127) {
printf(" E8M0 round: amax=%.6f → e=%d, expected 127 (no round)\n", below, e);
round_bad++;
}
}
// amax at sqrt(2): mantissa >= 0x3504F3 → rounds up → e=128
{
float at_sqrt2 = 1.41422f;
int e = ggml_mxfp_e8m0_base_estimate(at_sqrt2, 0);
if (e != 128) {
printf(" E8M0 round: amax=%.6f → e=%d, expected 128 (rounds up)\n", at_sqrt2, e);
round_bad++;
}
}
// Verify emax_offset shifts the result
{
int e_no_off = ggml_mxfp_e8m0_base_estimate(448.0f, 0);
int e_e4m3 = ggml_mxfp_e8m0_base_estimate(448.0f, MXFP8_E4M3_EMAX_OFFSET);
if (e_no_off - e_e4m3 != MXFP8_E4M3_EMAX_OFFSET) {
printf(" E8M0 emax_offset: diff=%d, expected %d\n",
e_no_off - e_e4m3, MXFP8_E4M3_EMAX_OFFSET);
round_bad++;
}
}
failed = (round_bad > 0);
num_failed += failed;
if (failed || verbose) {
printf(" E8M0 rounding boundary: %s\n", RESULT_STR[failed]);
}
}
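The boundary cases above come from rounding log2(amax) to the nearest integer: within a binade [2^k, 2^(k+1)), the midpoint in log space is sqrt(2)*2^k, and the IEEE-754 fraction field of sqrt(2)f is 0x3504F3. A sketch of just that core decision (ignoring emax_offset, clamping, and non-finite inputs; an illustration, not the exact ggml_mxfp_e8m0_base_estimate implementation):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Biased round-to-nearest log2: take the float's biased exponent field
// (floor(log2(amax)) + 127 for normals) and add one when the fraction
// field is at or above that of sqrt(2), i.e. 0x3504F3.
static int e8m0_estimate(float amax) {
    uint32_t u;
    memcpy(&u, &amax, sizeof(u));                // bit-cast without UB
    int e = (int)((u >> 23) & 0xFF);             // biased exponent field
    if ((u & 0x7FFFFF) >= 0x3504F3) {            // fraction >= sqrt(2)'s fraction
        e += 1;                                  // round up to the next power of two
    }
    return e;
}
```

This reproduces the four boundary cases tested above: 1.0 and 1.4142 map to 127, while 2.0 and 1.41422 map to 128.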
// Element converter exhaustive round-trip: quantize(dequantize(i)) == i for all valid bit patterns.
// Catches asymmetries between the to_float and to_quant paths.
{
struct rt_test {
const char * name;
int count;
float (*to_float)(uint8_t);
uint8_t (*to_quant)(float);
uint8_t nan_bits; // bit pattern for NaN (0 = no NaN in format)
};
const rt_test rt_tests[] = {
{ "fp8_e4m3", 256, fp8_e4m3_to_float, float_to_fp8_e4m3_rn, 0x7F },
{ "fp8_e5m2", 256, fp8_e5m2_to_float, float_to_fp8_e5m2_rn, 0 },
{ "fp6_e2m3", 64, fp6_e2m3_to_float, float_to_fp6_e2m3_rn, 0 },
{ "fp6_e3m2", 64, fp6_e3m2_to_float, float_to_fp6_e3m2_rn, 0 },
};
for (const auto & t : rt_tests) {
int rt_bad = 0;
for (int i = 0; i < t.count; i++) {
if ((uint8_t)i == t.nan_bits) continue; // skip NaN: quantize(NaN) is implementation-defined
float f = t.to_float((uint8_t)i);
if (isnan(f) || isinf(f)) continue; // E5M2 Inf/NaN
uint8_t back = t.to_quant(f);
// Negative zero may round-trip to positive zero; both are valid
if (back != (uint8_t)i && !(f == 0.0f && t.to_float(back) == 0.0f)) {
if (rt_bad == 0 || verbose) {
printf(" %s roundtrip: 0x%02X → %.6g → 0x%02X\n",
t.name, i, f, back);
}
rt_bad++;
}
}
failed = (rt_bad > 0);
num_failed += failed;
if (failed || verbose) {
printf("%5s converter round-trip: %s (%d/%d survived)\n",
t.name, RESULT_STR[failed], t.count - rt_bad, t.count);
}
}
// FP4 E2M1: uses static inline converters (not GGML_API wrappers), only 16 values
{
int rt_bad = 0;
for (int i = 0; i < 16; i++) {
float f = ggml_mxfp_fp4_e2m1_to_float((uint8_t)i);
uint8_t back = ggml_mxfp_float_to_fp4_e2m1(f);
if (back != (uint8_t)i && !(f == 0.0f && ggml_mxfp_fp4_e2m1_to_float(back) == 0.0f)) {
if (rt_bad == 0 || verbose) {
printf(" fp4_e2m1 roundtrip: 0x%02X → %.6g → 0x%02X\n", i, f, back);
}
rt_bad++;
}
}
failed = (rt_bad > 0);
num_failed += failed;
if (failed || verbose) {
printf("fp4_e2m1 converter round-trip: %s (%d/16 survived)\n",
RESULT_STR[failed], 16 - rt_bad);
}
}
}
// E8M0 scale computation: verify base exponent is reasonable for various amax values
{
const float test_amax[] = { 0.001f, 0.1f, 1.0f, 6.0f, 100.0f, 448.0f, 10000.0f };
int bad = 0;
for (float amax : test_amax) {
// ggml_mxfp_e8m0_base_estimate returns unclamped e_base
int e_base = ggml_mxfp_e8m0_base_estimate(amax, 0);
if (e_base < 1 || e_base > 254) {
if (bad == 0 || verbose) {
printf(" E8M0 bad e_base=%d for amax=%.4f\n", e_base, amax);
}
bad++;
continue;
}
float scale = ggml_mxfp_e8m0_to_fp32((uint8_t)e_base);
// Scale should be within 2x of amax (rough sanity check)
float ratio = amax / scale;
if (ratio < 0.25f || ratio > 4.0f) {
if (bad == 0 || verbose) {
printf(" E8M0 scale=%.6g for amax=%.4f, ratio=%.4f (expected ~1)\n",
scale, amax, ratio);
}
bad++;
}
}
failed = (bad > 0);
num_failed += failed;
if (failed || verbose) {
printf(" E8M0 scale sanity check: %s (%d/%d passed)\n",
RESULT_STR[failed], (int)(sizeof(test_amax)/sizeof(test_amax[0])) - bad,
(int)(sizeof(test_amax)/sizeof(test_amax[0])));
}
}
// SoA layout: verify offset macros produce correct byte positions
{
const struct { ggml_type type; int qs_per_block; } soa_types[] = {
{ GGML_TYPE_MXFP4_E2M1, MXFP4_SOA_QS_PER_BLOCK },
{ GGML_TYPE_MXFP8_E4M3, MXFP8_SOA_QS_PER_BLOCK },
{ GGML_TYPE_MXFP6_E2M3, MXFP6_SOA_QS_PER_BLOCK },
};
for (const auto & st : soa_types) {
for (int nblocks : { 1, 4, 8, 32 }) {
size_t expected_e8m0_off = (size_t)nblocks * st.qs_per_block;
size_t actual_e8m0_off = MXFP_SOA_E8M0_OFFSET(nblocks, st.qs_per_block);
size_t total = actual_e8m0_off + nblocks; // e8m0 region = 1 byte per block
size_t row_size = ggml_row_size(st.type, nblocks * 32);
bool offset_ok = (actual_e8m0_off == expected_e8m0_off);
bool size_ok = (total == row_size);
if (!offset_ok || !size_ok) {
failed = true;
num_failed++;
if (verbose) {
printf(" %s SoA layout nblocks=%d: e8m0_off=%zu (expected %zu), total=%zu (row_size=%zu)\n",
ggml_type_name(st.type), nblocks, actual_e8m0_off, expected_e8m0_off, total, row_size);
}
}
}
}
if (verbose) {
printf(" SoA layout offset check: %s\n", RESULT_STR[0]); // only prints failures above
}
}
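The offset relation being checked is just the SoA split: all quantized element bytes for every block come first, followed by one E8M0 scale byte per block. A sketch of the arithmetic behind MXFP_SOA_QS_OFFSET/MXFP_SOA_E8M0_OFFSET (the per-type qs widths used in the assertions are inferred from this test file, not independent facts):

```cpp
#include <cassert>
#include <cstddef>

// SoA row layout for n blocks of qs_per_block element bytes each:
//   [ qs block 0 | qs block 1 | ... | qs block n-1 | e8m0[0..n-1] ]
static size_t soa_qs_offset(int block, size_t qs_per_block) {
    return (size_t)block * qs_per_block;          // element bytes of block b
}

static size_t soa_e8m0_offset(int nblocks, size_t qs_per_block) {
    return (size_t)nblocks * qs_per_block;        // scales start after all qs
}

static size_t soa_row_bytes(int nblocks, size_t qs_per_block) {
    // one E8M0 byte per block appended after the qs region
    return soa_e8m0_offset(nblocks, qs_per_block) + (size_t)nblocks;
}
```

For a 32-element block this gives 16+1 bytes per MXFP4 block (packed nibbles), 24+1 for MXFP6, and 32+1 for MXFP8, which is the row_size equality the loop above asserts.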
// block size consistency
{
failed = !(QK_MXFP4 == 32 && QK_MXFP8 == 32 && QK_MXFP6 == 32);
num_failed += failed;
if (failed || verbose) {
printf(" MXFP block size == 32: %s (QK4=%d, QK8=%d, QK6=%d)\n",
RESULT_STR[failed], QK_MXFP4, QK_MXFP8, QK_MXFP6);
}
}
// EMAX_OFFSET produces valid E8M0 for each format's max finite value
{
struct emax_check {
const char * name;
int emax_offset;
float max_finite; // from LUT / converter
};
const emax_check emax_checks[] = {
{ "fp4_e2m1", MXFP4_E2M1_EMAX_OFFSET, 6.0f },
{ "fp6_e2m3", MXFP6_E2M3_EMAX_OFFSET, 7.5f },
{ "fp6_e3m2", MXFP6_E3M2_EMAX_OFFSET, 28.0f },
{ "fp8_e4m3", MXFP8_E4M3_EMAX_OFFSET, 448.0f },
{ "fp8_e5m2", MXFP8_E5M2_EMAX_OFFSET, 57344.0f },
};
int emax_bad = 0;
for (const auto & e : emax_checks) {
// When amax == max_finite, the base estimate must produce a valid E8M0 (1..254)
int e_base = ggml_mxfp_e8m0_base_estimate(e.max_finite, e.emax_offset);
if (e_base < 1 || e_base > 254) {
if (emax_bad == 0 || verbose) {
printf(" %s emax_offset=%d: max_finite=%.1f gives e_base=%d (out of range)\n",
e.name, e.emax_offset, e.max_finite, e_base);
}
emax_bad++;
}
}
failed = (emax_bad > 0);
num_failed += failed;
if (failed || verbose) {
printf(" EMAX_OFFSET vs format max: %s\n", RESULT_STR[failed]);
}
}
// MXFP4 AoS vs SoA: two independent code paths, same result
{
const int nelems = 64; // 2 blocks
float input[64];
for (int i = 0; i < 64; i++) {
input[i] = 0.5f + 2.0f * sinf(i * 0.7f + 0.3f);
}
// Quantize and dequant via AoS (block_mxfp4 structs)
std::vector<block_mxfp4> aos_q(nelems / QK_MXFP4);
std::vector<float> aos_out(nelems);
quantize_row_mxfp4_ref(input, aos_q.data(), nelems);
dequantize_row_mxfp4(aos_q.data(), aos_out.data(), nelems);
// Quantize and dequant via SoA
const size_t soa_buf_size = ggml_row_size(GGML_TYPE_MXFP4_E2M1, nelems);
std::vector<uint8_t> soa_q(soa_buf_size);
std::vector<float> soa_out(nelems);
quantize_row_mxfp4_soa(input, soa_q.data(), nelems);
dequantize_row_mxfp4_soa(soa_q.data(), soa_out.data(), nelems);
// Compare: both paths should produce identical results
int mismatches = 0;
for (int i = 0; i < nelems; i++) {
uint32_t a, b;
memcpy(&a, &aos_out[i], 4);
memcpy(&b, &soa_out[i], 4);
if (a != b) {
if (mismatches == 0 || verbose) {
printf(" mxfp4 AoS/SoA mismatch at [%d]: AoS=%.8g, SoA=%.8g\n",
i, aos_out[i], soa_out[i]);
}
mismatches++;
}
}
failed = (mismatches > 0);
num_failed += failed;
if (failed || verbose) {
printf("mxfp4 AoS vs SoA cross-check: %s (%d/%d match)\n",
RESULT_STR[failed], nelems - mismatches, nelems);
}
}
// Hadamard + quantize + dequant + Hadamard roundtrip (KV cache write/read path)
{
struct hadamard_pipeline_check {
const char * name;
ggml_type type;
float max_err;
};
const hadamard_pipeline_check pipeline_checks[] = {
{ "mxfp4", GGML_TYPE_MXFP4_E2M1, MAX_MXFP_PIPELINE_ERROR_MXFP4 },
{ "mxfp8", GGML_TYPE_MXFP8_E4M3, MAX_MXFP_PIPELINE_ERROR_MXFP8 },
{ "mxfp6", GGML_TYPE_MXFP6_E2M3, MAX_MXFP_PIPELINE_ERROR_MXFP6 },
};
for (const auto & p : pipeline_checks) {
const auto * cpu = ggml_get_type_traits_cpu(p.type);
std::vector<float> original(test_size);
std::vector<float> rotated(test_size);
std::vector<float> recovered(test_size);
generate_data(2.0, test_size, original.data());
// Write path: Hadamard each block, then quantize
memcpy(rotated.data(), original.data(), test_size * sizeof(float));
for (size_t b = 0; b < test_size / 32; b++) {
ggml_hadamard_32_inplace(&rotated[b * 32]);
}
const size_t buf_size = ggml_row_size(p.type, test_size);
std::vector<uint8_t> qbuf(buf_size);
cpu->from_float_soa(rotated.data(), qbuf.data(), test_size);
// Read path: dequant, then Hadamard each block (self-inverse)
cpu->to_float_soa(qbuf.data(), recovered.data(), test_size);
for (size_t b = 0; b < test_size / 32; b++) {
ggml_hadamard_32_inplace(&recovered[b * 32]);
}
float err = mxfp_rmse(original.data(), recovered.data(), test_size);
failed = !(err < p.max_err);
num_failed += failed;
if (failed || verbose) {
printf("%5s Hadamard pipeline roundtrip: %s (err=%.6f, max=%.6f)\n",
p.name, RESULT_STR[failed], err, p.max_err);
}
}
}
// Hadamard known output: H([1,0,...,0]) = [1/sqrt(32), ...]
{
float unit[32] = {};
unit[0] = 1.0f;
ggml_hadamard_32_inplace(unit);
const float expected = MXFP_HADAMARD_32_NORM; // 1/sqrt(32)
float max_err = 0.0f;
for (int i = 0; i < 32; i++) {
float err = fabsf(unit[i] - expected);
if (err > max_err) max_err = err;
}
failed = !(max_err < 1e-7f);
num_failed += failed;
if (failed || verbose) {
printf("hadamard unit vector: %s (max_err=%.2e, expected %.8f)\n",
RESULT_STR[failed], max_err, expected);
}
}
// zero block produces E8M0=0
{
float zeros[32] = {};
const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP8_E4M3, 32);
std::vector<uint8_t> buf(buf_size, 0xFF); // fill with 0xFF to detect non-writes
quantize_row_mxfp8_soa(zeros, buf.data(), 32);
// E8M0 scale is at offset MXFP8_SOA_QS_PER_BLOCK (32) for 1 block
uint8_t e8m0 = buf[MXFP8_SOA_QS_PER_BLOCK];
failed = (e8m0 != 0);
num_failed += failed;
if (failed || verbose) {
printf(" zero block E8M0: %s (e8m0=%d, expected 0)\n",
RESULT_STR[failed], e8m0);
}
}
// SoA format spec: quantize, manually walk raw bytes, compare against reference dequant
{
// 2 blocks, asymmetric data
const int nblocks = 2;
const int nelems = nblocks * 32;
float input[64];
for (int i = 0; i < 64; i++) {
// Block 0: small values, Block 1: large values, so the blocks get different E8M0 scales
input[i] = (i < 32) ? 0.1f * sinf(i + 0.5f) : 3.0f * cosf(i + 0.5f);
}
// MXFP4
{
const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP4_E2M1, nelems);
std::vector<uint8_t> buf(buf_size);
std::vector<float> ref_out(nelems);
std::vector<float> manual_out(nelems);
quantize_row_mxfp4_soa(input, buf.data(), nelems);
dequantize_row_mxfp4_soa(buf.data(), ref_out.data(), nelems);
// manual dequant from raw bytes
const uint8_t * qs = buf.data();
const uint8_t * e8m0 = buf.data() + MXFP_SOA_E8M0_OFFSET(nblocks, MXFP4_SOA_QS_PER_BLOCK);
for (int b = 0; b < nblocks; b++) {
const float d = ggml_mxfp_e8m0_to_fp32_half(e8m0[b]);
const uint8_t * block_qs = qs + MXFP_SOA_QS_OFFSET(b, MXFP4_SOA_QS_PER_BLOCK);
for (int j = 0; j < 16; j++) {
// low nibble = first half, high nibble = second half
int8_t v_lo = kvalues_mxfp4[block_qs[j] & 0x0F];
int8_t v_hi = kvalues_mxfp4[block_qs[j] >> 4];
manual_out[b*32 + j] = v_lo * d;
manual_out[b*32 + j + 16] = v_hi * d;
}
}
int mismatches = 0;
for (int i = 0; i < nelems; i++) {
uint32_t a, b;
memcpy(&a, &ref_out[i], 4);
memcpy(&b, &manual_out[i], 4);
if (a != b) mismatches++;
}
failed = (mismatches > 0);
num_failed += failed;
if (failed || verbose) {
printf("mxfp4 SoA format spec: %s (%d/%d match)\n",
RESULT_STR[failed], nelems - mismatches, nelems);
}
}
// MXFP8
{
const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP8_E4M3, nelems);
std::vector<uint8_t> buf(buf_size);
std::vector<float> ref_out(nelems);
std::vector<float> manual_out(nelems);
quantize_row_mxfp8_soa(input, buf.data(), nelems);
dequantize_row_mxfp8_soa(buf.data(), ref_out.data(), nelems);
const uint8_t * qs = buf.data();
const uint8_t * e8m0 = buf.data() + MXFP_SOA_E8M0_OFFSET(nblocks, MXFP8_SOA_QS_PER_BLOCK);
for (int b = 0; b < nblocks; b++) {
const float d = ggml_mxfp_e8m0_to_fp32(e8m0[b]);
const uint8_t * block_qs = qs + MXFP_SOA_QS_OFFSET(b, MXFP8_SOA_QS_PER_BLOCK);
for (int j = 0; j < 32; j++) {
// one byte per element
manual_out[b*32 + j] = fp8_e4m3_to_float(block_qs[j]) * d;
}
}
int mismatches = 0;
for (int i = 0; i < nelems; i++) {
uint32_t a, b;
memcpy(&a, &ref_out[i], 4);
memcpy(&b, &manual_out[i], 4);
if (a != b) mismatches++;
}
failed = (mismatches > 0);
num_failed += failed;
if (failed || verbose) {
printf("mxfp8 SoA format spec: %s (%d/%d match)\n",
RESULT_STR[failed], nelems - mismatches, nelems);
}
}
// MXFP6
{
const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP6_E2M3, nelems);
std::vector<uint8_t> buf(buf_size);
std::vector<float> ref_out(nelems);
std::vector<float> manual_out(nelems);
quantize_row_mxfp6_soa(input, buf.data(), nelems);
dequantize_row_mxfp6_soa(buf.data(), ref_out.data(), nelems);
const uint8_t * qs = buf.data();
const uint8_t * e8m0 = buf.data() + MXFP_SOA_E8M0_OFFSET(nblocks, MXFP6_SOA_QS_PER_BLOCK);
for (int b = 0; b < nblocks; b++) {
const float d = ggml_mxfp_e8m0_to_fp32(e8m0[b]);
const uint8_t * block_qs = qs + MXFP_SOA_QS_OFFSET(b, MXFP6_SOA_QS_PER_BLOCK);
for (int j = 0; j < 32; j += 4) {
// 4 elements packed into 3 bytes
uint8_t vals[4];
unpack_fp6x4(&block_qs[j * 3 / 4], vals);
for (int k = 0; k < 4; k++) {
manual_out[b*32 + j + k] = fp6_e2m3_to_float(vals[k]) * d;
}
}
}
int mismatches = 0;
for (int i = 0; i < nelems; i++) {
uint32_t a, b;
memcpy(&a, &ref_out[i], 4);
memcpy(&b, &manual_out[i], 4);
if (a != b) mismatches++;
}
failed = (mismatches > 0);
num_failed += failed;
if (failed || verbose) {
printf("mxfp6 SoA format spec: %s (%d/%d match)\n",
RESULT_STR[failed], nelems - mismatches, nelems);
}
}
}
if (num_failed || verbose) {
printf("%d tests failed\n", num_failed);
}