commit aba3778ca8 by Tim Burke, 2026-03-24 05:35:01 +02:00
21 changed files with 3208 additions and 192 deletions

View File

@ -398,6 +398,9 @@ const std::vector<ggml_type> kv_cache_types = {
GGML_TYPE_IQ4_NL,
GGML_TYPE_Q5_0,
GGML_TYPE_Q5_1,
GGML_TYPE_MXFP4,
GGML_TYPE_MXFP8,
GGML_TYPE_MXFP6,
};
static ggml_type kv_cache_type_from_str(const std::string & s) {
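With the three new entries, the MXFP formats become selectable KV cache types. A minimal sketch of the lookup this list feeds into (illustrative, not part of the diff; it assumes kv_cache_type_from_str resolves a user-supplied name against the list via ggml_type_name, so the exact strings come from that function):

// Hypothetical sketch of the name -> type lookup over kv_cache_types.
static ggml_type kv_cache_type_from_str_sketch(const std::string & s) {
    for (const auto type : kv_cache_types) {
        if (s == ggml_type_name(type)) { // e.g. "mxfp8" -> GGML_TYPE_MXFP8
            return type;
        }
    }
    throw std::runtime_error("unsupported cache type: " + s);
}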

View File

@ -115,9 +115,12 @@ extern "C" {
struct ggml_type_traits_cpu {
ggml_from_float_t from_float;
ggml_to_float_t to_float;
ggml_from_float_t from_float_soa; // SoA quantize (MXFP flash attention layout)
ggml_to_float_t to_float_soa; // SoA dequant (MXFP flash attention layout)
ggml_vec_dot_t vec_dot;
enum ggml_type vec_dot_type;
int64_t nrows; // number of rows to process simultaneously
};
GGML_BACKEND_API const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type);

View File

@ -426,9 +426,14 @@ extern "C" {
// GGML_TYPE_IQ4_NL_4_4 = 36,
// GGML_TYPE_IQ4_NL_4_8 = 37,
// GGML_TYPE_IQ4_NL_8_8 = 38,
GGML_TYPE_MXFP4_E2M1 = 39, // MX FP4 E2M1
GGML_TYPE_MXFP4 = GGML_TYPE_MXFP4_E2M1, // compat alias
GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale)
GGML_TYPE_MXFP8_E4M3 = 41, // MX FP8 E4M3
GGML_TYPE_MXFP8 = GGML_TYPE_MXFP8_E4M3, // compat alias
GGML_TYPE_MXFP6_E2M3 = 42, // MX FP6 E2M3
GGML_TYPE_MXFP6 = GGML_TYPE_MXFP6_E2M3, // compat alias
GGML_TYPE_COUNT = 43,
};
// precision
@ -463,7 +468,8 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
GGML_FTYPE_MOSTLY_MXFP4_E2M1 = 25, // except 1d tensors
GGML_FTYPE_MOSTLY_MXFP4 = GGML_FTYPE_MOSTLY_MXFP4_E2M1, // compat alias
GGML_FTYPE_MOSTLY_NVFP4 = 26, // except 1d tensors
};
@ -748,6 +754,9 @@ extern "C" {
GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
GGML_API bool ggml_is_quantized(enum ggml_type type);
GGML_API bool ggml_is_type_mxfp(enum ggml_type type);
GGML_API bool ggml_mxfp_use_hadamard(enum ggml_type type);
GGML_API int ggml_mxfp_qs_per_block(enum ggml_type type); // quantized bytes per 32-element block (SoA qs region)
// TODO: temporary until model loading of ggml examples is refactored
GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);

View File

@ -71,6 +71,8 @@ typedef sycl::half2 ggml_half2;
#define GGML_COMMON_DECL
#endif
#define MXFP_HADAMARD_32_NORM 0.17677669529663689f // 1/sqrt(32)
#if defined(GGML_COMMON_DECL)
#ifndef __cplusplus
@ -105,6 +107,12 @@ typedef sycl::half2 ggml_half2;
#define QI_NVFP4 (QK_NVFP4 / (4 * QR_NVFP4))
#define QR_NVFP4 2
#define QI_MXFP8 (QK_MXFP8 / (4 * QR_MXFP8))
#define QR_MXFP8 1
#define QI_MXFP6 (QK_MXFP6 / (4 * QR_MXFP6))
#define QR_MXFP6 1
#define QI5_0 (QK5_0 / (4 * QR5_0))
#define QR5_0 2
@ -190,6 +198,103 @@ typedef struct {
} block_q4_1;
static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
// E8M0 shared exponent constants (OCP MX v1.0 §5.3).
// EMAX_OFFSET ≈ log2(max_finite) of the element format; used by the
// round(log2(amax)) base-exponent estimate.
#define MXFP4_E2M1_EMAX_OFFSET 2 // floor(log2(6.0)) = 2
#define MXFP6_E2M3_EMAX_OFFSET 3 // ceil(log2(7.5)) = 3
#define MXFP6_E3M2_EMAX_OFFSET 5 // ceil(log2(28.0)) = 5
#define MXFP8_E4M3_EMAX_OFFSET 8 // floor(log2(448)) = 8
#define MXFP8_E5M2_EMAX_OFFSET 16 // ceil(log2(57344)) = 16
// MXFP type properties -- shared across all backends.
#define MXFP_BITS_PER_ELEM_E2M1 4
#define MXFP_BITS_PER_ELEM_E4M3 8
#define MXFP_BITS_PER_ELEM_E5M2 8
#define MXFP_BITS_PER_ELEM_E2M3 6
#define MXFP_BITS_PER_ELEM_E3M2 6
#define MXFP_QS_PER_BLOCK_E2M1 16 // 32 * 4 / 8
#define MXFP_QS_PER_BLOCK_E4M3 32 // 32 * 8 / 8
#define MXFP_QS_PER_BLOCK_E5M2 32
#define MXFP_QS_PER_BLOCK_E2M3 24 // 32 * 6 / 8
#define MXFP_QS_PER_BLOCK_E3M2 24
#define MXFP_USE_HADAMARD_E2M1 1
#define MXFP_USE_HADAMARD_E4M3 1
#define MXFP_USE_HADAMARD_E5M2 0
#define MXFP_USE_HADAMARD_E2M3 1
#define MXFP_USE_HADAMARD_E3M2 0
// SIMD dequant constants for IEEE-754 bit reconstruction of FP8/FP6 elements.
// For a format with sign(1), exp(E), mant(M), bias(B):
// EXP_MASK = (1<<E)-1 MANT_MASK = (1<<M)-1 EXP_SHIFT = M
// IEEE_EXP_OFF = 127-B MANT_SHIFT = 23-M SUB_SCALE = 2^(1-B-M)
// Used by the x86 AVX2 and ARM NEON vectorized paths: dot product, AoS dequant, SoA dequant.
#define MXFP8_E4M3_EXP_MASK 0xF
#define MXFP8_E4M3_MANT_MASK 0x7
#define MXFP8_E4M3_EXP_SHIFT 3
#define MXFP8_E4M3_IEEE_EXP_OFF 120
#define MXFP8_E4M3_MANT_SHIFT 20
#define MXFP8_E4M3_SUB_SCALE (1.0f/512.0f)
#define MXFP8_E5M2_EXP_MASK 0x1F
#define MXFP8_E5M2_MANT_MASK 0x3
#define MXFP8_E5M2_EXP_SHIFT 2
#define MXFP8_E5M2_IEEE_EXP_OFF 112
#define MXFP8_E5M2_MANT_SHIFT 21
#define MXFP8_E5M2_SUB_SCALE (1.0f/65536.0f)
#define MXFP6_E2M3_EXP_MASK 0x3
#define MXFP6_E2M3_MANT_MASK 0x7
#define MXFP6_E2M3_EXP_SHIFT 3
#define MXFP6_E2M3_IEEE_EXP_OFF 126
#define MXFP6_E2M3_MANT_SHIFT 20
#define MXFP6_E2M3_SUB_SCALE (1.0f/8.0f)
#define MXFP6_E3M2_EXP_MASK 0x7
#define MXFP6_E3M2_MANT_MASK 0x3
#define MXFP6_E3M2_EXP_SHIFT 2
#define MXFP6_E3M2_IEEE_EXP_OFF 124
#define MXFP6_E3M2_MANT_SHIFT 21
#define MXFP6_E3M2_SUB_SCALE (1.0f/16.0f)
// MXFP dequant traits for IEEE-754 bit reconstruction (FP8/FP6).
typedef struct {
int exp_mask;
int mant_mask;
int exp_shift;
int ieee_exp_off;
int mant_shift;
float sub_scale;
int sign_mask; // 0x80 for 8-bit, 0x20 for 6-bit
int sign_shift; // 24 for 8-bit, 26 for 6-bit
int qs_per_block;
int emax_offset;
} mxfp_dequant_traits_t;
#if defined(GGML_COMMON_IMPL)
static const mxfp_dequant_traits_t MXFP_TRAITS_E4M3 = {
MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT,
MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE,
0x80, 24, MXFP_QS_PER_BLOCK_E4M3, MXFP8_E4M3_EMAX_OFFSET
};
static const mxfp_dequant_traits_t MXFP_TRAITS_E5M2 = {
MXFP8_E5M2_EXP_MASK, MXFP8_E5M2_MANT_MASK, MXFP8_E5M2_EXP_SHIFT,
MXFP8_E5M2_IEEE_EXP_OFF, MXFP8_E5M2_MANT_SHIFT, MXFP8_E5M2_SUB_SCALE,
0x80, 24, MXFP_QS_PER_BLOCK_E5M2, MXFP8_E5M2_EMAX_OFFSET
};
static const mxfp_dequant_traits_t MXFP_TRAITS_E2M3 = {
MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT,
MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE,
0x20, 26, MXFP_QS_PER_BLOCK_E2M3, MXFP6_E2M3_EMAX_OFFSET
};
static const mxfp_dequant_traits_t MXFP_TRAITS_E3M2 = {
MXFP6_E3M2_EXP_MASK, MXFP6_E3M2_MANT_MASK, MXFP6_E3M2_EXP_SHIFT,
MXFP6_E3M2_IEEE_EXP_OFF, MXFP6_E3M2_MANT_SHIFT, MXFP6_E3M2_SUB_SCALE,
0x20, 26, MXFP_QS_PER_BLOCK_E3M2, MXFP6_E3M2_EMAX_OFFSET
};
#endif // GGML_COMMON_IMPL
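The traits encode exactly the reconstruction rule from the comment above: a normal element becomes sign | (exp + IEEE_EXP_OFF) << 23 | mant << MANT_SHIFT, and a subnormal becomes mant * SUB_SCALE with the sign reattached. A scalar reference sketch (illustrative; static inline stands in for the per-backend GGML_MXFP_FUNC qualifier, and the committed NEON/AVX2 kernels below do the same thing 4 or 8 lanes at a time):

// Hypothetical scalar reference: dequantize one raw FP8/FP6 code via the traits.
static inline float mxfp_dequant_scalar_sketch(uint8_t raw, const mxfp_dequant_traits_t * t) {
    const uint32_t sign = ((uint32_t)(raw & t->sign_mask)) << t->sign_shift; // -> IEEE bit 31
    const uint32_t exp  = ((uint32_t)raw >> t->exp_shift) & (uint32_t)t->exp_mask;
    const uint32_t mant =  (uint32_t)raw & (uint32_t)t->mant_mask;
    if (exp == 0) {
        // Subnormal: magnitude = mant * 2^(1-B-M); OR the sign bit back in.
        const float v = (float)mant * t->sub_scale;
        return GGML_MXFP_U32_AS_F32(GGML_MXFP_F32_AS_U32(v) | sign);
    }
    // Normal: rebias the exponent into FP32 and left-align the mantissa.
    return GGML_MXFP_U32_AS_F32(sign | ((exp + (uint32_t)t->ieee_exp_off) << 23) | (mant << t->mant_shift));
}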
#define QK_MXFP4 32
typedef struct {
uint8_t e; // E8M0
@ -205,6 +310,29 @@ typedef struct {
} block_nvfp4;
static_assert(sizeof(block_nvfp4) == sizeof(uint8_t)*(QK_NVFP4/QK_NVFP4_SUB) + QK_NVFP4/2, "wrong nvfp4 block size/padding");
#define QK_MXFP8 32
typedef struct {
uint8_t e; // E8M0 shared exponent
uint8_t qs[QK_MXFP8]; // 32 FP8 values (1 byte each), used for E4M3 and E5M2
} block_mxfp8;
static_assert(sizeof(block_mxfp8) == sizeof(uint8_t) + QK_MXFP8, "wrong mxfp8 block size/padding");
#define QK_MXFP6 32
typedef struct {
uint8_t e; // E8M0 shared exponent
uint8_t qs[QK_MXFP6 * 6 / 8]; // 24 bytes: 32 six-bit values tightly packed, used for E2M3 and E3M2
} block_mxfp6;
static_assert(sizeof(block_mxfp6) == sizeof(uint8_t) + QK_MXFP6 * 6 / 8, "wrong mxfp6 block size/padding");
// SoA layout for MXFP KV cache: [qs blocks][e8m0 scales]
#define MXFP4_SOA_QS_PER_BLOCK MXFP_QS_PER_BLOCK_E2M1
#define MXFP8_SOA_QS_PER_BLOCK MXFP_QS_PER_BLOCK_E4M3
#define MXFP6_SOA_QS_PER_BLOCK MXFP_QS_PER_BLOCK_E2M3
// SoA offset helpers
#define MXFP_SOA_QS_OFFSET(block_idx, qs_per_block) ((block_idx) * (qs_per_block))
#define MXFP_SOA_E8M0_OFFSET(nblocks, qs_per_block) ((nblocks) * (qs_per_block))
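An SoA row therefore stores all quantized blocks first, then all E8M0 scales. A small sizing/addressing sketch using these macros (illustrative; row and k are hypothetical):

// For k elements (k % 32 == 0): nb blocks of qs bytes, then nb E8M0 scale bytes.
static inline size_t mxfp_soa_row_size_sketch(int64_t k, int qs_per_block) {
    const int64_t nb = k / 32;
    return (size_t)(nb * qs_per_block + nb); // [qs blocks][e8m0 scales]
}
// Block ib of a row buffer:
//   qs bytes at row + MXFP_SOA_QS_OFFSET(ib, qs_per_block)
//   scale at   row + MXFP_SOA_E8M0_OFFSET(nb, qs_per_block) + ib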
#define QK5_0 32
typedef struct {
ggml_half d; // delta
@ -445,18 +573,47 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_
#ifndef GGML_COMMON_IMPL
// NaN/Infinity for FP8 LUT initializers (CPU-only, guarded out of GPU builds).
#if defined(_MSC_VER) && !defined(__clang__)
#include <math.h>
#define GGML_TABLE_NAN NAN
#define GGML_TABLE_INFINITY INFINITY
#else
#define GGML_TABLE_NAN __builtin_nanf("")
#define GGML_TABLE_INFINITY __builtin_inff()
#endif
#if defined(GGML_COMMON_IMPL_C)
#include <stdint.h>
#include <string.h>
#include <math.h>
#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
#define GGML_TABLE_END() };
#define GGML_MXFP_FUNC static inline
static inline uint32_t ggml_mxfp_f32_as_u32_(float f) { uint32_t u; memcpy(&u, &f, sizeof(u)); return u; }
static inline float ggml_mxfp_u32_as_f32_(uint32_t u) { float f; memcpy(&f, &u, sizeof(f)); return f; }
#define GGML_MXFP_F32_AS_U32(f) ggml_mxfp_f32_as_u32_(f)
#define GGML_MXFP_U32_AS_F32(u) ggml_mxfp_u32_as_f32_(u)
#define GGML_MXFP_LDEXPF(x, n) ldexpf(x, n)
#define GGML_MXFP_THREAD
#define GGML_MXFP_UNROLL
#define GGML_COMMON_IMPL
#elif defined(GGML_COMMON_IMPL_CPP)
#include <cstdint>
#include <cstring>
#include <cmath>
#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
#define GGML_TABLE_END() };
#define GGML_MXFP_FUNC static inline
static inline uint32_t ggml_mxfp_f32_as_u32_(float f) { uint32_t u; memcpy(&u, &f, sizeof(u)); return u; }
static inline float ggml_mxfp_u32_as_f32_(uint32_t u) { float f; memcpy(&f, &u, sizeof(f)); return f; }
#define GGML_MXFP_F32_AS_U32(f) ggml_mxfp_f32_as_u32_(f)
#define GGML_MXFP_U32_AS_F32(u) ggml_mxfp_u32_as_f32_(u)
#define GGML_MXFP_LDEXPF(x, n) ldexpf(x, n)
#define GGML_MXFP_THREAD
#define GGML_MXFP_UNROLL
#define GGML_COMMON_IMPL
#elif defined(GGML_COMMON_IMPL_METAL)
@ -464,21 +621,43 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_
#define GGML_TABLE_BEGIN(type, name, size) static const constant type name[size] = {
#define GGML_TABLE_END() };
#define GGML_MXFP_FUNC static inline
#define GGML_MXFP_F32_AS_U32(f) as_type<uint32_t>(f)
#define GGML_MXFP_U32_AS_F32(u) as_type<float>(u)
#define GGML_MXFP_LDEXPF(x, n) metal::ldexp(x, n)
#define GGML_MXFP_THREAD thread
#define GGML_MXFP_UNROLL _Pragma("unroll")
#define GGML_COMMON_IMPL
#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) || defined(GGML_COMMON_IMPL_MUSA)
#include <cstdint>
#include <cstring>
#define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = {
#define GGML_TABLE_END() };
#define GGML_MXFP_FUNC static __device__ __forceinline__
#define GGML_MXFP_F32_AS_U32(f) __float_as_uint(f)
#define GGML_MXFP_U32_AS_F32(u) __uint_as_float(u)
#define GGML_MXFP_LDEXPF(x, n) ldexpf(x, n)
#define GGML_MXFP_THREAD
#define GGML_MXFP_UNROLL _Pragma("unroll")
#define GGML_COMMON_IMPL
#elif defined(GGML_COMMON_IMPL_SYCL)
#include <cstdint>
#include <cstring>
#include <cmath>
#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
#define GGML_TABLE_END() };
#define GGML_MXFP_FUNC static inline
static inline uint32_t ggml_mxfp_f32_as_u32_(float f) { uint32_t u; memcpy(&u, &f, sizeof(u)); return u; }
static inline float ggml_mxfp_u32_as_f32_(uint32_t u) { float f; memcpy(&f, &u, sizeof(f)); return f; }
#define GGML_MXFP_F32_AS_U32(f) ggml_mxfp_f32_as_u32_(f)
#define GGML_MXFP_U32_AS_F32(u) ggml_mxfp_u32_as_f32_(u)
#define GGML_MXFP_LDEXPF(x, n) ldexpf(x, n)
#define GGML_MXFP_THREAD
#define GGML_MXFP_UNROLL
#define GGML_COMMON_IMPL
#endif
@ -1100,12 +1279,410 @@ GGML_TABLE_BEGIN(int8_t, kvalues_iq4nl, 16)
-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
GGML_TABLE_END()
// Canonical E2M1 values (true FP4 magnitudes).
// ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
GGML_TABLE_BEGIN(float, kvalues_mxfp4_float, 16)
0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f,
-0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f,
GGML_TABLE_END()
// E2M1 values doubled (for integer arithmetic with half-scale).
GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16)
0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
GGML_TABLE_END()
// FP6 E2M3 dequantization LUT: 6-bit value -> float.
GGML_TABLE_BEGIN(float, kvalues_mxfp6_e2m3, 64)
0.0f, 0.125f, 0.25f, 0.375f, 0.5f, 0.625f, 0.75f, 0.875f,
1.0f, 1.125f, 1.25f, 1.375f, 1.5f, 1.625f, 1.75f, 1.875f,
2.0f, 2.25f, 2.5f, 2.75f, 3.0f, 3.25f, 3.5f, 3.75f,
4.0f, 4.5f, 5.0f, 5.5f, 6.0f, 6.5f, 7.0f, 7.5f,
-0.0f, -0.125f, -0.25f, -0.375f, -0.5f, -0.625f, -0.75f, -0.875f,
-1.0f, -1.125f, -1.25f, -1.375f, -1.5f, -1.625f, -1.75f, -1.875f,
-2.0f, -2.25f, -2.5f, -2.75f, -3.0f, -3.25f, -3.5f, -3.75f,
-4.0f, -4.5f, -5.0f, -5.5f, -6.0f, -6.5f, -7.0f, -7.5f,
GGML_TABLE_END()
// FP6 E3M2 dequantization LUT: 6-bit value -> float. No NaN/Inf.
GGML_TABLE_BEGIN(float, kvalues_mxfp6_e3m2, 64)
0.0f, 0.0625f, 0.125f, 0.1875f, 0.25f, 0.3125f, 0.375f, 0.4375f,
0.5f, 0.625f, 0.75f, 0.875f, 1.0f, 1.25f, 1.5f, 1.75f,
2.0f, 2.5f, 3.0f, 3.5f, 4.0f, 5.0f, 6.0f, 7.0f,
8.0f, 10.0f, 12.0f, 14.0f, 16.0f, 20.0f, 24.0f, 28.0f,
-0.0f, -0.0625f, -0.125f,-0.1875f, -0.25f,-0.3125f, -0.375f,-0.4375f,
-0.5f, -0.625f, -0.75f, -0.875f, -1.0f, -1.25f, -1.5f, -1.75f,
-2.0f, -2.5f, -3.0f, -3.5f, -4.0f, -5.0f, -6.0f, -7.0f,
-8.0f, -10.0f, -12.0f, -14.0f, -16.0f, -20.0f, -24.0f, -28.0f,
GGML_TABLE_END()
// FP8 E4M3/E5M2 LUTs contain NaN/Inf which cannot be constexpr-initialized in
// __device__ tables. GPU backends use the converter functions instead.
#if !defined(GGML_COMMON_DECL_CUDA) && !defined(GGML_COMMON_DECL_HIP) && !defined(GGML_COMMON_DECL_MUSA)
// FP8 E4M3 dequantization LUT: byte -> float. Entries 126/254 = ±448 (max finite), 127/255 = NaN.
GGML_TABLE_BEGIN(float, kvalues_mxfp8_e4m3, 256)
0.0f, 0.001953125f, 0.00390625f, 0.005859375f, 0.0078125f, 0.009765625f, 0.01171875f, 0.013671875f,
0.015625f, 0.017578125f, 0.01953125f, 0.021484375f, 0.0234375f, 0.025390625f, 0.02734375f, 0.029296875f,
0.03125f, 0.03515625f, 0.0390625f, 0.04296875f, 0.046875f, 0.05078125f, 0.0546875f, 0.05859375f,
0.0625f, 0.0703125f, 0.078125f, 0.0859375f, 0.09375f, 0.1015625f, 0.109375f, 0.1171875f,
0.125f, 0.140625f, 0.15625f, 0.171875f, 0.1875f, 0.203125f, 0.21875f, 0.234375f,
0.25f, 0.28125f, 0.3125f, 0.34375f, 0.375f, 0.40625f, 0.4375f, 0.46875f,
0.5f, 0.5625f, 0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f, 0.9375f,
1.0f, 1.125f, 1.25f, 1.375f, 1.5f, 1.625f, 1.75f, 1.875f,
2.0f, 2.25f, 2.5f, 2.75f, 3.0f, 3.25f, 3.5f, 3.75f,
4.0f, 4.5f, 5.0f, 5.5f, 6.0f, 6.5f, 7.0f, 7.5f,
8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
16.0f, 18.0f, 20.0f, 22.0f, 24.0f, 26.0f, 28.0f, 30.0f,
32.0f, 36.0f, 40.0f, 44.0f, 48.0f, 52.0f, 56.0f, 60.0f,
64.0f, 72.0f, 80.0f, 88.0f, 96.0f, 104.0f, 112.0f, 120.0f,
128.0f, 144.0f, 160.0f, 176.0f, 192.0f, 208.0f, 224.0f, 240.0f,
256.0f, 288.0f, 320.0f, 352.0f, 384.0f, 416.0f, 448.0f, GGML_TABLE_NAN,
-0.0f,-0.001953125f, -0.00390625f,-0.005859375f, -0.0078125f,-0.009765625f, -0.01171875f,-0.013671875f,
-0.015625f,-0.017578125f, -0.01953125f,-0.021484375f, -0.0234375f,-0.025390625f, -0.02734375f,-0.029296875f,
-0.03125f, -0.03515625f, -0.0390625f, -0.04296875f, -0.046875f, -0.05078125f, -0.0546875f, -0.05859375f,
-0.0625f, -0.0703125f, -0.078125f, -0.0859375f, -0.09375f, -0.1015625f, -0.109375f, -0.1171875f,
-0.125f, -0.140625f, -0.15625f, -0.171875f, -0.1875f, -0.203125f, -0.21875f, -0.234375f,
-0.25f, -0.28125f, -0.3125f, -0.34375f, -0.375f, -0.40625f, -0.4375f, -0.46875f,
-0.5f, -0.5625f, -0.625f, -0.6875f, -0.75f, -0.8125f, -0.875f, -0.9375f,
-1.0f, -1.125f, -1.25f, -1.375f, -1.5f, -1.625f, -1.75f, -1.875f,
-2.0f, -2.25f, -2.5f, -2.75f, -3.0f, -3.25f, -3.5f, -3.75f,
-4.0f, -4.5f, -5.0f, -5.5f, -6.0f, -6.5f, -7.0f, -7.5f,
-8.0f, -9.0f, -10.0f, -11.0f, -12.0f, -13.0f, -14.0f, -15.0f,
-16.0f, -18.0f, -20.0f, -22.0f, -24.0f, -26.0f, -28.0f, -30.0f,
-32.0f, -36.0f, -40.0f, -44.0f, -48.0f, -52.0f, -56.0f, -60.0f,
-64.0f, -72.0f, -80.0f, -88.0f, -96.0f, -104.0f, -112.0f, -120.0f,
-128.0f, -144.0f, -160.0f, -176.0f, -192.0f, -208.0f, -224.0f, -240.0f,
-256.0f, -288.0f, -320.0f, -352.0f, -384.0f, -416.0f, -448.0f, GGML_TABLE_NAN,
GGML_TABLE_END()
// FP8 E5M2 dequantization LUT: byte -> float. Entries 124-127 = {Inf, NaN, NaN, NaN}.
// Generated from ggml_mxfp_fp8_e5m2_to_float() with %.9e precision for exact float round-trip.
GGML_TABLE_BEGIN(float, kvalues_mxfp8_e5m2, 256)
0.000000000e+00f, 1.525878906e-05f, 3.051757812e-05f, 4.577636719e-05f, 6.103515625e-05f, 7.629394531e-05f, 9.155273438e-05f, 1.068115234e-04f,
1.220703125e-04f, 1.525878906e-04f, 1.831054688e-04f, 2.136230469e-04f, 2.441406250e-04f, 3.051757812e-04f, 3.662109375e-04f, 4.272460938e-04f,
4.882812500e-04f, 6.103515625e-04f, 7.324218750e-04f, 8.544921875e-04f, 9.765625000e-04f, 1.220703125e-03f, 1.464843750e-03f, 1.708984375e-03f,
1.953125000e-03f, 2.441406250e-03f, 2.929687500e-03f, 3.417968750e-03f, 3.906250000e-03f, 4.882812500e-03f, 5.859375000e-03f, 6.835937500e-03f,
7.812500000e-03f, 9.765625000e-03f, 1.171875000e-02f, 1.367187500e-02f, 1.562500000e-02f, 1.953125000e-02f, 2.343750000e-02f, 2.734375000e-02f,
3.125000000e-02f, 3.906250000e-02f, 4.687500000e-02f, 5.468750000e-02f, 6.250000000e-02f, 7.812500000e-02f, 9.375000000e-02f, 1.093750000e-01f,
1.250000000e-01f, 1.562500000e-01f, 1.875000000e-01f, 2.187500000e-01f, 2.500000000e-01f, 3.125000000e-01f, 3.750000000e-01f, 4.375000000e-01f,
5.000000000e-01f, 6.250000000e-01f, 7.500000000e-01f, 8.750000000e-01f, 1.000000000e+00f, 1.250000000e+00f, 1.500000000e+00f, 1.750000000e+00f,
2.000000000e+00f, 2.500000000e+00f, 3.000000000e+00f, 3.500000000e+00f, 4.000000000e+00f, 5.000000000e+00f, 6.000000000e+00f, 7.000000000e+00f,
8.000000000e+00f, 1.000000000e+01f, 1.200000000e+01f, 1.400000000e+01f, 1.600000000e+01f, 2.000000000e+01f, 2.400000000e+01f, 2.800000000e+01f,
3.200000000e+01f, 4.000000000e+01f, 4.800000000e+01f, 5.600000000e+01f, 6.400000000e+01f, 8.000000000e+01f, 9.600000000e+01f, 1.120000000e+02f,
1.280000000e+02f, 1.600000000e+02f, 1.920000000e+02f, 2.240000000e+02f, 2.560000000e+02f, 3.200000000e+02f, 3.840000000e+02f, 4.480000000e+02f,
5.120000000e+02f, 6.400000000e+02f, 7.680000000e+02f, 8.960000000e+02f, 1.024000000e+03f, 1.280000000e+03f, 1.536000000e+03f, 1.792000000e+03f,
2.048000000e+03f, 2.560000000e+03f, 3.072000000e+03f, 3.584000000e+03f, 4.096000000e+03f, 5.120000000e+03f, 6.144000000e+03f, 7.168000000e+03f,
8.192000000e+03f, 1.024000000e+04f, 1.228800000e+04f, 1.433600000e+04f, 1.638400000e+04f, 2.048000000e+04f, 2.457600000e+04f, 2.867200000e+04f,
3.276800000e+04f, 4.096000000e+04f, 4.915200000e+04f, 5.734400000e+04f, GGML_TABLE_INFINITY, GGML_TABLE_NAN, GGML_TABLE_NAN, GGML_TABLE_NAN,
-0.000000000e+00f,-1.525878906e-05f,-3.051757812e-05f,-4.577636719e-05f,-6.103515625e-05f,-7.629394531e-05f,-9.155273438e-05f,-1.068115234e-04f,
-1.220703125e-04f,-1.525878906e-04f,-1.831054688e-04f,-2.136230469e-04f,-2.441406250e-04f,-3.051757812e-04f,-3.662109375e-04f,-4.272460938e-04f,
-4.882812500e-04f,-6.103515625e-04f,-7.324218750e-04f,-8.544921875e-04f,-9.765625000e-04f,-1.220703125e-03f,-1.464843750e-03f,-1.708984375e-03f,
-1.953125000e-03f,-2.441406250e-03f,-2.929687500e-03f,-3.417968750e-03f,-3.906250000e-03f,-4.882812500e-03f,-5.859375000e-03f,-6.835937500e-03f,
-7.812500000e-03f,-9.765625000e-03f,-1.171875000e-02f,-1.367187500e-02f,-1.562500000e-02f,-1.953125000e-02f,-2.343750000e-02f,-2.734375000e-02f,
-3.125000000e-02f,-3.906250000e-02f,-4.687500000e-02f,-5.468750000e-02f,-6.250000000e-02f,-7.812500000e-02f,-9.375000000e-02f,-1.093750000e-01f,
-1.250000000e-01f,-1.562500000e-01f,-1.875000000e-01f,-2.187500000e-01f,-2.500000000e-01f,-3.125000000e-01f,-3.750000000e-01f,-4.375000000e-01f,
-5.000000000e-01f,-6.250000000e-01f,-7.500000000e-01f,-8.750000000e-01f,-1.000000000e+00f,-1.250000000e+00f,-1.500000000e+00f,-1.750000000e+00f,
-2.000000000e+00f,-2.500000000e+00f,-3.000000000e+00f,-3.500000000e+00f,-4.000000000e+00f,-5.000000000e+00f,-6.000000000e+00f,-7.000000000e+00f,
-8.000000000e+00f,-1.000000000e+01f,-1.200000000e+01f,-1.400000000e+01f,-1.600000000e+01f,-2.000000000e+01f,-2.400000000e+01f,-2.800000000e+01f,
-3.200000000e+01f,-4.000000000e+01f,-4.800000000e+01f,-5.600000000e+01f,-6.400000000e+01f,-8.000000000e+01f,-9.600000000e+01f,-1.120000000e+02f,
-1.280000000e+02f,-1.600000000e+02f,-1.920000000e+02f,-2.240000000e+02f,-2.560000000e+02f,-3.200000000e+02f,-3.840000000e+02f,-4.480000000e+02f,
-5.120000000e+02f,-6.400000000e+02f,-7.680000000e+02f,-8.960000000e+02f,-1.024000000e+03f,-1.280000000e+03f,-1.536000000e+03f,-1.792000000e+03f,
-2.048000000e+03f,-2.560000000e+03f,-3.072000000e+03f,-3.584000000e+03f,-4.096000000e+03f,-5.120000000e+03f,-6.144000000e+03f,-7.168000000e+03f,
-8.192000000e+03f,-1.024000000e+04f,-1.228800000e+04f,-1.433600000e+04f,-1.638400000e+04f,-2.048000000e+04f,-2.457600000e+04f,-2.867200000e+04f,
-3.276800000e+04f,-4.096000000e+04f,-4.915200000e+04f,-5.734400000e+04f, -GGML_TABLE_INFINITY, GGML_TABLE_NAN, GGML_TABLE_NAN, GGML_TABLE_NAN,
GGML_TABLE_END()
#endif // !CUDA && !HIP && !MUSA
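Since the E5M2 table is generated from the converter, the two must agree on every byte, NaN entries included. A consistency-check sketch (illustrative, assumes <assert.h> and <math.h>; NaN payloads may differ between the table initializer and the converter, so the check is value-level):

// Every byte decodes identically through the LUT and the converter.
static void mxfp8_e5m2_lut_check(void) {
    for (int b = 0; b < 256; ++b) {
        const float lut = kvalues_mxfp8_e5m2[b];
        const float cvt = ggml_mxfp_fp8_e5m2_to_float((uint8_t)b);
        assert((isnan(lut) && isnan(cvt)) || lut == cvt); // +/-Inf compare equal
    }
}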
// MXFP element converters -- portable IEEE-754 bit manipulation.
#if defined(GGML_MXFP_FUNC)
// FP4 E2M1: [S(1) | E(2) | M(1)], max normal = 6.0
GGML_MXFP_FUNC float ggml_mxfp_fp4_e2m1_to_float(uint8_t v) {
const float sign = (v & 0x8) ? -1.0f : 1.0f;
const int exp = (v >> 1) & 0x3;
const int mant = v & 0x1;
if (exp == 0) return sign * (float)mant * 0.5f;
return sign * (1.0f + mant * 0.5f) * (float)(1 << (exp - 1));
}
GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp4_e2m1(float x) {
uint8_t sign = 0;
if (x < 0) { sign = 0x8; x = -x; }
if (x == 0) return sign;
if (x >= 6.0f) return sign | 0x7; // max finite
if (x < 0.25f) return sign | 0x0; // 0
else if (x < 0.75f) return sign | 0x1; // 0.5
else if (x < 1.25f) return sign | 0x2; // 1.0
else if (x < 1.75f) return sign | 0x3; // 1.5
else if (x < 2.5f) return sign | 0x4; // 2.0
else if (x < 3.5f) return sign | 0x5; // 3.0
else if (x < 5.0f) return sign | 0x6; // 4.0
else return sign | 0x7; // 6.0
}
// FP6 E2M3: [S(1) | E(2) | M(3)], max normal = 7.5
GGML_MXFP_FUNC float ggml_mxfp_fp6_e2m3_to_float(uint8_t v) {
const float sign = (v & 0x20) ? -1.0f : 1.0f;
const int exp = (v >> 3) & 0x3;
const int mant = v & 0x7;
if (exp == 0) return sign * (float)mant * 0.125f;
return sign * (1.0f + mant * 0.125f) * (float)(1 << (exp - 1));
}
GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp6_e2m3(float x) {
uint8_t sign = 0;
if (x < 0) { sign = 0x20; x = -x; }
if (x == 0) return sign;
if (x >= 7.5f) return sign | 0x1F; // max finite
uint32_t bits = GGML_MXFP_F32_AS_U32(x);
int f32_exp = (int)((bits >> 23) & 0xFF) - 127;
if (f32_exp < 0) {
// Subnormal in E2M3: mant * 2^(-3)
float scaled = x * 8.0f;
int mant = (int)(scaled + 0.5f);
if (mant > 7) return sign | 0x08; // smallest normal
return sign | (uint8_t)mant;
}
if (f32_exp > 2) f32_exp = 2;
float mantf = (x / (float)(1 << f32_exp)) - 1.0f;
int mant = (int)(mantf * 8.0f + 0.5f);
if (mant > 7) { mant = 0; f32_exp++; }
if (f32_exp > 2) return sign | 0x1F;
return sign | (uint8_t)(((f32_exp + 1) << 3) | mant);
}
// FP6 E3M2: [S(1) | E(3) | M(2)], max normal = 28.0, no NaN/Inf
GGML_MXFP_FUNC float ggml_mxfp_fp6_e3m2_to_float(uint8_t v) {
const float sign = (v & 0x20) ? -1.0f : 1.0f;
const int exp = (v >> 2) & 0x7;
const int mant = v & 0x3;
if (exp == 0) return sign * (float)mant * 0.0625f; // 2^(-4)
// MX E3M2 has no NaN/Inf — exp=7 is a valid normal value (max finite = 28.0).
return sign * GGML_MXFP_LDEXPF(1.0f + mant * 0.25f, exp - 3);
}
GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp6_e3m2(float x) {
uint8_t sign = 0;
if (x < 0) { sign = 0x20; x = -x; }
if (x == 0) return sign;
if (x >= 28.0f) return sign | 0x1F; // max finite
uint32_t bits = GGML_MXFP_F32_AS_U32(x);
int f32_exp = (int)((bits >> 23) & 0xFF) - 127;
int biased_exp = f32_exp + 3;
if (biased_exp <= 0) {
// Subnormal in E3M2: mant * 2^(-4)
float scaled = x * 16.0f;
int mant = (int)(scaled + 0.5f);
if (mant > 3) return sign | 0x04; // smallest normal
return sign | (uint8_t)mant;
}
if (biased_exp > 7) return sign | 0x1F;
float pow2 = (f32_exp >= 0) ? (float)(1 << f32_exp) : 1.0f / (float)(1 << (-f32_exp));
float mantf = (x / pow2) - 1.0f;
int mant = (int)(mantf * 4.0f + 0.5f);
if (mant > 3) { mant = 0; biased_exp++; }
if (biased_exp > 7) return sign | 0x1F;
return sign | (uint8_t)((biased_exp << 2) | mant);
}
// FP8 E4M3: [S(1) | E(4) | M(3)], bias=7, max finite=448
GGML_MXFP_FUNC float ggml_mxfp_fp8_e4m3_to_float(uint8_t v) {
uint32_t sign = ((uint32_t)(v & 0x80)) << 24;
uint32_t exp = (v >> 3) & 0xF;
uint32_t mant = v & 0x7;
if (exp == 0) {
if (mant == 0) return GGML_MXFP_U32_AS_F32(sign);
// Subnormal: mant * 2^(1-7) * 2^(-3) = mant * 2^(-9)
float val = (float)mant * (1.0f / 512.0f);
uint32_t vb = GGML_MXFP_F32_AS_U32(val);
vb = (vb & 0x7FFFFFFFu) | sign;
return GGML_MXFP_U32_AS_F32(vb);
}
if (exp == 15 && mant == 7) {
return GGML_MXFP_U32_AS_F32(sign | 0x7FC00000u);
}
// Normal: (-1)^S * 2^(E-7) * (1 + M/8) → F32 exp = E-7+127 = E+120
return GGML_MXFP_U32_AS_F32(sign | ((exp + 120) << 23) | (mant << 20));
}
GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp8_e4m3(float x) {
uint32_t bits = GGML_MXFP_F32_AS_U32(x);
uint8_t sign = (bits >> 24) & 0x80;
bits &= 0x7FFFFFFFu;
if (bits == 0) return sign;
uint32_t f32_exp = (bits >> 23) & 0xFF;
uint32_t f32_mant = bits & 0x7FFFFF;
int e4m3_exp = (int)f32_exp - 120;
if (e4m3_exp <= 0) {
// Subnormal in E4M3
int shift = 1 - e4m3_exp;
uint32_t full_mant = (1u << 23) | f32_mant;
int total_shift = 20 + shift;
if (total_shift >= 32) return sign;
uint32_t mant3 = full_mant >> total_shift;
if (total_shift > 0 && total_shift < 32) {
uint32_t round_bit = (full_mant >> (total_shift - 1)) & 1;
uint32_t sticky = (total_shift > 1) ? (full_mant & ((1u << (total_shift - 1)) - 1)) : 0;
if (round_bit && (sticky || (mant3 & 1))) mant3++;
}
if (mant3 > 7) return sign | 0x08;
return sign | (uint8_t)mant3;
}
uint32_t round_bit = (f32_mant >> 19) & 1;
uint32_t sticky = f32_mant & ((1u << 19) - 1);
uint32_t mant3 = f32_mant >> 20;
if (round_bit && (sticky || (mant3 & 1))) {
mant3++;
if (mant3 > 7) { mant3 = 0; e4m3_exp++; }
}
if (e4m3_exp > 15 || (e4m3_exp == 15 && mant3 >= 7)) return sign | 0x7E; // max finite
return sign | (uint8_t)((e4m3_exp << 3) | mant3);
}
// FP8 E5M2: [S(1) | E(5) | M(2)], bias=15, max finite=57344
GGML_MXFP_FUNC float ggml_mxfp_fp8_e5m2_to_float(uint8_t v) {
uint32_t sign = ((uint32_t)(v & 0x80)) << 24;
uint32_t exp = (v >> 2) & 0x1F;
uint32_t mant = v & 0x3;
if (exp == 0) {
if (mant == 0) return GGML_MXFP_U32_AS_F32(sign);
// Subnormal: mant * 2^(1-15) * 2^(-2) = mant/4 * 2^(-14)
float val = (float)mant * 0.25f * (1.0f / 16384.0f);
uint32_t vb = GGML_MXFP_F32_AS_U32(val);
vb = (vb & 0x7FFFFFFFu) | sign;
return GGML_MXFP_U32_AS_F32(vb);
}
if (exp == 31) {
return GGML_MXFP_U32_AS_F32(sign | 0x7F800000u | (mant ? 0x400000u : 0));
}
// Normal: F32 exp = E-15+127 = E+112
return GGML_MXFP_U32_AS_F32(sign | ((exp + 112) << 23) | (mant << 21));
}
GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp8_e5m2(float x) {
uint32_t bits = GGML_MXFP_F32_AS_U32(x);
uint8_t sign = (bits >> 24) & 0x80;
bits &= 0x7FFFFFFFu;
if (bits == 0) return sign;
uint32_t f32_exp = (bits >> 23) & 0xFF;
uint32_t f32_mant = bits & 0x7FFFFF;
int e5m2_exp = (int)f32_exp - 112;
if (e5m2_exp <= 0) {
int shift = 1 - e5m2_exp;
uint32_t full_mant = (1u << 23) | f32_mant;
int total_shift = 21 + shift;
if (total_shift >= 32) return sign;
uint32_t mant2 = full_mant >> total_shift;
if (total_shift > 0 && total_shift < 32) {
uint32_t round_bit = (full_mant >> (total_shift - 1)) & 1;
uint32_t sticky = (total_shift > 1) ? (full_mant & ((1u << (total_shift - 1)) - 1)) : 0;
if (round_bit && (sticky || (mant2 & 1))) mant2++;
}
if (mant2 > 3) return sign | 0x04;
return sign | (uint8_t)mant2;
}
uint32_t round_bit = (f32_mant >> 20) & 1;
uint32_t sticky = f32_mant & ((1u << 20) - 1);
uint32_t mant2 = f32_mant >> 21;
if (round_bit && (sticky || (mant2 & 1))) {
mant2++;
if (mant2 > 3) { mant2 = 0; e5m2_exp++; }
}
if (e5m2_exp >= 31) return sign | 0x7B; // max finite
return sign | (uint8_t)((e5m2_exp << 2) | mant2);
}
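A round-trip sketch for the E4M3 pair: every non-NaN code should decode and re-encode to itself, including signed zeros and subnormals (illustrative, assumes <assert.h>):

// encode(decode(b)) == b for all finite E4M3 codes (0x7F / 0xFF are NaN).
static void mxfp8_e4m3_roundtrip_check(void) {
    for (int b = 0; b < 256; ++b) {
        if ((b & 0x7F) == 0x7F) continue; // skip the two NaN encodings
        const float   f = ggml_mxfp_fp8_e4m3_to_float((uint8_t)b);
        const uint8_t r = ggml_mxfp_float_to_fp8_e4m3(f);
        assert(r == (uint8_t)b);
    }
}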
// FP6 packing/unpacking
// Pack 4 six-bit values into 3 bytes
GGML_MXFP_FUNC void ggml_mxfp_pack_fp6x4(const uint8_t v[4], uint8_t out[3]) {
uint32_t packed = (v[0] & 0x3F) | ((v[1] & 0x3F) << 6) |
((v[2] & 0x3F) << 12) | ((v[3] & 0x3F) << 18);
out[0] = (uint8_t)(packed);
out[1] = (uint8_t)(packed >> 8);
out[2] = (uint8_t)(packed >> 16);
}
// Unpack 3 bytes into 4 six-bit values
GGML_MXFP_FUNC void ggml_mxfp_unpack_fp6x4(const uint8_t in[3], uint8_t v[4]) {
uint32_t packed = (uint32_t)in[0] | ((uint32_t)in[1] << 8) | ((uint32_t)in[2] << 16);
v[0] = packed & 0x3F;
v[1] = (packed >> 6) & 0x3F;
v[2] = (packed >> 12) & 0x3F;
v[3] = (packed >> 18) & 0x3F;
}
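A quick round-trip check for the 6-bit packing (illustrative, assumes <assert.h>): each code occupies bits [6i, 6i+6) of a 24-bit word, so four codes survive the 3-byte trip intact.

static void mxfp_fp6_pack_roundtrip_check(void) {
    const uint8_t v[4] = { 0x00, 0x1F, 0x2A, 0x3F };
    uint8_t packed[3], out[4];
    ggml_mxfp_pack_fp6x4(v, packed);
    ggml_mxfp_unpack_fp6x4(packed, out);
    for (int i = 0; i < 4; ++i) {
        assert(out[i] == v[i]);
    }
}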
// E8M0 shared exponent → float conversion.
// E8M0 encoding: value = 2^(x - 127) for x > 0, 2^(-127) for x == 0.
// E8M0 = 255 is NaN per the MX spec, but we clamp it to 254 (max finite),
// matching the encode path (which also clamps to 254); this prevents
// Inf * 0 = NaN in dequant.
GGML_MXFP_FUNC float ggml_mxfp_e8m0_to_fp32(uint8_t x) {
if (x == 255) { x = 254; }
uint32_t bits = (x == 0) ? 0x00400000u : ((uint32_t)x << 23);
return GGML_MXFP_U32_AS_F32(bits);
}
// E8M0 → float/2. Used with MXFP4 since E2M1 values are doubled in kvalues_mxfp4.
GGML_MXFP_FUNC float ggml_mxfp_e8m0_to_fp32_half(uint8_t x) {
if (x == 255) { x = 254; }
uint32_t bits = (x < 2) ? (0x00200000u << x) : ((uint32_t)(x - 1) << 23);
return GGML_MXFP_U32_AS_F32(bits);
}
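The half-scale helper keeps the two MXFP4 LUTs interchangeable: for any code q and exponent e, kvalues_mxfp4[q] * e8m0_to_fp32_half(e) equals kvalues_mxfp4_float[q] * e8m0_to_fp32(e), exactly, since every factor is a power of two times an exactly representable value. As a check (illustrative, assumes <assert.h>):

// Doubled int8 LUT x half scale == float LUT x full scale, bit-exact.
static void mxfp4_half_scale_check(void) {
    for (int e = 0; e < 256; ++e) { // 255 clamps to 254 identically in both helpers
        for (int q = 0; q < 16; ++q) {
            const float a = (float)kvalues_mxfp4[q] * ggml_mxfp_e8m0_to_fp32_half((uint8_t)e);
            const float b = kvalues_mxfp4_float[q]  * ggml_mxfp_e8m0_to_fp32((uint8_t)e);
            assert(a == b);
        }
    }
}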
// E8M0 base exponent estimate: round(log2(amax)) - emax_offset + 127.
// Uses integer bit extraction — no log2f() SFU dependency.
// Caller must ensure amax > 0 and finite. Returns unclamped e_base.
GGML_MXFP_FUNC int ggml_mxfp_e8m0_base_estimate(float amax, int emax_offset) {
uint32_t amax_bits = GGML_MXFP_F32_AS_U32(amax);
const int floor_log2 = (int)((amax_bits >> 23) & 0xFF) - 127;
// Round: add 1 if mantissa >= sqrt(2)-1 (0x3504F3 in 23-bit IEEE mantissa).
const int round_log2 = floor_log2 + ((amax_bits & 0x7FFFFF) >= 0x3504F3 ? 1 : 0);
return round_log2 - emax_offset + 127;
}
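A hedged sketch of how the estimate slots into a block quantizer; the committed quantize_row_* encode paths are not in this hunk, so the function name and clamping below are assumptions (the 254 clamp matches the e8m0 dequant comment above):

// Hypothetical: derive the E8M0 block scale for 32 values.
static uint8_t mxfp_block_scale_sketch(const float * x, int emax_offset) {
    float amax = 0.0f;
    for (int i = 0; i < 32; ++i) {
        const float ax = fabsf(x[i]);
        if (ax > amax) amax = ax;
    }
    if (amax == 0.0f) return 0; // all-zero block: smallest scale
    int e = ggml_mxfp_e8m0_base_estimate(amax, emax_offset);
    if (e < 0)   e = 0;
    if (e > 254) e = 254;       // never emit 255 (NaN)
    // elements would then be encoded as fmt(x[i] / ggml_mxfp_e8m0_to_fp32(e))
    return (uint8_t)e;
}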
// Block-32 Walsh-Hadamard Transform, normalized by 1/sqrt(32).
GGML_MXFP_FUNC void ggml_mxfp_hadamard_32_inplace(GGML_MXFP_THREAD float * vals) {
GGML_MXFP_UNROLL
for (int stride = 1; stride < 32; stride *= 2) {
GGML_MXFP_UNROLL
for (int i = 0; i < 32; i += 2 * stride) {
GGML_MXFP_UNROLL
for (int j = 0; j < stride; ++j) {
const float a = vals[i + j];
const float b = vals[i + j + stride];
vals[i + j] = a + b;
vals[i + j + stride] = a - b;
}
}
}
GGML_MXFP_UNROLL
for (int i = 0; i < 32; ++i) {
vals[i] *= MXFP_HADAMARD_32_NORM;
}
}
#endif // GGML_MXFP_FUNC
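The normalized WHT is orthogonal and symmetric (H·H = 32·I, so (H/√32)·(H/√32) = I), which makes the transform its own inverse: the quantize path can rotate with the same routine the dequant path uses to rotate back. A round-trip sketch (illustrative, assumes <assert.h> and <math.h>):

// Applying the normalized WHT twice recovers the input up to rounding.
static void hadamard_roundtrip_check(void) {
    float v[32], ref[32];
    for (int i = 0; i < 32; ++i) { ref[i] = (float)(i - 16); v[i] = ref[i]; }
    ggml_mxfp_hadamard_32_inplace(v);
    ggml_mxfp_hadamard_32_inplace(v);
    for (int i = 0; i < 32; ++i) {
        assert(fabsf(v[i] - ref[i]) < 1e-5f);
    }
}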
#define NGRID_IQ1S 2048
#define IQ1S_DELTA 0.125f
#define IQ1M_DELTA 0.125f

View File

@ -2,7 +2,6 @@
#pragma once
// Rename `_generic` functions if no native implementation is available.
// This effectively selects the generic implementation.
#if defined(GGML_CPU_GENERIC)
// quants.c
@ -15,7 +14,12 @@
#define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
#define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0
#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
@ -70,6 +74,9 @@
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
// quants.c
#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0
#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
@ -81,6 +88,8 @@
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
// quants.c
#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0
#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
@ -112,6 +121,9 @@
// quants.c
#define quantize_row_q8_K_generic quantize_row_q8_K
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
@ -159,7 +171,12 @@
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0
#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@ -200,6 +217,9 @@
#elif defined(__riscv)
// quants.c
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu
// repack.cpp
#define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
@ -240,6 +260,9 @@
// quants.c
#define quantize_row_q8_K_generic quantize_row_q8_K
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
@ -290,6 +313,9 @@
#elif defined(__wasm__)
// quants.c
#define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1
#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K
@ -302,6 +328,8 @@
#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0
#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4

View File

@ -4134,3 +4134,223 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
#endif
}
// MXFP FP8/FP6 NEON helpers
// Separate FP8/FP6 functions because NEON vshlq_n_u32 requires compile-time constants.
#if defined(__ARM_NEON)
#define mxfp_neon_traits_t mxfp_dequant_traits_t
// Dequantize 4 FP8 values to floats.
static inline float32x4_t mxfp8_dequant_neon(
const uint32x4_t v_raw,
const uint32x4_t v_exp_mask, const uint32x4_t v_mant_mask,
const uint32x4_t v_ieee_off, const float32x4_t v_sub_sc,
const int32x4_t v_neg_exp_shift, const int32x4_t v_mant_shift) {
const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x80));
const uint32x4_t exp = vandq_u32(vshlq_u32(v_raw, v_neg_exp_shift), v_exp_mask);
const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask);
const uint32x4_t ieee = vorrq_u32(
vorrq_u32(vshlq_n_u32(sign, 24),
vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)),
vshlq_u32(mant, v_mant_shift));
const float32x4_t normal = vreinterpretq_f32_u32(ieee);
const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc);
const float32x4_t sub_val = vreinterpretq_f32_u32(
vorrq_u32(vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 24)));
const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0));
return vbslq_f32(is_sub, sub_val, normal);
}
// Dequantize 4 FP6 values to floats.
static inline float32x4_t mxfp6_dequant_neon(
const uint32x4_t v_raw,
const uint32x4_t v_exp_mask, const uint32x4_t v_mant_mask,
const uint32x4_t v_ieee_off, const float32x4_t v_sub_sc,
const int32x4_t v_neg_exp_shift, const int32x4_t v_mant_shift) {
const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x20));
const uint32x4_t exp = vandq_u32(vshlq_u32(v_raw, v_neg_exp_shift), v_exp_mask);
const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask);
const uint32x4_t ieee = vorrq_u32(
vorrq_u32(vshlq_n_u32(sign, 26),
vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)),
vshlq_u32(mant, v_mant_shift));
const float32x4_t normal = vreinterpretq_f32_u32(ieee);
const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc);
const float32x4_t sub_val = vreinterpretq_f32_u32(
vorrq_u32(vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 26)));
const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0));
return vbslq_f32(is_sub, sub_val, normal);
}
// Unpack 4 tightly-packed 6-bit values from 3 bytes, widen to uint32x4_t.
static inline uint32x4_t unpack_fp6x4_neon(const uint8_t * p) {
uint8_t u[4];
ggml_mxfp_unpack_fp6x4(p, u);
const uint8x8_t raw8 = vcreate_u8(
(uint64_t)u[0] | ((uint64_t)u[1] << 8) |
((uint64_t)u[2] << 16) | ((uint64_t)u[3] << 24));
return vmovl_u16(vget_low_u16(vmovl_u8(raw8)));
}
// Widen 8 raw bytes to two uint32x4_t halves.
static inline void widen_u8x8_to_u32x4x2(const uint8_t * src,
uint32x4_t * lo, uint32x4_t * hi) {
const uint8x8_t raw8 = vld1_u8(src);
const uint16x8_t raw16 = vmovl_u8(raw8);
*lo = vmovl_u16(vget_low_u16(raw16));
*hi = vmovl_u16(vget_high_u16(raw16));
}
// Widen 8 Q8_0 int8 values to two float32x4_t halves.
static inline void widen_s8x8_to_f32x4x2(const int8_t * src,
float32x4_t * lo, float32x4_t * hi) {
const int8x8_t q8 = vld1_s8(src);
const int16x8_t q16 = vmovl_s8(q8);
*lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(q16)));
*hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16)));
}
// MXFP SoA dequant (flash attention)
static void dequantize_row_mxfp8_soa_neon(
const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k,
const mxfp_neon_traits_t * t) {
assert(k % QK_MXFP8 == 0);
const int nb = k / QK_MXFP8;
const char * qs_base = (const char *)src;
const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK);
const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask);
const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask);
const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off);
const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale);
const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift);
const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift);
for (int ib = 0; ib < nb; ++ib) {
const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib]));
const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP8_SOA_QS_PER_BLOCK));
for (int j = 0; j < 32; j += 8) {
uint32x4_t v_lo, v_hi;
widen_u8x8_to_u32x4x2(qs + j, &v_lo, &v_hi);
const float32x4_t val_lo = mxfp8_dequant_neon(v_lo,
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
const float32x4_t val_hi = mxfp8_dequant_neon(v_hi,
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
vst1q_f32(y + ib * QK_MXFP8 + j, vmulq_f32(val_lo, v_scale));
vst1q_f32(y + ib * QK_MXFP8 + j + 4, vmulq_f32(val_hi, v_scale));
}
}
}
static void dequantize_row_mxfp6_soa_neon(
const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k,
const mxfp_neon_traits_t * t) {
assert(k % QK_MXFP6 == 0);
const int nb = k / QK_MXFP6;
const char * qs_base = (const char *)src;
const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK);
const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask);
const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask);
const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off);
const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale);
const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift);
const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift);
for (int ib = 0; ib < nb; ++ib) {
const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib]));
const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP6_SOA_QS_PER_BLOCK));
for (int j = 0; j < 32; j += 4) {
const uint32x4_t v_raw = unpack_fp6x4_neon(qs + (j * 3 / 4));
const float32x4_t val = mxfp6_dequant_neon(v_raw,
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
vst1q_f32(y + ib * QK_MXFP6 + j, vmulq_f32(val, v_scale));
}
}
}
// MXFP4 SoA dequant — LUT-based, no IEEE reconstruction needed.
static void dequantize_row_mxfp4_soa_neon(
const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) {
assert(k % QK_MXFP4 == 0);
const int nb = k / QK_MXFP4;
const char * qs_base = (const char *)src;
const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK);
const int8x16_t values = vld1q_s8(kvalues_mxfp4);
const uint8x16_t m4b = vdupq_n_u8(0x0f);
for (int i = 0; i < nb; i++) {
const float d = GGML_E8M0_TO_FP32_HALF((uint8_t)e8m0_base[i]);
const float32x4_t v_scale = vdupq_n_f32(d);
const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP4_SOA_QS_PER_BLOCK));
const uint8x16_t q4bits = vld1q_u8(qs);
const int8x16_t lo = ggml_vqtbl1q_s8(values, vandq_u8(q4bits, m4b));
const int8x16_t hi = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits, 4));
float * out_lo = y + i * QK_MXFP4;
float * out_hi = y + i * QK_MXFP4 + QK_MXFP4/2;
{
const int16x8_t lo16_0 = vmovl_s8(vget_low_s8(lo));
const int16x8_t lo16_1 = vmovl_s8(vget_high_s8(lo));
vst1q_f32(out_lo + 0, vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(lo16_0))), v_scale));
vst1q_f32(out_lo + 4, vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(lo16_0))), v_scale));
vst1q_f32(out_lo + 8, vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(lo16_1))), v_scale));
vst1q_f32(out_lo + 12, vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(lo16_1))), v_scale));
}
{
const int16x8_t hi16_0 = vmovl_s8(vget_low_s8(hi));
const int16x8_t hi16_1 = vmovl_s8(vget_high_s8(hi));
vst1q_f32(out_hi + 0, vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(hi16_0))), v_scale));
vst1q_f32(out_hi + 4, vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(hi16_0))), v_scale));
vst1q_f32(out_hi + 8, vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(hi16_1))), v_scale));
vst1q_f32(out_hi + 12, vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(hi16_1))), v_scale));
}
}
}
#endif // __ARM_NEON
// Public dispatch functions
void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__ARM_NEON)
dequantize_row_mxfp4_soa_neon(x, y, k);
#else
dequantize_row_mxfp4_soa_cpu_generic(x, y, k);
#endif
}
void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__ARM_NEON)
dequantize_row_mxfp8_soa_neon(x, y, k, &MXFP_TRAITS_E4M3);
#else
dequantize_row_mxfp8_soa_cpu_generic(x, y, k);
#endif
}
void dequantize_row_mxfp6_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__ARM_NEON)
dequantize_row_mxfp6_soa_neon(x, y, k, &MXFP_TRAITS_E2M3);
#else
dequantize_row_mxfp6_soa_cpu_generic(x, y, k);
#endif
}
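A sanity-check sketch for the SoA dequant against the FP8 LUT (illustrative, assumes <assert.h>; it builds a one-block [qs][e8m0] buffer with scale 2^0 so the outputs are the raw LUT values):

static void mxfp8_soa_dequant_check(void) {
    uint8_t buf[QK_MXFP8 + 1]; // 32 qs bytes + 1 E8M0 scale byte
    for (int i = 0; i < QK_MXFP8; ++i) buf[i] = (uint8_t)(i * 7 + 3); // finite codes only
    buf[MXFP_SOA_E8M0_OFFSET(1, MXFP8_SOA_QS_PER_BLOCK)] = 127;       // scale = 2^(127-127) = 1
    float y[QK_MXFP8];
    dequantize_row_mxfp8_soa_cpu(buf, y, QK_MXFP8);
    for (int i = 0; i < QK_MXFP8; ++i) {
        assert(y[i] == kvalues_mxfp8_e4m3[buf[i]]);
    }
}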

View File

@ -3818,3 +3818,169 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
// MXFP FP8/FP6 AVX2 helpers
#if defined(__AVX2__)
#define mxfp_avx2_traits_t mxfp_dequant_traits_t
// Dequantize 8 FP8/FP6 values to floats.
static inline __m256 mxfp_dequant_avx2(
const __m256i v_raw,
const __m256i v_exp_mask, const __m256i v_mant_mask,
const __m256i v_ieee_off, const __m256 v_sub_sc,
const __m256i v_sign_mask, const __m256i v_zero,
int exp_shift, int sign_shift, int mant_shift) {
const __m256i sign = _mm256_and_si256(v_raw, v_sign_mask);
const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask);
const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask);
const __m256i ieee = _mm256_or_si256(
_mm256_or_si256(_mm256_slli_epi32(sign, sign_shift),
_mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)),
_mm256_slli_epi32(mant, mant_shift));
const __m256 normal = _mm256_castsi256_ps(ieee);
const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc);
const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256(
_mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, sign_shift)));
const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero));
return _mm256_blendv_ps(normal, sub_val, is_sub);
}
// Unpack 8 FP6 values (two groups of 4) from packed qs data at offset j.
static inline __m256i unpack_fp6x8_avx2(const uint8_t * qs, int j) {
uint8_t unpacked[8];
ggml_mxfp_unpack_fp6x4(qs + (j * 3 / 4), unpacked);
ggml_mxfp_unpack_fp6x4(qs + ((j + 4) * 3 / 4), unpacked + 4);
return _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)unpacked));
}
// MXFP SoA dequant (flash attention)
static void dequantize_row_mxfp8_soa_avx2(
const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k,
const mxfp_avx2_traits_t * t) {
assert(k % QK_MXFP8 == 0);
const int nb = k / QK_MXFP8;
const char * qs_base = (const char *)src;
const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK);
const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask);
const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask);
const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off);
const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale);
const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask);
const __m256i v_zero = _mm256_setzero_si256();
for (int ib = 0; ib < nb; ++ib) {
const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib]));
const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP8_SOA_QS_PER_BLOCK));
for (int j = 0; j < 32; j += 8) {
const __m256i v_raw = _mm256_cvtepu8_epi32(
_mm_loadl_epi64((const __m128i *)(qs + j)));
const __m256 val = mxfp_dequant_avx2(v_raw,
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc,
v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift);
_mm256_storeu_ps(y + ib * QK_MXFP8 + j, _mm256_mul_ps(val, v_scale));
}
}
}
static void dequantize_row_mxfp6_soa_avx2(
const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k,
const mxfp_avx2_traits_t * t) {
assert(k % QK_MXFP6 == 0);
const int nb = k / QK_MXFP6;
const char * qs_base = (const char *)src;
const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK);
const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask);
const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask);
const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off);
const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale);
const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask);
const __m256i v_zero = _mm256_setzero_si256();
for (int ib = 0; ib < nb; ++ib) {
const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib]));
const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP6_SOA_QS_PER_BLOCK));
for (int j = 0; j < 32; j += 8) {
const __m256i v_raw = unpack_fp6x8_avx2(qs, j);
const __m256 val = mxfp_dequant_avx2(v_raw,
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc,
v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift);
_mm256_storeu_ps(y + ib * QK_MXFP6 + j, _mm256_mul_ps(val, v_scale));
}
}
}
// MXFP4 SoA dequant — LUT-based, no IEEE reconstruction needed.
static void dequantize_row_mxfp4_soa_avx2(
const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) {
assert(k % QK_MXFP4 == 0);
const int nb = k / QK_MXFP4;
const char * qs_base = (const char *)src;
const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK);
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
const __m128i m4b = _mm_set1_epi8(0x0f);
for (int i = 0; i < nb; i++) {
const float d = GGML_E8M0_TO_FP32_HALF((uint8_t)e8m0_base[i]);
const __m256 v_scale = _mm256_set1_ps(d);
const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP4_SOA_QS_PER_BLOCK));
const __m128i q4bits = _mm_loadu_si128((const __m128i *)qs);
const __m128i lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits, m4b));
const __m128i hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4b));
const __m256i lo32_0 = _mm256_cvtepi8_epi32(lo);
const __m256i lo32_1 = _mm256_cvtepi8_epi32(_mm_srli_si128(lo, 8));
_mm256_storeu_ps(y + i * QK_MXFP4 + 0, _mm256_mul_ps(_mm256_cvtepi32_ps(lo32_0), v_scale));
_mm256_storeu_ps(y + i * QK_MXFP4 + 8, _mm256_mul_ps(_mm256_cvtepi32_ps(lo32_1), v_scale));
const __m256i hi32_0 = _mm256_cvtepi8_epi32(hi);
const __m256i hi32_1 = _mm256_cvtepi8_epi32(_mm_srli_si128(hi, 8));
_mm256_storeu_ps(y + i * QK_MXFP4 + 16, _mm256_mul_ps(_mm256_cvtepi32_ps(hi32_0), v_scale));
_mm256_storeu_ps(y + i * QK_MXFP4 + 24, _mm256_mul_ps(_mm256_cvtepi32_ps(hi32_1), v_scale));
}
}
#endif // __AVX2__
// Public dispatch functions
void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__AVX2__)
dequantize_row_mxfp4_soa_avx2(x, y, k);
#else
dequantize_row_mxfp4_soa_cpu_generic(x, y, k);
#endif
}
void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__AVX2__)
dequantize_row_mxfp8_soa_avx2(x, y, k, &MXFP_TRAITS_E4M3);
#else
dequantize_row_mxfp8_soa_cpu_generic(x, y, k);
#endif
}
void dequantize_row_mxfp6_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__AVX2__)
dequantize_row_mxfp6_soa_avx2(x, y, k, &MXFP_TRAITS_E2M3);
#else
dequantize_row_mxfp6_soa_cpu_generic(x, y, k);
#endif
}

View File

@ -7,6 +7,7 @@
#include "ggml-cpu-impl.h"
#include "ggml-impl.h"
#include "quants.h"
#include "ggml-quants.h"
#include "ggml-threading.h"
#include "unary-ops.h"
#include "binary-ops.h"
@ -266,6 +267,8 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
},
[GGML_TYPE_MXFP4] = {
.from_float = quantize_row_mxfp4,
.from_float_soa = quantize_row_mxfp4_soa,
.to_float_soa = dequantize_row_mxfp4_soa_cpu,
.vec_dot = ggml_vec_dot_mxfp4_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
@ -276,6 +279,22 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
},
[GGML_TYPE_MXFP8] = {
.from_float = (ggml_from_float_t)quantize_row_mxfp8_ref,
.from_float_soa = quantize_row_mxfp8_soa,
.to_float_soa = dequantize_row_mxfp8_soa_cpu,
.vec_dot = ggml_vec_dot_mxfp8_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
},
[GGML_TYPE_MXFP6] = {
.from_float = (ggml_from_float_t)quantize_row_mxfp6_ref,
.from_float_soa = quantize_row_mxfp6_soa,
.to_float_soa = dequantize_row_mxfp6_soa_cpu,
.vec_dot = ggml_vec_dot_mxfp6_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
},
[GGML_TYPE_Q2_K] = {
.from_float = quantize_row_q2_K,
.vec_dot = ggml_vec_dot_q2_K_q8_K,
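With these entries registered, callers reach the new SoA hooks through the usual traits lookup; a hedged sketch (the actual dispatch sites are outside this hunk):

// Hypothetical: quantize one row into the SoA flash-attention layout.
static void quantize_row_soa_sketch(const float * src, void * dst, int64_t n) {
    const struct ggml_type_traits_cpu * tr = ggml_get_type_traits_cpu(GGML_TYPE_MXFP8);
    if (tr->from_float_soa) {
        tr->from_float_soa(src, dst, n); // SoA quantize; AoS from_float otherwise
    }
}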

View File

@ -2,6 +2,8 @@
#include "ggml-cpu.h"
#include "ggml-impl.h"
#include "ggml-quants.h"
#include "quants.h"
#include "binary-ops.h"
#include "simd-gemm.h"
#include "ggml.h"
@ -11,6 +13,7 @@
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <cstring>
// ggml_compute_forward_dup
@ -671,6 +674,8 @@ void ggml_compute_forward_add(
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
case GGML_TYPE_MXFP8:
case GGML_TYPE_MXFP6:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@ -1121,6 +1126,8 @@ void ggml_compute_forward_add1(
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
case GGML_TYPE_MXFP8:
case GGML_TYPE_MXFP6:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@ -1250,6 +1257,8 @@ void ggml_compute_forward_acc(
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
case GGML_TYPE_MXFP8:
case GGML_TYPE_MXFP6:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@ -4338,6 +4347,8 @@ void ggml_compute_forward_out_prod(
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
case GGML_TYPE_MXFP8:
case GGML_TYPE_MXFP6:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@ -4614,6 +4625,8 @@ void ggml_compute_forward_set(
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
case GGML_TYPE_MXFP8:
case GGML_TYPE_MXFP6:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@ -4837,6 +4850,8 @@ void ggml_compute_forward_get_rows(
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
case GGML_TYPE_MXFP8:
case GGML_TYPE_MXFP6:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@ -4894,6 +4909,191 @@ void ggml_compute_forward_get_rows(
//}
}
// SIMD-optimized Hadamard; scalar fallback below
#if defined(__AVX2__) || defined(__AVX__)
static void hadamard_32_inplace(float vals[32]) {
// 32 floats = 4 × __m256
__m256 v0 = _mm256_loadu_ps(vals + 0);
__m256 v1 = _mm256_loadu_ps(vals + 8);
__m256 v2 = _mm256_loadu_ps(vals + 16);
__m256 v3 = _mm256_loadu_ps(vals + 24);
// Stride 1: butterfly on adjacent pairs within each 256-bit register
{
// Interleave even/odd elements, add/sub
__m256 a, b, s, d;
a = _mm256_shuffle_ps(v0, v0, _MM_SHUFFLE(2, 2, 0, 0));
b = _mm256_shuffle_ps(v0, v0, _MM_SHUFFLE(3, 3, 1, 1));
s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b);
v0 = _mm256_shuffle_ps(s, d, _MM_SHUFFLE(2, 0, 2, 0));
v0 = _mm256_shuffle_ps(v0, v0, _MM_SHUFFLE(3, 1, 2, 0));
a = _mm256_shuffle_ps(v1, v1, _MM_SHUFFLE(2, 2, 0, 0));
b = _mm256_shuffle_ps(v1, v1, _MM_SHUFFLE(3, 3, 1, 1));
s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b);
v1 = _mm256_shuffle_ps(s, d, _MM_SHUFFLE(2, 0, 2, 0));
v1 = _mm256_shuffle_ps(v1, v1, _MM_SHUFFLE(3, 1, 2, 0));
a = _mm256_shuffle_ps(v2, v2, _MM_SHUFFLE(2, 2, 0, 0));
b = _mm256_shuffle_ps(v2, v2, _MM_SHUFFLE(3, 3, 1, 1));
s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b);
v2 = _mm256_shuffle_ps(s, d, _MM_SHUFFLE(2, 0, 2, 0));
v2 = _mm256_shuffle_ps(v2, v2, _MM_SHUFFLE(3, 1, 2, 0));
a = _mm256_shuffle_ps(v3, v3, _MM_SHUFFLE(2, 2, 0, 0));
b = _mm256_shuffle_ps(v3, v3, _MM_SHUFFLE(3, 3, 1, 1));
s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b);
v3 = _mm256_shuffle_ps(s, d, _MM_SHUFFLE(2, 0, 2, 0));
v3 = _mm256_shuffle_ps(v3, v3, _MM_SHUFFLE(3, 1, 2, 0));
}
// Stride 2: butterfly on pairs separated by 2 within 128-bit lanes
{
__m256 a, b, s, d;
a = _mm256_permute_ps(v0, _MM_SHUFFLE(1, 0, 1, 0));
b = _mm256_permute_ps(v0, _MM_SHUFFLE(3, 2, 3, 2));
s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b);
v0 = _mm256_blend_ps(s, d, 0xCC); // 0b11001100
a = _mm256_permute_ps(v1, _MM_SHUFFLE(1, 0, 1, 0));
b = _mm256_permute_ps(v1, _MM_SHUFFLE(3, 2, 3, 2));
s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b);
v1 = _mm256_blend_ps(s, d, 0xCC);
a = _mm256_permute_ps(v2, _MM_SHUFFLE(1, 0, 1, 0));
b = _mm256_permute_ps(v2, _MM_SHUFFLE(3, 2, 3, 2));
s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b);
v2 = _mm256_blend_ps(s, d, 0xCC);
a = _mm256_permute_ps(v3, _MM_SHUFFLE(1, 0, 1, 0));
b = _mm256_permute_ps(v3, _MM_SHUFFLE(3, 2, 3, 2));
s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b);
v3 = _mm256_blend_ps(s, d, 0xCC);
}
// Stride 4: butterfly between 128-bit lanes within each 256-bit register
{
__m128 lo, hi;
lo = _mm256_castps256_ps128(v0); hi = _mm256_extractf128_ps(v0, 1);
v0 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_add_ps(lo, hi)), _mm_sub_ps(lo, hi), 1);
lo = _mm256_castps256_ps128(v1); hi = _mm256_extractf128_ps(v1, 1);
v1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_add_ps(lo, hi)), _mm_sub_ps(lo, hi), 1);
lo = _mm256_castps256_ps128(v2); hi = _mm256_extractf128_ps(v2, 1);
v2 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_add_ps(lo, hi)), _mm_sub_ps(lo, hi), 1);
lo = _mm256_castps256_ps128(v3); hi = _mm256_extractf128_ps(v3, 1);
v3 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_add_ps(lo, hi)), _mm_sub_ps(lo, hi), 1);
}
// Stride 8: butterfly between registers
{
__m256 s, d;
s = _mm256_add_ps(v0, v1); d = _mm256_sub_ps(v0, v1); v0 = s; v1 = d;
s = _mm256_add_ps(v2, v3); d = _mm256_sub_ps(v2, v3); v2 = s; v3 = d;
}
// Stride 16: butterfly between register pairs
{
__m256 s, d;
s = _mm256_add_ps(v0, v2); d = _mm256_sub_ps(v0, v2); v0 = s; v2 = d;
s = _mm256_add_ps(v1, v3); d = _mm256_sub_ps(v1, v3); v1 = s; v3 = d;
}
// Normalize by 1/sqrt(32)
const __m256 norm = _mm256_set1_ps(MXFP_HADAMARD_32_NORM);
_mm256_storeu_ps(vals + 0, _mm256_mul_ps(v0, norm));
_mm256_storeu_ps(vals + 8, _mm256_mul_ps(v1, norm));
_mm256_storeu_ps(vals + 16, _mm256_mul_ps(v2, norm));
_mm256_storeu_ps(vals + 24, _mm256_mul_ps(v3, norm));
}
#elif defined(__ARM_NEON)
static void hadamard_32_inplace(float vals[32]) {
float32x4_t v0 = vld1q_f32(vals + 0);
float32x4_t v1 = vld1q_f32(vals + 4);
float32x4_t v2 = vld1q_f32(vals + 8);
float32x4_t v3 = vld1q_f32(vals + 12);
float32x4_t v4 = vld1q_f32(vals + 16);
float32x4_t v5 = vld1q_f32(vals + 20);
float32x4_t v6 = vld1q_f32(vals + 24);
float32x4_t v7 = vld1q_f32(vals + 28);
#define HADAMARD_S1(v) do { \
float32x2_t lo = vget_low_f32(v); \
float32x2_t hi = vget_high_f32(v); \
float32x2x2_t t = vtrn_f32(lo, hi); \
float32x2_t sum = vadd_f32(t.val[0], t.val[1]); \
float32x2_t dif = vsub_f32(t.val[0], t.val[1]); \
float32x2x2_t r = vtrn_f32(sum, dif); \
(v) = vcombine_f32(r.val[0], r.val[1]); \
} while (0)
HADAMARD_S1(v0); HADAMARD_S1(v1); HADAMARD_S1(v2); HADAMARD_S1(v3);
HADAMARD_S1(v4); HADAMARD_S1(v5); HADAMARD_S1(v6); HADAMARD_S1(v7);
#undef HADAMARD_S1
#define HADAMARD_S2(v) do { \
float32x2_t lo = vget_low_f32(v); \
float32x2_t hi = vget_high_f32(v); \
(v) = vcombine_f32(vadd_f32(lo, hi), vsub_f32(lo, hi)); \
} while (0)
HADAMARD_S2(v0); HADAMARD_S2(v1); HADAMARD_S2(v2); HADAMARD_S2(v3);
HADAMARD_S2(v4); HADAMARD_S2(v5); HADAMARD_S2(v6); HADAMARD_S2(v7);
#undef HADAMARD_S2
#define HADAMARD_S4(a, b) do { \
float32x4_t s = vaddq_f32(a, b); \
float32x4_t d = vsubq_f32(a, b); \
(a) = s; (b) = d; \
} while (0)
HADAMARD_S4(v0, v1); HADAMARD_S4(v2, v3);
HADAMARD_S4(v4, v5); HADAMARD_S4(v6, v7);
#undef HADAMARD_S4
{ float32x4_t s, d;
s = vaddq_f32(v0, v2); d = vsubq_f32(v0, v2); v0 = s; v2 = d;
s = vaddq_f32(v1, v3); d = vsubq_f32(v1, v3); v1 = s; v3 = d;
s = vaddq_f32(v4, v6); d = vsubq_f32(v4, v6); v4 = s; v6 = d;
s = vaddq_f32(v5, v7); d = vsubq_f32(v5, v7); v5 = s; v7 = d;
}
{ float32x4_t s, d;
s = vaddq_f32(v0, v4); d = vsubq_f32(v0, v4); v0 = s; v4 = d;
s = vaddq_f32(v1, v5); d = vsubq_f32(v1, v5); v1 = s; v5 = d;
s = vaddq_f32(v2, v6); d = vsubq_f32(v2, v6); v2 = s; v6 = d;
s = vaddq_f32(v3, v7); d = vsubq_f32(v3, v7); v3 = s; v7 = d;
}
const float32x4_t norm = vdupq_n_f32(MXFP_HADAMARD_32_NORM);
vst1q_f32(vals + 0, vmulq_f32(v0, norm));
vst1q_f32(vals + 4, vmulq_f32(v1, norm));
vst1q_f32(vals + 8, vmulq_f32(v2, norm));
vst1q_f32(vals + 12, vmulq_f32(v3, norm));
vst1q_f32(vals + 16, vmulq_f32(v4, norm));
vst1q_f32(vals + 20, vmulq_f32(v5, norm));
vst1q_f32(vals + 24, vmulq_f32(v6, norm));
vst1q_f32(vals + 28, vmulq_f32(v7, norm));
}
#else
static void hadamard_32_inplace(float vals[32]) {
ggml_hadamard_32_inplace(vals);
}
#endif
static void ggml_apply_hadamard_blocks(float * data, int64_t n) {
GGML_ASSERT(n % 32 == 0);
for (int64_t i = 0; i < n; i += 32) {
hadamard_32_inplace(data + i);
}
}
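// Illustrative scalar reference (a sketch for exposition, not the shipped
// fallback; that is ggml_hadamard_32_inplace, wrapping the canonical routine
// in ggml-common.h). The SIMD stages above compute exactly this iterative
// radix-2 Walsh-Hadamard butterfly over strides 1, 2, 4, 8, 16, then
// normalize by 1/sqrt(32):
//
//   static void hadamard_32_ref_sketch(float vals[32]) {
//       for (int stride = 1; stride < 32; stride *= 2) {
//           for (int i = 0; i < 32; i += 2*stride) {
//               for (int j = i; j < i + stride; j++) {
//                   const float a = vals[j];
//                   const float b = vals[j + stride];
//                   vals[j]          = a + b;
//                   vals[j + stride] = a - b;
//               }
//           }
//       }
//       for (int j = 0; j < 32; j++) {
//           vals[j] *= MXFP_HADAMARD_32_NORM; // 1/sqrt(32)
//       }
//   }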
// Prefer SIMD-optimized CPU dequant, fall back to scalar reference.
static inline ggml_to_float_t ggml_get_to_float_fn(ggml_type type) {
ggml_to_float_t fn = ggml_get_type_traits_cpu(type)->to_float;
if (!fn) { fn = ggml_get_type_traits(type)->to_float; }
return fn;
}
template<typename idx_t>
static void ggml_compute_forward_set_rows_f32(
const ggml_compute_params * params,
@ -4924,7 +5124,22 @@ static void ggml_compute_forward_set_rows_f32(
const int64_t ir0 = dr*ith;
const int64_t ir1 = std::min(ir0 + dr, nr);
ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
const int32_t apply_hadamard = ((const int32_t *)dst->op_params)[0];
const struct ggml_type_traits_cpu * dst_traits = ggml_get_type_traits_cpu(dst->type);
ggml_from_float_t mxfp_soa_quantize = dst_traits->from_float_soa;
ggml_from_float_t from_float = mxfp_soa_quantize ? nullptr : dst_traits->from_float;
// Fused Hadamard+quantize: one pass per block, 32-float stack buffer, no heap allocation.
ggml_from_float_t mxfp_soa_hadamard_quantize = nullptr;
if (apply_hadamard && mxfp_soa_quantize) {
switch (dst->type) {
case GGML_TYPE_MXFP4: mxfp_soa_hadamard_quantize = (ggml_from_float_t)quantize_row_mxfp4_soa_hadamard; break;
case GGML_TYPE_MXFP8: mxfp_soa_hadamard_quantize = (ggml_from_float_t)quantize_row_mxfp8_soa_hadamard; break;
case GGML_TYPE_MXFP6: mxfp_soa_hadamard_quantize = (ggml_from_float_t)quantize_row_mxfp6_soa_hadamard; break;
default: break;
}
}
for (int64_t i03 = 0; i03 < ne03; ++i03) {
for (int64_t i02 = 0; i02 < ne02; ++i02) {
@ -4937,9 +5152,16 @@ static void ggml_compute_forward_set_rows_f32(
GGML_ASSERT(i1 >= 0 && i1 < ne1);
from_float(
(const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc);
const float * src_row = (const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03);
char * dst_row = ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3);
if (mxfp_soa_hadamard_quantize) {
mxfp_soa_hadamard_quantize(src_row, dst_row, nc);
} else if (mxfp_soa_quantize) {
mxfp_soa_quantize(src_row, dst_row, nc);
} else {
from_float(src_row, dst_row, nc);
}
}
}
}
@ -5562,6 +5784,8 @@ void ggml_compute_forward_clamp(
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
case GGML_TYPE_MXFP8:
case GGML_TYPE_MXFP6:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@ -8127,6 +8351,115 @@ void ggml_compute_forward_top_k(
}
}
// Max head dimension for stack-allocated MXFP buffers.
static constexpr int64_t MXFP_FA_MAX_D = 1024;
// SoA buffer size for MXFP_FA_MAX_D with MXFP8 (worst case: 1024 + 32 e8m0 = 1056, rounded up).
static constexpr int MXFP_FA_SOA_BUF = 1088;
// SoA function pointer types for MXFP flash attention paths.
typedef void (*mxfp_soa_quantize_fn)(const float *, void *, int64_t);
typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t);
// Per-KV-type MXFP parameters (shared between K and V).
struct mxfp_kv_params {
mxfp_soa_dequantize_fn dequantize;
bool multihead;
int qs_per_block;
int head_qs_bytes;
int64_t head_e8m0_offset;
int blocks_per_head;
};
// MXFP dispatch parameters for flash attention.
struct mxfp_fa_params {
    mxfp_soa_quantize_fn q_quantize; // SoA quantize for Q (set for MXFP K; unused once a fused q_roundtrip is selected)
// Fused Q round-trip: Hadamard + quantize + dequant in one pass, no SoA buffer.
void (*q_roundtrip)(const float *, float *, int64_t);
mxfp_kv_params k;
mxfp_kv_params v;
bool apply_hadamard;
};
// Compute the SoA row base pointer for a given KV position and head.
// In multihead mode, the SoA region spans all heads at one KV position,
// so the row base must NOT include the per-head offset (head_idx * nb2).
// mxfp_dequant_head handles per-head indexing within the SoA region.
// In per-head mode, each head has its own SoA region, so the base includes nb2.
static inline const char * mxfp_row_ptr(
const mxfp_kv_params & kv, const char * data,
int64_t kv_pos, size_t nb1, int head_idx, size_t nb2, int batch_idx, size_t nb3) {
if (kv.multihead) {
return data + kv_pos*nb1 + batch_idx*nb3;
}
return data + kv_pos*nb1 + head_idx*nb2 + batch_idx*nb3;
}
// Extract one head's SoA data from a multihead row and dequantize.
static inline void mxfp_dequant_head(
const mxfp_kv_params & kv, const char * row, int head_idx,
char * soa_buf, float * out, int64_t D) {
if (kv.multihead) {
const int qs_off = head_idx * kv.head_qs_bytes;
const int e8m0_off = (int)kv.head_e8m0_offset + head_idx * kv.blocks_per_head;
memcpy(soa_buf, row + qs_off, kv.head_qs_bytes);
memcpy(soa_buf + kv.head_qs_bytes, row + e8m0_off, kv.blocks_per_head);
kv.dequantize(soa_buf, out, D);
} else {
kv.dequantize(row, out, D);
}
}
// Initialize per-KV-type params from tensor metadata.
// Multihead detection: nb2 == row_size(D) means heads are contiguous within
// one KV-position stride, so SoA spans all heads. Otherwise SoA is per-head.
static mxfp_kv_params mxfp_kv_params_init(ggml_type type, int64_t D, size_t nb2, int64_t ne2) {
mxfp_kv_params kv = {};
kv.dequantize = ggml_get_type_traits_cpu(type)->to_float_soa;
kv.multihead = (nb2 == (size_t)ggml_row_size(type, D));
kv.qs_per_block = ggml_mxfp_qs_per_block(type);
kv.blocks_per_head = (int)(D / 32);
kv.head_qs_bytes = kv.blocks_per_head * kv.qs_per_block;
const int64_t total_blocks = kv.multihead ? ne2 * kv.blocks_per_head : kv.blocks_per_head;
kv.head_e8m0_offset = total_blocks * kv.qs_per_block;
return kv;
}
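// Worked example (hypothetical figures): an MXFP8 K cache with D = 128 and
// 8 KV heads in multihead layout. blocks_per_head = 128/32 = 4 and, at one
// byte per element, qs_per_block = 32, so head_qs_bytes = 128. One SoA row
// packs all heads' qs first (8 * 128 = 1024 bytes), then one e8m0 byte per
// block, so head_e8m0_offset = 32 * 32 = 1024.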
static mxfp_fa_params mxfp_fa_params_init(
const ggml_tensor * k, const ggml_tensor * v,
int64_t DK, int64_t DV,
size_t nbk2, size_t nbv2,
int64_t nek2, int64_t nev2) {
mxfp_fa_params p = {};
const bool is_mxfp_k = ggml_is_type_mxfp(k->type);
const bool is_mxfp_v = ggml_is_type_mxfp(v->type);
if (is_mxfp_k) {
p.q_quantize = ggml_get_type_traits_cpu(k->type)->from_float_soa;
p.k = mxfp_kv_params_init(k->type, DK, nbk2, nek2);
}
// Select fused Q round-trip (Hadamard + quantize error, no SoA buffer).
if (is_mxfp_k) {
        const bool had = (DK == DV) && ggml_mxfp_use_hadamard(k->type); // is_mxfp_k already holds here; same condition as apply_hadamard below
switch (k->type) {
case GGML_TYPE_MXFP4: p.q_roundtrip = had ? mxfp4_hadamard_roundtrip : mxfp4_roundtrip; break;
case GGML_TYPE_MXFP8: p.q_roundtrip = had ? mxfp8_hadamard_roundtrip : mxfp8_roundtrip; break;
case GGML_TYPE_MXFP6: p.q_roundtrip = had ? mxfp6_hadamard_roundtrip : mxfp6_roundtrip; break;
default: break;
}
}
if (is_mxfp_v) {
p.v = mxfp_kv_params_init(v->type, DV, nbv2, nev2);
}
// Hadamard rotation must match K rotation.
// Skipped for MLA (DK != DV, V is a view of K).
p.apply_hadamard = is_mxfp_k && (DK == DV) && ggml_mxfp_use_hadamard(k->type);
return p;
}
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
const ggml_compute_params * params,
ggml_tensor * dst,
@ -8201,21 +8534,53 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
const bool is_mxfp_k = ggml_is_type_mxfp(k->type);
const bool is_mxfp_v = ggml_is_type_mxfp(v->type);
GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
const mxfp_fa_params mxfp = mxfp_fa_params_init(k, v, DK, DV, nbk2, nbv2, nek2, nev2);
ggml_from_float_t q_to_vec_dot = nullptr;
ggml_vec_dot_t kq_vec_dot = nullptr;
ggml_to_float_t v_to_float = nullptr;
if (!is_mxfp_k) {
ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
}
if (!is_mxfp_v) {
v_to_float = ggml_get_to_float_fn(v->type);
}
GGML_ASSERT((is_mxfp_k || q_to_vec_dot) && "fattn: unsupported K-type");
GGML_ASSERT((v->type == GGML_TYPE_F32 || is_mxfp_v || v_to_float) && "fattn: unsupported V-type");
int ith = params->ith;
if (is_mxfp_k) { GGML_ASSERT(DK <= MXFP_FA_MAX_D); }
if (is_mxfp_v) { GGML_ASSERT(DV <= MXFP_FA_MAX_D); }
float k_dequant_buf[MXFP_FA_MAX_D];
float v_dequant_buf[MXFP_FA_MAX_D];
char k_head_soa[MXFP_FA_SOA_BUF]; // max: DK=1024 MXFP8 -> 1056 bytes, rounded up
char v_head_soa[MXFP_FA_SOA_BUF];
float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32);
float * V32 = (VKQ32 + 1*DV);
ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV);
ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*DV);
const bool v_is_f16 = (v->type == GGML_TYPE_F16);
const bool use_softcap = (logit_softcap != 0.0f);
const int64_t neq2_x_neq1 = neq2 * neq1;
for (int ir = ir0; ir < ir1; ++ir) {
// q indices
const int iq3 = ir/(neq2*neq1);
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
const int iq3 = ir / neq2_x_neq1;
const int iq2 = (ir - iq3*neq2_x_neq1) / neq1;
const int iq1 = (ir - iq3*neq2_x_neq1 - iq2*neq1);
const uint32_t h = iq2; // head index
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
@ -8223,12 +8588,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
float S = 0.0f; // sum
float M = -INFINITY; // maximum KQ value
float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
float * V32 = (VKQ32 + 1*DV); // (temporary) FP32 V buffer
ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
if (v->type == GGML_TYPE_F16) {
if (v_is_f16) {
memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
} else {
memset(VKQ32, 0, DV*sizeof(float));
@ -8236,16 +8596,35 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
// k indices
// k/v head indices — constant for this query row
const int ik3 = iq3 / rk3;
const int ik2 = iq2 / rk2;
// v indices
const int iv3 = iq3 / rv3;
const int iv2 = iq2 / rv2;
const size_t k_base_offset = ik2*nbk2 + ik3*nbk3;
const size_t v_base_offset = iv2*nbv2 + iv3*nbv3;
const char * k_base = (const char *) k->data + k_base_offset;
const char * v_base = (const char *) v->data + v_base_offset;
const char * k_data_base = (const char *) k->data;
const char * v_data_base = (const char *) v->data;
const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
q_to_vec_dot(pq, Q_q, DK);
float Q_f32[MXFP_FA_MAX_D];
if (mxfp.q_roundtrip) {
// Q preprocessing: fused Hadamard + quantize round-trip, no SoA buffer.
mxfp.q_roundtrip(pq, Q_f32, DK);
} else {
if (mxfp.apply_hadamard) {
float q_tmp[MXFP_FA_MAX_D];
memcpy(q_tmp, pq, DK * sizeof(float));
ggml_apply_hadamard_blocks(q_tmp, DK);
q_to_vec_dot(q_tmp, Q_q, DK);
} else {
q_to_vec_dot(pq, Q_q, DK);
}
}
// online softmax / attention
// loop over n_kv and n_head_kv
@ -8259,12 +8638,18 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
float s; // KQ value
const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
if (is_mxfp_k) {
const char * k_row = mxfp_row_ptr(mxfp.k, k_data_base,
ic, nbk1, ik2, nbk2, ik3, nbk3);
mxfp_dequant_head(mxfp.k, k_row, ik2, k_head_soa, k_dequant_buf, DK);
ggml_vec_dot_f32(DK, &s, 0, k_dequant_buf, 0, Q_f32, 0, 1);
} else {
kq_vec_dot(DK, &s, 0, k_base + ic*nbk1, 0, Q_q, 0, 1);
}
s = s*scale; // scale KQ value
if (logit_softcap != 0.0f) {
if (use_softcap) {
s = logit_softcap*tanhf(s);
}
@ -8275,15 +8660,11 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
float vs = 1.0f; // post-softmax KQ value, expf(s - M)
const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
if (v->type == GGML_TYPE_F16) {
if (v_is_f16) {
if (s > M) {
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
M = s;
ms = expf(Mold - M);
// V = V*expf(Mold - M)
ggml_vec_scale_f16(DV, VKQ16, ms);
} else {
// no new maximum, ms == 1.0f, vs != 1.0f
@ -8291,14 +8672,12 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
}
// V += v*expf(s - M)
ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) (v_base + ic*nbv1), vs);
} else {
if (s > M) {
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
M = s;
ms = expf(Mold - M);
// V = V*expf(Mold - M)
ggml_vec_scale_f32(DV, VKQ32, ms);
} else {
// no new maximum, ms == 1.0f, vs != 1.0f
@ -8306,12 +8685,17 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
}
// V += v*expf(s - M)
if (v_to_float) {
v_to_float(v_data, V32, DV);
if (mxfp.v.dequantize) {
const char * v_row = mxfp_row_ptr(mxfp.v, v_data_base,
ic, nbv1, iv2, nbv2, iv3, nbv3);
mxfp_dequant_head(mxfp.v, v_row, iv2, v_head_soa, v_dequant_buf, DV);
ggml_vec_mad_f32(DV, VKQ32, v_dequant_buf, vs);
} else if (v_to_float) {
v_to_float(v_base + ic*nbv1, V32, DV);
ggml_vec_mad_f32(DV, VKQ32, V32, vs);
} else {
// V is F32
ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
ggml_vec_mad_f32(DV, VKQ32, (const float *) (v_base + ic*nbv1), vs);
}
}
@ -8408,9 +8792,17 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
GGML_ASSERT(k->type == v->type);
const ggml_type kv_type = k->type;
const ggml_type k_type = k->type;
const ggml_type v_type = v->type;
const bool is_mxfp_k = ggml_is_type_mxfp(k_type);
const bool is_mxfp_v = ggml_is_type_mxfp(v_type);
const mxfp_fa_params mxfp = mxfp_fa_params_init(k, v, DK, DV, nbk2, nbv2, nek2, nev2);
// Non-MXFP dequant functions
ggml_to_float_t k_to_float = is_mxfp_k ? nullptr : ggml_get_to_float_fn(k_type);
ggml_to_float_t v_to_float = is_mxfp_v ? nullptr : ggml_get_to_float_fn(v_type);
// broadcast factors
const int64_t rk2 = neq2/nek2;
@ -8442,6 +8834,14 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
if (is_mxfp_k) { GGML_ASSERT(DK <= MXFP_FA_MAX_D); }
if (is_mxfp_v) { GGML_ASSERT(DV <= MXFP_FA_MAX_D); }
float k_dequant_buf[MXFP_FA_MAX_D];
char k_head_soa[MXFP_FA_SOA_BUF];
char v_head_soa[MXFP_FA_SOA_BUF];
int ir = ir0;
while (ir < ir1) {
// q indices for the start of this tile
@ -8499,6 +8899,11 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
for (int tq = 0; tq < tile_rows; tq++) {
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float));
if (mxfp.q_roundtrip) {
// In-place: Q_f32 is already populated by memcpy above, roundtrip overwrites.
mxfp.q_roundtrip(Q_f32 + tq * DK, Q_f32 + tq * DK, DK);
}
}
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
memset(Q_f32 + tq * DK, 0, DK * sizeof(float));
@ -8537,16 +8942,29 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
// Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns
for (int tk = 0; tk < kv_tile; tk++) {
const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3;
if (kv_type == GGML_TYPE_F16) {
if (k_type == GGML_TYPE_F16) {
const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data;
for (int64_t dk = 0; dk < DK; dk++) {
K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]);
}
} else {
} else if (k_type == GGML_TYPE_F32) {
const float * k_f32_src = (const float *)k_data;
for (int64_t dk = 0; dk < DK; dk++) {
K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
}
} else if (mxfp.k.dequantize) {
const char * k_row = mxfp_row_ptr(mxfp.k, (const char *)k->data,
ic + tk, nbk1, ik2, nbk2, ik3, nbk3);
mxfp_dequant_head(mxfp.k, k_row, ik2, k_head_soa, k_dequant_buf, DK);
for (int64_t dk = 0; dk < DK; dk++) {
K_f32[dk * KV_TILE_SZ + tk] = k_dequant_buf[dk];
}
} else {
float k_tmp[MXFP_FA_MAX_D];
k_to_float(k_data, k_tmp, DK);
for (int64_t dk = 0; dk < DK; dk++) {
K_f32[dk * KV_TILE_SZ + tk] = k_tmp[dk];
}
}
}
memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
@ -8602,10 +9020,16 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
// Pack V tile to contiguous F32, zero-padded
for (int tk = 0; tk < kv_tile; tk++) {
const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3;
if (kv_type == GGML_TYPE_F16) {
if (v_type == GGML_TYPE_F16) {
ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);
} else {
} else if (v_type == GGML_TYPE_F32) {
memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
} else if (mxfp.v.dequantize) {
const char * v_row = mxfp_row_ptr(mxfp.v, (const char *)v->data,
ic + tk, nbv1, iv2, nbv2, iv3, nbv3);
mxfp_dequant_head(mxfp.v, v_row, iv2, v_head_soa, V32 + tk * DV, DV);
} else {
v_to_float(v_data, V32 + tk * DV, DV);
}
}
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
@ -8773,8 +9197,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
// When use_ref is set, force the vec-only reference implementation (no tiling, no KV-chunking)
const bool use_ref = params->use_ref;
const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512;
// Split-KV: parallelize across KV chunks for single-query decode (token generation).
// Only for types whose tiled/one_chunk paths produce identical results (f32, f16, MXFP).
// Standard quant types (q8_0, q4_0) must use the scalar path to preserve vec_dot semantics.
const bool k_is_f32_f16_or_mxfp = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16
|| ggml_is_type_mxfp(k->type));
const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1)
&& k_is_f32_f16_or_mxfp
&& q->type == GGML_TYPE_F32 && nek1 >= 512;
if (use_split_kv_path) {
const int64_t chunk_size = (nek1 + nth - 1) / nth;
@ -8831,10 +9261,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int64_t dr = (nr + nchunk - 1) / nchunk;
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
// Tiled path: f32, f16, and MXFP only (quant types use one_chunk)
bool use_tiled = !use_ref &&
(q->type == GGML_TYPE_F32 &&
kv_is_f32_or_f16 &&
k->type == v->type &&
k_is_f32_f16_or_mxfp &&
(k->type == v->type || ggml_is_type_mxfp(k->type)) &&
neq1 >= Q_TILE_SZ);
#ifdef GGML_SIMD
use_tiled &= (DV % GGML_F32_EPR == 0);

View File

@ -189,6 +189,54 @@ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c
*s = sumf;
}
void ggml_vec_dot_mxfp8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs);
assert(n % QK_MXFP8 == 0);
static_assert(QK_MXFP8 == QK8_0, "QK_MXFP8 and QK8_0 must be the same");
const block_mxfp8 * GGML_RESTRICT x = vx;
const block_q8_0 * GGML_RESTRICT y = vy;
const int nb = n / QK_MXFP8;
float sumf = 0;
for (int ib = 0; ib < nb; ++ib) {
const float d = GGML_CPU_FP16_TO_FP32(y[ib].d) * GGML_E8M0_TO_FP32(x[ib].e);
float sumi = 0;
for (int j = 0; j < QK_MXFP8; ++j) {
sumi += y[ib].qs[j] * ggml_mxfp_fp8_e4m3_to_float(x[ib].qs[j]);
}
sumf += d * sumi;
}
*s = sumf;
}
void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs);
assert(n % QK_MXFP6 == 0);
static_assert(QK_MXFP6 == QK8_0, "QK_MXFP6 and QK8_0 must be the same");
const block_mxfp6 * GGML_RESTRICT x = vx;
const block_q8_0 * GGML_RESTRICT y = vy;
const int nb = n / QK_MXFP6;
float sumf = 0;
for (int ib = 0; ib < nb; ++ib) {
const float d = GGML_CPU_FP16_TO_FP32(y[ib].d) * GGML_E8M0_TO_FP32(x[ib].e);
float sumi = 0;
for (int j = 0; j < QK_MXFP6; j += 4) {
uint8_t vals[4];
ggml_mxfp_unpack_fp6x4(&x[ib].qs[j * 3 / 4], vals);
for (int jj = 0; jj < 4; jj++) {
sumi += y[ib].qs[j + jj] * ggml_mxfp_fp6_e2m3_to_float(vals[jj]);
}
}
sumf += d * sumi;
}
*s = sumf;
}
void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
@ -256,6 +304,16 @@ void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
*s = sumf;
}
// Generic SoA dequant wrappers — arch-specific SIMD versions override via fallback.h.
void dequantize_row_mxfp4_soa_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
dequantize_row_mxfp4_soa(x, y, k);
}
void dequantize_row_mxfp8_soa_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
dequantize_row_mxfp8_soa(x, y, k);
}
void dequantize_row_mxfp6_soa_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
dequantize_row_mxfp6_soa(x, y, k);
}
void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;

View File

@ -21,7 +21,6 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@ -43,8 +42,9 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@ -76,6 +76,14 @@ void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, c
void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
// SoA dequant (SIMD-dispatched, CPU backend)
void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_mxfp6_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_mxfp4_soa_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_mxfp8_soa_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_mxfp6_soa_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);

View File

@ -430,59 +430,25 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
// E8M0 shared exponent to float: returns 2^(x - 127).
// Canonical implementation is ggml_mxfp_e8m0_to_fp32 in ggml-common.h.
// This thin wrapper exists because not all callers include ggml-common.h.
// MUST stay in sync — if you change the logic, change ggml-common.h too.
//
// E8M0 = 255 is NaN per MX spec; clamped to 254 (max finite) to match
// the encode path which also clamps to 254, preventing Inf * 0 = NaN.
static inline float ggml_e8m0_to_fp32(uint8_t x) {
uint32_t bits; // Stores the raw bit representation of the float
// Handle special case for minimum exponent (denormalized float)
if (x == 0) {
// Bit pattern for 2^(-127):
// - Sign bit: 0 (positive)
// - Exponent: 0 (denormalized number)
// - Mantissa: 0x400000 (0.5 in fractional form)
// Value = 0.5 * 2^(-126) = 2^(-127)
bits = 0x00400000;
}
// note: disabled as we don't need to handle NaNs
//// Handle special case for NaN (all bits set)
//else if (x == 0xFF) {
// // Standard quiet NaN pattern:
// // - Sign bit: 0
// // - Exponent: all 1s (0xFF)
// // - Mantissa: 0x400000 (quiet NaN flag)
// bits = 0x7FC00000;
//}
// Normalized values (most common case)
else {
// Construct normalized float by shifting exponent into position:
// - Exponent field: 8 bits (positions 30-23)
// - Mantissa: 0 (implicit leading 1)
// Value = 2^(x - 127)
bits = (uint32_t) x << 23;
}
float result; // Final float value
// Safely reinterpret bit pattern as float without type-punning issues
if (x == 255) { x = 254; }
uint32_t bits = (x == 0) ? 0x00400000u : ((uint32_t)x << 23);
float result;
memcpy(&result, &bits, sizeof(float));
return result;
}
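// Worked examples: x = 127 decodes to 2^0 = 1.0 and x = 130 to 2^3 = 8.0;
// x = 0 is the denormal 2^(-127), and x = 255 (the NaN encoding) is clamped
// to 254, decoding to 2^127.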
// Equal to ggml_e8m0_to_fp32/2
// Useful with MXFP4 quantization since the E2M1 values are doubled
// E8M0 to float/2: returns 2^(x - 128).
static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
uint32_t bits;
// For x < 2: use precomputed denormal patterns
if (x < 2) {
// 0x00200000 = 2^(-128), 0x00400000 = 2^(-127)
bits = 0x00200000 << x;
}
// For x >= 2: normalized exponent adjustment
else {
// 0.5 * 2^(x-127) = 2^(x-128) = normalized with exponent (x-1)
bits = (uint32_t)(x - 1) << 23;
}
// Note: NaNs are not handled here
if (x == 255) { x = 254; }
uint32_t bits = (x < 2) ? (0x00200000u << x) : ((uint32_t)(x - 1) << 23);
float result;
memcpy(&result, &bits, sizeof(float));
return result;
@ -491,23 +457,26 @@ static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
#define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x)
#define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x)
// UE4M3: unsigned, 4 exp bits (bias=7), 3 mantissa bits
// Returns value * 0.5 to match kvalues_mxfp4 convention (kvalues = 2 * E2M1_float)
// UE4M3 (unsigned E4M3): 4 exponent bits (bias 7), 3 mantissa bits.
// Returns value * 0.5 to match kvalues_mxfp4 convention (kvalues = 2 * E2M1_float).
static inline float ggml_ue4m3_to_fp32(uint8_t x) {
if (x == 0 || x == 0x7F) {
return 0.0f;
return 0.0f; // zero and NaN → 0
}
int exp = (x >> 3) & 0xF;
int man = x & 0x7;
float raw;
if (exp == 0) {
// subnormal: value = man * 2^(1 - bias - mantissa_bits) = man * 2^(-9)
raw = ldexpf((float) man, -9);
} else {
// normalized: value = (1 + man/8) * 2^(exp - 7)
raw = ldexpf(1.0f + (float) man / 8.0f, exp - 7);
}
return raw * 0.5f;
}
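// Worked example: x = 0x38 has exp = 7 and man = 0, so raw = 1.0 and the
// function returns 0.5 (the doubled-kvalues convention); x = 0x01 is
// subnormal with raw = 2^(-9), returning 2^(-10).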
// Float32 to UE4M3 with round-to-nearest.
static inline uint8_t ggml_fp32_to_ue4m3(float x) {
if (!(x > 0.0f)) {
return 0;
@ -521,7 +490,7 @@ static inline uint8_t ggml_fp32_to_ue4m3(float x) {
int fp32_man = (bits >> 20) & 0x7;
int ue4m3_exp = fp32_exp + 7;
if (ue4m3_exp <= 0) {
// subnormal: value = man * 2^-9, man = round(x * 2^9)
// subnormal: value = man * 2^(-9), so man = round(x * 512)
int man = (int) (x * 512.0f + 0.5f);
if (man > 7) {
man = 7;

View File

@ -1010,6 +1010,19 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
}
}
// MXFP8/MXFP6: no Metal shaders yet; reject for all ops.
// MXFP4: has AoS shaders (MUL_MAT, GET_ROWS) but no SoA/flash attention support yet.
for (size_t i = 0, n = 3; i < n; ++i) {
if (op->src[i] != NULL && ggml_is_type_mxfp(op->src[i]->type)) {
if (op->src[i]->type != GGML_TYPE_MXFP4) {
return false;
}
if (op->op == GGML_OP_FLASH_ATTN_EXT || op->op == GGML_OP_SET_ROWS) {
return false;
}
}
}
switch (op->op) {
case GGML_OP_SCALE:
case GGML_OP_FILL:

View File

@ -257,50 +257,158 @@ void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_REST
}
}
// ====================== MXFP element conversions (wrappers around ggml-common.h)
float fp8_e4m3_to_float(uint8_t v) { return ggml_mxfp_fp8_e4m3_to_float(v); }
uint8_t float_to_fp8_e4m3_rn(float x) { return ggml_mxfp_float_to_fp8_e4m3(x); }
// ====================== MXFP quantization infrastructure
typedef struct {
int emax_offset; // type-specific offset to max representable exponent
int qs_per_block; // quantized scalar bytes per 32-element block
int bits_per_elem; // 8 = byte-aligned, 6 = packed via fp6x4
uint8_t (*to_elem)(float);
float (*to_float)(uint8_t);
} mxfp_elem_traits_t;
static inline int best_index_mxfp4(float x, float e);
// E8M0 shared exponent: round(log2(amax)) — no MSE search needed.
static inline uint8_t mxfp_compute_e8m0(const float * x, int qk, int emax_offset) {
float amax = 0.0f;
for (int j = 0; j < qk; j++) {
const float a = fabsf(x[j]);
if (a > amax) amax = a;
}
if (amax == 0.0f) return 0;
const int e = ggml_mxfp_e8m0_base_estimate(amax, emax_offset);
return (uint8_t)(e < 0 ? 0 : (e > 254 ? 254 : e));
}
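// Worked example (assuming ggml_mxfp_e8m0_base_estimate computes
// round(log2(amax)) - emax_offset + 127): an MXFP8 E4M3 block with
// amax = 256 gives round(log2(256)) = 8, so e = 8 - 8 + 127 = 127, a shared
// scale of 2^0 = 1; amax then maps to 256, inside E4M3's +/-448 range.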
static inline int best_index_mxfp4(float x, float e) {
int best_index = 0;
float best_err = fabsf(kvalues_mxfp4[0]*e - x);
for (int i = 1; i < 16; i++) {
float err = fabsf(kvalues_mxfp4[i]*e - x);
if (err < best_err) {
best_index = i;
best_err = err;
const float inv_e = (e > 0.0f) ? 1.0f / e : 0.0f;
const float normalized = fabsf(x) * inv_e;
int idx;
if (normalized < 0.5f) idx = 0;
else if (normalized < 1.5f) idx = 1;
else if (normalized < 2.5f) idx = 2;
else if (normalized < 3.5f) idx = 3;
else if (normalized < 5.0f) idx = 4;
else if (normalized < 7.0f) idx = 5;
else if (normalized < 10.0f) idx = 6;
else idx = 7;
return (x < 0.0f) ? (idx + 8) : idx;
}
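// The cutoffs above are the midpoints of the positive kvalues_mxfp4
// magnitudes {0, 1, 2, 3, 4, 6, 8, 12} (doubled E2M1 values), so this matches
// the old nearest-value search up to tie-breaking at exact midpoints; adding
// 8 selects the mirrored negative half of the table.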
// Per-block MXFP4 quantize: shared between AoS and SoA paths.
static inline void quantize_block_mxfp4(const float * GGML_RESTRICT src, uint8_t * GGML_RESTRICT qs, uint8_t * e_out) {
const uint8_t e = mxfp_compute_e8m0(src, QK_MXFP4, MXFP4_E2M1_EMAX_OFFSET);
const float d = GGML_E8M0_TO_FP32_HALF(e);
*e_out = e;
for (int j = 0; j < QK_MXFP4/2; ++j) {
const uint8_t x0 = best_index_mxfp4(src[0 + j], d);
const uint8_t x1 = best_index_mxfp4(src[QK_MXFP4/2 + j], d);
qs[j] = x0 | (x1 << 4);
}
}
// Per-block MXFP4 quantize round-trip: apply quantization error without materializing bytes.
// Used for Q preprocessing in flash attention — matches K's error pattern.
static inline void roundtrip_block_mxfp4(float * GGML_RESTRICT vals) {
const uint8_t e = mxfp_compute_e8m0(vals, QK_MXFP4, MXFP4_E2M1_EMAX_OFFSET);
const float d = GGML_E8M0_TO_FP32_HALF(e);
for (int j = 0; j < QK_MXFP4; ++j) {
const int idx = best_index_mxfp4(vals[j], d);
vals[j] = kvalues_mxfp4[idx] * d; // kvalues are doubled, d is halved — matches dequant
}
}
// Per-block generic MXFP quantize round-trip (MXFP8/MXFP6).
static inline void roundtrip_block_mxfp(float * GGML_RESTRICT vals, const mxfp_elem_traits_t * traits) {
const uint8_t e = mxfp_compute_e8m0(vals, 32, traits->emax_offset);
const float d = GGML_E8M0_TO_FP32(e);
const float inv_d = d > 0.0f ? 1.0f / d : 0.0f;
for (int j = 0; j < 32; ++j) {
vals[j] = traits->to_float(traits->to_elem(vals[j] * inv_d)) * d;
}
}
// Fused Hadamard + quantize round-trip: one pass, output is float with quantization error.
void mxfp4_hadamard_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k) {
assert(k % 32 == 0);
for (int64_t i = 0; i < k; i += 32) {
memcpy(dst + i, src + i, 32 * sizeof(float));
ggml_mxfp_hadamard_32_inplace(dst + i);
roundtrip_block_mxfp4(dst + i);
}
}
// Non-Hadamard round-trip for MXFP4 (Hadamard disabled or V cache).
void mxfp4_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k) {
assert(k % 32 == 0);
for (int64_t i = 0; i < k; i += 32) {
memcpy(dst + i, src + i, 32 * sizeof(float));
roundtrip_block_mxfp4(dst + i);
}
}
// Per-block MXFP4 dequant: shared between AoS and SoA paths.
static inline void dequantize_block_mxfp4(const uint8_t * GGML_RESTRICT qs, uint8_t e, float * GGML_RESTRICT dst) {
const float d = GGML_E8M0_TO_FP32_HALF(e);
for (int j = 0; j < QK_MXFP4/2; ++j) {
dst[0 + j] = kvalues_mxfp4[qs[j] & 0x0F] * d;
dst[QK_MXFP4/2 + j] = kvalues_mxfp4[qs[j] >> 4] * d;
}
}
// Per-block generic MXFP quantize/dequant: shared between AoS and SoA for MXFP8/MXFP6.
static inline void quantize_block_mxfp(const float * GGML_RESTRICT src, uint8_t * GGML_RESTRICT qs,
uint8_t * e_out, const mxfp_elem_traits_t * traits) {
const uint8_t e = mxfp_compute_e8m0(src, 32, traits->emax_offset);
const float d = GGML_E8M0_TO_FP32(e);
const float inv_d = d > 0.0f ? 1.0f / d : 0.0f;
*e_out = e;
if (traits->bits_per_elem == 8) {
for (int j = 0; j < 32; ++j) {
qs[j] = traits->to_elem(src[j] * inv_d);
}
} else {
for (int j = 0; j < 32; j += 4) {
uint8_t vals[4];
for (int jj = 0; jj < 4; jj++) {
vals[jj] = traits->to_elem(src[j + jj] * inv_d);
}
pack_fp6x4(vals, &qs[j * 3 / 4]);
}
}
}
static inline void dequantize_block_mxfp(const uint8_t * GGML_RESTRICT qs, uint8_t e,
float * GGML_RESTRICT dst, const mxfp_elem_traits_t * traits) {
const float d = GGML_E8M0_TO_FP32(e);
if (traits->bits_per_elem == 8) {
for (int j = 0; j < 32; ++j) {
dst[j] = traits->to_float(qs[j]) * d;
}
} else {
for (int j = 0; j < 32; j += 4) {
uint8_t vals[4];
unpack_fp6x4(&qs[j * 3 / 4], vals);
for (int jj = 0; jj < 4; jj++) {
dst[j + jj] = traits->to_float(vals[jj]) * d;
}
}
}
return best_index;
}
void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) {
static const int qk = QK_MXFP4;
assert(k % qk == 0);
const int nb = k / qk;
assert(k % QK_MXFP4 == 0);
const int nb = k / QK_MXFP4;
for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max
for (int j = 0; j < qk; j++) {
const float v = x[i*qk + j];
if (amax < fabsf(v)) {
amax = fabsf(v);
}
}
const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 2 + 127) : 0;
const float d = GGML_E8M0_TO_FP32_HALF(e);
y[i].e = e;
for (int j = 0; j < qk/2; ++j) {
const uint8_t x0 = best_index_mxfp4(x[i*qk + 0 + j], d);
const uint8_t x1 = best_index_mxfp4(x[i*qk + qk/2 + j], d);
y[i].qs[j] = x0;
y[i].qs[j] |= x1 << 4;
}
quantize_block_mxfp4(&x[i*QK_MXFP4], y[i].qs, &y[i].e);
}
}
@ -450,22 +558,10 @@ void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRI
}
void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
static const int qk = QK_MXFP4;
assert(k % qk == 0);
const int nb = k / qk;
assert(k % QK_MXFP4 == 0);
const int nb = k / QK_MXFP4;
for (int i = 0; i < nb; i++) {
const float d = GGML_E8M0_TO_FP32_HALF(x[i].e);
for (int j = 0; j < qk/2; ++j) {
const int8_t x0 = kvalues_mxfp4[x[i].qs[j] & 0x0F];
const int8_t x1 = kvalues_mxfp4[x[i].qs[j] >> 4];
y[i*qk + j + 0 ] = x0*d;
y[i*qk + j + qk/2] = x1*d;
}
dequantize_block_mxfp4(x[i].qs, x[i].e, &y[i*QK_MXFP4]);
}
}
@ -494,6 +590,203 @@ void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_REST
}
}
// ====================== Hadamard rotation
void ggml_hadamard_32_inplace(float vals[32]) {
ggml_mxfp_hadamard_32_inplace(vals);
}
float fp6_e2m3_to_float(uint8_t v) { return ggml_mxfp_fp6_e2m3_to_float(v); }
uint8_t float_to_fp6_e2m3_rn(float x) { return ggml_mxfp_float_to_fp6_e2m3(x); }
float fp6_e3m2_to_float(uint8_t v) { return ggml_mxfp_fp6_e3m2_to_float(v); }
uint8_t float_to_fp6_e3m2_rn(float x) { return ggml_mxfp_float_to_fp6_e3m2(x); }
float fp8_e5m2_to_float(uint8_t v) { return ggml_mxfp_fp8_e5m2_to_float(v); }
uint8_t float_to_fp8_e5m2_rn(float x) { return ggml_mxfp_float_to_fp8_e5m2(x); }
void pack_fp6x4(const uint8_t v[4], uint8_t out[3]) { ggml_mxfp_pack_fp6x4(v, out); }
void unpack_fp6x4(const uint8_t in[3], uint8_t v[4]) { ggml_mxfp_unpack_fp6x4(in, v); }
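// One plausible bit layout (a sketch only; the canonical order is whatever
// ggml_mxfp_pack_fp6x4 in ggml-common.h defines, and pack_fp6x4_sketch is a
// hypothetical name): four 6-bit codes packed little-endian across 24 bits.
//
//   static inline void pack_fp6x4_sketch(const uint8_t v[4], uint8_t out[3]) {
//       const uint32_t w = (uint32_t)(v[0] & 0x3F)
//                        | ((uint32_t)(v[1] & 0x3F) << 6)
//                        | ((uint32_t)(v[2] & 0x3F) << 12)
//                        | ((uint32_t)(v[3] & 0x3F) << 18);
//       out[0] = (uint8_t) w;
//       out[1] = (uint8_t)(w >> 8);
//       out[2] = (uint8_t)(w >> 16);
//   }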
static const mxfp_elem_traits_t mxfp8_e4m3_traits = { MXFP8_E4M3_EMAX_OFFSET, MXFP8_SOA_QS_PER_BLOCK, 8, float_to_fp8_e4m3_rn, fp8_e4m3_to_float };
static const mxfp_elem_traits_t mxfp6_e2m3_traits = { MXFP6_E2M3_EMAX_OFFSET, MXFP6_SOA_QS_PER_BLOCK, 6, float_to_fp6_e2m3_rn, fp6_e2m3_to_float };
// MXFP8 AoS quantize/dequant — uses shared per-block helpers.
void quantize_row_mxfp8_ref(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k) {
assert(k % QK_MXFP8 == 0);
const int nb = k / QK_MXFP8;
for (int i = 0; i < nb; i++) {
quantize_block_mxfp(&x[i*QK_MXFP8], y[i].qs, &y[i].e, &mxfp8_e4m3_traits);
}
}
void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
assert(k % QK_MXFP8 == 0);
const int nb = k / QK_MXFP8;
for (int i = 0; i < nb; i++) {
dequantize_block_mxfp(x[i].qs, x[i].e, &y[i*QK_MXFP8], &mxfp8_e4m3_traits);
}
}
// MXFP6 AoS quantize/dequant — uses shared per-block helpers.
void quantize_row_mxfp6_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k) {
assert(k % QK_MXFP6 == 0);
const int nb = k / QK_MXFP6;
for (int i = 0; i < nb; i++) {
quantize_block_mxfp(&x[i*QK_MXFP6], y[i].qs, &y[i].e, &mxfp6_e2m3_traits);
}
}
void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
assert(k % QK_MXFP6 == 0);
const int nb = k / QK_MXFP6;
for (int i = 0; i < nb; i++) {
dequantize_block_mxfp(x[i].qs, x[i].e, &y[i*QK_MXFP6], &mxfp6_e2m3_traits);
}
}
// ====================== SoA (Struct-of-Arrays) quantize/dequantize for flash attention
void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) {
assert(k % QK_MXFP4 == 0);
const int nb = k / QK_MXFP4;
char * qs_base = (char *)dst;
char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK);
for (int i = 0; i < nb; i++) {
uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP4_SOA_QS_PER_BLOCK));
quantize_block_mxfp4(&x[i*QK_MXFP4], qs, (uint8_t *)&e8m0_base[i]);
}
}
void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) {
assert(k % QK_MXFP4 == 0);
const int nb = k / QK_MXFP4;
const char * qs_base = (const char *)src;
const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK);
for (int i = 0; i < nb; i++) {
const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP4_SOA_QS_PER_BLOCK));
dequantize_block_mxfp4(qs, (uint8_t)e8m0_base[i], &y[i*QK_MXFP4]);
}
}
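// SoA layout of one row of nb blocks (MXFP4 shown; MXFP8/6 differ only in
// qs_per_block): all packed qs regions first, then all shared exponents.
//
//   [ qs block 0 | qs block 1 | ... | qs block nb-1 | e8m0[0..nb-1] ]
//
// This assumes the offset macros are the straightforward linear forms their
// uses here imply: MXFP_SOA_QS_OFFSET(i, qpb) = i*qpb and
// MXFP_SOA_E8M0_OFFSET(nb, qpb) = nb*qpb.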
// Unified SoA quantize/dequantize — delegates to shared per-block helpers.
static void quantize_row_mxfp_soa_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst,
int64_t k, const mxfp_elem_traits_t * traits) {
assert(k % 32 == 0);
const int nb = k / 32;
const int qpb = traits->qs_per_block;
char * qs_base = (char *)dst;
char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, qpb);
for (int i = 0; i < nb; i++) {
uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, qpb));
quantize_block_mxfp(&x[i*32], qs, (uint8_t *)&e8m0_base[i], traits);
}
}
static void dequantize_row_mxfp_soa_impl(const void * GGML_RESTRICT src, float * GGML_RESTRICT y,
int64_t k, const mxfp_elem_traits_t * traits) {
assert(k % 32 == 0);
const int nb = k / 32;
const int qpb = traits->qs_per_block;
const char * qs_base = (const char *)src;
const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, qpb);
for (int i = 0; i < nb; i++) {
const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, qpb));
dequantize_block_mxfp(qs, (uint8_t)e8m0_base[i], &y[i*32], traits);
}
}
void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) {
quantize_row_mxfp_soa_impl(x, dst, k, &mxfp8_e4m3_traits);
}
void dequantize_row_mxfp8_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) {
dequantize_row_mxfp_soa_impl(src, y, k, &mxfp8_e4m3_traits);
}
void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) {
quantize_row_mxfp_soa_impl(x, dst, k, &mxfp6_e2m3_traits);
}
void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) {
dequantize_row_mxfp_soa_impl(src, y, k, &mxfp6_e2m3_traits);
}
// Fused Hadamard + SoA quantize: one read, one write, 32-float stack buffer per block.
// Eliminates the full-row temp buffer and extra memory pass.
void quantize_row_mxfp4_soa_hadamard(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) {
assert(k % QK_MXFP4 == 0);
const int nb = k / QK_MXFP4;
char * qs_base = (char *)dst;
char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK);
for (int i = 0; i < nb; i++) {
float tmp[32];
memcpy(tmp, &x[i*QK_MXFP4], QK_MXFP4 * sizeof(float));
ggml_mxfp_hadamard_32_inplace(tmp);
uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP4_SOA_QS_PER_BLOCK));
quantize_block_mxfp4(tmp, qs, (uint8_t *)&e8m0_base[i]);
}
}
static void quantize_row_mxfp_soa_hadamard_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst,
int64_t k, const mxfp_elem_traits_t * traits) {
assert(k % 32 == 0);
const int nb = k / 32;
const int qpb = traits->qs_per_block;
char * qs_base = (char *)dst;
char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, qpb);
for (int i = 0; i < nb; i++) {
float tmp[32];
memcpy(tmp, &x[i*32], 32 * sizeof(float));
ggml_mxfp_hadamard_32_inplace(tmp);
uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, qpb));
quantize_block_mxfp(tmp, qs, (uint8_t *)&e8m0_base[i], traits);
}
}
void quantize_row_mxfp8_soa_hadamard(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) {
quantize_row_mxfp_soa_hadamard_impl(x, dst, k, &mxfp8_e4m3_traits);
}
void quantize_row_mxfp6_soa_hadamard(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) {
quantize_row_mxfp_soa_hadamard_impl(x, dst, k, &mxfp6_e2m3_traits);
}
// MXFP8/6 quantize round-trips (with and without Hadamard).
void mxfp8_hadamard_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k) {
assert(k % 32 == 0);
for (int64_t i = 0; i < k; i += 32) {
memcpy(dst + i, src + i, 32 * sizeof(float));
ggml_mxfp_hadamard_32_inplace(dst + i);
roundtrip_block_mxfp(dst + i, &mxfp8_e4m3_traits);
}
}
void mxfp6_hadamard_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k) {
assert(k % 32 == 0);
for (int64_t i = 0; i < k; i += 32) {
memcpy(dst + i, src + i, 32 * sizeof(float));
ggml_mxfp_hadamard_32_inplace(dst + i);
roundtrip_block_mxfp(dst + i, &mxfp6_e2m3_traits);
}
}
void mxfp8_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k) {
assert(k % 32 == 0);
for (int64_t i = 0; i < k; i += 32) {
memcpy(dst + i, src + i, 32 * sizeof(float));
roundtrip_block_mxfp(dst + i, &mxfp8_e4m3_traits);
}
}
void mxfp6_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k) {
assert(k % 32 == 0);
for (int64_t i = 0; i < k; i += 32) {
memcpy(dst + i, src + i, 32 * sizeof(float));
roundtrip_block_mxfp(dst + i, &mxfp6_e2m3_traits);
}
}
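// Usage sketch (hypothetical buffers): the round-trip is bit-identical to
// quantize-then-dequantize through the SoA codecs, since both share
// mxfp_compute_e8m0 and the same element converters. Given some input row
// q[128]:
//
//   float q_rt[128], q_ref[128];
//   char soa[4*32 + 4]; // 4 MXFP8 qs blocks + 4 e8m0 bytes (linear SoA offsets assumed)
//   mxfp8_roundtrip(q, q_rt, 128);
//   quantize_row_mxfp8_soa(q, soa, 128);
//   dequantize_row_mxfp8_soa(soa, q_ref, 128);
//   // q_rt and q_ref now agree element-for-element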
//
// 2-6 bit quantization in super-blocks
//
@ -2164,6 +2457,18 @@ size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
return nrow * ggml_row_size(GGML_TYPE_NVFP4, n_per_row);
}
size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
GGML_UNUSED(quant_weights);
quantize_row_mxfp8_ref(src, dst, (int64_t)nrow*n_per_row);
return nrow * ggml_row_size(GGML_TYPE_MXFP8, n_per_row);
}
size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
GGML_UNUSED(quant_weights);
quantize_row_mxfp6_ref(src, dst, (int64_t)nrow*n_per_row);
return nrow * ggml_row_size(GGML_TYPE_MXFP6, n_per_row);
}
// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) {
@ -5310,6 +5615,14 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
{
VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb);
} break;
case GGML_TYPE_MXFP8:
{
VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp8, data, nb);
} break;
case GGML_TYPE_MXFP6:
{
VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp6, data, nb);
} break;
case GGML_TYPE_NVFP4:
{
// UE4M3 scales are uint8_t — all byte values are valid

View File

@ -22,8 +22,9 @@ GGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 *
GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp8_ref(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp6_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
@ -49,7 +50,28 @@ GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GG
//GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
// SoA quantize/dequantize for flash attention
GGML_API void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
GGML_API void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
GGML_API void dequantize_row_mxfp8_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
GGML_API void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
// Fused Hadamard + SoA quantize (one pass, no temp buffer)
GGML_API void quantize_row_mxfp4_soa_hadamard(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
GGML_API void quantize_row_mxfp8_soa_hadamard(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
GGML_API void quantize_row_mxfp6_soa_hadamard(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
// Quantize round-trip: apply quantization error to floats without materializing bytes.
// Hadamard variants include the rotation. Used for Q preprocessing in flash attention.
GGML_API void mxfp4_hadamard_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k);
GGML_API void mxfp8_hadamard_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k);
GGML_API void mxfp6_hadamard_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k);
GGML_API void mxfp4_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k);
GGML_API void mxfp8_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k);
GGML_API void mxfp6_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k);
GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@ -97,7 +119,30 @@ GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTR
GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
// MXFP element converters
GGML_API float fp8_e4m3_to_float(uint8_t v);
GGML_API uint8_t float_to_fp8_e4m3_rn(float x);
GGML_API float fp8_e5m2_to_float(uint8_t v);
GGML_API uint8_t float_to_fp8_e5m2_rn(float x);
// no NaN/Inf in FP6 — all bit patterns are valid numbers
GGML_API float fp6_e2m3_to_float(uint8_t v);
GGML_API uint8_t float_to_fp6_e2m3_rn(float x);
// no NaN/Inf — exp=7 is a valid normal value (unlike IEEE-754)
GGML_API float fp6_e3m2_to_float(uint8_t v);
GGML_API uint8_t float_to_fp6_e3m2_rn(float x);
// Pack/unpack 4 six-bit values into 3 bytes
GGML_API void pack_fp6x4(const uint8_t v[4], uint8_t out[3]);
GGML_API void unpack_fp6x4(const uint8_t in[3], uint8_t v[4]);
// Block-32 Walsh-Hadamard transform, normalized by 1/sqrt(32)
GGML_API void ggml_hadamard_32_inplace(float vals[32]);
GGML_API void iq2xs_init_impl(enum ggml_type type);
GGML_API void iq2xs_free_impl(enum ggml_type type);

View File

@ -726,6 +726,22 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
.to_float = (ggml_to_float_t) dequantize_row_nvfp4,
.from_float_ref = (ggml_from_float_t)quantize_row_nvfp4_ref,
},
[GGML_TYPE_MXFP8] = {
.type_name = "mxfp8",
.blck_size = QK_MXFP8,
.type_size = sizeof(block_mxfp8),
.is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_mxfp8,
.from_float_ref = (ggml_from_float_t)quantize_row_mxfp8_ref,
},
[GGML_TYPE_MXFP6] = {
.type_name = "mxfp6",
.blck_size = QK_MXFP6,
.type_size = sizeof(block_mxfp6),
.is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_mxfp6,
.from_float_ref = (ggml_from_float_t)quantize_row_mxfp6_ref,
},
[GGML_TYPE_Q2_K] = {
.type_name = "q2_K",
.blck_size = QK_K,
@ -1312,6 +1328,30 @@ bool ggml_is_quantized(enum ggml_type type) {
return type_traits[type].is_quantized;
}
bool ggml_is_type_mxfp(enum ggml_type type) {
return type == GGML_TYPE_MXFP4 ||
type == GGML_TYPE_MXFP8 ||
type == GGML_TYPE_MXFP6;
}
bool ggml_mxfp_use_hadamard(enum ggml_type type) {
switch (type) {
case GGML_TYPE_MXFP4: return MXFP_USE_HADAMARD_E2M1;
case GGML_TYPE_MXFP8: return MXFP_USE_HADAMARD_E4M3;
case GGML_TYPE_MXFP6: return MXFP_USE_HADAMARD_E2M3;
default: return false;
}
}
int ggml_mxfp_qs_per_block(enum ggml_type type) {
switch (type) {
case GGML_TYPE_MXFP4: return MXFP_QS_PER_BLOCK_E2M1;
case GGML_TYPE_MXFP8: return MXFP_QS_PER_BLOCK_E4M3;
case GGML_TYPE_MXFP6: return MXFP_QS_PER_BLOCK_E2M3;
default: return 0;
}
}
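// For a 32-element MX block these constants should work out to 32*4/8 = 16
// (MXFP4), 32*8/8 = 32 (MXFP8) and 32*6/8 = 24 (MXFP6) qs bytes, matching the
// bits_per_elem used by the quantization traits.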
const char * ggml_op_name(enum ggml_op op) {
return GGML_OP_NAME[op];
}
@ -1387,7 +1427,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break;
case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break;
case GGML_FTYPE_MOSTLY_MXFP4: wtype = GGML_TYPE_MXFP4; break;
case GGML_FTYPE_MOSTLY_MXFP4_E2M1: wtype = GGML_TYPE_MXFP4_E2M1; break;
case GGML_FTYPE_MOSTLY_NVFP4: wtype = GGML_TYPE_NVFP4; break;
case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break;
case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break;
@ -7655,8 +7695,10 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_MXFP4: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_MXFP4: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_MXFP8: result = quantize_mxfp8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_MXFP6: result = quantize_mxfp6(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;

View File

@ -135,7 +135,18 @@ llama_kv_cache::llama_kv_cache(
const bool has_k = true;
const bool has_v = !is_mla;
ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream) : nullptr;
// MXFP: align block count to 16 for cp.async
uint32_t n_embd_k_alloc = n_embd_k_gqa;
const bool is_mxfp_k = ggml_is_type_mxfp(type_k);
if (is_mxfp_k) {
const int qk = (int)ggml_blck_size(type_k);
GGML_ASSERT(n_embd_k_gqa % qk == 0 && "MXFP K cache requires n_embd_k_gqa divisible by block size");
const int blocks = (int)n_embd_k_gqa / qk;
const int blocks_aligned = (blocks + 15) & ~15;
n_embd_k_alloc = (uint32_t)(blocks_aligned * qk);
}
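// e.g. with qk = 32 and a hypothetical n_embd_k_gqa = 1184 (37 blocks):
// blocks_aligned = 48, so n_embd_k_alloc = 1536 and the trailing 11 blocks per
// row are padding that readers must skip via the nb[] strides (see get_k below)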
ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, type_k, n_embd_k_alloc, kv_size, n_stream) : nullptr;
ggml_tensor * v = has_v ? ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream) : nullptr;
has_k && ggml_format_name(k, "cache_k_l%d", il);
@ -1025,19 +1036,15 @@ ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_k
auto * k = layers[ikv].k;
const uint64_t kv_size = get_size();
const uint64_t n_embd_k_gqa = k->ne[0];
assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));
// note: for MXFP types, k->ne[0] may be padded for block alignment; use nb[] for strides
const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
return ggml_view_4d(ctx, k,
hparams.n_embd_head_k(il), hparams.n_head_kv(il), n_kv, ns,
ggml_row_size(k->type, hparams.n_embd_head_k(il)),
ggml_row_size(k->type, n_embd_k_gqa),
ggml_row_size(k->type, n_embd_k_gqa*kv_size),
ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
k->nb[1],
k->nb[2],
k->nb[2]*sinfo.s0);
}
ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
@ -1092,19 +1099,35 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm
k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0);
const int64_t n_stream = k->ne[2];
const int64_t kv_size = get_size();
if (n_stream > 1) {
const int64_t kv_size = get_size();
assert(n_embd_gqa == k->ne[0]);
assert(kv_size == k->ne[1]);
assert(kv_size == k->ne[1]);
// merge the buffer across all streams because the idxs are global
k = ggml_reshape_2d(ctx, k, n_embd_gqa, kv_size*n_stream);
// note: use view_2d to preserve nb[1] (includes MXFP alignment padding)
k = ggml_view_2d(ctx, k, k->ne[0], kv_size*n_stream, k->nb[1], 0);
}
const bool is_mxfp = ggml_is_type_mxfp(k->type);
// for MXFP: ne[0] may be padded, narrow view to n_embd_gqa while keeping row stride
ggml_tensor * k_dst = k;
if (is_mxfp) {
k_dst = ggml_view_2d(ctx, k, n_embd_gqa, k->ne[1], k->nb[1], 0);
}
// store the current K values into the cache
return ggml_set_rows(ctx, k, k_cur, k_idxs);
ggml_tensor * result = ggml_set_rows(ctx, k_dst, k_cur, k_idxs);
// enable Hadamard rotation for MXFP K cache (QuaRot arXiv:2404.00456, BRQ arXiv:2511.04214)
// skipped when DK != DV (MLA) and for E5M2/E3M2 (2-bit mantissa, no benefit).
// condition must match flash attention read path (ops.cpp: DK == DV).
if (is_mxfp && hparams.n_embd_head_k(il) == hparams.n_embd_head_v(il) && ggml_mxfp_use_hadamard(k->type)) {
((int32_t *)result->op_params)[0] = 1;
}
return result;
}
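For context, a minimal sketch of how a write-side kernel can consume this flag (names here are hypothetical; the real consumer lives in the CPU ops):

static void mxfp_store_row_sketch(const struct ggml_tensor * op, const float * src,
                                  void * dst, int64_t ne0, float * scratch /* >= ne0 floats */) {
    const struct ggml_type_traits_cpu * tt = ggml_get_type_traits_cpu(op->type);
    if (((const int32_t *) op->op_params)[0] != 0) {
        // rotate each 32-element block before quantizing; the read path undoes
        // this with the same (self-inverse) transform after dequantizing
        memcpy(scratch, src, ne0*sizeof(float));
        for (int64_t b = 0; b < ne0/32; ++b) {
            ggml_hadamard_32_inplace(scratch + b*32);
        }
        src = scratch;
    }
    tt->from_float_soa(src, dst, ne0);
}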
ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {

View File

@ -252,6 +252,7 @@ if (NOT GGML_BACKEND_DL)
# these tests use the backends directly and cannot be built with dynamic loading
llama_build_and_test(test-barrier.cpp)
llama_build_and_test(test-quantize-fns.cpp)
target_include_directories(test-quantize-fns PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src)
llama_build_and_test(test-quantize-perf.cpp)
llama_build_and_test(test-rope.cpp)
endif()

View File

@ -150,6 +150,87 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
}
}
// MXFP SoA quantization functions
extern "C" {
void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_mxfp8_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
}
typedef void (*mxfp_soa_quantize_fn)(const float *, void *, int64_t);
typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t);
struct mxfp_soa_fns {
ggml_type type;
mxfp_soa_quantize_fn quantize;
mxfp_soa_dequantize_fn dequantize;
};
static const mxfp_soa_fns mxfp_soa_table[] = {
{ GGML_TYPE_MXFP4, quantize_row_mxfp4_soa, dequantize_row_mxfp4_soa },
{ GGML_TYPE_MXFP8, quantize_row_mxfp8_soa, dequantize_row_mxfp8_soa },
{ GGML_TYPE_MXFP6, quantize_row_mxfp6_soa, dequantize_row_mxfp6_soa },
};
static const mxfp_soa_fns * get_mxfp_soa(ggml_type type) {
for (const auto & e : mxfp_soa_table) {
if (e.type == type) return &e;
}
return nullptr;
}
// init MXFP tensor with SoA layout
static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
GGML_ASSERT(ggml_is_type_mxfp(tensor->type));
const auto * soa = get_mxfp_soa(tensor->type);
GGML_ASSERT(soa && "unsupported MXFP type for SoA init");
const int64_t DK = tensor->ne[0];
const size_t row_sz = ggml_row_size(tensor->type, DK);
// multihead: heads at one position are packed contiguously (head stride == row size)
const bool multihead = (tensor->nb[2] == row_sz) && (tensor->ne[2] > 1);
std::default_random_engine gen(42);
std::uniform_real_distribution<float> dist(min, max);
std::vector<uint8_t> buf(ggml_nbytes(tensor), 0);
if (multihead) {
// all heads at one position share one SoA region
const int64_t n_heads = tensor->ne[2];
const int64_t soa_elems = n_heads * DK;
std::vector<float> region(soa_elems);
for (int64_t i3 = 0; i3 < tensor->ne[3]; i3++) {
for (int64_t i1 = 0; i1 < tensor->ne[1]; i1++) {
size_t offset = i3*tensor->nb[3] + i1*tensor->nb[1];
for (int64_t j = 0; j < soa_elems; j++) { region[j] = dist(gen); }
soa->quantize(region.data(), buf.data() + offset, soa_elems);
}
}
} else {
// per-head SoA: each head independently packed
std::vector<float> region(DK);
for (int64_t i3 = 0; i3 < tensor->ne[3]; i3++) {
for (int64_t i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int64_t i1 = 0; i1 < tensor->ne[1]; i1++) {
size_t offset = i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1];
for (int64_t j = 0; j < DK; j++) { region[j] = dist(gen); }
soa->quantize(region.data(), buf.data() + offset, DK);
}
}
}
}
ggml_backend_tensor_set(tensor, buf.data(), 0, buf.size());
}
// generate an F16 mask where certain blocks are randomly masked with -INF value
static void init_tensor_kq_mask(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
GGML_ASSERT(tensor->type == GGML_TYPE_F16);
@ -239,11 +320,27 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) {
size_t bs = ggml_blck_size(t->type);
std::vector<float> vq(ggml_blck_size(t->type));
bool quantized = ggml_is_quantized(t->type);
const bool is_mxfp = ggml_is_type_mxfp(t->type);
mxfp_soa_dequantize_fn mxfp_dequant_soa = nullptr;
std::vector<float> mxfp_row_f32;
if (is_mxfp) {
const auto * soa_fns = get_mxfp_soa(t->type);
GGML_ASSERT(soa_fns && "unsupported MXFP type in tensor_to_float");
mxfp_dequant_soa = soa_fns->dequantize;
mxfp_row_f32.resize(t->ne[0]);
}
// access elements by index to avoid gaps in views
for (int64_t i3 = 0; i3 < t->ne[3]; i3++) {
for (int64_t i2 = 0; i2 < t->ne[2]; i2++) {
for (int64_t i1 = 0; i1 < t->ne[1]; i1++) {
if (is_mxfp) {
size_t row_off = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1];
mxfp_dequant_soa(&buf[row_off], mxfp_row_f32.data(), t->ne[0]);
tv.insert(tv.end(), mxfp_row_f32.begin(), mxfp_row_f32.end());
continue;
}
for (int64_t i0 = 0; i0 < t->ne[0]; i0 += bs) {
size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0];
if (t->type == GGML_TYPE_F16) {
@ -2309,8 +2406,12 @@ struct test_set_rows : public test_case {
const std::array<int, 2> nr23; // broadcast only dims 2 and 3
const int r; // rows to set
const bool v; // view (non-contiguous src1)
const bool hadamard; // apply Walsh-Hadamard rotation before quantization
std::string vars() override {
if (hadamard) {
return VARS_TO_STR6(type, type_idx, ne, nr23, r, v) + ",hadamard=1";
}
return VARS_TO_STR6(type, type_idx, ne, nr23, r, v);
}
@ -2318,8 +2419,8 @@ struct test_set_rows : public test_case {
ggml_type type_idx,
std::array<int64_t, 4> ne,
std::array<int, 2> nr23,
int r, bool v = false)
: type(type), type_idx(type_idx), ne(ne), nr23(nr23), r(r), v(v) {}
int r, bool v = false, bool hadamard = false)
: type(type), type_idx(type_idx), ne(ne), nr23(nr23), r(r), v(v), hadamard(hadamard) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * dst = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);
@ -2338,6 +2439,11 @@ struct test_set_rows : public test_case {
}
ggml_tensor * out = ggml_set_rows(ctx, dst, src, row_idxs);
if (hadamard) {
((int32_t *)out->op_params)[0] = 1;
}
ggml_set_name(out, "out");
return out;
@ -2351,6 +2457,10 @@ struct test_set_rows : public test_case {
}
init_set_rows_row_ids(t, ne[1]);
} else if (ggml_is_type_mxfp(t->type)) {
// MXFP dst tensors must use SoA layout — set_rows writes SoA,
// and tensor_to_float reads back assuming SoA for MXFP types.
init_tensor_mxfp_soa(t);
} else {
init_tensor_uniform(t);
}
@ -6180,9 +6290,14 @@ struct test_flash_attn_ext : public test_case {
const ggml_prec prec;
const ggml_type type_KV;
const ggml_type type_V; // V type, defaults to type_KV for same-type K/V
std::array<int32_t, 4> permute;
std::string vars() override {
if (type_V != type_KV) {
return VARS_TO_STR13(hsk, hsv, nh, nr23, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV, permute)
+ ",type_V=" + ggml_type_name(type_V);
}
return VARS_TO_STR13(hsk, hsv, nh, nr23, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV, permute);
}
@ -6199,12 +6314,14 @@ struct test_flash_attn_ext : public test_case {
test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, std::array<int64_t, 2> nr23 = {1, 1}, int64_t kv = 96, int64_t nb = 8,
bool mask = true, bool sinks = false, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
: hsk(hsk), hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), sinks(sinks), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3},
ggml_type type_V_override = GGML_TYPE_COUNT)
: hsk(hsk), hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), sinks(sinks), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV),
type_V(type_V_override == GGML_TYPE_COUNT ? type_KV : type_V_override), permute(permute) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_KV));
const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_V));
auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, bool is_view) -> ggml_tensor * {
int64_t ne[4] = {ne0, ne1, ne2, ne3};
@ -6242,7 +6359,7 @@ struct test_flash_attn_ext : public test_case {
// - https://github.com/ggml-org/llama.cpp/pull/18986
v = ggml_view_4d(ctx, k, hsv_padded, kv, nh, nr23[1], k->nb[1], k->nb[2], k->nb[3], 0);
} else {
v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache
v = create_permuted(type_V, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache
}
ggml_set_name(v, "v");
@ -6273,6 +6390,8 @@ struct test_flash_attn_ext : public test_case {
init_tensor_uniform(t, -10.0f, 10.0f);
} else if (strcmp(t->name, "m") == 0) {
init_tensor_kq_mask(t);
} else if (ggml_is_type_mxfp(t->type)) {
init_tensor_mxfp_soa(t);
} else {
init_tensor_uniform(t);
}
@ -7279,7 +7398,7 @@ static const ggml_type all_types[] = {
GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
GGML_TYPE_Q8_0,
GGML_TYPE_MXFP4,
GGML_TYPE_MXFP4, GGML_TYPE_MXFP8, GGML_TYPE_MXFP6,
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
GGML_TYPE_Q6_K,
@ -7295,7 +7414,7 @@ static const ggml_type base_types[] = {
GGML_TYPE_Q4_0,
GGML_TYPE_Q4_1, // for I8MM tests
GGML_TYPE_Q4_K,
GGML_TYPE_MXFP4, // TODO: or "other"
GGML_TYPE_MXFP4,
GGML_TYPE_IQ2_XXS
};
@ -7413,6 +7532,17 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
// SET_ROWS with Hadamard rotation (exercises the op_params[0] flag used by MXFP KV cache)
for (ggml_type type : {GGML_TYPE_MXFP4, GGML_TYPE_MXFP8, GGML_TYPE_MXFP6}) {
// ne[0] must be divisible by 32 (Hadamard block size)
test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 128, 5, 1, 1 }, { 1, 1 }, 1, false, true));
test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 256, 5, 1, 3 }, { 1, 1 }, 1, false, true));
// multi-row, broadcast, views
test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 128, 5, 1, 1 }, { 1, 1 }, 1, true, true));
test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 256, 11, 1, 1 }, { 2, 3 }, 7, false, true));
test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 512, 5, 3, 1 }, { 1, 1 }, 1, false, true));
}
for (int mode : { GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_VISION }) {
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
for (int ne2 : {1, 8, 512}) {
@ -8603,8 +8733,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (int nb : { 1, 3, 32, 75, }) {
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72) continue;
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0,
GGML_TYPE_MXFP4, GGML_TYPE_MXFP8, GGML_TYPE_MXFP6,
}) {
// Non-F16 types: test at D=64, D=72, and D=128.
if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72 && hsk != 128) continue;
// MXFP types require D % 32 == 0, skip D=72.
if (ggml_is_type_mxfp(type_KV) && hsk == 72) continue;
test_cases.emplace_back(new test_flash_attn_ext(
hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV));
// run fewer test cases permuted
@ -8626,6 +8761,25 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
// MXFP-specific K/V type combinations (mixed and same-type)
// Mixed: mxfp8 K + mxfp4 V, mxfp6 K + mxfp4 V (our recommended configs)
for (ggml_type type_K : {GGML_TYPE_MXFP8, GGML_TYPE_MXFP6}) {
for (ggml_type type_V : {GGML_TYPE_MXFP4}) {
if (type_K == type_V) continue;
for (int nb : {1, 3, 32}) {
test_cases.emplace_back(new test_flash_attn_ext(
128, 128, 4, {1, 1}, 512, nb, true, false, 0.0f, 0.0f, GGML_PREC_F32, type_K, {0, 1, 2, 3}, type_V));
}
}
}
// Same-type: mxfp8/mxfp8, mxfp6/mxfp6
for (ggml_type type_KV : {GGML_TYPE_MXFP8, GGML_TYPE_MXFP6}) {
for (int nb : {1, 3, 32}) {
test_cases.emplace_back(new test_flash_attn_ext(
128, 128, 4, {1, 1}, 512, nb, true, false, 0.0f, 0.0f, GGML_PREC_F32, type_KV, {0, 1, 2, 3}, type_KV));
}
}
test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, { 10, 5, 4, 3}));
test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, {30000, 1, 1, 1}));
test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, { 10, 5, 4, 3}));

View File

@ -2,6 +2,11 @@
#include "ggml.h"
#include "ggml-cpu.h"
#include "ggml-quants.h"
#define GGML_COMMON_DECL_CPP
#define GGML_COMMON_IMPL_CPP
#include "ggml-common.h"
#undef NDEBUG
#include <assert.h>
@ -21,9 +26,19 @@ constexpr float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075f;
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040f;
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS = 0.0050f;
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_FP4 = 0.0030f;
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 = 0.0070f;
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 = 0.0040f;
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 = 0.0020f;
// MXFP Hadamard pipeline thresholds (mxfp_rmse, which computes sqrt(sum/n)).
// These represent actual RMSE through the full KV cache write/read path.
constexpr float MAX_MXFP_PIPELINE_ERROR_MXFP4 = 0.40f;
constexpr float MAX_MXFP_PIPELINE_ERROR_MXFP8 = 0.08f;
constexpr float MAX_MXFP_PIPELINE_ERROR_MXFP6 = 0.30f;
constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f;
constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f;
constexpr float MAX_DOT_PRODUCT_ERROR_FP4 = 0.03f;
constexpr float MAX_DOT_PRODUCT_ERROR_MXFP = 0.04f;
constexpr float MAX_DOT_PRODUCT_ERROR_TERNARY = 0.15f;
static const char* RESULT_STR[] = {"ok", "FAILED"};
@ -46,6 +61,16 @@ static float array_rmse(const float * a1, const float * a2, size_t n) {
return sqrtf(sum) / n;
}
// MXFP RMSE: sqrt(sum/n), used with MAX_MXFP_PIPELINE_ERROR_* thresholds
static float mxfp_rmse(const float * a1, const float * a2, size_t n) {
double sum = 0;
for (size_t i = 0; i < n; i++) {
double diff = a1[i] - a2[i];
sum += diff * diff;
}
return sqrtf((float)(sum / n));
}
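// Unlike array_rmse above (sqrt(sum)/n), this is the textbook RMSE. For n = 4
// elements that each differ by 0.1: array_rmse = sqrt(0.04)/4 = 0.05, while
// mxfp_rmse = sqrt(0.04/4) = 0.1.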
// Total quantization error on test data
static float total_quantization_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data) {
std::vector<uint8_t> tmp_q(2*test_size);
@ -152,7 +177,10 @@ int main(int argc, char * argv[]) {
type == GGML_TYPE_Q3_K ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
type == GGML_TYPE_IQ3_S ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS :
type == GGML_TYPE_NVFP4 ? MAX_QUANTIZATION_TOTAL_ERROR_FP4 : MAX_QUANTIZATION_TOTAL_ERROR;
type == GGML_TYPE_NVFP4 ? MAX_QUANTIZATION_TOTAL_ERROR_FP4 :
type == GGML_TYPE_MXFP4 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 :
type == GGML_TYPE_MXFP6 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 :
type == GGML_TYPE_MXFP8 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 : MAX_QUANTIZATION_TOTAL_ERROR;
failed = !(total_error < max_quantization_error);
num_failed += failed;
if (failed || verbose) {
@ -174,6 +202,8 @@ int main(int argc, char * argv[]) {
? MAX_DOT_PRODUCT_ERROR_TERNARY
: type == GGML_TYPE_NVFP4
? MAX_DOT_PRODUCT_ERROR_FP4
: type == GGML_TYPE_MXFP4 || type == GGML_TYPE_MXFP6 || type == GGML_TYPE_MXFP8
? MAX_DOT_PRODUCT_ERROR_MXFP
: MAX_DOT_PRODUCT_ERROR;
failed = !(vec_dot_error < max_allowed_error);
num_failed += failed;
@ -183,6 +213,902 @@ int main(int argc, char * argv[]) {
}
}
// MXFP SoA roundtrip via traits
for (int i = 0; i < GGML_TYPE_COUNT; i++) {
ggml_type type = (ggml_type) i;
const auto * qfns_cpu = ggml_get_type_traits_cpu(type);
if (!qfns_cpu->from_float_soa || !qfns_cpu->to_float_soa) {
continue;
}
const size_t buf_size = ggml_row_size(type, test_size);
std::vector<uint8_t> tmp_q(buf_size);
std::vector<float> tmp_out(test_size);
qfns_cpu->from_float_soa(test_data.data(), tmp_q.data(), test_size);
qfns_cpu->to_float_soa(tmp_q.data(), tmp_out.data(), test_size);
const float soa_error = array_rmse(test_data.data(), tmp_out.data(), test_size);
const float max_soa_error =
type == GGML_TYPE_MXFP4 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 :
type == GGML_TYPE_MXFP6 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 :
type == GGML_TYPE_MXFP8 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 : MAX_QUANTIZATION_TOTAL_ERROR;
failed = !(soa_error < max_soa_error);
num_failed += failed;
if (failed || verbose) {
printf("%5s SoA quantization error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], soa_error);
}
}
// MXFP traits: SoA required, MXFP6/MXFP8 are KV-cache-only (no AoS dequant)
{
const ggml_type all_mxfp_types[] = { GGML_TYPE_MXFP4, GGML_TYPE_MXFP8, GGML_TYPE_MXFP6 };
for (ggml_type type : all_mxfp_types) {
const auto * cpu = ggml_get_type_traits_cpu(type);
failed = !(cpu->from_float_soa && cpu->to_float_soa);
num_failed += failed;
if (failed || verbose) {
printf("%5s SoA traits present: %s\n", ggml_type_name(type), RESULT_STR[failed]);
}
}
// KV-cache-only types: no AoS dequant
const ggml_type kv_only_types[] = { GGML_TYPE_MXFP8, GGML_TYPE_MXFP6 };
for (ggml_type type : kv_only_types) {
const auto * cpu = ggml_get_type_traits_cpu(type);
failed = (cpu->to_float != nullptr);
num_failed += failed;
if (failed || verbose) {
printf("%5s AoS CPU to_float absent: %s\n", ggml_type_name(type), RESULT_STR[failed]);
}
}
}
// Hadamard self-inverse: H(H(x)) == x
{
float original[32], transformed[32];
for (int i = 0; i < 32; i++) {
original[i] = 0.1f + 2.0f * cosf(i + 0.5f);
transformed[i] = original[i];
}
ggml_hadamard_32_inplace(transformed);
ggml_hadamard_32_inplace(transformed); // apply twice = identity
float max_err = 0.0f;
for (int i = 0; i < 32; i++) {
float err = fabsf(transformed[i] - original[i]);
if (err > max_err) max_err = err;
}
// floating-point rounding tolerance
failed = !(max_err < 1e-5f);
num_failed += failed;
if (failed || verbose) {
printf("hadamard H(H(x))==x roundtrip: %s (max_err=%.2e)\n", RESULT_STR[failed], max_err);
}
}
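For reference, a minimal butterfly implementation of the transform under test (a sketch; the in-tree ggml_hadamard_32_inplace may be vectorized differently, but it must produce the same values for these checks to pass):

static void hadamard_32_sketch(float v[32]) {
    // standard Walsh-Hadamard butterflies over 32 elements...
    for (int len = 1; len < 32; len <<= 1) {
        for (int i = 0; i < 32; i += 2*len) {
            for (int j = i; j < i + len; ++j) {
                const float a = v[j];
                const float b = v[j + len];
                v[j]       = a + b;
                v[j + len] = a - b;
            }
        }
    }
    // ...then scale by 1/sqrt(32) so that applying the transform twice is the identity
    for (int i = 0; i < 32; ++i) {
        v[i] *= MXFP_HADAMARD_32_NORM;
    }
}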
// SoA SIMD vs scalar dequant
{
struct soa_cross_check {
ggml_type type;
void (*ref_dequant)(const void *, float *, int64_t);
};
const soa_cross_check checks[] = {
{ GGML_TYPE_MXFP4, dequantize_row_mxfp4_soa },
{ GGML_TYPE_MXFP8, dequantize_row_mxfp8_soa },
{ GGML_TYPE_MXFP6, dequantize_row_mxfp6_soa },
};
for (const auto & c : checks) {
const auto * cpu = ggml_get_type_traits_cpu(c.type);
if (!cpu->from_float_soa || !cpu->to_float_soa) continue;
const size_t buf_size = ggml_row_size(c.type, test_size);
std::vector<uint8_t> tmp_q(buf_size);
std::vector<float> out_ref(test_size);
std::vector<float> out_simd(test_size);
// Quantize with SoA
cpu->from_float_soa(test_data.data(), tmp_q.data(), test_size);
// Dequant with scalar reference
c.ref_dequant(tmp_q.data(), out_ref.data(), test_size);
// Dequant with CPU/SIMD path
cpu->to_float_soa(tmp_q.data(), out_simd.data(), test_size);
// Compare bitwise
int mismatches = 0;
for (size_t j = 0; j < test_size; j++) {
uint32_t a, b;
memcpy(&a, &out_ref[j], 4);
memcpy(&b, &out_simd[j], 4);
if (a != b) mismatches++;
}
failed = (mismatches > 0);
num_failed += failed;
if (failed || verbose) {
printf("%5s SoA SIMD vs scalar ref: %s (%zu/%zu match)\n",
ggml_type_name(c.type), RESULT_STR[failed],
test_size - mismatches, test_size);
}
}
}
// element converters vs canonical LUT values
{
struct lut_test {
const char * name;
const float * lut;
int count;
float (*converter)(uint8_t);
};
const lut_test lut_tests[] = {
{ "fp8_e4m3", kvalues_mxfp8_e4m3, 256, fp8_e4m3_to_float },
{ "fp8_e5m2", kvalues_mxfp8_e5m2, 256, fp8_e5m2_to_float },
{ "fp6_e2m3", kvalues_mxfp6_e2m3, 64, fp6_e2m3_to_float },
{ "fp6_e3m2", kvalues_mxfp6_e3m2, 64, fp6_e3m2_to_float },
};
for (const auto & t : lut_tests) {
int mismatches = 0;
for (int i = 0; i < t.count; i++) {
const float converter_val = t.converter((uint8_t)i);
const float lut_val = t.lut[i];
// both NaN = match
if (isnan(converter_val) && isnan(lut_val)) continue;
if (converter_val != lut_val) {
if (mismatches == 0 || verbose) {
printf(" %s LUT mismatch at [%d]: converter=%.8g, lut=%.8g\n",
t.name, i, converter_val, lut_val);
}
mismatches++;
}
}
failed = (mismatches > 0);
num_failed += failed;
if (failed || verbose) {
printf("%5s converter vs LUT: %s (%d/%d values match)\n",
t.name, RESULT_STR[failed], t.count - mismatches, t.count);
}
}
// FP4 E2M1
{
int mismatches = 0;
for (int i = 0; i < 16; i++) {
const float converter_val = ggml_mxfp_fp4_e2m1_to_float((uint8_t)i);
const float lut_val = kvalues_mxfp4_float[i];
if (converter_val != lut_val) {
if (mismatches == 0 || verbose) {
printf(" fp4_e2m1 LUT mismatch at [%d]: converter=%.8g, lut=%.8g\n",
i, converter_val, lut_val);
}
mismatches++;
}
}
failed = (mismatches > 0);
num_failed += failed;
if (failed || verbose) {
printf("fp4_e2m1 converter vs LUT: %s (%d/16 values match)\n",
RESULT_STR[failed], 16 - mismatches);
}
}
}
// element converter edge cases (expected values validated against LUTs)
{
struct conv_check {
const char * name;
float input;
uint8_t expected_bits;
bool is_saturation; // true = input overflows, expected_bits is max finite
const float * lut; // canonical LUT to validate expected_bits against (NULL for FP4)
float (*to_float)(uint8_t);
uint8_t (*to_quant)(float);
};
const conv_check checks[] = {
// FP4 E2M1 [S(1)|E(2)|M(1)], bias=1
{ "fp4 zero", 0.0f, 0x00, false, nullptr, nullptr, nullptr },
{ "fp4 sub 0.5", 0.5f, 0x01, false, nullptr, nullptr, nullptr },
{ "fp4 norm 1.0", 1.0f, 0x02, false, nullptr, nullptr, nullptr },
{ "fp4 max 6.0", 6.0f, 0x07, false, nullptr, nullptr, nullptr },
{ "fp4 neg -3.0", -3.0f, 0x0D, false, nullptr, nullptr, nullptr },
{ "fp4 sat 100", 100.0f, 0x07, true, nullptr, nullptr, nullptr },
// FP8 E4M3 [S(1)|E(4)|M(3)], bias=7
{ "e4m3 zero", 0.0f, 0x00, false, kvalues_mxfp8_e4m3, fp8_e4m3_to_float, float_to_fp8_e4m3_rn },
{ "e4m3 sub", 1.f/512, 0x01, false, kvalues_mxfp8_e4m3, fp8_e4m3_to_float, float_to_fp8_e4m3_rn },
{ "e4m3 max 448", 448.0f, 0x7E, false, kvalues_mxfp8_e4m3, fp8_e4m3_to_float, float_to_fp8_e4m3_rn },
{ "e4m3 sat 500", 500.0f, 0x7E, true, kvalues_mxfp8_e4m3, fp8_e4m3_to_float, float_to_fp8_e4m3_rn },
{ "e4m3 neg -1", -1.0f, 0xB8, false, kvalues_mxfp8_e4m3, fp8_e4m3_to_float, float_to_fp8_e4m3_rn },
// FP6 E2M3 [S(1)|E(2)|M(3)], no NaN/Inf
{ "e2m3 zero", 0.0f, 0x00, false, kvalues_mxfp6_e2m3, fp6_e2m3_to_float, float_to_fp6_e2m3_rn },
{ "e2m3 sub", 0.125f, 0x01, false, kvalues_mxfp6_e2m3, fp6_e2m3_to_float, float_to_fp6_e2m3_rn },
{ "e2m3 max 7.5", 7.5f, 0x1F, false, kvalues_mxfp6_e2m3, fp6_e2m3_to_float, float_to_fp6_e2m3_rn },
{ "e2m3 sat 100", 100.0f, 0x1F, true, kvalues_mxfp6_e2m3, fp6_e2m3_to_float, float_to_fp6_e2m3_rn },
// FP6 E3M2 [S(1)|E(3)|M(2)], no NaN/Inf, exp=7 is NORMAL
{ "e3m2 zero", 0.0f, 0x00, false, kvalues_mxfp6_e3m2, fp6_e3m2_to_float, float_to_fp6_e3m2_rn },
{ "e3m2 sub", 0.0625f, 0x01, false, kvalues_mxfp6_e3m2, fp6_e3m2_to_float, float_to_fp6_e3m2_rn },
{ "e3m2 max 28.0", 28.0f, 0x1F, false, kvalues_mxfp6_e3m2, fp6_e3m2_to_float, float_to_fp6_e3m2_rn },
{ "e3m2 exp7 16", 16.0f, 0x1C, false, kvalues_mxfp6_e3m2, fp6_e3m2_to_float, float_to_fp6_e3m2_rn },
// FP8 E5M2 [S(1)|E(5)|M(2)], bias=15
{ "e5m2 zero", 0.0f, 0x00, false, kvalues_mxfp8_e5m2, fp8_e5m2_to_float, float_to_fp8_e5m2_rn },
{ "e5m2 max", 57344.f, 0x7B, false, kvalues_mxfp8_e5m2, fp8_e5m2_to_float, float_to_fp8_e5m2_rn },
};
int conv_bad = 0;
// validate expected_bits against LUTs
for (const auto & c : checks) {
if (c.lut && !c.is_saturation) {
float lut_val = c.lut[c.expected_bits];
if (c.input != lut_val && !(c.input == 0.0f && lut_val == 0.0f)) {
printf(" TEST BUG %s: expected_bits=0x%02X → LUT=%.8g, but input=%.8g\n",
c.name, c.expected_bits, lut_val, c.input);
conv_bad++;
}
} else if (!c.lut && !c.is_saturation) {
float lut_val = kvalues_mxfp4_float[c.expected_bits];
if (c.input != lut_val && !(c.input == 0.0f && lut_val == 0.0f)) {
printf(" TEST BUG %s: expected_bits=0x%02X → LUT=%.8g, but input=%.8g\n",
c.name, c.expected_bits, lut_val, c.input);
conv_bad++;
}
}
}
// Now test the quantize direction
for (const auto & c : checks) {
uint8_t got;
if (c.to_quant) {
got = c.to_quant(c.input);
} else {
got = ggml_mxfp_float_to_fp4_e2m1(c.input);
}
if (got != c.expected_bits) {
if (conv_bad == 0 || verbose) {
printf(" %s: quantize(%.6g) = 0x%02X, expected 0x%02X\n",
c.name, c.input, got, c.expected_bits);
}
conv_bad++;
}
}
// FP8 E4M3: 0x7F must dequantize to NaN
{
float nan_val = fp8_e4m3_to_float(0x7F);
if (!isnan(nan_val)) {
if (conv_bad == 0 || verbose) {
printf(" e4m3 0x7F dequant: expected NaN, got %.6g\n", nan_val);
}
conv_bad++;
}
}
// FP6 E3M2: exp=7 must dequant to valid float (NOT Inf/NaN)
{
float exp7_val = fp6_e3m2_to_float(0x1F); // max: exp=7, mant=3 → 28.0
if (isnan(exp7_val) || exp7_val != 28.0f) {
if (conv_bad == 0 || verbose) {
printf(" e3m2 0x1F dequant: expected 28.0, got %.6g\n", exp7_val);
}
conv_bad++;
}
}
failed = (conv_bad > 0);
num_failed += failed;
if (failed || verbose) {
printf(" element converter edge cases: %s (%d/%d passed)\n",
RESULT_STR[failed],
(int)(sizeof(checks)/sizeof(checks[0])) + 2 - conv_bad,
(int)(sizeof(checks)/sizeof(checks[0])) + 2);
}
}
// FP6 pack/unpack round-trip
{
int pack_bad = 0;
// Test all 64 possible 6-bit values in each of the 4 positions
for (int pos = 0; pos < 4; pos++) {
for (int val = 0; val < 64; val++) {
uint8_t in[4] = {0, 0, 0, 0};
in[pos] = (uint8_t)val;
uint8_t packed[3], out[4];
pack_fp6x4(in, packed);
unpack_fp6x4(packed, out);
if (out[pos] != (uint8_t)val) {
if (pack_bad == 0 || verbose) {
printf(" fp6 pack roundtrip: pos=%d val=0x%02X → got 0x%02X\n",
pos, val, out[pos]);
}
pack_bad++;
}
// no crosstalk
for (int k = 0; k < 4; k++) {
if (k != pos && out[k] != 0) {
if (pack_bad == 0 || verbose) {
printf(" fp6 pack crosstalk: pos=%d val=0x%02X leaked to pos=%d (0x%02X)\n",
pos, val, k, out[k]);
}
pack_bad++;
}
}
}
}
// known-answer: [0x3F, 0x00, 0x3F, 0x00] -> {0x3F, 0xF0, 0x03}
{
uint8_t in[4] = {0x3F, 0x00, 0x3F, 0x00};
uint8_t packed[3];
pack_fp6x4(in, packed);
uint8_t expected[3] = {0x3F, 0xF0, 0x03};
if (packed[0] != expected[0] || packed[1] != expected[1] || packed[2] != expected[2]) {
if (pack_bad == 0 || verbose) {
printf(" fp6 known-answer: packed [%02X,%02X,%02X] expected [%02X,%02X,%02X]\n",
packed[0], packed[1], packed[2], expected[0], expected[1], expected[2]);
}
pack_bad++;
}
}
failed = (pack_bad > 0);
num_failed += failed;
if (failed || verbose) {
printf(" fp6 pack/unpack round-trip: %s\n", RESULT_STR[failed]);
}
}
// E8M0 known-answer decode + HALF vs FULL (MXFP4 uses HALF, MXFP6/8 use FULL)
{
int e8m0_bad = 0;
// Known-answer E8M0 decodes
struct { uint8_t e; float expected; } e8m0_known[] = {
{ 127, 1.0f }, // 2^(127-127) = 2^0 = 1.0
{ 128, 2.0f }, // 2^(128-127) = 2^1 = 2.0
{ 126, 0.5f }, // 2^(126-127) = 2^(-1) = 0.5
{ 254, 1.70141183e+38f }, // 2^127 (max representable)
{ 1, 1.17549435e-38f }, // 2^(-126) (min normal)
};
for (const auto & t : e8m0_known) {
float got = ggml_mxfp_e8m0_to_fp32(t.e);
if (got != t.expected) {
if (e8m0_bad == 0 || verbose) {
printf(" E8M0 decode e=%d: got %.8g, expected %.8g\n", t.e, got, t.expected);
}
e8m0_bad++;
}
}
// HALF must be exactly half of FULL wherever both stay normal (e >= 2)
for (int e = 2; e < 255; e++) {
float full = ggml_mxfp_e8m0_to_fp32((uint8_t)e);
float half = ggml_mxfp_e8m0_to_fp32_half((uint8_t)e);
if (half != full * 0.5f) {
if (e8m0_bad == 0 || verbose) {
printf(" E8M0 HALF!=FULL/2 at e=%d: half=%.8g, full/2=%.8g\n", e, half, full * 0.5f);
}
e8m0_bad++;
break; // one failure is enough to flag the pattern
}
}
failed = (e8m0_bad > 0);
num_failed += failed;
if (failed || verbose) {
printf(" E8M0 known-answer + HALF/FULL: %s\n", RESULT_STR[failed]);
}
}
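A minimal decode sketch for the normal range, consistent with the known answers above (the real ggml_mxfp_e8m0_to_fp32 may additionally special-case e = 0 and e = 255):

static inline float e8m0_to_fp32_sketch(uint8_t e) {
    // place e directly in the IEEE-754 exponent field: 2^(e - 127)
    uint32_t bits = (uint32_t) e << 23;
    float f;
    memcpy(&f, &bits, sizeof(f));
    return f;
}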
// E8M0 rounding at sqrt(2) threshold
{
int round_bad = 0;
// amax=1.0: floor_log2=0, mantissa=0 → no round → e_base = 0 - 0 + 127 = 127
{
int e = ggml_mxfp_e8m0_base_estimate(1.0f, 0);
if (e != 127) {
printf(" E8M0 round: amax=1.0 → e=%d, expected 127\n", e);
round_bad++;
}
}
// amax=2.0: floor_log2=1, mantissa=0 → no round → e_base = 1 + 127 = 128
{
int e = ggml_mxfp_e8m0_base_estimate(2.0f, 0);
if (e != 128) {
printf(" E8M0 round: amax=2.0 → e=%d, expected 128\n", e);
round_bad++;
}
}
// amax just below sqrt(2): mantissa < 0x3504F3 → floor only → e=127
{
// 1.4142f is just below sqrt(2); its IEEE mantissa is below 0x3504F3
float below = 1.4142f;
int e = ggml_mxfp_e8m0_base_estimate(below, 0);
if (e != 127) {
printf(" E8M0 round: amax=%.6f → e=%d, expected 127 (no round)\n", below, e);
round_bad++;
}
}
// amax at sqrt(2): mantissa >= 0x3504F3 → rounds up → e=128
{
float at_sqrt2 = 1.41422f;
int e = ggml_mxfp_e8m0_base_estimate(at_sqrt2, 0);
if (e != 128) {
printf(" E8M0 round: amax=%.6f → e=%d, expected 128 (rounds up)\n", at_sqrt2, e);
round_bad++;
}
}
// Verify emax_offset shifts the result
{
int e_no_off = ggml_mxfp_e8m0_base_estimate(448.0f, 0);
int e_e4m3 = ggml_mxfp_e8m0_base_estimate(448.0f, MXFP8_E4M3_EMAX_OFFSET);
if (e_no_off - e_e4m3 != MXFP8_E4M3_EMAX_OFFSET) {
printf(" E8M0 emax_offset: diff=%d, expected %d\n",
e_no_off - e_e4m3, MXFP8_E4M3_EMAX_OFFSET);
round_bad++;
}
}
failed = (round_bad > 0);
num_failed += failed;
if (failed || verbose) {
printf(" E8M0 rounding boundary: %s\n", RESULT_STR[failed]);
}
}
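The cases above pin the estimator down; one sketch that satisfies all of them (the in-tree ggml_mxfp_e8m0_base_estimate may differ in form):

static int e8m0_base_estimate_sketch(float amax, int emax_offset) {
    uint32_t bits;
    memcpy(&bits, &amax, sizeof(bits));
    const int      floor_log2 = (int)(bits >> 23) - 127; // assumes amax is a positive normal float
    const uint32_t mant       = bits & 0x7FFFFF;
    // round log2(amax) to nearest: the halfway point within a binade is sqrt(2),
    // whose IEEE-754 mantissa is 0x3504F3
    const int rounded = floor_log2 + (mant >= 0x3504F3 ? 1 : 0);
    return rounded - emax_offset + 127; // unclamped E8M0 exponent
}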
// Element converter exhaustive round-trip: quantize(dequantize(i)) == i for all valid bit patterns.
// Catches asymmetries between the to_float and to_quant paths.
{
struct rt_test {
const char * name;
int count;
float (*to_float)(uint8_t);
uint8_t (*to_quant)(float);
uint8_t nan_bits; // bit pattern for NaN (0 = no NaN in format)
};
const rt_test rt_tests[] = {
{ "fp8_e4m3", 256, fp8_e4m3_to_float, float_to_fp8_e4m3_rn, 0x7F },
{ "fp8_e5m2", 256, fp8_e5m2_to_float, float_to_fp8_e5m2_rn, 0 },
{ "fp6_e2m3", 64, fp6_e2m3_to_float, float_to_fp6_e2m3_rn, 0 },
{ "fp6_e3m2", 64, fp6_e3m2_to_float, float_to_fp6_e3m2_rn, 0 },
};
for (const auto & t : rt_tests) {
int rt_bad = 0;
for (int i = 0; i < t.count; i++) {
if (t.nan_bits != 0 && (uint8_t)i == t.nan_bits) continue; // skip NaN: quantize(NaN) is implementation-defined
float f = t.to_float((uint8_t)i);
if (isnan(f) || isinf(f)) continue; // E5M2 Inf/NaN
uint8_t back = t.to_quant(f);
// Negative zero may round-trip to positive zero; both are valid
if (back != (uint8_t)i && !(f == 0.0f && t.to_float(back) == 0.0f)) {
if (rt_bad == 0 || verbose) {
printf(" %s roundtrip: 0x%02X → %.6g → 0x%02X\n",
t.name, i, f, back);
}
rt_bad++;
}
}
failed = (rt_bad > 0);
num_failed += failed;
if (failed || verbose) {
printf("%5s converter round-trip: %s (%d/%d survived)\n",
t.name, RESULT_STR[failed], t.count - rt_bad, t.count);
}
}
// FP4 E2M1: uses static inline converters (not GGML_API wrappers), only 16 values
{
int rt_bad = 0;
for (int i = 0; i < 16; i++) {
float f = ggml_mxfp_fp4_e2m1_to_float((uint8_t)i);
uint8_t back = ggml_mxfp_float_to_fp4_e2m1(f);
if (back != (uint8_t)i && !(f == 0.0f && ggml_mxfp_fp4_e2m1_to_float(back) == 0.0f)) {
if (rt_bad == 0 || verbose) {
printf(" fp4_e2m1 roundtrip: 0x%02X → %.6g → 0x%02X\n", i, f, back);
}
rt_bad++;
}
}
failed = (rt_bad > 0);
num_failed += failed;
if (failed || verbose) {
printf("fp4_e2m1 converter round-trip: %s (%d/16 survived)\n",
RESULT_STR[failed], 16 - rt_bad);
}
}
}
// E8M0 scale computation: verify base exponent is reasonable for various amax values
{
const float test_amax[] = { 0.001f, 0.1f, 1.0f, 6.0f, 100.0f, 448.0f, 10000.0f };
int bad = 0;
for (float amax : test_amax) {
// ggml_mxfp_e8m0_base_estimate returns unclamped e_base
int e_base = ggml_mxfp_e8m0_base_estimate(amax, 0);
if (e_base < 1 || e_base > 254) {
if (bad == 0 || verbose) {
printf(" E8M0 bad e_base=%d for amax=%.4f\n", e_base, amax);
}
bad++;
continue;
}
float scale = ggml_mxfp_e8m0_to_fp32((uint8_t)e_base);
// Scale should be within a factor of 4 of amax (rough sanity check)
float ratio = amax / scale;
if (ratio < 0.25f || ratio > 4.0f) {
if (bad == 0 || verbose) {
printf(" E8M0 scale=%.6g for amax=%.4f, ratio=%.4f (expected ~1)\n",
scale, amax, ratio);
}
bad++;
}
}
failed = (bad > 0);
num_failed += failed;
if (failed || verbose) {
printf(" E8M0 scale sanity check: %s (%d/%d passed)\n",
RESULT_STR[failed], (int)(sizeof(test_amax)/sizeof(test_amax[0])) - bad,
(int)(sizeof(test_amax)/sizeof(test_amax[0])));
}
}
// SoA layout: verify offset macros produce correct byte positions
{
const struct { ggml_type type; int qs_per_block; } soa_types[] = {
{ GGML_TYPE_MXFP4, MXFP4_SOA_QS_PER_BLOCK },
{ GGML_TYPE_MXFP8, MXFP8_SOA_QS_PER_BLOCK },
{ GGML_TYPE_MXFP6, MXFP6_SOA_QS_PER_BLOCK },
};
for (const auto & st : soa_types) {
for (int nblocks : { 1, 4, 8, 32 }) {
size_t expected_e8m0_off = (size_t)nblocks * st.qs_per_block;
size_t actual_e8m0_off = MXFP_SOA_E8M0_OFFSET(nblocks, st.qs_per_block);
size_t total = actual_e8m0_off + nblocks; // e8m0 region = 1 byte per block
size_t row_size = ggml_row_size(st.type, nblocks * 32);
bool offset_ok = (actual_e8m0_off == expected_e8m0_off);
bool size_ok = (total == row_size);
if (!offset_ok || !size_ok) {
failed = true;
num_failed++;
printf(" %s SoA layout nblocks=%d: e8m0_off=%zu (expected %zu), total=%zu (row_size=%zu)\n",
ggml_type_name(st.type), nblocks, actual_e8m0_off, expected_e8m0_off, total, row_size);
}
}
}
if (verbose) {
printf(" SoA layout offset check: %s\n", RESULT_STR[0]); // only prints failures above
}
}
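// The layout verified above, as a sketch (macros assumed to be plain offsets):
//   row = [ qs(block 0) | qs(block 1) | ... | qs(block n-1) | e8m0[0] .. e8m0[n-1] ]
//   MXFP_SOA_QS_OFFSET(b, qs_per_block)   == (size_t) (b) * (qs_per_block)
//   MXFP_SOA_E8M0_OFFSET(n, qs_per_block) == (size_t) (n) * (qs_per_block)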
// block size consistency
{
failed = !(QK_MXFP4 == 32 && QK_MXFP8 == 32 && QK_MXFP6 == 32);
num_failed += failed;
if (failed || verbose) {
printf(" MXFP block size == 32: %s (QK4=%d, QK8=%d, QK6=%d)\n",
RESULT_STR[failed], QK_MXFP4, QK_MXFP8, QK_MXFP6);
}
}
// EMAX_OFFSET produces valid E8M0 for each format's max finite value
{
struct emax_check {
const char * name;
int emax_offset;
float max_finite; // from LUT / converter
};
const emax_check emax_checks[] = {
{ "fp4_e2m1", MXFP4_E2M1_EMAX_OFFSET, 6.0f },
{ "fp6_e2m3", MXFP6_E2M3_EMAX_OFFSET, 7.5f },
{ "fp6_e3m2", MXFP6_E3M2_EMAX_OFFSET, 28.0f },
{ "fp8_e4m3", MXFP8_E4M3_EMAX_OFFSET, 448.0f },
{ "fp8_e5m2", MXFP8_E5M2_EMAX_OFFSET, 57344.0f },
};
int emax_bad = 0;
for (const auto & e : emax_checks) {
// When amax == max_finite, the base estimate must produce a valid E8M0 (1..254)
int e_base = ggml_mxfp_e8m0_base_estimate(e.max_finite, e.emax_offset);
if (e_base < 1 || e_base > 254) {
if (emax_bad == 0 || verbose) {
printf(" %s emax_offset=%d: max_finite=%.1f gives e_base=%d (out of range)\n",
e.name, e.emax_offset, e.max_finite, e_base);
}
emax_bad++;
}
}
failed = (emax_bad > 0);
num_failed += failed;
if (failed || verbose) {
printf(" EMAX_OFFSET vs format max: %s\n", RESULT_STR[failed]);
}
}
// MXFP4 AoS vs SoA: two independent code paths, same result
{
const int nelems = 64; // 2 blocks
float input[64];
for (int i = 0; i < 64; i++) {
input[i] = 0.5f + 2.0f * sinf(i * 0.7f + 0.3f);
}
// Quantize and dequant via AoS (block_mxfp4 structs)
std::vector<block_mxfp4> aos_q(nelems / QK_MXFP4);
std::vector<float> aos_out(nelems);
quantize_row_mxfp4_ref(input, aos_q.data(), nelems);
dequantize_row_mxfp4(aos_q.data(), aos_out.data(), nelems);
// Quantize and dequant via SoA
const size_t soa_buf_size = ggml_row_size(GGML_TYPE_MXFP4, nelems);
std::vector<uint8_t> soa_q(soa_buf_size);
std::vector<float> soa_out(nelems);
quantize_row_mxfp4_soa(input, soa_q.data(), nelems);
dequantize_row_mxfp4_soa(soa_q.data(), soa_out.data(), nelems);
// Compare: both paths should produce identical results
int mismatches = 0;
for (int i = 0; i < nelems; i++) {
uint32_t a, b;
memcpy(&a, &aos_out[i], 4);
memcpy(&b, &soa_out[i], 4);
if (a != b) {
if (mismatches == 0 || verbose) {
printf(" mxfp4 AoS/SoA mismatch at [%d]: AoS=%.8g, SoA=%.8g\n",
i, aos_out[i], soa_out[i]);
}
mismatches++;
}
}
failed = (mismatches > 0);
num_failed += failed;
if (failed || verbose) {
printf("mxfp4 AoS vs SoA cross-check: %s (%d/%d match)\n",
RESULT_STR[failed], nelems - mismatches, nelems);
}
}
// Hadamard + quantize + dequant + Hadamard roundtrip (KV cache write/read path)
{
struct hadamard_pipeline_check {
const char * name;
ggml_type type;
float max_err;
};
const hadamard_pipeline_check pipeline_checks[] = {
{ "mxfp4", GGML_TYPE_MXFP4, MAX_MXFP_PIPELINE_ERROR_MXFP4 },
{ "mxfp8", GGML_TYPE_MXFP8, MAX_MXFP_PIPELINE_ERROR_MXFP8 },
{ "mxfp6", GGML_TYPE_MXFP6, MAX_MXFP_PIPELINE_ERROR_MXFP6 },
};
for (const auto & p : pipeline_checks) {
const auto * cpu = ggml_get_type_traits_cpu(p.type);
std::vector<float> original(test_size);
std::vector<float> rotated(test_size);
std::vector<float> recovered(test_size);
generate_data(2.0, test_size, original.data());
// Write path: Hadamard each block, then quantize
memcpy(rotated.data(), original.data(), test_size * sizeof(float));
for (size_t b = 0; b < test_size / 32; b++) {
ggml_hadamard_32_inplace(&rotated[b * 32]);
}
const size_t buf_size = ggml_row_size(p.type, test_size);
std::vector<uint8_t> qbuf(buf_size);
cpu->from_float_soa(rotated.data(), qbuf.data(), test_size);
// Read path: dequant, then Hadamard each block (self-inverse)
cpu->to_float_soa(qbuf.data(), recovered.data(), test_size);
for (size_t b = 0; b < test_size / 32; b++) {
ggml_hadamard_32_inplace(&recovered[b * 32]);
}
float err = mxfp_rmse(original.data(), recovered.data(), test_size);
failed = !(err < p.max_err);
num_failed += failed;
if (failed || verbose) {
printf("%5s Hadamard pipeline roundtrip: %s (err=%.6f, max=%.6f)\n",
p.name, RESULT_STR[failed], err, p.max_err);
}
}
}
// Hadamard known output: H([1,0,...,0]) = [1/sqrt(32), ...]
{
float unit[32] = {};
unit[0] = 1.0f;
ggml_hadamard_32_inplace(unit);
const float expected = MXFP_HADAMARD_32_NORM; // 1/sqrt(32)
float max_err = 0.0f;
for (int i = 0; i < 32; i++) {
float err = fabsf(unit[i] - expected);
if (err > max_err) max_err = err;
}
failed = !(max_err < 1e-7f);
num_failed += failed;
if (failed || verbose) {
printf("hadamard unit vector: %s (max_err=%.2e, expected %.8f)\n",
RESULT_STR[failed], max_err, expected);
}
}
// zero block produces E8M0=0
{
float zeros[32] = {};
const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP8, 32);
std::vector<uint8_t> buf(buf_size, 0xFF); // fill with 0xFF to detect non-writes
quantize_row_mxfp8_soa(zeros, buf.data(), 32);
// E8M0 scale is at offset MXFP8_SOA_QS_PER_BLOCK (32) for 1 block
uint8_t e8m0 = buf[MXFP8_SOA_QS_PER_BLOCK];
failed = (e8m0 != 0);
num_failed += failed;
if (failed || verbose) {
printf(" zero block E8M0: %s (e8m0=%d, expected 0)\n",
RESULT_STR[failed], e8m0);
}
}
// SoA format spec: quantize, manually walk raw bytes, compare against reference dequant
{
// 2 blocks, asymmetric data
const int nblocks = 2;
const int nelems = nblocks * 32;
float input[64];
for (int i = 0; i < 64; i++) {
// Block 0: small values, Block 1: large values, so the two blocks get different E8M0 scales
input[i] = (i < 32) ? 0.1f * sinf(i + 0.5f) : 3.0f * cosf(i + 0.5f);
}
// MXFP4
{
const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP4, nelems);
std::vector<uint8_t> buf(buf_size);
std::vector<float> ref_out(nelems);
std::vector<float> manual_out(nelems);
quantize_row_mxfp4_soa(input, buf.data(), nelems);
dequantize_row_mxfp4_soa(buf.data(), ref_out.data(), nelems);
// manual dequant from raw bytes
const uint8_t * qs = buf.data();
const uint8_t * e8m0 = buf.data() + MXFP_SOA_E8M0_OFFSET(nblocks, MXFP4_SOA_QS_PER_BLOCK);
for (int b = 0; b < nblocks; b++) {
const float d = ggml_mxfp_e8m0_to_fp32_half(e8m0[b]);
const uint8_t * block_qs = qs + MXFP_SOA_QS_OFFSET(b, MXFP4_SOA_QS_PER_BLOCK);
for (int j = 0; j < 16; j++) {
// low nibble = first half, high nibble = second half
int8_t v_lo = kvalues_mxfp4[block_qs[j] & 0x0F];
int8_t v_hi = kvalues_mxfp4[block_qs[j] >> 4];
manual_out[b*32 + j] = v_lo * d;
manual_out[b*32 + j + 16] = v_hi * d;
}
}
int mismatches = 0;
for (int i = 0; i < nelems; i++) {
uint32_t a, b;
memcpy(&a, &ref_out[i], 4);
memcpy(&b, &manual_out[i], 4);
if (a != b) mismatches++;
}
failed = (mismatches > 0);
num_failed += failed;
if (failed || verbose) {
printf("mxfp4 SoA format spec: %s (%d/%d match)\n",
RESULT_STR[failed], nelems - mismatches, nelems);
}
}
// MXFP8
{
const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP8, nelems);
std::vector<uint8_t> buf(buf_size);
std::vector<float> ref_out(nelems);
std::vector<float> manual_out(nelems);
quantize_row_mxfp8_soa(input, buf.data(), nelems);
dequantize_row_mxfp8_soa(buf.data(), ref_out.data(), nelems);
const uint8_t * qs = buf.data();
const uint8_t * e8m0 = buf.data() + MXFP_SOA_E8M0_OFFSET(nblocks, MXFP8_SOA_QS_PER_BLOCK);
for (int b = 0; b < nblocks; b++) {
const float d = ggml_mxfp_e8m0_to_fp32(e8m0[b]);
const uint8_t * block_qs = qs + MXFP_SOA_QS_OFFSET(b, MXFP8_SOA_QS_PER_BLOCK);
for (int j = 0; j < 32; j++) {
// one byte per element
manual_out[b*32 + j] = fp8_e4m3_to_float(block_qs[j]) * d;
}
}
int mismatches = 0;
for (int i = 0; i < nelems; i++) {
uint32_t a, b;
memcpy(&a, &ref_out[i], 4);
memcpy(&b, &manual_out[i], 4);
if (a != b) mismatches++;
}
failed = (mismatches > 0);
num_failed += failed;
if (failed || verbose) {
printf("mxfp8 SoA format spec: %s (%d/%d match)\n",
RESULT_STR[failed], nelems - mismatches, nelems);
}
}
// MXFP6
{
const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP6, nelems);
std::vector<uint8_t> buf(buf_size);
std::vector<float> ref_out(nelems);
std::vector<float> manual_out(nelems);
quantize_row_mxfp6_soa(input, buf.data(), nelems);
dequantize_row_mxfp6_soa(buf.data(), ref_out.data(), nelems);
const uint8_t * qs = buf.data();
const uint8_t * e8m0 = buf.data() + MXFP_SOA_E8M0_OFFSET(nblocks, MXFP6_SOA_QS_PER_BLOCK);
for (int b = 0; b < nblocks; b++) {
const float d = ggml_mxfp_e8m0_to_fp32(e8m0[b]);
const uint8_t * block_qs = qs + MXFP_SOA_QS_OFFSET(b, MXFP6_SOA_QS_PER_BLOCK);
for (int j = 0; j < 32; j += 4) {
// 4 elements packed into 3 bytes
uint8_t vals[4];
unpack_fp6x4(&block_qs[j * 3 / 4], vals);
for (int k = 0; k < 4; k++) {
manual_out[b*32 + j + k] = fp6_e2m3_to_float(vals[k]) * d;
}
}
}
int mismatches = 0;
for (int i = 0; i < nelems; i++) {
uint32_t a, b;
memcpy(&a, &ref_out[i], 4);
memcpy(&b, &manual_out[i], 4);
if (a != b) mismatches++;
}
failed = (mismatches > 0);
num_failed += failed;
if (failed || verbose) {
printf("mxfp6 SoA format spec: %s (%d/%d match)\n",
RESULT_STR[failed], nelems - mismatches, nelems);
}
}
}
if (num_failed || verbose) {
printf("%d tests failed\n", num_failed);
}

View File

@ -483,7 +483,15 @@ static ggml_type ggml_type_from_name(const std::string & s) {
if (s == "iq4_nl") {
return GGML_TYPE_IQ4_NL;
}
if (s == "mxfp4" || s == "mxfp4_e2m1") {
return GGML_TYPE_MXFP4;
}
if (s == "mxfp8" || s == "mxfp8_e4m3") {
return GGML_TYPE_MXFP8;
}
if (s == "mxfp6" || s == "mxfp6_e2m3") {
return GGML_TYPE_MXFP6;
}
return GGML_TYPE_COUNT;
}