ggml: MXFP flash attention with SoA layout (CPU scalar reference)

Add MXFP KV cache quantization for flash attention using Struct-of-Arrays
(SoA) memory layout exclusively. Three MX types: MXFP4 (E2M1), MXFP8
(E4M3), MXFP6 (E2M3), implementing the OCP Microscaling v1.0 spec.

SoA layout stores [qs contiguous][e8m0 contiguous] per row, enabling
aligned memory access patterns for GPU backends. All functions in the
flash attention pipeline — set_rows quantization, Q preprocessing, K/V
dequantization — use SoA end-to-end. The existing AoS block layout
remains for MUL_MAT weight quantization (untouched).

Q preprocessing applies Walsh-Hadamard rotation (block-32) before
quantize/dequant round-trip, distributing outlier energy across the
shared exponent group. This is essential for perplexity:
  MXFP8: +0.22 PPL without rotation
  MXFP6: +3.34 PPL without rotation
Hadamard is skipped for MLA models (DK != DV) where V is a view of K.

Shared infrastructure in ggml-common.h:
- Block structures (block_mxfp8: 33B, block_mxfp6: 25B per 32 elements)
- E8M0 MSE-optimal scale search with ±2 range (MXFP_E8M0_MSE_RANGE)
- Canonical element converters (FP8 E4M3/E5M2, FP6 E2M3/E3M2)
- FP6 tight packing (4 six-bit values in 3 bytes, 25% savings)
- IEEE-754 bit reconstruction constants for SIMD backends
- SoA layout macros, portable bit cast, type property queries

CPU implementation:
- Scalar reference + ARM NEON + x86 AVX2 optimized paths
- Both FA paths supported: one_chunk (scalar) and tiled (SIMD GEMM)
- Split-KV path extended for single-query decode
- Generic vec_dot via dequant-to-float for MUL_MAT compatibility
- Arch fallbacks for loongarch, powerpc, riscv, s390, wasm

KV cache integration:
- set_rows writes SoA with optional Hadamard (op_params[0] flag)
- K cache block-aligned to 16 for CUDA cp.async compatibility
- CLI: --cache-type-k/v with short aliases (mxfp4, mxfp6, mxfp8)

Tests:
- Flash attention: all 3 types at D=64/128, mixed K/V (mxfp8+mxfp4)
- SET_ROWS: Hadamard rotation for all types
- SoA-aware test initialization and comparison for MXFP tensors
- Quantize functions coverage for all types

Rename GGML_TYPE_MXFP4 → GGML_TYPE_MXFP4_E2M1 across all backends
(CPU, OpenCL, SYCL) for consistency with the MX type family naming.
This commit is contained in:
Tim Burke 2026-03-15 15:29:13 -04:00
parent b91d7dfe5b
commit d8c9f9c7f6
28 changed files with 3002 additions and 151 deletions

View File

@ -398,9 +398,23 @@ const std::vector<ggml_type> kv_cache_types = {
GGML_TYPE_IQ4_NL,
GGML_TYPE_Q5_0,
GGML_TYPE_Q5_1,
GGML_TYPE_MXFP4_E2M1,
GGML_TYPE_MXFP8_E4M3,
GGML_TYPE_MXFP6_E2M3,
};
static ggml_type kv_cache_type_from_str(const std::string & s) {
// Short aliases: "mxfp4" → E2M1, "mxfp6" → E2M3, "mxfp8" → E4M3.
// Full names (mxfp4_e2m1, mxfp8_e4m3, mxfp6_e2m3, etc.) match via ggml_type_name() below.
if (s == "mxfp4") {
return GGML_TYPE_MXFP4_E2M1;
}
if (s == "mxfp6") {
return GGML_TYPE_MXFP6_E2M3;
}
if (s == "mxfp8") {
return GGML_TYPE_MXFP8_E4M3;
}
for (const auto & type : kv_cache_types) {
if (ggml_type_name(type) == s) {
return type;

View File

@ -115,6 +115,7 @@ extern "C" {
struct ggml_type_traits_cpu {
ggml_from_float_t from_float;
ggml_to_float_t to_float; // SIMD-optimized dequant (NULL = use global to_float)
ggml_vec_dot_t vec_dot;
enum ggml_type vec_dot_type;
int64_t nrows; // number of rows to process simultaneously

View File

@ -426,9 +426,11 @@ extern "C" {
// GGML_TYPE_IQ4_NL_4_4 = 36,
// GGML_TYPE_IQ4_NL_4_8 = 37,
// GGML_TYPE_IQ4_NL_8_8 = 38,
GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block)
GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale)
GGML_TYPE_COUNT = 41,
GGML_TYPE_MXFP4_E2M1 = 39, // MX FP4 E2M1
GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale)
GGML_TYPE_MXFP8_E4M3 = 41, // MX FP8 E4M3
GGML_TYPE_MXFP6_E2M3 = 42, // MX FP6 E2M3
GGML_TYPE_COUNT = 43,
};
// precision
@ -463,7 +465,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors
GGML_FTYPE_MOSTLY_MXFP4_E2M1 = 25, // except 1d tensors
GGML_FTYPE_MOSTLY_NVFP4 = 26, // except 1d tensors
};
@ -744,6 +746,9 @@ extern "C" {
GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
GGML_API bool ggml_is_quantized(enum ggml_type type);
// MXFP type property queries.
// NOTE(review): presumably true for the MX family types (MXFP4_E2M1, MXFP8_E4M3, MXFP6_E2M3) — confirm against the definition.
GGML_API bool ggml_is_type_mxfp(enum ggml_type type);
// NOTE(review): presumably mirrors the MXFP_USE_HADAMARD_* flags in ggml-common.h — confirm against the definition.
GGML_API bool ggml_mxfp_use_hadamard(enum ggml_type type);
GGML_API int ggml_mxfp_qs_per_block(enum ggml_type type); // quantized bytes per 32-element block (SoA qs region)
// TODO: temporary until model loading of ggml examples is refactored
GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);

View File

@ -71,6 +71,9 @@ typedef sycl::half2 ggml_half2;
#define GGML_COMMON_DECL
#endif
// Pure numeric constants needed by both DECL and IMPL sections.
#define MXFP_HADAMARD_32_NORM 0.17677669529663689f // 1/sqrt(32)
#if defined(GGML_COMMON_DECL)
#ifndef __cplusplus
@ -105,6 +108,12 @@ typedef sycl::half2 ggml_half2;
#define QI_NVFP4 (QK_NVFP4 / (4 * QR_NVFP4))
#define QR_NVFP4 2
#define QI_MXFP8 (QK_MXFP8 / (4 * QR_MXFP8))
#define QR_MXFP8 1
#define QI_MXFP6 (QK_MXFP6 / (4 * QR_MXFP6))
#define QR_MXFP6 1
#define QI5_0 (QK5_0 / (4 * QR5_0))
#define QR5_0 2
@ -190,6 +199,74 @@ typedef struct {
} block_q4_1;
static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
// MXFP E8M0 shared exponent constants (OCP MX v1.0 §5.3).
// EMAX_OFFSET: log2(max_finite) of each element type rounded to an integer — used to center the E8M0 scale.
// The rounding direction is not uniform across the table below: E2M1 and E4M3 hold
// floor(log2(max_finite)), while E2M3, E3M2 and E5M2 hold ceil(log2(max_finite)).
// NOTE(review): confirm the mixed floor/ceil choice is intentional before "fixing" either side.
// MSE_RANGE: search radius around round(log2(amax)). Tests 2*range+1 candidate exponents,
// picking the one that minimizes total round-trip quantization error per block.
// Inspired by "Four Over Six" (arXiv:2512.02010); generalized to all MX types.
#define MXFP_E8M0_MSE_RANGE 2
#define MXFP4_E2M1_EMAX_OFFSET 2 // floor(log2(6.0)) (ceil would be 3)
#define MXFP6_E2M3_EMAX_OFFSET 3 // ceil(log2(7.5))
#define MXFP6_E3M2_EMAX_OFFSET 5 // ceil(log2(28.0))
#define MXFP8_E4M3_EMAX_OFFSET 8 // floor(log2(448)) (ceil would be 9)
#define MXFP8_E5M2_EMAX_OFFSET 16 // ceil(log2(57344))
// MXFP type properties — single source of truth for all backends.
// Bits per element, quantized bytes per block, and Hadamard rotation flag.
// USE_HADAMARD: 1 for E2M1, E4M3 and E2M3 (E4M3/E2M3 have a 3-bit mantissa; E2M1 has only a
// 1-bit mantissa but enables the rotation as well).
// 0 for 2-bit mantissa types (E5M2, E3M2) where Hadamard provides
// no quality benefit and hurts models with D_head ≤ 64.
#define MXFP_BITS_PER_ELEM_E2M1 4
#define MXFP_BITS_PER_ELEM_E4M3 8
#define MXFP_BITS_PER_ELEM_E5M2 8
#define MXFP_BITS_PER_ELEM_E2M3 6
#define MXFP_BITS_PER_ELEM_E3M2 6
#define MXFP_QS_PER_BLOCK_E2M1 16 // 32 * 4 / 8
#define MXFP_QS_PER_BLOCK_E4M3 32 // 32 * 8 / 8
#define MXFP_QS_PER_BLOCK_E5M2 32
#define MXFP_QS_PER_BLOCK_E2M3 24 // 32 * 6 / 8
#define MXFP_QS_PER_BLOCK_E3M2 24
#define MXFP_USE_HADAMARD_E2M1 1
#define MXFP_USE_HADAMARD_E4M3 1
#define MXFP_USE_HADAMARD_E5M2 0
#define MXFP_USE_HADAMARD_E2M3 1
#define MXFP_USE_HADAMARD_E3M2 0
// SIMD dequant constants for IEEE-754 bit reconstruction of FP8/FP6 elements.
// For a format with sign(1), exp(E), mant(M), bias(B):
// EXP_MASK = (1<<E)-1 MANT_MASK = (1<<M)-1 EXP_SHIFT = M
// IEEE_EXP_OFF = 127-B MANT_SHIFT = 23-M SUB_SCALE = 2^(1-B-M)
// Used by x86 AVX2 and ARM NEON vectorized dequant in dot product, AoS dequant, SoA dequant.
#define MXFP8_E4M3_EXP_MASK 0xF // (1<<4)-1
#define MXFP8_E4M3_MANT_MASK 0x7 // (1<<3)-1
#define MXFP8_E4M3_EXP_SHIFT 3
#define MXFP8_E4M3_IEEE_EXP_OFF 120 // 127-7
#define MXFP8_E4M3_MANT_SHIFT 20 // 23-3
#define MXFP8_E4M3_SUB_SCALE (1.0f/512.0f) // 2^(-9) = 2^(1-7-3)
#define MXFP8_E5M2_EXP_MASK 0x1F // (1<<5)-1
#define MXFP8_E5M2_MANT_MASK 0x3 // (1<<2)-1
#define MXFP8_E5M2_EXP_SHIFT 2
#define MXFP8_E5M2_IEEE_EXP_OFF 112 // 127-15
#define MXFP8_E5M2_MANT_SHIFT 21 // 23-2
#define MXFP8_E5M2_SUB_SCALE (1.0f/65536.0f) // 2^(-16) = 2^(1-15-2)
#define MXFP6_E2M3_EXP_MASK 0x3 // (1<<2)-1
#define MXFP6_E2M3_MANT_MASK 0x7 // (1<<3)-1
#define MXFP6_E2M3_EXP_SHIFT 3
#define MXFP6_E2M3_IEEE_EXP_OFF 126 // 127-1
#define MXFP6_E2M3_MANT_SHIFT 20 // 23-3
#define MXFP6_E2M3_SUB_SCALE (1.0f/8.0f) // 2^(-3) = 2^(1-1-3)
#define MXFP6_E3M2_EXP_MASK 0x7 // (1<<3)-1
#define MXFP6_E3M2_MANT_MASK 0x3 // (1<<2)-1
#define MXFP6_E3M2_EXP_SHIFT 2
#define MXFP6_E3M2_IEEE_EXP_OFF 124 // 127-3
#define MXFP6_E3M2_MANT_SHIFT 21 // 23-2
#define MXFP6_E3M2_SUB_SCALE (1.0f/16.0f) // 2^(-4) = 2^(1-3-2)
#define QK_MXFP4 32
typedef struct {
uint8_t e; // E8M0
@ -205,6 +282,34 @@ typedef struct {
} block_nvfp4;
static_assert(sizeof(block_nvfp4) == sizeof(uint8_t)*(QK_NVFP4/QK_NVFP4_SUB) + QK_NVFP4/2, "wrong nvfp4 block size/padding");
// AoS block for MXFP8: one E8M0 scale byte followed by 32 one-byte FP8 elements (33 bytes total).
// The same storage layout serves both FP8 element encodings (E4M3 and E5M2).
#define QK_MXFP8 32
typedef struct {
uint8_t e; // E8M0 shared exponent
uint8_t qs[QK_MXFP8]; // 32 FP8 values (1 byte each), used for E4M3 and E5M2
} block_mxfp8;
// Guards against compiler-inserted padding: the block must be exactly 1 + 32 bytes.
static_assert(sizeof(block_mxfp8) == sizeof(uint8_t) + QK_MXFP8, "wrong mxfp8 block size/padding");
// AoS block for MXFP6: one E8M0 scale byte followed by 32 six-bit elements packed into 24 bytes
// (25 bytes total). The same storage layout serves both FP6 element encodings (E2M3 and E3M2).
#define QK_MXFP6 32
typedef struct {
uint8_t e; // E8M0 shared exponent
uint8_t qs[QK_MXFP6 * 6 / 8]; // 24 bytes: 32 six-bit values tightly packed, used for E2M3 and E3M2
} block_mxfp6;
// Guards against compiler-inserted padding: the block must be exactly 1 + 24 bytes.
static_assert(sizeof(block_mxfp6) == sizeof(uint8_t) + QK_MXFP6 * 6 / 8, "wrong mxfp6 block size/padding");
// SoA (Struct-of-Arrays) layout constants for MXFP KV cache.
// Per row: [qs_block0|qs_block1|...][e8m0_0|e8m0_1|...]
// Total bytes per row is IDENTICAL to AoS — same tensor strides, just rearranged.
// Example (MXFP8, 128 elements per row = 4 blocks): qs region is bytes [0, 128),
// the 4 E8M0 bytes are at [128, 132); 132 == 4 * sizeof(block_mxfp8).
// Aliases for the canonical MXFP_QS_PER_BLOCK_* defines above.
#define MXFP4_SOA_QS_PER_BLOCK MXFP_QS_PER_BLOCK_E2M1 // 16 bytes
#define MXFP8_SOA_QS_PER_BLOCK MXFP_QS_PER_BLOCK_E4M3 // 32 bytes
#define MXFP6_SOA_QS_PER_BLOCK MXFP_QS_PER_BLOCK_E2M3 // 24 bytes
// SoA offset helpers — single source of truth for the SoA memory layout contract.
// qs region: blocks 0..nblocks-1 at contiguous qs_per_block-byte strides.
// e8m0 region: starts immediately after all qs blocks.
// Both macros yield byte offsets relative to the start of the row; arguments are fully
// parenthesized, so expressions may be passed directly.
#define MXFP_SOA_QS_OFFSET(block_idx, qs_per_block) ((block_idx) * (qs_per_block))
#define MXFP_SOA_E8M0_OFFSET(nblocks, qs_per_block) ((nblocks) * (qs_per_block))
#define QK5_0 32
typedef struct {
ggml_half d; // delta
@ -447,16 +552,36 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_
#if defined(GGML_COMMON_IMPL_C)
#include <stdint.h>
#include <string.h>
#include <math.h>
#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
#define GGML_TABLE_END() };
#define GGML_MXFP_FUNC static inline
static inline uint32_t ggml_mxfp_f32_as_u32_(float f) { uint32_t u; memcpy(&u, &f, sizeof(u)); return u; }
static inline float ggml_mxfp_u32_as_f32_(uint32_t u) { float f; memcpy(&f, &u, sizeof(f)); return f; }
#define GGML_MXFP_F32_AS_U32(f) ggml_mxfp_f32_as_u32_(f)
#define GGML_MXFP_U32_AS_F32(u) ggml_mxfp_u32_as_f32_(u)
#define GGML_MXFP_LDEXPF(x, n) ldexpf(x, n)
#define GGML_MXFP_THREAD
#define GGML_MXFP_UNROLL
#define GGML_COMMON_IMPL
#elif defined(GGML_COMMON_IMPL_CPP)
#include <cstdint>
#include <cstring>
#include <cmath>
#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
#define GGML_TABLE_END() };
#define GGML_MXFP_FUNC static inline
static inline uint32_t ggml_mxfp_f32_as_u32_(float f) { uint32_t u; memcpy(&u, &f, sizeof(u)); return u; }
static inline float ggml_mxfp_u32_as_f32_(uint32_t u) { float f; memcpy(&f, &u, sizeof(f)); return f; }
#define GGML_MXFP_F32_AS_U32(f) ggml_mxfp_f32_as_u32_(f)
#define GGML_MXFP_U32_AS_F32(u) ggml_mxfp_u32_as_f32_(u)
#define GGML_MXFP_LDEXPF(x, n) ldexpf(x, n)
#define GGML_MXFP_THREAD
#define GGML_MXFP_UNROLL
#define GGML_COMMON_IMPL
#elif defined(GGML_COMMON_IMPL_METAL)
@ -464,21 +589,44 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_
#define GGML_TABLE_BEGIN(type, name, size) static const constant type name[size] = {
#define GGML_TABLE_END() };
#define GGML_MXFP_FUNC static inline
#define GGML_MXFP_F32_AS_U32(f) as_type<uint32_t>(f)
#define GGML_MXFP_U32_AS_F32(u) as_type<float>(u)
#define GGML_MXFP_LDEXPF(x, n) metal::ldexp(x, n)
#define GGML_MXFP_THREAD thread
#define GGML_MXFP_UNROLL _Pragma("unroll")
#define GGML_COMMON_IMPL
#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) || defined(GGML_COMMON_IMPL_MUSA)
#include <cstdint>
#include <cstring>
#define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = {
#define GGML_TABLE_END() };
#define GGML_MXFP_FUNC static __device__ __forceinline__
#define GGML_MXFP_F32_AS_U32(f) __float_as_uint(f)
#define GGML_MXFP_U32_AS_F32(u) __uint_as_float(u)
#define GGML_MXFP_LDEXPF(x, n) ldexpf(x, n)
#define GGML_MXFP_THREAD
#define GGML_MXFP_UNROLL _Pragma("unroll")
#define GGML_COMMON_IMPL
#elif defined(GGML_COMMON_IMPL_SYCL)
#include <cstdint>
#include <cstring>
#include <cmath>
#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
#define GGML_TABLE_END() };
#define GGML_MXFP_FUNC static inline
static inline uint32_t ggml_mxfp_f32_as_u32_(float f) { uint32_t u; memcpy(&u, &f, sizeof(u)); return u; }
static inline float ggml_mxfp_u32_as_f32_(uint32_t u) { float f; memcpy(&f, &u, sizeof(f)); return f; }
#define GGML_MXFP_F32_AS_U32(f) ggml_mxfp_f32_as_u32_(f)
#define GGML_MXFP_U32_AS_F32(u) ggml_mxfp_u32_as_f32_(u)
#define GGML_MXFP_LDEXPF(x, n) ldexpf(x, n)
#define GGML_MXFP_THREAD
#define GGML_MXFP_UNROLL
#define GGML_COMMON_IMPL
#endif
@ -1100,12 +1248,415 @@ GGML_TABLE_BEGIN(int8_t, kvalues_iq4nl, 16)
-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
GGML_TABLE_END()
// e2m1 values (doubled)
// Canonical E2M1 values (true FP4 magnitudes).
// ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
GGML_TABLE_BEGIN(float, kvalues_mxfp4_float, 16)
0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f,
-0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f,
GGML_TABLE_END()
// E2M1 values doubled (implementation detail for CPU/CUDA integer arithmetic).
// Used with GGML_E8M0_TO_FP32_HALF(e) = scale/2 so that int8 × half_scale = true value.
// Canonical values are in kvalues_mxfp4_float above.
GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16)
0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
GGML_TABLE_END()
// FP6 E2M3 dequantization LUT: 6-bit value → float (64 entries).
// Generated from ggml_mxfp_fp6_e2m3_to_float(). Indices 0-31 positive, 32-63 negative.
GGML_TABLE_BEGIN(float, kvalues_mxfp6_e2m3, 64)
0.0f, 0.125f, 0.25f, 0.375f, 0.5f, 0.625f, 0.75f, 0.875f,
1.0f, 1.125f, 1.25f, 1.375f, 1.5f, 1.625f, 1.75f, 1.875f,
2.0f, 2.25f, 2.5f, 2.75f, 3.0f, 3.25f, 3.5f, 3.75f,
4.0f, 4.5f, 5.0f, 5.5f, 6.0f, 6.5f, 7.0f, 7.5f,
-0.0f, -0.125f, -0.25f, -0.375f, -0.5f, -0.625f, -0.75f, -0.875f,
-1.0f, -1.125f, -1.25f, -1.375f, -1.5f, -1.625f, -1.75f, -1.875f,
-2.0f, -2.25f, -2.5f, -2.75f, -3.0f, -3.25f, -3.5f, -3.75f,
-4.0f, -4.5f, -5.0f, -5.5f, -6.0f, -6.5f, -7.0f, -7.5f,
GGML_TABLE_END()
// FP6 E3M2 dequantization LUT: 6-bit value → float (64 entries).
// Generated from ggml_mxfp_fp6_e3m2_to_float(). No NaN/Inf — all bit patterns are valid.
GGML_TABLE_BEGIN(float, kvalues_mxfp6_e3m2, 64)
0.0f, 0.0625f, 0.125f, 0.1875f, 0.25f, 0.3125f, 0.375f, 0.4375f,
0.5f, 0.625f, 0.75f, 0.875f, 1.0f, 1.25f, 1.5f, 1.75f,
2.0f, 2.5f, 3.0f, 3.5f, 4.0f, 5.0f, 6.0f, 7.0f,
8.0f, 10.0f, 12.0f, 14.0f, 16.0f, 20.0f, 24.0f, 28.0f,
-0.0f, -0.0625f, -0.125f,-0.1875f, -0.25f,-0.3125f, -0.375f,-0.4375f,
-0.5f, -0.625f, -0.75f, -0.875f, -1.0f, -1.25f, -1.5f, -1.75f,
-2.0f, -2.5f, -3.0f, -3.5f, -4.0f, -5.0f, -6.0f, -7.0f,
-8.0f, -10.0f, -12.0f, -14.0f, -16.0f, -20.0f, -24.0f, -28.0f,
GGML_TABLE_END()
// FP8 E4M3 dequantization LUT: byte → float (256 entries).
// Generated from ggml_mxfp_fp8_e4m3_to_float(). Entry 126 = 448 (max finite), entries 127 and 255 = NaN.
GGML_TABLE_BEGIN(float, kvalues_mxfp8_e4m3, 256)
0.0f, 0.001953125f, 0.00390625f, 0.005859375f, 0.0078125f, 0.009765625f, 0.01171875f, 0.013671875f,
0.015625f, 0.017578125f, 0.01953125f, 0.021484375f, 0.0234375f, 0.025390625f, 0.02734375f, 0.029296875f,
0.03125f, 0.03515625f, 0.0390625f, 0.04296875f, 0.046875f, 0.05078125f, 0.0546875f, 0.05859375f,
0.0625f, 0.0703125f, 0.078125f, 0.0859375f, 0.09375f, 0.1015625f, 0.109375f, 0.1171875f,
0.125f, 0.140625f, 0.15625f, 0.171875f, 0.1875f, 0.203125f, 0.21875f, 0.234375f,
0.25f, 0.28125f, 0.3125f, 0.34375f, 0.375f, 0.40625f, 0.4375f, 0.46875f,
0.5f, 0.5625f, 0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f, 0.9375f,
1.0f, 1.125f, 1.25f, 1.375f, 1.5f, 1.625f, 1.75f, 1.875f,
2.0f, 2.25f, 2.5f, 2.75f, 3.0f, 3.25f, 3.5f, 3.75f,
4.0f, 4.5f, 5.0f, 5.5f, 6.0f, 6.5f, 7.0f, 7.5f,
8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
16.0f, 18.0f, 20.0f, 22.0f, 24.0f, 26.0f, 28.0f, 30.0f,
32.0f, 36.0f, 40.0f, 44.0f, 48.0f, 52.0f, 56.0f, 60.0f,
64.0f, 72.0f, 80.0f, 88.0f, 96.0f, 104.0f, 112.0f, 120.0f,
128.0f, 144.0f, 160.0f, 176.0f, 192.0f, 208.0f, 224.0f, 240.0f,
256.0f, 288.0f, 320.0f, 352.0f, 384.0f, 416.0f, 448.0f, NAN,
-0.0f,-0.001953125f, -0.00390625f,-0.005859375f, -0.0078125f,-0.009765625f, -0.01171875f,-0.013671875f,
-0.015625f,-0.017578125f, -0.01953125f,-0.021484375f, -0.0234375f,-0.025390625f, -0.02734375f,-0.029296875f,
-0.03125f, -0.03515625f, -0.0390625f, -0.04296875f, -0.046875f, -0.05078125f, -0.0546875f, -0.05859375f,
-0.0625f, -0.0703125f, -0.078125f, -0.0859375f, -0.09375f, -0.1015625f, -0.109375f, -0.1171875f,
-0.125f, -0.140625f, -0.15625f, -0.171875f, -0.1875f, -0.203125f, -0.21875f, -0.234375f,
-0.25f, -0.28125f, -0.3125f, -0.34375f, -0.375f, -0.40625f, -0.4375f, -0.46875f,
-0.5f, -0.5625f, -0.625f, -0.6875f, -0.75f, -0.8125f, -0.875f, -0.9375f,
-1.0f, -1.125f, -1.25f, -1.375f, -1.5f, -1.625f, -1.75f, -1.875f,
-2.0f, -2.25f, -2.5f, -2.75f, -3.0f, -3.25f, -3.5f, -3.75f,
-4.0f, -4.5f, -5.0f, -5.5f, -6.0f, -6.5f, -7.0f, -7.5f,
-8.0f, -9.0f, -10.0f, -11.0f, -12.0f, -13.0f, -14.0f, -15.0f,
-16.0f, -18.0f, -20.0f, -22.0f, -24.0f, -26.0f, -28.0f, -30.0f,
-32.0f, -36.0f, -40.0f, -44.0f, -48.0f, -52.0f, -56.0f, -60.0f,
-64.0f, -72.0f, -80.0f, -88.0f, -96.0f, -104.0f, -112.0f, -120.0f,
-128.0f, -144.0f, -160.0f, -176.0f, -192.0f, -208.0f, -224.0f, -240.0f,
-256.0f, -288.0f, -320.0f, -352.0f, -384.0f, -416.0f, -448.0f, NAN,
GGML_TABLE_END()
// FP8 E5M2 dequantization LUT: byte → float (256 entries).
// Generated from ggml_mxfp_fp8_e5m2_to_float(). Entries 124-127 = {Inf, NaN, NaN, NaN}.
GGML_TABLE_BEGIN(float, kvalues_mxfp8_e5m2, 256)
0.0f, 1.525879e-05f, 3.051758e-05f, 4.577637e-05f, 6.103516e-05f, 7.629395e-05f, 9.155273e-05f, 1.068115e-04f,
1.220703e-04f, 1.525879e-04f, 1.831055e-04f, 2.136230e-04f, 2.441406e-04f, 3.051758e-04f, 3.662109e-04f, 4.272461e-04f,
4.882812e-04f, 6.103516e-04f, 7.324219e-04f, 8.544922e-04f, 9.765625e-04f, 1.220703e-03f, 1.464844e-03f, 1.708984e-03f,
1.953125e-03f, 2.441406e-03f, 2.929688e-03f, 3.417969e-03f, 3.906250e-03f, 4.882812e-03f, 5.859375e-03f, 6.835938e-03f,
7.812500e-03f, 9.765625e-03f, 1.171875e-02f, 1.367188e-02f, 1.562500e-02f, 1.953125e-02f, 2.343750e-02f, 2.734375e-02f,
3.125000e-02f, 3.906250e-02f, 4.687500e-02f, 5.468750e-02f, 6.250000e-02f, 7.812500e-02f, 9.375000e-02f, 1.093750e-01f,
0.125f, 0.15625f, 0.1875f, 0.21875f, 0.25f, 0.3125f, 0.375f, 0.4375f,
0.5f, 0.625f, 0.75f, 0.875f, 1.0f, 1.25f, 1.5f, 1.75f,
2.0f, 2.5f, 3.0f, 3.5f, 4.0f, 5.0f, 6.0f, 7.0f,
8.0f, 10.0f, 12.0f, 14.0f, 16.0f, 20.0f, 24.0f, 28.0f,
32.0f, 40.0f, 48.0f, 56.0f, 64.0f, 80.0f, 96.0f, 112.0f,
128.0f, 160.0f, 192.0f, 224.0f, 256.0f, 320.0f, 384.0f, 448.0f,
512.0f, 640.0f, 768.0f, 896.0f, 1024.0f, 1280.0f, 1536.0f, 1792.0f,
2048.0f, 2560.0f, 3072.0f, 3584.0f, 4096.0f, 5120.0f, 6144.0f, 7168.0f,
8192.0f, 10240.0f, 12288.0f, 14336.0f, 16384.0f, 20480.0f, 24576.0f, 28672.0f,
32768.0f, 40960.0f, 49152.0f, 57344.0f, INFINITY, NAN, NAN, NAN,
-0.0f,-1.525879e-05f,-3.051758e-05f,-4.577637e-05f,-6.103516e-05f,-7.629395e-05f,-9.155273e-05f,-1.068115e-04f,
-1.220703e-04f,-1.525879e-04f,-1.831055e-04f,-2.136230e-04f,-2.441406e-04f,-3.051758e-04f,-3.662109e-04f,-4.272461e-04f,
-4.882812e-04f,-6.103516e-04f,-7.324219e-04f,-8.544922e-04f,-9.765625e-04f,-1.220703e-03f,-1.464844e-03f,-1.708984e-03f,
-1.953125e-03f,-2.441406e-03f,-2.929688e-03f,-3.417969e-03f,-3.906250e-03f,-4.882812e-03f,-5.859375e-03f,-6.835938e-03f,
-7.812500e-03f,-9.765625e-03f,-1.171875e-02f,-1.367188e-02f,-1.562500e-02f,-1.953125e-02f,-2.343750e-02f,-2.734375e-02f,
-3.125000e-02f,-3.906250e-02f,-4.687500e-02f,-5.468750e-02f,-6.250000e-02f,-7.812500e-02f,-9.375000e-02f,-1.093750e-01f,
-0.125f, -0.15625f, -0.1875f, -0.21875f, -0.25f, -0.3125f, -0.375f, -0.4375f,
-0.5f, -0.625f, -0.75f, -0.875f, -1.0f, -1.25f, -1.5f, -1.75f,
-2.0f, -2.5f, -3.0f, -3.5f, -4.0f, -5.0f, -6.0f, -7.0f,
-8.0f, -10.0f, -12.0f, -14.0f, -16.0f, -20.0f, -24.0f, -28.0f,
-32.0f, -40.0f, -48.0f, -56.0f, -64.0f, -80.0f, -96.0f, -112.0f,
-128.0f, -160.0f, -192.0f, -224.0f, -256.0f, -320.0f, -384.0f, -448.0f,
-512.0f, -640.0f, -768.0f, -896.0f, -1024.0f, -1280.0f, -1536.0f, -1792.0f,
-2048.0f, -2560.0f, -3072.0f, -3584.0f, -4096.0f, -5120.0f, -6144.0f, -7168.0f,
-8192.0f, -10240.0f, -12288.0f, -14336.0f, -16384.0f, -20480.0f, -24576.0f, -28672.0f,
-32768.0f, -40960.0f, -49152.0f, -57344.0f, -INFINITY, NAN, NAN, NAN,
GGML_TABLE_END()
// ------------------------------------------------------------------------------------------------------------------
// Canonical MXFP element converters — portable IEEE-754 bit manipulation.
// Single source of truth for CPU, CUDA, HIP, MUSA, SYCL. Metal/Vulkan keep MSL/GLSL copies.
// ------------------------------------------------------------------------------------------------------------------
#if defined(GGML_MXFP_FUNC)
// --- FP4 E2M1: [S(1) | E(2) | M(1)] — max normal = 6.0 ---
// Canonical converters using true E2M1 values {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
// The int8 kvalues_mxfp4 LUT stores doubled values {0,1,2,3,4,6,8,12} for
// CPU/CUDA nibble-indexed integer arithmetic — that doubling is an implementation detail.
// Decode a 4-bit E2M1 code to float.
// Bit 3 is the sign, bits 2..1 the exponent field, bit 0 the mantissa bit.
// Upper nibble bits of v are ignored.
GGML_MXFP_FUNC float ggml_mxfp_fp4_e2m1_to_float(uint8_t v) {
    const int   e = (v >> 1) & 0x3;
    const int   m = v & 0x1;
    const float s = (v & 0x8) ? -1.0f : 1.0f;

    if (e != 0) {
        // Normal: (1 + m/2) * 2^(e-1)  ->  1, 1.5, 2, 3, 4, 6
        return s * (1.0f + m * 0.5f) * (float)(1 << (e - 1));
    }
    // Subnormal: m/2  ->  0 or 0.5 (sign is kept, so 0x8 decodes to -0.0f)
    return s * (float)m * 0.5f;
}
// Encode a float as a 4-bit E2M1 code (sign in bit 3, magnitude index in bits 2..0).
// The magnitude is mapped to the nearest of {0, 0.5, 1, 1.5, 2, 3, 4, 6}; a value that sits
// exactly on a midpoint goes to the larger magnitude. Anything >= 6 saturates to 6.
// Inputs for which every comparison is false (NaN) fall through to the max-magnitude code.
GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp4_e2m1(float x) {
    uint8_t sign = 0;
    if (x < 0) {
        sign = 0x8;
        x    = -x;
    }
    if (x == 0) {
        return sign;
    }

    uint8_t mag;
    if (x >= 6.0f) {
        mag = 0x7;          // saturate at max finite
    } else if (x < 0.25f) {
        mag = 0x0;          // 0
    } else if (x < 0.75f) {
        mag = 0x1;          // 0.5
    } else if (x < 1.25f) {
        mag = 0x2;          // 1.0
    } else if (x < 1.75f) {
        mag = 0x3;          // 1.5
    } else if (x < 2.5f) {
        mag = 0x4;          // 2.0
    } else if (x < 3.5f) {
        mag = 0x5;          // 3.0
    } else if (x < 5.0f) {
        mag = 0x6;          // 4.0
    } else {
        mag = 0x7;          // 6.0
    }
    return sign | mag;
}
// --- FP6 E2M3: [S(1) | E(2) | M(3)] — max normal = 7.5 ---
// Decode a 6-bit E2M3 code to float.
// Bit 5 is the sign, bits 4..3 the exponent field, bits 2..0 the mantissa.
// Bits above bit 5 are ignored.
GGML_MXFP_FUNC float ggml_mxfp_fp6_e2m3_to_float(uint8_t v) {
    const int   e = (v >> 3) & 0x3;
    const int   m = v & 0x7;
    const float s = (v & 0x20) ? -1.0f : 1.0f;

    if (e != 0) {
        // Normal: (1 + m/8) * 2^(e-1)
        return s * (1.0f + m * 0.125f) * (float)(1 << (e - 1));
    }
    // Subnormal: m/8
    return s * (float)m * 0.125f;
}
// Encode a float as a 6-bit E2M3 code (sign in bit 5, exponent in bits 4..3, mantissa in bits 2..0).
// Magnitudes >= 7.5 saturate to the max finite code 0x1F. Mantissa rounding adds 0.5 and
// truncates (round half up on the magnitude). Negative zero encodes as +0 because the sign
// test is `x < 0`.
// NaN input saturates to 0x1F (without sign), matching the E3M2 encoder; previously the NaN
// reached a float->int conversion, which is undefined behavior.
GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp6_e2m3(float x) {
    uint8_t sign = 0;
    if (x < 0) { sign = 0x20; x = -x; }
    if (x == 0) return sign;
    if (x >= 7.5f) return sign | 0x1F; // max finite

    uint32_t bits = GGML_MXFP_F32_AS_U32(x);
    int f32_exp = (int)((bits >> 23) & 0xFF) - 127;

    if (f32_exp < 0) {
        // x < 1.0 -> subnormal in E2M3: mant * 2^(-3)
        float scaled = x * 8.0f;
        int mant = (int)(scaled + 0.5f);
        if (mant > 7) return sign | 0x08; // rounded up to the smallest normal (1.0)
        return sign | (uint8_t)mant;
    }

    // Every finite x < 7.5 has f32_exp <= 2, so a larger exponent here means NaN.
    // Saturate instead of feeding NaN into the (int) conversions below.
    if (f32_exp > 2) return sign | 0x1F;

    float mantf = (x / (float)(1 << f32_exp)) - 1.0f;
    int mant = (int)(mantf * 8.0f + 0.5f);
    if (mant > 7) { mant = 0; f32_exp++; } // mantissa overflow carries into the exponent
    if (f32_exp > 2) return sign | 0x1F;
    return sign | (uint8_t)(((f32_exp + 1) << 3) | mant); // E2M3 bias is 1
}
// --- FP6 E3M2: [S(1) | E(3) | M(2)] — max normal = 28.0, no NaN/Inf ---
// Decode a 6-bit E3M2 code to float.
// Bit 5 is the sign, bits 4..2 the exponent field, bits 1..0 the mantissa.
// Every exponent value is a regular number: there is no NaN/Inf encoding, and
// exponent field 7 with mantissa 3 yields the max finite value 28.0.
GGML_MXFP_FUNC float ggml_mxfp_fp6_e3m2_to_float(uint8_t v) {
    const int   e = (v >> 2) & 0x7;
    const int   m = v & 0x3;
    const float s = (v & 0x20) ? -1.0f : 1.0f;

    if (e != 0) {
        // Normal: (1 + m/4) * 2^(e-3)
        return s * GGML_MXFP_LDEXPF(1.0f + m * 0.25f, e - 3);
    }
    // Subnormal: m * 2^(-4)
    return s * (float)m * 0.0625f;
}
// Encode a float as a 6-bit E3M2 code (sign in bit 5, exponent in bits 4..2, mantissa in bits 1..0).
// Magnitudes >= 28 saturate to the max finite code 0x1F. Mantissa rounding adds 0.5 and
// truncates (round half up on the magnitude). Negative zero encodes as +0 because the sign
// test is `x < 0`.
GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp6_e3m2(float x) {
uint8_t sign = 0;
if (x < 0) { sign = 0x20; x = -x; }
if (x == 0) return sign;
if (x >= 28.0f) return sign | 0x1F; // max finite
uint32_t bits = GGML_MXFP_F32_AS_U32(x);
// Unbiased float32 exponent, then re-biased for E3M2 (bias 3).
int f32_exp = (int)((bits >> 23) & 0xFF) - 127;
int biased_exp = f32_exp + 3;
if (biased_exp <= 0) {
// Subnormal in E3M2: mant * 2^(-4)
float scaled = x * 16.0f;
int mant = (int)(scaled + 0.5f);
if (mant > 3) return sign | 0x04; // smallest normal
return sign | (uint8_t)mant;
}
// Finite x < 28 always has biased_exp <= 7, so this guard catches NaN before the
// float->int conversions below.
if (biased_exp > 7) return sign | 0x1F;
// 2^f32_exp built from integer shifts; f32_exp is within [-2, 4] on this path.
float pow2 = (f32_exp >= 0) ? (float)(1 << f32_exp) : 1.0f / (float)(1 << (-f32_exp));
float mantf = (x / pow2) - 1.0f;
int mant = (int)(mantf * 4.0f + 0.5f);
// Mantissa overflow after rounding carries into the exponent.
if (mant > 3) { mant = 0; biased_exp++; }
if (biased_exp > 7) return sign | 0x1F;
return sign | (uint8_t)((biased_exp << 2) | mant);
}
// --- FP8 E4M3: [S(1) | E(4) | M(3)] — bias=7, max finite=448 ---
// Decode an FP8 E4M3 byte to float by assembling the IEEE-754 binary32 bit pattern.
// Bit 7 is the sign, bits 6..3 the exponent field (bias 7), bits 2..0 the mantissa.
// The single NaN encoding (exponent 15, mantissa 7) maps to a quiet NaN carrying the sign bit.
GGML_MXFP_FUNC float ggml_mxfp_fp8_e4m3_to_float(uint8_t v) {
    const uint32_t sign_bits = ((uint32_t)(v & 0x80)) << 24;
    const uint32_t e         = (v >> 3) & 0xF;
    const uint32_t m         = v & 0x7;

    if (e == 15 && m == 7) {
        return GGML_MXFP_U32_AS_F32(sign_bits | 0x7FC00000u);
    }
    if (e != 0) {
        // Normal: (-1)^S * 2^(E-7) * (1 + M/8); float32 exponent field = E - 7 + 127 = E + 120
        return GGML_MXFP_U32_AS_F32(sign_bits | ((e + 120) << 23) | (m << 20));
    }
    if (m == 0) {
        return GGML_MXFP_U32_AS_F32(sign_bits); // signed zero
    }
    // Subnormal: M * 2^(1-7) * 2^(-3) = M * 2^(-9); computed unsigned, then the sign bit is OR-ed in.
    const uint32_t mag_bits = GGML_MXFP_F32_AS_U32((float)m * (1.0f / 512.0f)) & 0x7FFFFFFFu;
    return GGML_MXFP_U32_AS_F32(mag_bits | sign_bits);
}
// Encode a float as an FP8 E4M3 byte (sign in bit 7, exponent bias 7, 3-bit mantissa).
// Rounding is round-to-nearest, ties-to-even, done on the float32 bit pattern.
// The result is always finite: any magnitude above 448 — including float32 Inf/NaN, whose
// exponent field is 255 — saturates to 0x7E (448). The NaN code 0x7F is never produced.
GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp8_e4m3(float x) {
uint32_t bits = GGML_MXFP_F32_AS_U32(x);
// Move the float32 sign bit (bit 31) down to bit 7 of the result.
uint8_t sign = (bits >> 24) & 0x80;
bits &= 0x7FFFFFFFu;
if (bits == 0) return sign;
uint32_t f32_exp = (bits >> 23) & 0xFF;
uint32_t f32_mant = bits & 0x7FFFFF;
// Re-bias the exponent: float32 bias 127 -> E4M3 bias 7.
int e4m3_exp = (int)f32_exp - 120;
if (e4m3_exp <= 0) {
// Subnormal in E4M3: shift the full 24-bit significand (implicit 1 made explicit)
// right until it is expressed in units of 2^(-9).
int shift = 1 - e4m3_exp;
uint32_t full_mant = (1u << 23) | f32_mant;
int total_shift = 20 + shift;
// Too small to round up to even the smallest subnormal: flush to signed zero.
// This also keeps the shifts below well-defined (shift count < 32).
if (total_shift >= 32) return sign;
uint32_t mant3 = full_mant >> total_shift;
if (total_shift > 0 && total_shift < 32) {
// round_bit = first discarded bit, sticky = OR of all lower discarded bits.
uint32_t round_bit = (full_mant >> (total_shift - 1)) & 1;
uint32_t sticky = (total_shift > 1) ? (full_mant & ((1u << (total_shift - 1)) - 1)) : 0;
if (round_bit && (sticky || (mant3 & 1))) mant3++;
}
// Rounding may carry out of the subnormal range into the smallest normal (exp=1, mant=0).
if (mant3 > 7) return sign | 0x08;
return sign | (uint8_t)mant3;
}
// Normal path: keep the top 3 mantissa bits, round on the 20 discarded bits.
uint32_t round_bit = (f32_mant >> 19) & 1;
uint32_t sticky = f32_mant & ((1u << 19) - 1);
uint32_t mant3 = f32_mant >> 20;
if (round_bit && (sticky || (mant3 & 1))) {
mant3++;
// Mantissa overflow carries into the exponent.
if (mant3 > 7) { mant3 = 0; e4m3_exp++; }
}
// exp=15/mant=7 is the NaN encoding, so it is clamped together with real overflow.
if (e4m3_exp > 15 || (e4m3_exp == 15 && mant3 >= 7)) return sign | 0x7E; // max finite
return sign | (uint8_t)((e4m3_exp << 3) | mant3);
}
// --- FP8 E5M2: [S(1) | E(5) | M(2)] — bias=15, max finite=57344 ---
// Decode an FP8 E5M2 byte to float by assembling the IEEE-754 binary32 bit pattern.
// Bit 7 is the sign, bits 6..2 the exponent field (bias 15), bits 1..0 the mantissa.
// Exponent field 31 decodes to Inf (mantissa 0) or a quiet NaN (mantissa != 0), sign preserved.
GGML_MXFP_FUNC float ggml_mxfp_fp8_e5m2_to_float(uint8_t v) {
    const uint32_t sign_bits = ((uint32_t)(v & 0x80)) << 24;
    const uint32_t e         = (v >> 2) & 0x1F;
    const uint32_t m         = v & 0x3;

    if (e == 31) {
        return GGML_MXFP_U32_AS_F32(sign_bits | 0x7F800000u | (m ? 0x400000u : 0));
    }
    if (e != 0) {
        // Normal: float32 exponent field = E - 15 + 127 = E + 112
        return GGML_MXFP_U32_AS_F32(sign_bits | ((e + 112) << 23) | (m << 21));
    }
    if (m == 0) {
        return GGML_MXFP_U32_AS_F32(sign_bits); // signed zero
    }
    // Subnormal: M * 2^(1-15) * 2^(-2) = M/4 * 2^(-14); computed unsigned, then the sign bit is OR-ed in.
    const uint32_t mag_bits = GGML_MXFP_F32_AS_U32((float)m * 0.25f * (1.0f / 16384.0f)) & 0x7FFFFFFFu;
    return GGML_MXFP_U32_AS_F32(mag_bits | sign_bits);
}
// Encode a float as an FP8 E5M2 byte (sign in bit 7, exponent bias 15, 2-bit mantissa).
// Rounding is round-to-nearest, ties-to-even, done on the float32 bit pattern.
// The result is always finite: once the re-biased exponent reaches 31 — which includes
// float32 Inf/NaN — the value saturates to 0x7B (57344). Inf/NaN codes are never produced.
GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp8_e5m2(float x) {
uint32_t bits = GGML_MXFP_F32_AS_U32(x);
// Move the float32 sign bit (bit 31) down to bit 7 of the result.
uint8_t sign = (bits >> 24) & 0x80;
bits &= 0x7FFFFFFFu;
if (bits == 0) return sign;
uint32_t f32_exp = (bits >> 23) & 0xFF;
uint32_t f32_mant = bits & 0x7FFFFF;
// Re-bias the exponent: float32 bias 127 -> E5M2 bias 15.
int e5m2_exp = (int)f32_exp - 112;
if (e5m2_exp <= 0) {
// Subnormal in E5M2: shift the full 24-bit significand (implicit 1 made explicit)
// right until it is expressed in units of 2^(-16).
int shift = 1 - e5m2_exp;
uint32_t full_mant = (1u << 23) | f32_mant;
int total_shift = 21 + shift;
// Too small to round up to even the smallest subnormal: flush to signed zero.
// This also keeps the shifts below well-defined (shift count < 32).
if (total_shift >= 32) return sign;
uint32_t mant2 = full_mant >> total_shift;
if (total_shift > 0 && total_shift < 32) {
// round_bit = first discarded bit, sticky = OR of all lower discarded bits.
uint32_t round_bit = (full_mant >> (total_shift - 1)) & 1;
uint32_t sticky = (total_shift > 1) ? (full_mant & ((1u << (total_shift - 1)) - 1)) : 0;
if (round_bit && (sticky || (mant2 & 1))) mant2++;
}
// Rounding may carry out of the subnormal range into the smallest normal (exp=1, mant=0).
if (mant2 > 3) return sign | 0x04;
return sign | (uint8_t)mant2;
}
// Normal path: keep the top 2 mantissa bits, round on the 21 discarded bits.
uint32_t round_bit = (f32_mant >> 20) & 1;
uint32_t sticky = f32_mant & ((1u << 20) - 1);
uint32_t mant2 = f32_mant >> 21;
if (round_bit && (sticky || (mant2 & 1))) {
mant2++;
// Mantissa overflow carries into the exponent.
if (mant2 > 3) { mant2 = 0; e5m2_exp++; }
}
// Exponent field 31 is reserved for Inf/NaN in E5M2, so clamp to exp=30/mant=3.
if (e5m2_exp >= 31) return sign | 0x7B; // max finite
return sign | (uint8_t)((e5m2_exp << 2) | mant2);
}
// --- FP6 packing/unpacking ---
// Pack four 6-bit codes into three bytes, little-endian bit order:
// v[0] occupies bits 0..5 of the 24-bit group, v[1] bits 6..11, v[2] bits 12..17, v[3] bits 18..23.
// Only the low 6 bits of each input are used.
GGML_MXFP_FUNC void ggml_mxfp_pack_fp6x4(const uint8_t v[4], uint8_t out[3]) {
    const uint8_t a = v[0] & 0x3F;
    const uint8_t b = v[1] & 0x3F;
    const uint8_t c = v[2] & 0x3F;
    const uint8_t d = v[3] & 0x3F;

    out[0] = (uint8_t)(a | (b << 6));        // a[5:0], b[1:0]
    out[1] = (uint8_t)((b >> 2) | (c << 4)); // b[5:2], c[3:0]
    out[2] = (uint8_t)((c >> 4) | (d << 2)); // c[5:4], d[5:0]
}
// Unpack three bytes into four 6-bit codes (inverse of ggml_mxfp_pack_fp6x4):
// v[0] comes from bits 0..5 of the 24-bit little-endian group, v[1] from bits 6..11,
// v[2] from bits 12..17, v[3] from bits 18..23.
GGML_MXFP_FUNC void ggml_mxfp_unpack_fp6x4(const uint8_t in[3], uint8_t v[4]) {
    const uint32_t b0 = in[0];
    const uint32_t b1 = in[1];
    const uint32_t b2 = in[2];

    v[0] = b0 & 0x3F;
    v[1] = ((b0 >> 6) | (b1 << 2)) & 0x3F;
    v[2] = ((b1 >> 4) | (b2 << 4)) & 0x3F;
    v[3] = (b2 >> 2) & 0x3F;
}
// Convert an E8M0 shared exponent to float: 2^(x - 127).
// For x > 0 the byte is placed directly in the float32 exponent field. x == 0 would be a
// zero exponent field, so 2^(-127) is emitted as the float32 subnormal 0x00400000 instead.
GGML_MXFP_FUNC float ggml_mxfp_e8m0_to_fp32(uint8_t x) {
    if (x == 0) {
        return GGML_MXFP_U32_AS_F32(0x00400000u);
    }
    return GGML_MXFP_U32_AS_F32((uint32_t)x << 23);
}
// Convert an E8M0 shared exponent to half its float value: 2^(x - 128).
// Used with MXFP4 since E2M1 values are doubled in kvalues_mxfp4.
// For x >= 2 the result is a normal float with exponent field x - 1; x == 0 and x == 1 land
// in the float32 subnormal range and are built as 0x00200000 and 0x00400000 respectively.
GGML_MXFP_FUNC float ggml_mxfp_e8m0_to_fp32_half(uint8_t x) {
    uint32_t pattern;
    if (x >= 2) {
        pattern = (uint32_t)(x - 1) << 23;
    } else {
        pattern = 0x00200000u << x;
    }
    return GGML_MXFP_U32_AS_F32(pattern);
}
// Estimate the base E8M0 exponent for a block: round(log2(amax)) - emax_offset + 127.
// log2 is taken from the float32 bit pattern (no log2f() call): the exponent field gives
// floor(log2(amax)), and it is bumped by one when the mantissa fraction is at least
// sqrt(2) - 1 (0x3504F3 as a 23-bit fraction), i.e. when amax >= sqrt(2) * 2^floor.
// The caller must pass a positive, finite amax. The result is not clamped to [0, 255].
GGML_MXFP_FUNC int ggml_mxfp_e8m0_base_estimate(float amax, int emax_offset) {
    const uint32_t u = GGML_MXFP_F32_AS_U32(amax);

    int log2_rounded = (int)((u >> 23) & 0xFF) - 127;
    if ((u & 0x7FFFFF) >= 0x3504F3) {
        log2_rounded += 1;
    }
    return log2_rounded - emax_offset + 127;
}
// In-place Walsh-Hadamard transform of exactly 32 floats, scaled by 1/sqrt(32).
// Five butterfly stages (pair distance 1, 2, 4, 8, 16) replace each pair (lo, hi) with
// (lo + hi, lo - hi); a final pass multiplies every element by MXFP_HADAMARD_32_NORM.
// The rotation spreads outlier energy across all elements sharing an E8M0 exponent,
// improving quantization quality (see QuaRot arXiv:2404.00456).
GGML_MXFP_FUNC void ggml_mxfp_hadamard_32_inplace(GGML_MXFP_THREAD float * vals) {
    GGML_MXFP_UNROLL
    for (int dist = 1; dist < 32; dist <<= 1) {
        GGML_MXFP_UNROLL
        for (int base = 0; base < 32; base += 2 * dist) {
            GGML_MXFP_UNROLL
            for (int k = base; k < base + dist; ++k) {
                const float lo = vals[k];
                const float hi = vals[k + dist];
                vals[k]        = lo + hi;
                vals[k + dist] = lo - hi;
            }
        }
    }
    GGML_MXFP_UNROLL
    for (int k = 0; k < 32; ++k) {
        vals[k] *= MXFP_HADAMARD_32_NORM;
    }
}
#endif // GGML_MXFP_FUNC
#define NGRID_IQ1S 2048
#define IQ1S_DELTA 0.125f
#define IQ1M_DELTA 0.125f

View File

@ -16,6 +16,8 @@
#define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0
#define ggml_vec_dot_mxfp6_e2m3_q8_0_generic ggml_vec_dot_mxfp6_e2m3_q8_0
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
@ -341,3 +343,14 @@
// Rename the generic Q8_0 GEMM kernels (4x4 and 4x8 variants) to their public names.
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
#endif
// MXFP dequantize has no arch-specific (SIMD) implementations except on arm and x86.
// All other targets use the scalar generic as the public cpu function.
// The guard below is true only when none of the ARM (32/64-bit, GCC/Clang or MSVC
// spellings) or x86 (32/64-bit) target macros is defined. In that case each
// *_cpu_generic name is renamed to the public *_cpu name.
#if !defined(__aarch64__) && !defined(__arm__) && !defined(_M_ARM) && !defined(_M_ARM64) && \
!defined(__x86_64__) && !defined(__i386__) && !defined(_M_IX86) && !defined(_M_X64)
// AoS (block) layouts.
#define dequantize_row_mxfp8_cpu_generic dequantize_row_mxfp8_cpu
#define dequantize_row_mxfp6_e2m3_cpu_generic dequantize_row_mxfp6_e2m3_cpu
// SoA layouts.
#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu
#endif

View File

@ -4134,3 +4134,541 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
#endif
}
// NEON kernel: dot product of one MXFP8 row (block_mxfp8) with one Q8_0 row.
// Each FP8 code is expanded to float by assembling IEEE-754 bits directly, scaled by
// (E8M0 block scale * Q8_0 block scale) and multiply-accumulated against the int8 values.
// The FP8 variant is described by: exp_mask, mant_mask, exp_shift, ieee_exp_off,
// mant_shift, sub_scale.
#if defined(__ARM_NEON)
// Expands four FP8 codes (one per 32-bit lane, sign in bit 7) to float.
//   normal    : bits  = sign << 24 | (exp + ieee_off) << 23 | mant << mant_shift
//   subnormal : value = mant * sub_scale with the sign bit OR-ed in (used when exp == 0)
// The format-dependent shift counts are runtime values, so they are applied with the
// variable-shift vshlq_u32 (a negative count shifts right); the _n_ intrinsics need
// literal constants.
static inline float32x4_t ggml_mxfp8_dot_decode4_neon(
        const uint32x4_t  codes,
        const uint32x4_t  exp_mask_v,
        const uint32x4_t  mant_mask_v,
        const uint32x4_t  ieee_off_v,
        const int32x4_t   neg_exp_shift_v,
        const int32x4_t   mant_shift_v,
        const float32x4_t sub_scale_v) {
    const uint32x4_t sign31 = vshlq_n_u32(vandq_u32(codes, vdupq_n_u32(0x80)), 24);
    const uint32x4_t e_bits = vandq_u32(vshlq_u32(codes, neg_exp_shift_v), exp_mask_v);
    const uint32x4_t m_bits = vandq_u32(codes, mant_mask_v);

    const uint32x4_t normal_bits = vorrq_u32(
        vorrq_u32(sign31, vshlq_n_u32(vaddq_u32(e_bits, ieee_off_v), 23)),
        vshlq_u32(m_bits, mant_shift_v));

    const float32x4_t sub_mag  = vmulq_f32(vcvtq_f32_u32(m_bits), sub_scale_v);
    const uint32x4_t  sub_bits = vorrq_u32(vreinterpretq_u32_f32(sub_mag), sign31);

    const uint32x4_t exp_is_zero = vceqq_u32(e_bits, vdupq_n_u32(0));
    return vbslq_f32(exp_is_zero,
                     vreinterpretq_f32_u32(sub_bits),
                     vreinterpretq_f32_u32(normal_bits));
}

static inline void ggml_vec_dot_mxfp8_q8_0_neon(
        int n, float * GGML_RESTRICT s,
        const void * GGML_RESTRICT vx,
        const void * GGML_RESTRICT vy,
        // FP8 format parameters:
        const uint32_t exp_mask,      // 0xF for E4M3, 0x1F for E5M2
        const uint32_t mant_mask,     // 0x7 for E4M3, 0x3 for E5M2
        const int      exp_shift,     // 3 for E4M3, 2 for E5M2
        const uint32_t ieee_exp_off,  // 120 for E4M3, 112 for E5M2
        const int      mant_shift,    // 20 for E4M3, 21 for E5M2
        const float    sub_scale) {   // 1/512 for E4M3, 1/65536 for E5M2
    assert(n % QK_MXFP8 == 0);

    const block_mxfp8 * GGML_RESTRICT xblk = vx;
    const block_q8_0  * GGML_RESTRICT yblk = vy;
    const int n_blocks = n / QK_MXFP8;

    // Broadcast the format description once, outside the block loop.
    const uint32x4_t  exp_mask_v      = vdupq_n_u32(exp_mask);
    const uint32x4_t  mant_mask_v     = vdupq_n_u32(mant_mask);
    const uint32x4_t  ieee_off_v      = vdupq_n_u32(ieee_exp_off);
    const float32x4_t sub_scale_v     = vdupq_n_f32(sub_scale);
    const int32x4_t   neg_exp_shift_v = vdupq_n_s32(-exp_shift);
    const int32x4_t   mant_shift_v    = vdupq_n_s32(mant_shift);

    // Two independent accumulators: lanes 0-3 and lanes 4-7 of every 8-element step.
    float32x4_t sum_lo = vdupq_n_f32(0.0f);
    float32x4_t sum_hi = vdupq_n_f32(0.0f);

    for (int ib = 0; ib < n_blocks; ++ib) {
        const float32x4_t block_scale = vdupq_n_f32(
            GGML_E8M0_TO_FP32(xblk[ib].e) * GGML_CPU_FP16_TO_FP32(yblk[ib].d));

        // 32 elements per block, 8 per step (two vectors of 4).
        for (int j = 0; j < 32; j += 8) {
            // FP8 codes: u8 -> u16 -> two u32x4.
            const uint16x8_t codes16  = vmovl_u8(vld1_u8(xblk[ib].qs + j));
            const uint32x4_t codes_lo = vmovl_u16(vget_low_u16(codes16));
            const uint32x4_t codes_hi = vmovl_u16(vget_high_u16(codes16));

            // Q8_0 values: s8 -> s16 -> s32 -> f32.
            const int16x8_t   q16  = vmovl_s8(vld1_s8(yblk[ib].qs + j));
            const float32x4_t q_lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(q16)));
            const float32x4_t q_hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16)));

            const float32x4_t x_lo = ggml_mxfp8_dot_decode4_neon(
                codes_lo, exp_mask_v, mant_mask_v, ieee_off_v,
                neg_exp_shift_v, mant_shift_v, sub_scale_v);
            const float32x4_t x_hi = ggml_mxfp8_dot_decode4_neon(
                codes_hi, exp_mask_v, mant_mask_v, ieee_off_v,
                neg_exp_shift_v, mant_shift_v, sub_scale_v);

            // acc += (x * scale) * q
            sum_lo = vfmaq_f32(sum_lo, vmulq_f32(x_lo, block_scale), q_lo);
            sum_hi = vfmaq_f32(sum_hi, vmulq_f32(x_hi, block_scale), q_hi);
        }
    }

    *s = vaddvq_f32(vaddq_f32(sum_lo, sum_hi));
}
#endif
// Public MXFP8 x Q8_0 dot product: writes the dot product of rows vx and vy
// (n elements) to *s.
// With __ARM_NEON defined, runs the NEON kernel with the E4M3 format constants;
// otherwise forwards every argument to the scalar generic implementation.
// Only nrc == 1 is supported (asserted).
void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
// The NEON branch does not consume these; UNUSED keeps that branch warning-free.
UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
#if defined(__ARM_NEON)
// E4M3: sign(1) exp(4) mant(3), bias=7
ggml_vec_dot_mxfp8_q8_0_neon(n, s, vx, vy,
MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT,
MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE);
#else
ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
// NEON kernel: dot product of one MXFP6 row with one Q8_0 row.
// FP6 codes are tightly packed (4 codes in 3 bytes, little-endian, 6 bits each).
// They are unpacked to one byte per code, expanded to float via IEEE-754 bit
// assembly, scaled and multiply-accumulated against the int8 values.
#if defined(__ARM_NEON)
// Expands four FP6 codes (one per 32-bit lane, sign in bit 5) to float.
//   normal    : bits  = sign << 26 | (exp + ieee_off) << 23 | mant << mant_shift
//   subnormal : value = mant * sub_scale with the sign bit OR-ed in (used when exp == 0)
// Format-dependent shifts are runtime values, hence the variable-shift vshlq_u32
// (a negative count shifts right).
static inline float32x4_t ggml_mxfp6_dot_decode4_neon(
        const uint32x4_t  codes,
        const uint32x4_t  exp_mask_v,
        const uint32x4_t  mant_mask_v,
        const uint32x4_t  ieee_off_v,
        const int32x4_t   neg_exp_shift_v,
        const int32x4_t   mant_shift_v,
        const float32x4_t sub_scale_v) {
    const uint32x4_t sign31 = vshlq_n_u32(vandq_u32(codes, vdupq_n_u32(0x20)), 26);
    const uint32x4_t e_bits = vandq_u32(vshlq_u32(codes, neg_exp_shift_v), exp_mask_v);
    const uint32x4_t m_bits = vandq_u32(codes, mant_mask_v);

    const uint32x4_t normal_bits = vorrq_u32(
        vorrq_u32(sign31, vshlq_n_u32(vaddq_u32(e_bits, ieee_off_v), 23)),
        vshlq_u32(m_bits, mant_shift_v));

    const float32x4_t sub_mag  = vmulq_f32(vcvtq_f32_u32(m_bits), sub_scale_v);
    const uint32x4_t  sub_bits = vorrq_u32(vreinterpretq_u32_f32(sub_mag), sign31);

    const uint32x4_t exp_is_zero = vceqq_u32(e_bits, vdupq_n_u32(0));
    return vbslq_f32(exp_is_zero,
                     vreinterpretq_f32_u32(sub_bits),
                     vreinterpretq_f32_u32(normal_bits));
}

static inline void ggml_vec_dot_mxfp6_q8_0_neon(
        int n, float * GGML_RESTRICT s,
        const void * GGML_RESTRICT vx,
        const void * GGML_RESTRICT vy,
        size_t block_size,            // byte stride between consecutive MXFP6 blocks
        // FP6 format parameters:
        const uint32_t exp_mask,      // 0x3 for E2M3, 0x7 for E3M2
        const uint32_t mant_mask,     // 0x7 for E2M3, 0x3 for E3M2
        const int      exp_shift,     // 3 for E2M3, 2 for E3M2
        const uint32_t ieee_exp_off,  // 126 for E2M3, 124 for E3M2
        const int      mant_shift,    // 20 for E2M3, 21 for E3M2
        const float    sub_scale) {   // 1/8 for E2M3, 1/16 for E3M2
    assert(n % QK_MXFP6 == 0);

    const block_q8_0 * GGML_RESTRICT yblk = vy;
    const int n_blocks = n / QK_MXFP6;

    const uint32x4_t  exp_mask_v      = vdupq_n_u32(exp_mask);
    const uint32x4_t  mant_mask_v     = vdupq_n_u32(mant_mask);
    const uint32x4_t  ieee_off_v      = vdupq_n_u32(ieee_exp_off);
    const float32x4_t sub_scale_v     = vdupq_n_f32(sub_scale);
    const int32x4_t   neg_exp_shift_v = vdupq_n_s32(-exp_shift);
    const int32x4_t   mant_shift_v    = vdupq_n_s32(mant_shift);

    float32x4_t sum_lo = vdupq_n_f32(0.0f);
    float32x4_t sum_hi = vdupq_n_f32(0.0f);

    for (int ib = 0; ib < n_blocks; ++ib) {
        const block_mxfp6 * GGML_RESTRICT blk =
            (const block_mxfp6 *)((const char *)vx + ib * block_size);
        const float32x4_t block_scale = vdupq_n_f32(
            GGML_E8M0_TO_FP32(blk->e) * GGML_CPU_FP16_TO_FP32(yblk[ib].d));

        // 32 elements per block, 8 per step = two 3-byte packs of 4 codes each.
        for (int j = 0; j < 32; j += 8) {
            uint8_t codes[8];
            for (int g = 0; g < 2; ++g) {
                const uint8_t * p    = blk->qs + ((j + 4 * g) * 3 / 4);
                const uint32_t  word = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
                for (int t = 0; t < 4; ++t) {
                    codes[4 * g + t] = (word >> (6 * t)) & 0x3F;
                }
            }

            // Unpacked codes: u8 -> u16 -> two u32x4.
            const uint16x8_t codes16  = vmovl_u8(vld1_u8(codes));
            const uint32x4_t codes_lo = vmovl_u16(vget_low_u16(codes16));
            const uint32x4_t codes_hi = vmovl_u16(vget_high_u16(codes16));

            // Q8_0 values: s8 -> s16 -> s32 -> f32.
            const int16x8_t   q16  = vmovl_s8(vld1_s8(yblk[ib].qs + j));
            const float32x4_t q_lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(q16)));
            const float32x4_t q_hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16)));

            const float32x4_t x_lo = ggml_mxfp6_dot_decode4_neon(
                codes_lo, exp_mask_v, mant_mask_v, ieee_off_v,
                neg_exp_shift_v, mant_shift_v, sub_scale_v);
            const float32x4_t x_hi = ggml_mxfp6_dot_decode4_neon(
                codes_hi, exp_mask_v, mant_mask_v, ieee_off_v,
                neg_exp_shift_v, mant_shift_v, sub_scale_v);

            // acc += (x * scale) * q
            sum_lo = vfmaq_f32(sum_lo, vmulq_f32(x_lo, block_scale), q_lo);
            sum_hi = vfmaq_f32(sum_hi, vmulq_f32(x_hi, block_scale), q_hi);
        }
    }

    *s = vaddvq_f32(vaddq_f32(sum_lo, sum_hi));
}
#endif
// Public MXFP6 (E2M3) x Q8_0 dot product: writes the dot product of rows vx and vy
// (n elements) to *s.
// With __ARM_NEON defined, runs the NEON kernel with the E2M3 format constants and a
// block stride of sizeof(block_mxfp6); otherwise forwards every argument to the
// scalar generic implementation. Only nrc == 1 is supported (asserted).
void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
// The NEON branch does not consume these; UNUSED keeps that branch warning-free.
UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
#if defined(__ARM_NEON)
// E2M3: sign(1) exp(2) mant(3), bias=1
ggml_vec_dot_mxfp6_q8_0_neon(n, s, vx, vy, sizeof(block_mxfp6),
MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT,
MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE);
#else
ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
// ---- MXFP dequantize_row (to_float) — NEON-optimized ----
#if defined(__ARM_NEON)
// Expands four FP8 codes (one per 32-bit lane, sign in bit 7) to float.
//   normal    : bits  = sign << 24 | (exp + ieee_off) << 23 | mant << mant_shift
//   subnormal : value = mant * sub_scale with the sign bit OR-ed in (used when exp == 0)
// Format-dependent shifts are runtime values, hence the variable-shift vshlq_u32
// (a negative count shifts right).
static inline float32x4_t ggml_mxfp8_row_decode4_neon(
        const uint32x4_t  codes,
        const uint32x4_t  exp_mask_v,
        const uint32x4_t  mant_mask_v,
        const uint32x4_t  ieee_off_v,
        const int32x4_t   neg_exp_shift_v,
        const int32x4_t   mant_shift_v,
        const float32x4_t sub_scale_v) {
    const uint32x4_t sign31 = vshlq_n_u32(vandq_u32(codes, vdupq_n_u32(0x80)), 24);
    const uint32x4_t e_bits = vandq_u32(vshlq_u32(codes, neg_exp_shift_v), exp_mask_v);
    const uint32x4_t m_bits = vandq_u32(codes, mant_mask_v);

    const uint32x4_t normal_bits = vorrq_u32(
        vorrq_u32(sign31, vshlq_n_u32(vaddq_u32(e_bits, ieee_off_v), 23)),
        vshlq_u32(m_bits, mant_shift_v));

    const float32x4_t sub_mag  = vmulq_f32(vcvtq_f32_u32(m_bits), sub_scale_v);
    const uint32x4_t  sub_bits = vorrq_u32(vreinterpretq_u32_f32(sub_mag), sign31);

    const uint32x4_t exp_is_zero = vceqq_u32(e_bits, vdupq_n_u32(0));
    return vbslq_f32(exp_is_zero,
                     vreinterpretq_f32_u32(sub_bits),
                     vreinterpretq_f32_u32(normal_bits));
}

// Dequantizes k MXFP8 elements stored as block_mxfp8 into floats at y.
// k must be a multiple of QK_MXFP8 (asserted). Each element is decoded from its
// FP8 code and multiplied by the block's E8M0 scale.
static inline void dequantize_row_mxfp8_neon(
        const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k,
        const uint32_t exp_mask, const uint32_t mant_mask,
        const int exp_shift, const uint32_t ieee_exp_off,
        const int mant_shift, const float sub_scale) {
    assert(k % QK_MXFP8 == 0);

    const block_mxfp8 * GGML_RESTRICT blocks = vx;
    const int n_blocks = k / QK_MXFP8;

    const uint32x4_t  exp_mask_v      = vdupq_n_u32(exp_mask);
    const uint32x4_t  mant_mask_v     = vdupq_n_u32(mant_mask);
    const uint32x4_t  ieee_off_v      = vdupq_n_u32(ieee_exp_off);
    const float32x4_t sub_scale_v     = vdupq_n_f32(sub_scale);
    const int32x4_t   neg_exp_shift_v = vdupq_n_s32(-exp_shift);
    const int32x4_t   mant_shift_v    = vdupq_n_s32(mant_shift);

    for (int ib = 0; ib < n_blocks; ++ib) {
        const float32x4_t block_scale = vdupq_n_f32(GGML_E8M0_TO_FP32(blocks[ib].e));
        float * out = y + ib * QK_MXFP8;

        // 8 codes per step: u8 -> u16 -> two u32x4, decoded and stored as 4 + 4 floats.
        for (int j = 0; j < 32; j += 8) {
            const uint16x8_t codes16  = vmovl_u8(vld1_u8(blocks[ib].qs + j));
            const uint32x4_t codes_lo = vmovl_u16(vget_low_u16(codes16));
            const uint32x4_t codes_hi = vmovl_u16(vget_high_u16(codes16));

            const float32x4_t x_lo = ggml_mxfp8_row_decode4_neon(
                codes_lo, exp_mask_v, mant_mask_v, ieee_off_v,
                neg_exp_shift_v, mant_shift_v, sub_scale_v);
            const float32x4_t x_hi = ggml_mxfp8_row_decode4_neon(
                codes_hi, exp_mask_v, mant_mask_v, ieee_off_v,
                neg_exp_shift_v, mant_shift_v, sub_scale_v);

            vst1q_f32(out + j,     vmulq_f32(x_lo, block_scale));
            vst1q_f32(out + j + 4, vmulq_f32(x_hi, block_scale));
        }
    }
}
// Dequantizes k MXFP6 elements (blocks spaced block_size bytes apart) into floats at y.
// k must be a multiple of QK_MXFP6 (asserted). FP6 codes are packed 4-per-3-bytes;
// each pack is unpacked, decoded via IEEE-754 bit assembly (sign in bit 5) and
// multiplied by the block's E8M0 scale.
static inline void dequantize_row_mxfp6_neon(
        const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k,
        size_t block_size,
        const uint32_t exp_mask, const uint32_t mant_mask,
        const int exp_shift, const uint32_t ieee_exp_off,
        const int mant_shift, const float sub_scale) {
    assert(k % QK_MXFP6 == 0);
    const int n_blocks = k / QK_MXFP6;

    const uint32x4_t  exp_mask_v      = vdupq_n_u32(exp_mask);
    const uint32x4_t  mant_mask_v     = vdupq_n_u32(mant_mask);
    const uint32x4_t  ieee_off_v      = vdupq_n_u32(ieee_exp_off);
    const float32x4_t sub_scale_v     = vdupq_n_f32(sub_scale);
    const int32x4_t   neg_exp_shift_v = vdupq_n_s32(-exp_shift);
    const int32x4_t   mant_shift_v    = vdupq_n_s32(mant_shift);
    const uint32x4_t  sign_mask_v     = vdupq_n_u32(0x20);
    const uint32x4_t  zero_v          = vdupq_n_u32(0);

    for (int ib = 0; ib < n_blocks; ++ib) {
        const block_mxfp6 * GGML_RESTRICT blk =
            (const block_mxfp6 *)((const char *)vx + ib * block_size);
        const float32x4_t block_scale = vdupq_n_f32(GGML_E8M0_TO_FP32(blk->e));
        float * out = y + ib * QK_MXFP6;

        // 4 codes (3 packed bytes) per step.
        for (int j = 0; j < 32; j += 4) {
            const uint8_t * p    = blk->qs + (j * 3 / 4);
            const uint32_t  word = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);

            // Spread the four 6-bit fields into the low four bytes of a 64-bit image,
            // then widen u8 -> u16 -> u32 so each code sits in its own lane.
            const uint64_t lanes =
                 (uint64_t)((word >>  0) & 0x3F)        |
                ((uint64_t)((word >>  6) & 0x3F) <<  8) |
                ((uint64_t)((word >> 12) & 0x3F) << 16) |
                ((uint64_t)((word >> 18) & 0x3F) << 24);
            const uint32x4_t codes = vmovl_u16(vget_low_u16(vmovl_u8(vcreate_u8(lanes))));

            const uint32x4_t sign31 = vshlq_n_u32(vandq_u32(codes, sign_mask_v), 26);
            const uint32x4_t e_bits = vandq_u32(vshlq_u32(codes, neg_exp_shift_v), exp_mask_v);
            const uint32x4_t m_bits = vandq_u32(codes, mant_mask_v);

            // Normal: sign | (exp + ieee_off) << 23 | mant << mant_shift
            const uint32x4_t normal_bits = vorrq_u32(
                vorrq_u32(sign31, vshlq_n_u32(vaddq_u32(e_bits, ieee_off_v), 23)),
                vshlq_u32(m_bits, mant_shift_v));

            // Subnormal: mant * sub_scale with the sign OR-ed in.
            const float32x4_t sub_mag  = vmulq_f32(vcvtq_f32_u32(m_bits), sub_scale_v);
            const uint32x4_t  sub_bits = vorrq_u32(vreinterpretq_u32_f32(sub_mag), sign31);

            const float32x4_t decoded = vbslq_f32(
                vceqq_u32(e_bits, zero_v),
                vreinterpretq_f32_u32(sub_bits),
                vreinterpretq_f32_u32(normal_bits));

            vst1q_f32(out + j, vmulq_f32(decoded, block_scale));
        }
    }
}
#endif
// Dequantizes k MXFP8 elements from x into floats at y.
// With __ARM_NEON defined, runs the NEON kernel with the E4M3 format constants;
// otherwise calls the scalar generic implementation.
void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__ARM_NEON)
dequantize_row_mxfp8_neon(x, y, k,
MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT,
MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE);
#else
dequantize_row_mxfp8_cpu_generic(x, y, k);
#endif
}
// Dequantizes k MXFP6 (E2M3) elements from x into floats at y.
// With __ARM_NEON defined, runs the NEON kernel with the E2M3 format constants and a
// block stride of sizeof(block_mxfp6); otherwise calls the scalar generic implementation.
void dequantize_row_mxfp6_e2m3_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__ARM_NEON)
dequantize_row_mxfp6_neon(x, y, k, sizeof(block_mxfp6),
MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT,
MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE);
#else
dequantize_row_mxfp6_e2m3_cpu_generic(x, y, k);
#endif
}
// ---- MXFP SoA dequantize_row (to_float) — NEON-optimized ----
#if defined(__ARM_NEON)
// Widens 16 signed bytes to float, multiplies by scale and stores 16 floats at dst.
static inline void ggml_mxfp4_soa_store16_neon(float * dst, const int8x16_t q, const float32x4_t scale) {
    const int16x8_t w_lo = vmovl_s8(vget_low_s8(q));
    const int16x8_t w_hi = vmovl_s8(vget_high_s8(q));
    vst1q_f32(dst +  0, vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(w_lo))),  scale));
    vst1q_f32(dst +  4, vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(w_lo))), scale));
    vst1q_f32(dst +  8, vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(w_hi))),  scale));
    vst1q_f32(dst + 12, vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(w_hi))), scale));
}

// Dequantizes k MXFP4 elements from an SoA row (quantized bytes first, E8M0 scales
// after them) into floats at y. k must be a multiple of QK_MXFP4 (asserted).
// Per block: 16 bytes hold 32 nibbles; the low nibbles produce the first half of the
// block's outputs and the high nibbles the second half. Nibbles index kvalues_mxfp4.
// NOTE(review): the scale is read with GGML_E8M0_TO_FP32_HALF, presumably because
// kvalues_mxfp4 stores doubled values — confirm against the table definition.
static inline void dequantize_row_mxfp4_soa_neon(
        const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) {
    assert(k % QK_MXFP4 == 0);
    const int n_blocks = k / QK_MXFP4;

    const char * row    = (const char *)src;
    const char * scales = row + MXFP_SOA_E8M0_OFFSET(n_blocks, MXFP4_SOA_QS_PER_BLOCK);

    const int8x16_t  lut         = vld1q_s8(kvalues_mxfp4);
    const uint8x16_t nibble_mask = vdupq_n_u8(0x0f);

    for (int ib = 0; ib < n_blocks; ib++) {
        const float32x4_t block_scale = vdupq_n_f32(GGML_E8M0_TO_FP32_HALF((uint8_t)scales[ib]));

        const uint8_t *  qs     = (const uint8_t *)(row + MXFP_SOA_QS_OFFSET(ib, MXFP4_SOA_QS_PER_BLOCK));
        const uint8x16_t packed = vld1q_u8(qs);

        // Table lookup turns each 4-bit code into its signed integer value.
        const int8x16_t from_low  = ggml_vqtbl1q_s8(lut, vandq_u8(packed, nibble_mask));
        const int8x16_t from_high = ggml_vqtbl1q_s8(lut, vshrq_n_u8(packed, 4));

        float * out = y + ib * QK_MXFP4;
        ggml_mxfp4_soa_store16_neon(out,                from_low,  block_scale);
        ggml_mxfp4_soa_store16_neon(out + QK_MXFP4 / 2, from_high, block_scale);
    }
}
// Expands four FP8 codes (one per 32-bit lane, sign in bit 7) to float.
//   normal    : bits  = sign << 24 | (exp + ieee_off) << 23 | mant << mant_shift
//   subnormal : value = mant * sub_scale with the sign bit OR-ed in (used when exp == 0)
// Format-dependent shifts are runtime values, hence the variable-shift vshlq_u32
// (a negative count shifts right).
static inline float32x4_t ggml_mxfp8_soa_decode4_neon(
        const uint32x4_t  codes,
        const uint32x4_t  exp_mask_v,
        const uint32x4_t  mant_mask_v,
        const uint32x4_t  ieee_off_v,
        const int32x4_t   neg_exp_shift_v,
        const int32x4_t   mant_shift_v,
        const float32x4_t sub_scale_v) {
    const uint32x4_t sign31 = vshlq_n_u32(vandq_u32(codes, vdupq_n_u32(0x80)), 24);
    const uint32x4_t e_bits = vandq_u32(vshlq_u32(codes, neg_exp_shift_v), exp_mask_v);
    const uint32x4_t m_bits = vandq_u32(codes, mant_mask_v);

    const uint32x4_t normal_bits = vorrq_u32(
        vorrq_u32(sign31, vshlq_n_u32(vaddq_u32(e_bits, ieee_off_v), 23)),
        vshlq_u32(m_bits, mant_shift_v));

    const float32x4_t sub_mag  = vmulq_f32(vcvtq_f32_u32(m_bits), sub_scale_v);
    const uint32x4_t  sub_bits = vorrq_u32(vreinterpretq_u32_f32(sub_mag), sign31);

    const uint32x4_t exp_is_zero = vceqq_u32(e_bits, vdupq_n_u32(0));
    return vbslq_f32(exp_is_zero,
                     vreinterpretq_f32_u32(sub_bits),
                     vreinterpretq_f32_u32(normal_bits));
}

// Dequantizes k MXFP8 elements from an SoA row (quantized bytes first, E8M0 scales
// after them) into floats at y. k must be a multiple of QK_MXFP8 (asserted).
// Block offsets come from MXFP_SOA_QS_OFFSET / MXFP_SOA_E8M0_OFFSET.
static inline void dequantize_row_mxfp8_soa_neon(
        const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k,
        const uint32_t exp_mask, const uint32_t mant_mask,
        const int exp_shift, const uint32_t ieee_exp_off,
        const int mant_shift, const float sub_scale) {
    assert(k % QK_MXFP8 == 0);
    const int n_blocks = k / QK_MXFP8;

    const char * row    = (const char *)src;
    const char * scales = row + MXFP_SOA_E8M0_OFFSET(n_blocks, MXFP8_SOA_QS_PER_BLOCK);

    const uint32x4_t  exp_mask_v      = vdupq_n_u32(exp_mask);
    const uint32x4_t  mant_mask_v     = vdupq_n_u32(mant_mask);
    const uint32x4_t  ieee_off_v      = vdupq_n_u32(ieee_exp_off);
    const float32x4_t sub_scale_v     = vdupq_n_f32(sub_scale);
    const int32x4_t   neg_exp_shift_v = vdupq_n_s32(-exp_shift);
    const int32x4_t   mant_shift_v    = vdupq_n_s32(mant_shift);

    for (int ib = 0; ib < n_blocks; ++ib) {
        const float32x4_t block_scale = vdupq_n_f32(GGML_E8M0_TO_FP32((uint8_t)scales[ib]));
        const uint8_t * qs  = (const uint8_t *)(row + MXFP_SOA_QS_OFFSET(ib, MXFP8_SOA_QS_PER_BLOCK));
        float *         out = y + ib * QK_MXFP8;

        // 8 codes per step: u8 -> u16 -> two u32x4, decoded and stored as 4 + 4 floats.
        for (int j = 0; j < 32; j += 8) {
            const uint16x8_t codes16  = vmovl_u8(vld1_u8(qs + j));
            const uint32x4_t codes_lo = vmovl_u16(vget_low_u16(codes16));
            const uint32x4_t codes_hi = vmovl_u16(vget_high_u16(codes16));

            const float32x4_t x_lo = ggml_mxfp8_soa_decode4_neon(
                codes_lo, exp_mask_v, mant_mask_v, ieee_off_v,
                neg_exp_shift_v, mant_shift_v, sub_scale_v);
            const float32x4_t x_hi = ggml_mxfp8_soa_decode4_neon(
                codes_hi, exp_mask_v, mant_mask_v, ieee_off_v,
                neg_exp_shift_v, mant_shift_v, sub_scale_v);

            vst1q_f32(out + j,     vmulq_f32(x_lo, block_scale));
            vst1q_f32(out + j + 4, vmulq_f32(x_hi, block_scale));
        }
    }
}
// Dequantizes k MXFP6 elements from an SoA row (quantized bytes first, E8M0 scales
// after them) into floats at y. k must be a multiple of QK_MXFP6 (asserted).
// FP6 codes are packed 4-per-3-bytes; each pack is unpacked, decoded via IEEE-754
// bit assembly (sign in bit 5) and multiplied by the block's E8M0 scale.
static inline void dequantize_row_mxfp6_soa_neon(
        const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k,
        const uint32_t exp_mask, const uint32_t mant_mask,
        const int exp_shift, const uint32_t ieee_exp_off,
        const int mant_shift, const float sub_scale) {
    assert(k % QK_MXFP6 == 0);
    const int n_blocks = k / QK_MXFP6;

    const char * row    = (const char *)src;
    const char * scales = row + MXFP_SOA_E8M0_OFFSET(n_blocks, MXFP6_SOA_QS_PER_BLOCK);

    const uint32x4_t  exp_mask_v      = vdupq_n_u32(exp_mask);
    const uint32x4_t  mant_mask_v     = vdupq_n_u32(mant_mask);
    const uint32x4_t  ieee_off_v      = vdupq_n_u32(ieee_exp_off);
    const float32x4_t sub_scale_v     = vdupq_n_f32(sub_scale);
    const int32x4_t   neg_exp_shift_v = vdupq_n_s32(-exp_shift);
    const int32x4_t   mant_shift_v    = vdupq_n_s32(mant_shift);
    const uint32x4_t  sign_mask_v     = vdupq_n_u32(0x20);
    const uint32x4_t  zero_v          = vdupq_n_u32(0);

    for (int ib = 0; ib < n_blocks; ++ib) {
        const float32x4_t block_scale = vdupq_n_f32(GGML_E8M0_TO_FP32((uint8_t)scales[ib]));
        const uint8_t * qs  = (const uint8_t *)(row + MXFP_SOA_QS_OFFSET(ib, MXFP6_SOA_QS_PER_BLOCK));
        float *         out = y + ib * QK_MXFP6;

        // 4 codes (3 packed bytes) per step.
        for (int j = 0; j < 32; j += 4) {
            const uint8_t * p    = qs + (j * 3 / 4);
            const uint32_t  word = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);

            // Spread the four 6-bit fields into the low four bytes of a 64-bit image,
            // then widen u8 -> u16 -> u32 so each code sits in its own lane.
            const uint64_t lanes =
                 (uint64_t)((word >>  0) & 0x3F)        |
                ((uint64_t)((word >>  6) & 0x3F) <<  8) |
                ((uint64_t)((word >> 12) & 0x3F) << 16) |
                ((uint64_t)((word >> 18) & 0x3F) << 24);
            const uint32x4_t codes = vmovl_u16(vget_low_u16(vmovl_u8(vcreate_u8(lanes))));

            const uint32x4_t sign31 = vshlq_n_u32(vandq_u32(codes, sign_mask_v), 26);
            const uint32x4_t e_bits = vandq_u32(vshlq_u32(codes, neg_exp_shift_v), exp_mask_v);
            const uint32x4_t m_bits = vandq_u32(codes, mant_mask_v);

            // Normal: sign | (exp + ieee_off) << 23 | mant << mant_shift
            const uint32x4_t normal_bits = vorrq_u32(
                vorrq_u32(sign31, vshlq_n_u32(vaddq_u32(e_bits, ieee_off_v), 23)),
                vshlq_u32(m_bits, mant_shift_v));

            // Subnormal: mant * sub_scale with the sign OR-ed in.
            const float32x4_t sub_mag  = vmulq_f32(vcvtq_f32_u32(m_bits), sub_scale_v);
            const uint32x4_t  sub_bits = vorrq_u32(vreinterpretq_u32_f32(sub_mag), sign31);

            const float32x4_t decoded = vbslq_f32(
                vceqq_u32(e_bits, zero_v),
                vreinterpretq_f32_u32(sub_bits),
                vreinterpretq_f32_u32(normal_bits));

            vst1q_f32(out + j, vmulq_f32(decoded, block_scale));
        }
    }
}
#endif
// Dequantizes k MXFP4 elements from an SoA row x into floats at y.
// Uses the NEON kernel when __ARM_NEON is defined, else the scalar generic one.
void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__ARM_NEON)
dequantize_row_mxfp4_soa_neon(x, y, k);
#else
dequantize_row_mxfp4_soa_cpu_generic(x, y, k);
#endif
}
// Dequantizes k MXFP8 elements from an SoA row x into floats at y.
// With __ARM_NEON defined, runs the NEON kernel with the E4M3 format constants;
// otherwise calls the scalar generic implementation.
void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__ARM_NEON)
dequantize_row_mxfp8_soa_neon(x, y, k,
MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT,
MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE);
#else
dequantize_row_mxfp8_soa_cpu_generic(x, y, k);
#endif
}
// Dequantizes k MXFP6 elements from an SoA row x into floats at y.
// With __ARM_NEON defined, runs the NEON kernel with the E2M3 format constants;
// otherwise calls the scalar generic implementation.
void dequantize_row_mxfp6_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__ARM_NEON)
dequantize_row_mxfp6_soa_neon(x, y, k,
MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT,
MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE);
#else
dequantize_row_mxfp6_soa_cpu_generic(x, y, k);
#endif
}

View File

@ -2157,3 +2157,14 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
#endif
}
// MXFP4 x Q8_0 dot product. No SIMD path here: every argument is forwarded
// unchanged to the generic implementation.
void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
}
// MXFP8 x Q8_0 dot product. No SIMD path here: every argument is forwarded
// unchanged to the generic implementation.
void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
}
// MXFP6 (E2M3) x Q8_0 dot product. No SIMD path here: every argument is forwarded
// unchanged to the generic implementation.
void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
}

View File

@ -2303,3 +2303,10 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
#endif
}
// MXFP8 x Q8_0 dot product. No SIMD path here: every argument is forwarded
// unchanged to the generic implementation.
void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
}
// MXFP6 (E2M3) x Q8_0 dot product. No SIMD path here: every argument is forwarded
// unchanged to the generic implementation.
void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
}

View File

@ -3607,3 +3607,11 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
return ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
// MXFP8 x Q8_0 dot product. No SIMD path here: every argument is forwarded
// unchanged to the generic implementation.
void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
}
// MXFP6 (E2M3) x Q8_0 dot product. No SIMD path here: every argument is forwarded
// unchanged to the generic implementation.
void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
}

View File

@ -1464,3 +1464,10 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
#endif
}
// MXFP8 x Q8_0 dot product. No SIMD path here: every argument is forwarded
// unchanged to the generic implementation.
void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
}
// MXFP6 (E2M3) x Q8_0 dot product. No SIMD path here: every argument is forwarded
// unchanged to the generic implementation.
void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
}

View File

@ -1219,3 +1219,14 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
#endif
}
// MXFP4 x Q8_0 dot product. No SIMD path here: every argument is forwarded
// unchanged to the generic implementation.
void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
}
// MXFP8 x Q8_0 dot product. No SIMD path here: every argument is forwarded
// unchanged to the generic implementation.
void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
}
// MXFP6 (E2M3) x Q8_0 dot product. No SIMD path here: every argument is forwarded
// unchanged to the generic implementation.
void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
}

View File

@ -3818,3 +3818,501 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
// AVX2 kernel: dot product of one MXFP8 row (block_mxfp8) with one Q8_0 row.
// Each FP8 code is expanded to float by assembling IEEE-754 bits directly, scaled by
// (E8M0 block scale * Q8_0 block scale) and multiply-accumulated against the int8 values.
// The FP8 variant is described by: exp_mask, mant_mask, exp_shift, ieee_exp_off,
// mant_shift, sub_scale.
#if defined(__AVX2__)
// Expands eight FP8 codes (one per 32-bit lane, sign in bit 7) to float.
//   normal    : bits  = sign << 24 | (exp + ieee_off) << 23 | mant << mant_shift
//   subnormal : value = mant * sub_scale with the sign bit OR-ed in (used when exp == 0)
static inline __m256 ggml_mxfp8_dot_decode8_avx2(
        const __m256i codes,
        const __m256i exp_mask_v,
        const __m256i mant_mask_v,
        const __m256i ieee_off_v,
        const __m256  sub_scale_v,
        const int     exp_shift,
        const int     mant_shift) {
    const __m256i sign31 = _mm256_slli_epi32(_mm256_and_si256(codes, _mm256_set1_epi32(0x80)), 24);
    const __m256i e_bits = _mm256_and_si256(_mm256_srli_epi32(codes, exp_shift), exp_mask_v);
    const __m256i m_bits = _mm256_and_si256(codes, mant_mask_v);

    const __m256i normal_bits = _mm256_or_si256(
        _mm256_or_si256(sign31, _mm256_slli_epi32(_mm256_add_epi32(e_bits, ieee_off_v), 23)),
        _mm256_slli_epi32(m_bits, mant_shift));

    const __m256 sub_mag = _mm256_mul_ps(_mm256_cvtepi32_ps(m_bits), sub_scale_v);
    const __m256 sub_val = _mm256_castsi256_ps(
        _mm256_or_si256(_mm256_castps_si256(sub_mag), sign31));

    // All-ones lanes where exp == 0 pick the subnormal value.
    const __m256 exp_is_zero = _mm256_castsi256_ps(
        _mm256_cmpeq_epi32(e_bits, _mm256_setzero_si256()));
    return _mm256_blendv_ps(_mm256_castsi256_ps(normal_bits), sub_val, exp_is_zero);
}

static inline void ggml_vec_dot_mxfp8_q8_0_avx2(
        int n, float * GGML_RESTRICT s,
        const void * GGML_RESTRICT vx,
        const void * GGML_RESTRICT vy,
        // FP8 format parameters:
        const int   exp_mask,      // 0xF for E4M3, 0x1F for E5M2
        const int   mant_mask,     // 0x7 for E4M3, 0x3 for E5M2
        const int   exp_shift,     // 3 for E4M3, 2 for E5M2
        const int   ieee_exp_off,  // 120 for E4M3, 112 for E5M2
        const int   mant_shift,    // 20 for E4M3, 21 for E5M2
        const float sub_scale) {   // 1/512 for E4M3, 1/65536 for E5M2
    assert(n % QK_MXFP8 == 0);

    const block_mxfp8 * GGML_RESTRICT xblk = vx;
    const block_q8_0  * GGML_RESTRICT yblk = vy;
    const int n_blocks = n / QK_MXFP8;

    const __m256i exp_mask_v  = _mm256_set1_epi32(exp_mask);
    const __m256i mant_mask_v = _mm256_set1_epi32(mant_mask);
    const __m256i ieee_off_v  = _mm256_set1_epi32(ieee_exp_off);
    const __m256  sub_scale_v = _mm256_set1_ps(sub_scale);

    __m256 sum = _mm256_setzero_ps();

    for (int ib = 0; ib < n_blocks; ++ib) {
        const __m256 block_scale = _mm256_set1_ps(
            GGML_E8M0_TO_FP32(xblk[ib].e) * GGML_CPU_FP16_TO_FP32(yblk[ib].d));

        // 32 elements per block, 8 per step.
        for (int j = 0; j < 32; j += 8) {
            // 8 FP8 bytes zero-extended to 8 int32 lanes.
            const __m256i codes = _mm256_cvtepu8_epi32(
                _mm_loadl_epi64((const __m128i *)(xblk[ib].qs + j)));

            // 8 Q8_0 bytes sign-extended to int32, then converted to float.
            const __m256 q = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
                _mm_loadl_epi64((const __m128i *)(yblk[ib].qs + j))));

            const __m256 x = ggml_mxfp8_dot_decode8_avx2(
                codes, exp_mask_v, mant_mask_v, ieee_off_v, sub_scale_v,
                exp_shift, mant_shift);

            // sum += (x * scale) * q
            sum = _mm256_fmadd_ps(_mm256_mul_ps(x, block_scale), q, sum);
        }
    }

    *s = hsum_float_8(sum);
}
#endif
// Public MXFP8 x Q8_0 dot product: writes the dot product of rows vx and vy
// (n elements) to *s.
// With __AVX2__ defined, runs the AVX2 kernel with the E4M3 format constants;
// otherwise forwards every argument to the scalar generic implementation.
// Only nrc == 1 is supported (asserted).
void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
// The AVX2 branch does not consume these; UNUSED keeps that branch warning-free.
UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
#if defined(__AVX2__)
// E4M3: sign(1) exp(4) mant(3), bias=7
ggml_vec_dot_mxfp8_q8_0_avx2(n, s, vx, vy,
MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT,
MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE);
#else
ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
// AVX2 kernel: dot product of one MXFP6 row with one Q8_0 row.
// FP6 codes are tightly packed (4 codes in 3 bytes, little-endian, 6 bits each).
// They are unpacked to one byte per code, expanded to float via IEEE-754 bit
// assembly, scaled and multiply-accumulated against the int8 values.
#if defined(__AVX2__)
// Expands eight FP6 codes (one per 32-bit lane, sign in bit 5) to float.
//   normal    : bits  = sign << 26 | (exp + ieee_off) << 23 | mant << mant_shift
//   subnormal : value = mant * sub_scale with the sign bit OR-ed in (used when exp == 0)
static inline __m256 ggml_mxfp6_dot_decode8_avx2(
        const __m256i codes,
        const __m256i exp_mask_v,
        const __m256i mant_mask_v,
        const __m256i ieee_off_v,
        const __m256  sub_scale_v,
        const int     exp_shift,
        const int     mant_shift) {
    const __m256i sign31 = _mm256_slli_epi32(_mm256_and_si256(codes, _mm256_set1_epi32(0x20)), 26);
    const __m256i e_bits = _mm256_and_si256(_mm256_srli_epi32(codes, exp_shift), exp_mask_v);
    const __m256i m_bits = _mm256_and_si256(codes, mant_mask_v);

    const __m256i normal_bits = _mm256_or_si256(
        _mm256_or_si256(sign31, _mm256_slli_epi32(_mm256_add_epi32(e_bits, ieee_off_v), 23)),
        _mm256_slli_epi32(m_bits, mant_shift));

    const __m256 sub_mag = _mm256_mul_ps(_mm256_cvtepi32_ps(m_bits), sub_scale_v);
    const __m256 sub_val = _mm256_castsi256_ps(
        _mm256_or_si256(_mm256_castps_si256(sub_mag), sign31));

    // All-ones lanes where exp == 0 pick the subnormal value.
    const __m256 exp_is_zero = _mm256_castsi256_ps(
        _mm256_cmpeq_epi32(e_bits, _mm256_setzero_si256()));
    return _mm256_blendv_ps(_mm256_castsi256_ps(normal_bits), sub_val, exp_is_zero);
}

static inline void ggml_vec_dot_mxfp6_q8_0_avx2(
        int n, float * GGML_RESTRICT s,
        const void * GGML_RESTRICT vx,
        const void * GGML_RESTRICT vy,
        size_t block_size,         // byte stride between consecutive MXFP6 blocks
        // FP6 format parameters:
        const int   exp_mask,      // 0x3 for E2M3, 0x7 for E3M2
        const int   mant_mask,     // 0x7 for E2M3, 0x3 for E3M2
        const int   exp_shift,     // 3 for E2M3, 2 for E3M2
        const int   ieee_exp_off,  // 126 for E2M3, 124 for E3M2
        const int   mant_shift,    // 20 for E2M3, 21 for E3M2
        const float sub_scale) {   // 1/8 for E2M3, 1/16 for E3M2
    assert(n % QK_MXFP6 == 0);

    const block_q8_0 * GGML_RESTRICT yblk = vy;
    const int n_blocks = n / QK_MXFP6;

    const __m256i exp_mask_v  = _mm256_set1_epi32(exp_mask);
    const __m256i mant_mask_v = _mm256_set1_epi32(mant_mask);
    const __m256i ieee_off_v  = _mm256_set1_epi32(ieee_exp_off);
    const __m256  sub_scale_v = _mm256_set1_ps(sub_scale);

    __m256 sum = _mm256_setzero_ps();

    for (int ib = 0; ib < n_blocks; ++ib) {
        const block_mxfp6 * GGML_RESTRICT blk =
            (const block_mxfp6 *)((const char *)vx + ib * block_size);
        const __m256 block_scale = _mm256_set1_ps(
            GGML_E8M0_TO_FP32(blk->e) * GGML_CPU_FP16_TO_FP32(yblk[ib].d));

        // 32 elements per block, 8 per step = two 3-byte packs of 4 codes each.
        for (int j = 0; j < 32; j += 8) {
            uint8_t codes8[8];
            for (int g = 0; g < 2; ++g) {
                const uint8_t * p    = blk->qs + ((j + 4 * g) * 3 / 4);
                const uint32_t  word = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
                for (int t = 0; t < 4; ++t) {
                    codes8[4 * g + t] = (word >> (6 * t)) & 0x3F;
                }
            }

            // 8 unpacked codes zero-extended to 8 int32 lanes.
            const __m256i codes = _mm256_cvtepu8_epi32(
                _mm_loadl_epi64((const __m128i *)codes8));

            // 8 Q8_0 bytes sign-extended to int32, then converted to float.
            const __m256 q = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
                _mm_loadl_epi64((const __m128i *)(yblk[ib].qs + j))));

            const __m256 x = ggml_mxfp6_dot_decode8_avx2(
                codes, exp_mask_v, mant_mask_v, ieee_off_v, sub_scale_v,
                exp_shift, mant_shift);

            // sum += (x * scale) * q
            sum = _mm256_fmadd_ps(_mm256_mul_ps(x, block_scale), q, sum);
        }
    }

    *s = hsum_float_8(sum);
}
#endif
// Dot product of an MXFP6 (E2M3) row `vx` with a Q8_0 row `vy` over `n` elements.
// The scalar result is written to `*s`. Only single-row calls (nrc == 1) are supported.
// bs/bx/by are only consumed by the generic kernel used on non-AVX2 builds.
void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bs);
    UNUSED(bx);
    UNUSED(by);
#if !defined(__AVX2__)
    // No AVX2: defer to the portable implementation.
    ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#else
    // E2M3 element layout: 1 sign bit, 2 exponent bits, 3 mantissa bits, bias 1.
    ggml_vec_dot_mxfp6_q8_0_avx2(
        n, s, vx, vy, sizeof(block_mxfp6),
        MXFP6_E2M3_EXP_MASK,
        MXFP6_E2M3_MANT_MASK,
        MXFP6_E2M3_EXP_SHIFT,
        MXFP6_E2M3_IEEE_EXP_OFF,
        MXFP6_E2M3_MANT_SHIFT,
        MXFP6_E2M3_SUB_SCALE);
#endif
}
// ---- MXFP dequantize_row (to_float) — AVX2-optimized ----
// Extracts the SIMD dequant logic from vec_dot above, writing floats to output buffer
// instead of accumulating a dot product.
#if defined(__AVX2__)
// AVX2 dequantization of `k` MXFP8 elements stored as an array of block_mxfp8.
//
//   vx           input blocks (each holds a shared scale `e` and element codes `qs`)
//   y            output buffer, receives k floats
//   k            element count, must be a multiple of QK_MXFP8
//   exp_mask,
//   mant_mask,
//   exp_shift    describe where exponent and mantissa sit inside an 8-bit code
//   ieee_exp_off value added to the element exponent to obtain the IEEE-754 exponent field
//   mant_shift   left shift placing the element mantissa in the float32 mantissa field
//   sub_scale    value of one mantissa step for subnormal codes (exponent field == 0)
static inline void dequantize_row_mxfp8_avx2(
        const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k,
        const int exp_mask, const int mant_mask, const int exp_shift,
        const int ieee_exp_off, const int mant_shift, const float sub_scale) {
    assert(k % QK_MXFP8 == 0);

    const block_mxfp8 * GGML_RESTRICT blocks = vx;
    const int n_blocks = k / QK_MXFP8;

    // Broadcast constants shared by every group of 8 elements.
    const __m256i sign_bit  = _mm256_set1_epi32(0x80);
    const __m256i exp_bits  = _mm256_set1_epi32(exp_mask);
    const __m256i mant_bits = _mm256_set1_epi32(mant_mask);
    const __m256i exp_bias  = _mm256_set1_epi32(ieee_exp_off);
    const __m256  sub_step  = _mm256_set1_ps(sub_scale);
    const __m256i all_zero  = _mm256_setzero_si256();

    for (int b = 0; b < n_blocks; ++b) {
        const __m256 block_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32(blocks[b].e));
        float * out = y + b * QK_MXFP8;

        for (int off = 0; off < 32; off += 8) {
            // Widen 8 one-byte codes to 8 x int32.
            const __m256i code = _mm256_cvtepu8_epi32(
                _mm_loadl_epi64((const __m128i *)(blocks[b].qs + off)));

            // Split into fields; the sign is moved straight to float bit 31.
            const __m256i sign_ieee = _mm256_slli_epi32(_mm256_and_si256(code, sign_bit), 24);
            const __m256i e         = _mm256_and_si256(_mm256_srli_epi32(code, exp_shift), exp_bits);
            const __m256i m         = _mm256_and_si256(code, mant_bits);

            // Normal numbers: assemble the float32 bit pattern directly.
            const __m256i exp_ieee  = _mm256_slli_epi32(_mm256_add_epi32(e, exp_bias), 23);
            const __m256i mant_ieee = _mm256_slli_epi32(m, mant_shift);
            const __m256  as_normal = _mm256_castsi256_ps(
                _mm256_or_si256(_mm256_or_si256(sign_ieee, exp_ieee), mant_ieee));

            // Subnormal numbers: magnitude is mantissa * sub_scale, sign OR-ed back in.
            const __m256 sub_mag      = _mm256_mul_ps(_mm256_cvtepi32_ps(m), sub_step);
            const __m256 as_subnormal = _mm256_castsi256_ps(
                _mm256_or_si256(_mm256_castps_si256(sub_mag), sign_ieee));

            // A zero exponent field selects the subnormal decoding.
            const __m256 pick_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(e, all_zero));
            const __m256 decoded  = _mm256_blendv_ps(as_normal, as_subnormal, pick_sub);

            _mm256_storeu_ps(out + off, _mm256_mul_ps(decoded, block_scale));
        }
    }
}
// AVX2 dequantization of `k` MXFP6 elements stored as consecutive blocks of
// `block_size` bytes, each laid out as block_mxfp6.
//
//   vx           input blocks; four 6-bit codes are packed into every 3 bytes of `qs`
//   y            output buffer, receives k floats
//   k            element count, must be a multiple of QK_MXFP6
//   block_size   byte stride between consecutive blocks
//   exp_mask,
//   mant_mask,
//   exp_shift    describe where exponent and mantissa sit inside a 6-bit code
//   ieee_exp_off value added to the element exponent to obtain the IEEE-754 exponent field
//   mant_shift   left shift placing the element mantissa in the float32 mantissa field
//   sub_scale    value of one mantissa step for subnormal codes (exponent field == 0)
static inline void dequantize_row_mxfp6_avx2(
        const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k,
        size_t block_size,
        const int exp_mask, const int mant_mask, const int exp_shift,
        const int ieee_exp_off, const int mant_shift, const float sub_scale) {
    assert(k % QK_MXFP6 == 0);
    const int n_blocks = k / QK_MXFP6;

    // Broadcast constants shared by every group of 8 elements.
    const __m256i sign_bit  = _mm256_set1_epi32(0x20);
    const __m256i exp_bits  = _mm256_set1_epi32(exp_mask);
    const __m256i mant_bits = _mm256_set1_epi32(mant_mask);
    const __m256i exp_bias  = _mm256_set1_epi32(ieee_exp_off);
    const __m256  sub_step  = _mm256_set1_ps(sub_scale);
    const __m256i all_zero  = _mm256_setzero_si256();

    for (int b = 0; b < n_blocks; ++b) {
        const block_mxfp6 * GGML_RESTRICT blk = (const block_mxfp6 *)((const char *)vx + b * block_size);
        const __m256 block_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32(blk->e));
        float * out = y + b * QK_MXFP6;

        for (int off = 0; off < 32; off += 8) {
            // Expand 8 codes from 6 bytes: two 3-byte packs of four 6-bit values each.
            uint8_t codes[8];
            for (int half = 0; half < 2; ++half) {
                const uint8_t * p = blk->qs + ((off + 4 * half) * 3 / 4);
                const uint32_t pack = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
                for (int t = 0; t < 4; ++t) {
                    codes[4 * half + t] = (pack >> (6 * t)) & 0x3F;
                }
            }

            // Widen 8 one-byte codes to 8 x int32.
            const __m256i code = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)codes));

            // Split into fields; the sign (bit 5) is moved straight to float bit 31.
            const __m256i sign_ieee = _mm256_slli_epi32(_mm256_and_si256(code, sign_bit), 26);
            const __m256i e         = _mm256_and_si256(_mm256_srli_epi32(code, exp_shift), exp_bits);
            const __m256i m         = _mm256_and_si256(code, mant_bits);

            // Normal numbers: assemble the float32 bit pattern directly.
            const __m256i exp_ieee  = _mm256_slli_epi32(_mm256_add_epi32(e, exp_bias), 23);
            const __m256i mant_ieee = _mm256_slli_epi32(m, mant_shift);
            const __m256  as_normal = _mm256_castsi256_ps(
                _mm256_or_si256(_mm256_or_si256(sign_ieee, exp_ieee), mant_ieee));

            // Subnormal numbers: magnitude is mantissa * sub_scale, sign OR-ed back in.
            const __m256 sub_mag      = _mm256_mul_ps(_mm256_cvtepi32_ps(m), sub_step);
            const __m256 as_subnormal = _mm256_castsi256_ps(
                _mm256_or_si256(_mm256_castps_si256(sub_mag), sign_ieee));

            // A zero exponent field selects the subnormal decoding.
            const __m256 pick_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(e, all_zero));
            const __m256 decoded  = _mm256_blendv_ps(as_normal, as_subnormal, pick_sub);

            _mm256_storeu_ps(out + off, _mm256_mul_ps(decoded, block_scale));
        }
    }
}
#endif
// Dequantize `k` MXFP8 (E4M3) elements from block data `x` into floats at `y`.
// Uses the AVX2 kernel when the build enables it, the generic routine otherwise.
void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if !defined(__AVX2__)
    dequantize_row_mxfp8_cpu_generic(x, y, k);
#else
    dequantize_row_mxfp8_avx2(
        x, y, k,
        MXFP8_E4M3_EXP_MASK,
        MXFP8_E4M3_MANT_MASK,
        MXFP8_E4M3_EXP_SHIFT,
        MXFP8_E4M3_IEEE_EXP_OFF,
        MXFP8_E4M3_MANT_SHIFT,
        MXFP8_E4M3_SUB_SCALE);
#endif
}
// Dequantize `k` MXFP6 (E2M3) elements from block data `x` into floats at `y`.
// Uses the AVX2 kernel (block stride sizeof(block_mxfp6)) when the build enables it,
// the generic routine otherwise.
void dequantize_row_mxfp6_e2m3_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if !defined(__AVX2__)
    dequantize_row_mxfp6_e2m3_cpu_generic(x, y, k);
#else
    dequantize_row_mxfp6_avx2(
        x, y, k, sizeof(block_mxfp6),
        MXFP6_E2M3_EXP_MASK,
        MXFP6_E2M3_MANT_MASK,
        MXFP6_E2M3_EXP_SHIFT,
        MXFP6_E2M3_IEEE_EXP_OFF,
        MXFP6_E2M3_MANT_SHIFT,
        MXFP6_E2M3_SUB_SCALE);
#endif
}
// SoA dequant for flash attention — contiguous qs region + separate e8m0 region
#if defined(__AVX2__)
// AVX2 dequantization of `k` MXFP4 elements from an SoA row: the packed nibbles of
// all blocks come first, the per-block E8M0 scales start at MXFP_SOA_E8M0_OFFSET.
//
//   src  start of the SoA row
//   y    output buffer, receives k floats
//   k    element count, must be a multiple of QK_MXFP4
static inline void dequantize_row_mxfp4_soa_avx2(
        const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) {
    assert(k % QK_MXFP4 == 0);
    const int n_blocks = k / QK_MXFP4;

    const char * nibble_region = (const char *)src;
    const char * scale_region  = nibble_region + MXFP_SOA_E8M0_OFFSET(n_blocks, MXFP4_SOA_QS_PER_BLOCK);

    // 16-entry lookup table mapping a 4-bit code to its int8 value.
    const __m128i lut         = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
    const __m128i nibble_mask = _mm_set1_epi8(0x0f);

    for (int b = 0; b < n_blocks; b++) {
        const float d = GGML_E8M0_TO_FP32_HALF((uint8_t)scale_region[b]);
        const __m256 block_scale = _mm256_set1_ps(d);

        const uint8_t * packed_ptr = (const uint8_t *)(nibble_region + MXFP_SOA_QS_OFFSET(b, MXFP4_SOA_QS_PER_BLOCK));
        const __m128i packed = _mm_loadu_si128((const __m128i *)packed_ptr);

        // halves[0]: values of the low nibbles  -> output elements  0..15
        // halves[1]: values of the high nibbles -> output elements 16..31
        __m128i halves[2];
        halves[0] = _mm_shuffle_epi8(lut, _mm_and_si128(packed, nibble_mask));
        halves[1] = _mm_shuffle_epi8(lut, _mm_and_si128(_mm_srli_epi16(packed, 4), nibble_mask));

        float * out = y + b * QK_MXFP4;
        for (int h = 0; h < 2; ++h) {
            // Sign-extend the lower and upper 8 bytes to int32, convert, scale, store.
            const __m256i first8 = _mm256_cvtepi8_epi32(halves[h]);
            const __m256i last8  = _mm256_cvtepi8_epi32(_mm_srli_si128(halves[h], 8));
            _mm256_storeu_ps(out + 16 * h + 0, _mm256_mul_ps(_mm256_cvtepi32_ps(first8), block_scale));
            _mm256_storeu_ps(out + 16 * h + 8, _mm256_mul_ps(_mm256_cvtepi32_ps(last8),  block_scale));
        }
    }
}
// AVX2 dequantization of `k` MXFP8 elements from an SoA row: the element codes of
// all blocks come first, the per-block E8M0 scales start at MXFP_SOA_E8M0_OFFSET.
//
//   src          start of the SoA row
//   y            output buffer, receives k floats
//   k            element count, must be a multiple of QK_MXFP8
//   exp_mask,
//   mant_mask,
//   exp_shift    describe where exponent and mantissa sit inside an 8-bit code
//   ieee_exp_off value added to the element exponent to obtain the IEEE-754 exponent field
//   mant_shift   left shift placing the element mantissa in the float32 mantissa field
//   sub_scale    value of one mantissa step for subnormal codes (exponent field == 0)
static inline void dequantize_row_mxfp8_soa_avx2(
        const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k,
        const int exp_mask, const int mant_mask, const int exp_shift,
        const int ieee_exp_off, const int mant_shift, const float sub_scale) {
    assert(k % QK_MXFP8 == 0);
    const int n_blocks = k / QK_MXFP8;

    const char * code_region  = (const char *)src;
    const char * scale_region = code_region + MXFP_SOA_E8M0_OFFSET(n_blocks, MXFP8_SOA_QS_PER_BLOCK);

    // Broadcast constants shared by every group of 8 elements.
    const __m256i sign_bit  = _mm256_set1_epi32(0x80);
    const __m256i exp_bits  = _mm256_set1_epi32(exp_mask);
    const __m256i mant_bits = _mm256_set1_epi32(mant_mask);
    const __m256i exp_bias  = _mm256_set1_epi32(ieee_exp_off);
    const __m256  sub_step  = _mm256_set1_ps(sub_scale);
    const __m256i all_zero  = _mm256_setzero_si256();

    for (int b = 0; b < n_blocks; ++b) {
        const __m256 block_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32((uint8_t)scale_region[b]));
        const uint8_t * codes = (const uint8_t *)(code_region + MXFP_SOA_QS_OFFSET(b, MXFP8_SOA_QS_PER_BLOCK));
        float * out = y + b * QK_MXFP8;

        for (int off = 0; off < 32; off += 8) {
            // Widen 8 one-byte codes to 8 x int32.
            const __m256i code = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(codes + off)));

            // Split into fields; the sign is moved straight to float bit 31.
            const __m256i sign_ieee = _mm256_slli_epi32(_mm256_and_si256(code, sign_bit), 24);
            const __m256i e         = _mm256_and_si256(_mm256_srli_epi32(code, exp_shift), exp_bits);
            const __m256i m         = _mm256_and_si256(code, mant_bits);

            // Normal numbers: assemble the float32 bit pattern directly.
            const __m256i exp_ieee  = _mm256_slli_epi32(_mm256_add_epi32(e, exp_bias), 23);
            const __m256i mant_ieee = _mm256_slli_epi32(m, mant_shift);
            const __m256  as_normal = _mm256_castsi256_ps(
                _mm256_or_si256(_mm256_or_si256(sign_ieee, exp_ieee), mant_ieee));

            // Subnormal numbers: magnitude is mantissa * sub_scale, sign OR-ed back in.
            const __m256 sub_mag      = _mm256_mul_ps(_mm256_cvtepi32_ps(m), sub_step);
            const __m256 as_subnormal = _mm256_castsi256_ps(
                _mm256_or_si256(_mm256_castps_si256(sub_mag), sign_ieee));

            // A zero exponent field selects the subnormal decoding.
            const __m256 pick_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(e, all_zero));
            const __m256 decoded  = _mm256_blendv_ps(as_normal, as_subnormal, pick_sub);

            _mm256_storeu_ps(out + off, _mm256_mul_ps(decoded, block_scale));
        }
    }
}
// AVX2 dequantization of `k` MXFP6 elements from an SoA row: the packed 6-bit codes
// of all blocks come first, the per-block E8M0 scales start at MXFP_SOA_E8M0_OFFSET.
//
//   src          start of the SoA row
//   y            output buffer, receives k floats
//   k            element count, must be a multiple of QK_MXFP6
//   exp_mask,
//   mant_mask,
//   exp_shift    describe where exponent and mantissa sit inside a 6-bit code
//   ieee_exp_off value added to the element exponent to obtain the IEEE-754 exponent field
//   mant_shift   left shift placing the element mantissa in the float32 mantissa field
//   sub_scale    value of one mantissa step for subnormal codes (exponent field == 0)
static inline void dequantize_row_mxfp6_soa_avx2(
        const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k,
        const int exp_mask, const int mant_mask, const int exp_shift,
        const int ieee_exp_off, const int mant_shift, const float sub_scale) {
    assert(k % QK_MXFP6 == 0);
    const int n_blocks = k / QK_MXFP6;

    const char * code_region  = (const char *)src;
    const char * scale_region = code_region + MXFP_SOA_E8M0_OFFSET(n_blocks, MXFP6_SOA_QS_PER_BLOCK);

    // Broadcast constants shared by every group of 8 elements.
    const __m256i sign_bit  = _mm256_set1_epi32(0x20);
    const __m256i exp_bits  = _mm256_set1_epi32(exp_mask);
    const __m256i mant_bits = _mm256_set1_epi32(mant_mask);
    const __m256i exp_bias  = _mm256_set1_epi32(ieee_exp_off);
    const __m256  sub_step  = _mm256_set1_ps(sub_scale);
    const __m256i all_zero  = _mm256_setzero_si256();

    for (int b = 0; b < n_blocks; ++b) {
        const __m256 block_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32((uint8_t)scale_region[b]));
        const uint8_t * packed = (const uint8_t *)(code_region + MXFP_SOA_QS_OFFSET(b, MXFP6_SOA_QS_PER_BLOCK));
        float * out = y + b * QK_MXFP6;

        for (int off = 0; off < 32; off += 8) {
            // Expand 8 codes from 6 bytes: two 3-byte packs of four 6-bit values each.
            uint8_t codes[8];
            for (int half = 0; half < 2; ++half) {
                const uint8_t * p = packed + ((off + 4 * half) * 3 / 4);
                const uint32_t pack = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
                for (int t = 0; t < 4; ++t) {
                    codes[4 * half + t] = (pack >> (6 * t)) & 0x3F;
                }
            }

            // Widen 8 one-byte codes to 8 x int32.
            const __m256i code = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)codes));

            // Split into fields; the sign (bit 5) is moved straight to float bit 31.
            const __m256i sign_ieee = _mm256_slli_epi32(_mm256_and_si256(code, sign_bit), 26);
            const __m256i e         = _mm256_and_si256(_mm256_srli_epi32(code, exp_shift), exp_bits);
            const __m256i m         = _mm256_and_si256(code, mant_bits);

            // Normal numbers: assemble the float32 bit pattern directly.
            const __m256i exp_ieee  = _mm256_slli_epi32(_mm256_add_epi32(e, exp_bias), 23);
            const __m256i mant_ieee = _mm256_slli_epi32(m, mant_shift);
            const __m256  as_normal = _mm256_castsi256_ps(
                _mm256_or_si256(_mm256_or_si256(sign_ieee, exp_ieee), mant_ieee));

            // Subnormal numbers: magnitude is mantissa * sub_scale, sign OR-ed back in.
            const __m256 sub_mag      = _mm256_mul_ps(_mm256_cvtepi32_ps(m), sub_step);
            const __m256 as_subnormal = _mm256_castsi256_ps(
                _mm256_or_si256(_mm256_castps_si256(sub_mag), sign_ieee));

            // A zero exponent field selects the subnormal decoding.
            const __m256 pick_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(e, all_zero));
            const __m256 decoded  = _mm256_blendv_ps(as_normal, as_subnormal, pick_sub);

            _mm256_storeu_ps(out + off, _mm256_mul_ps(decoded, block_scale));
        }
    }
}
#endif
// Dequantize `k` MXFP4 elements from the SoA row `x` into floats at `y`.
// Uses the AVX2 kernel when the build enables it, the generic routine otherwise.
void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if !defined(__AVX2__)
    dequantize_row_mxfp4_soa_cpu_generic(x, y, k);
#else
    dequantize_row_mxfp4_soa_avx2(x, y, k);
#endif
}
// Dequantize `k` MXFP8 (E4M3) elements from the SoA row `x` into floats at `y`.
// Uses the AVX2 kernel when the build enables it, the generic routine otherwise.
void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if !defined(__AVX2__)
    dequantize_row_mxfp8_soa_cpu_generic(x, y, k);
#else
    dequantize_row_mxfp8_soa_avx2(
        x, y, k,
        MXFP8_E4M3_EXP_MASK,
        MXFP8_E4M3_MANT_MASK,
        MXFP8_E4M3_EXP_SHIFT,
        MXFP8_E4M3_IEEE_EXP_OFF,
        MXFP8_E4M3_MANT_SHIFT,
        MXFP8_E4M3_SUB_SCALE);
#endif
}
// Dequantize `k` MXFP6 (E2M3) elements from the SoA row `x` into floats at `y`.
// Uses the AVX2 kernel when the build enables it, the generic routine otherwise.
void dequantize_row_mxfp6_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if !defined(__AVX2__)
    dequantize_row_mxfp6_soa_cpu_generic(x, y, k);
#else
    dequantize_row_mxfp6_soa_avx2(
        x, y, k,
        MXFP6_E2M3_EXP_MASK,
        MXFP6_E2M3_MANT_MASK,
        MXFP6_E2M3_EXP_SHIFT,
        MXFP6_E2M3_IEEE_EXP_OFF,
        MXFP6_E2M3_MANT_SHIFT,
        MXFP6_E2M3_SUB_SCALE);
#endif
}

View File

@ -264,7 +264,7 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_1,
.nrows = 1,
},
[GGML_TYPE_MXFP4] = {
[GGML_TYPE_MXFP4_E2M1] = {
.from_float = quantize_row_mxfp4,
.vec_dot = ggml_vec_dot_mxfp4_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
@ -276,6 +276,20 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
},
[GGML_TYPE_MXFP8_E4M3] = {
.from_float = quantize_row_mxfp8,
.to_float = dequantize_row_mxfp8_cpu,
.vec_dot = ggml_vec_dot_mxfp8_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
},
[GGML_TYPE_MXFP6_E2M3] = {
.from_float = quantize_row_mxfp6_e2m3,
.to_float = dequantize_row_mxfp6_e2m3_cpu,
.vec_dot = ggml_vec_dot_mxfp6_e2m3_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
},
[GGML_TYPE_Q2_K] = {
.from_float = quantize_row_q2_K,
.vec_dot = ggml_vec_dot_q2_K_q8_K,

View File

@ -2,6 +2,8 @@
#include "ggml-cpu.h"
#include "ggml-impl.h"
#include "ggml-quants.h"
#include "quants.h"
#include "binary-ops.h"
#include "simd-gemm.h"
#include "ggml.h"
@ -11,6 +13,7 @@
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <cstring>
// ggml_compute_forward_dup
@ -669,8 +672,10 @@ void ggml_compute_forward_add(
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_MXFP4_E2M1:
case GGML_TYPE_NVFP4:
case GGML_TYPE_MXFP8_E4M3:
case GGML_TYPE_MXFP6_E2M3:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@ -1119,8 +1124,10 @@ void ggml_compute_forward_add1(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_MXFP4_E2M1:
case GGML_TYPE_NVFP4:
case GGML_TYPE_MXFP8_E4M3:
case GGML_TYPE_MXFP6_E2M3:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@ -1248,8 +1255,10 @@ void ggml_compute_forward_acc(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_MXFP4_E2M1:
case GGML_TYPE_NVFP4:
case GGML_TYPE_MXFP8_E4M3:
case GGML_TYPE_MXFP6_E2M3:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@ -4336,8 +4345,10 @@ void ggml_compute_forward_out_prod(
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_MXFP4_E2M1:
case GGML_TYPE_NVFP4:
case GGML_TYPE_MXFP8_E4M3:
case GGML_TYPE_MXFP6_E2M3:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@ -4612,8 +4623,10 @@ void ggml_compute_forward_set(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_MXFP4_E2M1:
case GGML_TYPE_NVFP4:
case GGML_TYPE_MXFP8_E4M3:
case GGML_TYPE_MXFP6_E2M3:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@ -4835,8 +4848,10 @@ void ggml_compute_forward_get_rows(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_MXFP4_E2M1:
case GGML_TYPE_NVFP4:
case GGML_TYPE_MXFP8_E4M3:
case GGML_TYPE_MXFP6_E2M3:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@ -4894,6 +4909,96 @@ void ggml_compute_forward_get_rows(
//}
}
// NEON-optimized Hadamard for ARM platforms; scalar fallback uses ggml_hadamard_32_inplace
// from ggml-quants.c (the reference implementation).
#if defined(__ARM_NEON)
static void hadamard_32_inplace(float vals[32]) {
float32x4_t v0 = vld1q_f32(vals + 0);
float32x4_t v1 = vld1q_f32(vals + 4);
float32x4_t v2 = vld1q_f32(vals + 8);
float32x4_t v3 = vld1q_f32(vals + 12);
float32x4_t v4 = vld1q_f32(vals + 16);
float32x4_t v5 = vld1q_f32(vals + 20);
float32x4_t v6 = vld1q_f32(vals + 24);
float32x4_t v7 = vld1q_f32(vals + 28);
#define HADAMARD_S1(v) do { \
float32x2_t lo = vget_low_f32(v); \
float32x2_t hi = vget_high_f32(v); \
float32x2x2_t t = vtrn_f32(lo, hi); \
float32x2_t sum = vadd_f32(t.val[0], t.val[1]); \
float32x2_t dif = vsub_f32(t.val[0], t.val[1]); \
float32x2x2_t r = vtrn_f32(sum, dif); \
(v) = vcombine_f32(r.val[0], r.val[1]); \
} while (0)
HADAMARD_S1(v0); HADAMARD_S1(v1); HADAMARD_S1(v2); HADAMARD_S1(v3);
HADAMARD_S1(v4); HADAMARD_S1(v5); HADAMARD_S1(v6); HADAMARD_S1(v7);
#undef HADAMARD_S1
#define HADAMARD_S2(v) do { \
float32x2_t lo = vget_low_f32(v); \
float32x2_t hi = vget_high_f32(v); \
(v) = vcombine_f32(vadd_f32(lo, hi), vsub_f32(lo, hi)); \
} while (0)
HADAMARD_S2(v0); HADAMARD_S2(v1); HADAMARD_S2(v2); HADAMARD_S2(v3);
HADAMARD_S2(v4); HADAMARD_S2(v5); HADAMARD_S2(v6); HADAMARD_S2(v7);
#undef HADAMARD_S2
#define HADAMARD_S4(a, b) do { \
float32x4_t s = vaddq_f32(a, b); \
float32x4_t d = vsubq_f32(a, b); \
(a) = s; (b) = d; \
} while (0)
HADAMARD_S4(v0, v1); HADAMARD_S4(v2, v3);
HADAMARD_S4(v4, v5); HADAMARD_S4(v6, v7);
#undef HADAMARD_S4
{ float32x4_t s, d;
s = vaddq_f32(v0, v2); d = vsubq_f32(v0, v2); v0 = s; v2 = d;
s = vaddq_f32(v1, v3); d = vsubq_f32(v1, v3); v1 = s; v3 = d;
s = vaddq_f32(v4, v6); d = vsubq_f32(v4, v6); v4 = s; v6 = d;
s = vaddq_f32(v5, v7); d = vsubq_f32(v5, v7); v5 = s; v7 = d;
}
{ float32x4_t s, d;
s = vaddq_f32(v0, v4); d = vsubq_f32(v0, v4); v0 = s; v4 = d;
s = vaddq_f32(v1, v5); d = vsubq_f32(v1, v5); v1 = s; v5 = d;
s = vaddq_f32(v2, v6); d = vsubq_f32(v2, v6); v2 = s; v6 = d;
s = vaddq_f32(v3, v7); d = vsubq_f32(v3, v7); v3 = s; v7 = d;
}
const float32x4_t norm = vdupq_n_f32(MXFP_HADAMARD_32_NORM);
vst1q_f32(vals + 0, vmulq_f32(v0, norm));
vst1q_f32(vals + 4, vmulq_f32(v1, norm));
vst1q_f32(vals + 8, vmulq_f32(v2, norm));
vst1q_f32(vals + 12, vmulq_f32(v3, norm));
vst1q_f32(vals + 16, vmulq_f32(v4, norm));
vst1q_f32(vals + 20, vmulq_f32(v5, norm));
vst1q_f32(vals + 24, vmulq_f32(v6, norm));
vst1q_f32(vals + 28, vmulq_f32(v7, norm));
}
#else
// Non-NEON builds: forward to ggml_hadamard_32_inplace, the reference
// implementation provided by ggml-quants.c.
static void hadamard_32_inplace(float vals[32]) {
    ggml_hadamard_32_inplace(vals);
}
#endif
// Rotate `data` in place, one independent 32-element Hadamard transform per
// consecutive block. `n` is the total element count and must be a multiple of 32.
static void ggml_apply_hadamard_blocks(float * data, int64_t n) {
    GGML_ASSERT(n % 32 == 0);

    const int64_t n_blocks = n / 32;
    for (int64_t b = 0; b < n_blocks; ++b) {
        hadamard_32_inplace(data + b * 32);
    }
}
// Look up the dequantization routine for `type`: the CPU type-traits entry wins
// when it provides one, otherwise the entry from the generic type traits is
// returned (which may itself be null).
static inline ggml_to_float_t ggml_get_to_float_fn(ggml_type type) {
    if (ggml_to_float_t cpu_fn = ggml_get_type_traits_cpu(type)->to_float) {
        return cpu_fn;
    }
    return ggml_get_type_traits(type)->to_float;
}
template<typename idx_t>
static void ggml_compute_forward_set_rows_f32(
const ggml_compute_params * params,
@ -4924,7 +5029,22 @@ static void ggml_compute_forward_set_rows_f32(
const int64_t ir0 = dr*ith;
const int64_t ir1 = std::min(ir0 + dr, nr);
ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
const int32_t apply_hadamard = ((const int32_t *)dst->op_params)[0];
// For MXFP types, use SoA quantize (canonical FA layout).
// For non-MXFP types, use the standard AoS from_float.
typedef void (*quantize_soa_fn)(const float *, void *, int64_t);
quantize_soa_fn mxfp_soa_quantize = nullptr;
ggml_from_float_t from_float = nullptr;
switch (dst->type) {
case GGML_TYPE_MXFP4_E2M1: mxfp_soa_quantize = quantize_row_mxfp4_soa; break;
case GGML_TYPE_MXFP8_E4M3: mxfp_soa_quantize = quantize_row_mxfp8_soa; break;
case GGML_TYPE_MXFP6_E2M3: mxfp_soa_quantize = quantize_row_mxfp6_soa; break;
default:
from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
break;
}
for (int64_t i03 = 0; i03 < ne03; ++i03) {
for (int64_t i02 = 0; i02 < ne02; ++i02) {
@ -4937,9 +5057,26 @@ static void ggml_compute_forward_set_rows_f32(
GGML_ASSERT(i1 >= 0 && i1 < ne1);
from_float(
(const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc);
const float * src_row = (const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03);
char * dst_row = ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3);
if (apply_hadamard) {
GGML_ASSERT(nc <= 1024);
float tmp[1024];
memcpy(tmp, src_row, nc * sizeof(float));
ggml_apply_hadamard_blocks(tmp, nc);
if (mxfp_soa_quantize) {
mxfp_soa_quantize(tmp, dst_row, nc);
} else {
from_float(tmp, dst_row, nc);
}
} else {
if (mxfp_soa_quantize) {
mxfp_soa_quantize(src_row, dst_row, nc);
} else {
from_float(src_row, dst_row, nc);
}
}
}
}
}
@ -5560,8 +5697,10 @@ void ggml_compute_forward_clamp(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_MXFP4_E2M1:
case GGML_TYPE_NVFP4:
case GGML_TYPE_MXFP8_E4M3:
case GGML_TYPE_MXFP6_E2M3:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@ -8118,6 +8257,67 @@ void ggml_compute_forward_top_k(
}
}
// Function-pointer signatures used by the MXFP flash-attention paths.
//   quantize:   (float input, SoA output, element count)
//   dequantize: (SoA input, float output, element count)
using mxfp_soa_quantize_fn   = void (*)(const float *, void *, int64_t);
using mxfp_soa_dequantize_fn = void (*)(const void *, float *, int64_t);

// MXFP dispatch state for one flash-attention call.
// Filled in once and then read by both the one_chunk and the tiled path.
struct mxfp_fa_params {
    // SoA quantizer applied to Q rows; null unless K is an MXFP type.
    mxfp_soa_quantize_fn q_quantize;
    // SoA dequantizers for K and V; null when the tensor is not an MXFP type.
    mxfp_soa_dequantize_fn k_dequantize;
    mxfp_soa_dequantize_fn v_dequantize;
    // True when one SoA region covers all heads of a KV position.
    bool k_multihead;
    bool v_multihead;
    // Number of elements to dequantize per SoA region (0 for non-MXFP tensors).
    int64_t k_soa_elems;
    int64_t v_soa_elems;
    // True when Q has to be Hadamard-rotated before quantization.
    bool apply_hadamard;
};
// Build the MXFP dispatch state for a flash-attention call.
//
//   k, v        K and V tensors; only their `type` is inspected here
//   DK, DV      head sizes of K and V
//   nbk2, nbv2  byte stride between heads (nb[2]) of K and V
//   nek2, nev2  number of heads (ne[2]) of K and V
//
// Fields belonging to a tensor that is not an MXFP type stay zero / null.
// Aborts if a tensor reports an MXFP type that has no SoA routines here.
static mxfp_fa_params mxfp_fa_params_init(
        const ggml_tensor * k, const ggml_tensor * v,
        int64_t DK, int64_t DV,
        size_t nbk2, size_t nbv2,
        int64_t nek2, int64_t nev2) {
    mxfp_fa_params out = {};

    const bool k_is_mxfp = ggml_is_type_mxfp(k->type);
    const bool v_is_mxfp = ggml_is_type_mxfp(v->type);

    // K decides both the Q quantizer and the K dequantizer.
    if (k_is_mxfp) {
        if (k->type == GGML_TYPE_MXFP4_E2M1) {
            out.q_quantize   = quantize_row_mxfp4_soa;
            out.k_dequantize = dequantize_row_mxfp4_soa_cpu;
        } else if (k->type == GGML_TYPE_MXFP8_E4M3) {
            out.q_quantize   = quantize_row_mxfp8_soa;
            out.k_dequantize = dequantize_row_mxfp8_soa_cpu;
        } else if (k->type == GGML_TYPE_MXFP6_E2M3) {
            out.q_quantize   = quantize_row_mxfp6_soa;
            out.k_dequantize = dequantize_row_mxfp6_soa_cpu;
        } else {
            GGML_ABORT("unsupported MXFP K type");
        }
    }

    if (v_is_mxfp) {
        if (v->type == GGML_TYPE_MXFP4_E2M1) {
            out.v_dequantize = dequantize_row_mxfp4_soa_cpu;
        } else if (v->type == GGML_TYPE_MXFP8_E4M3) {
            out.v_dequantize = dequantize_row_mxfp8_soa_cpu;
        } else if (v->type == GGML_TYPE_MXFP6_E2M3) {
            out.v_dequantize = dequantize_row_mxfp6_soa_cpu;
        } else {
            GGML_ABORT("unsupported MXFP V type");
        }
    }

    // The Q rotation has to mirror the rotation applied to K.
    // It is skipped for MLA (DK != DV), where V is a view of K.
    out.apply_hadamard = k_is_mxfp && (DK == DV) && ggml_mxfp_use_hadamard(k->type);

    // SoA extent detection. When the head stride equals the byte size of one
    // DK/DV row, heads are packed back to back (real KV cache) and a single SoA
    // region spans all heads. Otherwise (e.g. test tensors with distant heads)
    // every head carries its own SoA region.
    if (k_is_mxfp) {
        out.k_multihead = (nbk2 == (size_t)ggml_row_size(k->type, DK));
        out.k_soa_elems = out.k_multihead ? nek2 * DK : DK;
    }
    if (v_is_mxfp) {
        out.v_multihead = (nbv2 == (size_t)ggml_row_size(v->type, DV));
        out.v_soa_elems = out.v_multihead ? nev2 * DV : DV;
    }

    return out;
}
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
const ggml_compute_params * params,
ggml_tensor * dst,
@ -8192,13 +8392,29 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
const bool is_mxfp_k = ggml_is_type_mxfp(k->type);
const bool is_mxfp_v = ggml_is_type_mxfp(v->type);
GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
const mxfp_fa_params mxfp = mxfp_fa_params_init(k, v, DK, DV, nbk2, nbv2, nek2, nev2);
ggml_from_float_t q_to_vec_dot = nullptr;
ggml_vec_dot_t kq_vec_dot = nullptr;
ggml_to_float_t v_to_float = nullptr;
if (is_mxfp_k) {
kq_vec_dot = nullptr;
} else {
ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
}
if (!is_mxfp_v) {
v_to_float = ggml_get_to_float_fn(v->type);
}
GGML_ASSERT((is_mxfp_k || q_to_vec_dot) && "fattn: unsupported K-type");
GGML_ASSERT((v->type == GGML_TYPE_F32 || is_mxfp_v || v_to_float) && "fattn: unsupported V-type");
int ith = params->ith;
@ -8236,7 +8452,31 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
const int iv2 = iq2 / rv2;
const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
q_to_vec_dot(pq, Q_q, DK);
float Q_f32[1024];
if (is_mxfp_k) {
// Q preprocessing: Hadamard → SoA quantize → SoA dequant (round-trip).
// Captures the same quantization loss as K, matching GPU MMA semantics.
GGML_ASSERT(DK <= 1024);
if (mxfp.apply_hadamard) {
float q_tmp[1024];
memcpy(q_tmp, pq, DK * sizeof(float));
ggml_apply_hadamard_blocks(q_tmp, DK);
mxfp.q_quantize(q_tmp, Q_q, DK);
} else {
mxfp.q_quantize(pq, Q_q, DK);
}
mxfp.k_dequantize(Q_q, Q_f32, DK);
} else {
if (mxfp.apply_hadamard) {
GGML_ASSERT(DK <= 1024);
float q_tmp[1024];
memcpy(q_tmp, pq, DK * sizeof(float));
ggml_apply_hadamard_blocks(q_tmp, DK);
q_to_vec_dot(q_tmp, Q_q, DK);
} else {
q_to_vec_dot(pq, Q_q, DK);
}
}
// online softmax / attention
// loop over n_kv and n_head_kv
@ -8251,7 +8491,20 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
float s; // KQ value
const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
if (is_mxfp_k) {
// Dequant SoA data. Multi-head: full row base, extract head portion.
// Per-head: use k_data directly.
const char * k_soa_base = mxfp.k_multihead
? ((const char *) k->data + ic*nbk1 + ik3*nbk3)
: k_data;
float k_soa_f32[4096];
GGML_ASSERT(mxfp.k_soa_elems <= 4096);
mxfp.k_dequantize(k_soa_base, k_soa_f32, mxfp.k_soa_elems);
const float * k_head = k_soa_f32 + (mxfp.k_multihead ? ik2 * DK : 0);
ggml_vec_dot_f32(DK, &s, 0, k_head, 0, Q_f32, 0, 1);
} else {
kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
}
s = s*scale; // scale KQ value
@ -8297,7 +8550,15 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
}
// V += v*expf(s - M)
if (v_to_float) {
if (mxfp.v_dequantize) {
const char * v_soa_base = mxfp.v_multihead
? ((const char *) v->data + ic*nbv1 + iv3*nbv3)
: v_data;
float v_soa_f32[4096];
GGML_ASSERT(mxfp.v_soa_elems <= 4096);
mxfp.v_dequantize(v_soa_base, v_soa_f32, mxfp.v_soa_elems);
ggml_vec_mad_f32(DV, VKQ32, v_soa_f32 + (mxfp.v_multihead ? iv2 * DV : 0), vs);
} else if (v_to_float) {
v_to_float(v_data, V32, DV);
ggml_vec_mad_f32(DV, VKQ32, V32, vs);
} else {
@ -8399,9 +8660,17 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
GGML_ASSERT(k->type == v->type);
const ggml_type kv_type = k->type;
const ggml_type k_type = k->type;
const ggml_type v_type = v->type;
const bool is_mxfp_k = ggml_is_type_mxfp(k_type);
const bool is_mxfp_v = ggml_is_type_mxfp(v_type);
const mxfp_fa_params mxfp = mxfp_fa_params_init(k, v, DK, DV, nbk2, nbv2, nek2, nev2);
// Non-MXFP dequant functions
ggml_to_float_t k_to_float = is_mxfp_k ? nullptr : ggml_get_to_float_fn(k_type);
ggml_to_float_t v_to_float = is_mxfp_v ? nullptr : ggml_get_to_float_fn(v_type);
// broadcast factors
const int64_t rk2 = neq2/nek2;
@ -8490,6 +8759,16 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
for (int tq = 0; tq < tile_rows; tq++) {
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float));
if (is_mxfp_k) {
if (mxfp.apply_hadamard) {
ggml_apply_hadamard_blocks(Q_f32 + tq * DK, DK);
}
// SoA round-trip: quantize Q to SoA, then dequant back to float.
uint8_t q_mxfp_buf[1024];
mxfp.q_quantize(Q_f32 + tq * DK, q_mxfp_buf, DK);
mxfp.k_dequantize(q_mxfp_buf, Q_f32 + tq * DK, DK);
}
}
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
memset(Q_f32 + tq * DK, 0, DK * sizeof(float));
@ -8528,16 +8807,33 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
// Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns
for (int tk = 0; tk < kv_tile; tk++) {
const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3;
if (kv_type == GGML_TYPE_F16) {
if (k_type == GGML_TYPE_F16) {
const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data;
for (int64_t dk = 0; dk < DK; dk++) {
K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]);
}
} else {
} else if (k_type == GGML_TYPE_F32) {
const float * k_f32_src = (const float *)k_data;
for (int64_t dk = 0; dk < DK; dk++) {
K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
}
} else if (mxfp.k_dequantize) {
const char * k_soa_base = mxfp.k_multihead
? ((const char *)k->data + (ic + tk)*nbk1 + ik3*nbk3)
: k_data;
float k_soa[4096];
GGML_ASSERT(mxfp.k_soa_elems <= 4096);
mxfp.k_dequantize(k_soa_base, k_soa, mxfp.k_soa_elems);
const float * k_head = k_soa + (mxfp.k_multihead ? ik2 * DK : 0);
for (int64_t dk = 0; dk < DK; dk++) {
K_f32[dk * KV_TILE_SZ + tk] = k_head[dk];
}
} else {
float k_tmp[1024];
k_to_float(k_data, k_tmp, DK);
for (int64_t dk = 0; dk < DK; dk++) {
K_f32[dk * KV_TILE_SZ + tk] = k_tmp[dk];
}
}
}
memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
@ -8593,10 +8889,20 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
// Pack V tile to contiguous F32, zero-padded
for (int tk = 0; tk < kv_tile; tk++) {
const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3;
if (kv_type == GGML_TYPE_F16) {
if (v_type == GGML_TYPE_F16) {
ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);
} else {
} else if (v_type == GGML_TYPE_F32) {
memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
} else if (mxfp.v_dequantize) {
const char * v_soa_base = mxfp.v_multihead
? ((const char *)v->data + (ic + tk)*nbv1 + iv3*nbv3)
: v_data;
float v_soa[4096];
GGML_ASSERT(mxfp.v_soa_elems <= 4096);
mxfp.v_dequantize(v_soa_base, v_soa, mxfp.v_soa_elems);
memcpy(V32 + tk * DV, v_soa + (mxfp.v_multihead ? iv2 * DV : 0), DV * sizeof(float));
} else {
v_to_float(v_data, V32 + tk * DV, DV);
}
}
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
@ -8764,8 +9070,10 @@ static void ggml_compute_forward_flash_attn_ext_f16(
// When use_ref is set, force the vec-only reference implementation (no tiling, no KV-chunking)
const bool use_ref = params->use_ref;
const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512;
// Split-KV: parallelize across KV chunks for single-query decode (token generation).
// Delegates to one_chunk which handles all supported types (F16, Q8_0, Q4_0, MXFP, etc).
const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1)
&& q->type == GGML_TYPE_F32 && nek1 >= 512;
if (use_split_kv_path) {
const int64_t chunk_size = (nek1 + nth - 1) / nth;
@ -8824,8 +9132,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
bool use_tiled = !use_ref &&
(q->type == GGML_TYPE_F32 &&
kv_is_f32_or_f16 &&
k->type == v->type &&
neq1 >= Q_TILE_SZ);
#ifdef GGML_SIMD
use_tiled &= (DV % GGML_F32_EPR == 0);

View File

@ -54,6 +54,14 @@ void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, i
quantize_row_nvfp4_ref(x, y, k);
}
// CPU entry point for MXFP8 row quantization.
// Hands the k input floats in x to the scalar reference quantizer, which
// writes its block_mxfp8 output into y.
void quantize_row_mxfp8(
        const float * GGML_RESTRICT x,
        void        * GGML_RESTRICT y,
        int64_t                     k) {
    quantize_row_mxfp8_ref(x, (block_mxfp8 *) y, k);
}
// CPU entry point for MXFP6 (E2M3) row quantization.
// Hands the k input floats in x to the scalar reference quantizer, which
// writes its block_mxfp6 output into y.
void quantize_row_mxfp6_e2m3(
        const float * GGML_RESTRICT x,
        void        * GGML_RESTRICT y,
        int64_t                     k) {
    quantize_row_mxfp6_e2m3_ref(x, (block_mxfp6 *) y, k);
}
//
// 2-6 bit quantization in super-blocks
//
@ -256,6 +264,70 @@ void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
*s = sumf;
}
// Shared scalar kernel behind the generic MXFP x Q8_0 dot products.
//
//   n          number of elements; must be a multiple of QK8_0
//   s          receives the resulting dot product
//   vx         MX blocks, laid out back to back with a stride of block_size bytes
//   block_size size in bytes of one MX block of the type being decoded
//   vy         block_q8_0 array holding the other operand
//   dequant    row dequantizer used to expand one MX block to QK8_0 floats
//
// Each MX block is expanded to float, multiplied element-wise with the int8
// values of the matching Q8_0 block, and the per-block sum is weighted by that
// Q8_0 block's scale. Reference implementation - not SIMD-optimized.
static void ggml_vec_dot_mxfp_q8_0_impl(
        int n, float * GGML_RESTRICT s,
        const void * GGML_RESTRICT vx, size_t block_size,
        const void * GGML_RESTRICT vy,
        ggml_to_float_t dequant) {
    assert(n % QK8_0 == 0);

    const block_q8_0 * GGML_RESTRICT q8 = vy;
    const char * mx_block = (const char *) vx;
    const int    n_blocks = n / QK8_0;

    float total = 0;
    for (int b = 0; b < n_blocks; b++, mx_block += block_size) {
        float vals[QK8_0];
        dequant(mx_block, vals, QK8_0);

        const float q8_scale = GGML_CPU_FP16_TO_FP32(q8[b].d);

        float partial = 0;
        for (int i = 0; i < QK8_0; i++) {
            partial += vals[i] * (float) q8[b].qs[i];
        }
        total += partial * q8_scale;
    }

    *s = total;
}
// Generic (scalar) MXFP8 x Q8_0 dot product.
// Only the single-row case (nrc == 1) is supported; the stride arguments bs, bx
// and by are ignored. The work is done by the shared MXFP kernel, stepping
// through vx one block_mxfp8 at a time.
void ggml_vec_dot_mxfp8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(bs);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(nrc);

    const ggml_to_float_t to_float = (ggml_to_float_t)dequantize_row_mxfp8;
    ggml_vec_dot_mxfp_q8_0_impl(n, s, vx, sizeof(block_mxfp8), vy, to_float);
}
// Generic (scalar) MXFP6 E2M3 x Q8_0 dot product.
// Only the single-row case (nrc == 1) is supported; the stride arguments bs, bx
// and by are ignored. The work is done by the shared MXFP kernel, stepping
// through vx one block_mxfp6 at a time.
void ggml_vec_dot_mxfp6_e2m3_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(bs);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(nrc);

    const ggml_to_float_t to_float = (ggml_to_float_t)dequantize_row_mxfp6_e2m3;
    ggml_vec_dot_mxfp_q8_0_impl(n, s, vx, sizeof(block_mxfp6), vy, to_float);
}
// Generic (scalar) dequant wrappers - each one delegates to the reference
// implementation in ggml-quants.c.
// On x86/ARM, arch-specific SIMD versions override these via the fallback.h mapping.

// MXFP8 block layout: expands k elements from x into y.
void dequantize_row_mxfp8_cpu_generic(
        const void * GGML_RESTRICT x,
        float      * GGML_RESTRICT y,
        int64_t                    k) {
    dequantize_row_mxfp8((const block_mxfp8 *) x, y, k);
}
// Scalar fallback for MXFP6 E2M3 block layout: expands k elements from x into y
// by calling the reference dequantizer.
void dequantize_row_mxfp6_e2m3_cpu_generic(
        const void * GGML_RESTRICT x,
        float      * GGML_RESTRICT y,
        int64_t                    k) {
    dequantize_row_mxfp6_e2m3((const block_mxfp6 *) x, y, k);
}
// Scalar fallback for the MXFP4 SoA row layout: expands k elements from x into y
// by calling the reference SoA dequantizer.
void dequantize_row_mxfp4_soa_cpu_generic(
        const void * GGML_RESTRICT x,
        float      * GGML_RESTRICT y,
        int64_t                    k) {
    dequantize_row_mxfp4_soa(x, y, k);
}
// Scalar fallback for the MXFP8 SoA row layout: expands k elements from x into y
// by calling the reference SoA dequantizer.
void dequantize_row_mxfp8_soa_cpu_generic(
        const void * GGML_RESTRICT x,
        float      * GGML_RESTRICT y,
        int64_t                    k) {
    dequantize_row_mxfp8_soa(x, y, k);
}
// Scalar fallback for the MXFP6 SoA row layout: expands k elements from x into y
// by calling the reference SoA dequantizer.
void dequantize_row_mxfp6_soa_cpu_generic(
        const void * GGML_RESTRICT x,
        float      * GGML_RESTRICT y,
        int64_t                    k) {
    dequantize_row_mxfp6_soa(x, y, k);
}
void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;

View File

@ -21,6 +21,12 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_mxfp8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_mxfp6_e2m3(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
// Dequantization (SIMD-optimized, arch-dispatched)
void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_mxfp6_e2m3_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@ -44,6 +50,8 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@ -76,6 +84,20 @@ void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, c
void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_mxfp8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_mxfp6_e2m3_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void dequantize_row_mxfp8_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_mxfp6_e2m3_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
// SoA dequant (SIMD-optimized for FA)
void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_mxfp6_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_mxfp4_soa_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_mxfp8_soa_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_mxfp6_soa_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);

View File

@ -3767,7 +3767,7 @@ static block_mxfp4x4 make_block_mxfp4x4(block_mxfp4 * in, unsigned int blck_size
}
static int repack_mxfp4_to_mxfp4_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_MXFP4);
GGML_ASSERT(t->type == GGML_TYPE_MXFP4_E2M1);
GGML_ASSERT(interleave_block == 4);
const block_mxfp4 * src = (const block_mxfp4 *)data;
@ -3824,7 +3824,7 @@ static block_mxfp4x8 make_block_mxfp4x8(block_mxfp4 * in, unsigned int blck_size
}
static int repack_mxfp4_to_mxfp4_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_MXFP4);
GGML_ASSERT(t->type == GGML_TYPE_MXFP4_E2M1);
GGML_ASSERT(interleave_block == 8);
const block_mxfp4 * src = (const block_mxfp4 *)data;
@ -4682,7 +4682,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
}
#endif
}
} else if (cur->type == GGML_TYPE_MXFP4) {
} else if (cur->type == GGML_TYPE_MXFP4_E2M1) {
if (ggml_cpu_has_avx2()) {
if (cur->ne[1] % 8 == 0) {
return &mxfp4_8x8_q8_0;

View File

@ -430,59 +430,28 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
// E8M0 shared exponent to float. Canonical source: ggml_mxfp_e8m0_to_fp32() in ggml-common.h.
// Kept here because ggml-impl.h cannot depend on ggml-common.h IMPL section.
static inline float ggml_e8m0_to_fp32(uint8_t x) {
uint32_t bits; // Stores the raw bit representation of the float
// Handle special case for minimum exponent (denormalized float)
uint32_t bits;
if (x == 0) {
// Bit pattern for 2^(-127):
// - Sign bit: 0 (positive)
// - Exponent: 0 (denormalized number)
// - Mantissa: 0x400000 (0.5 in fractional form)
// Value = 0.5 * 2^(-126) = 2^(-127)
bits = 0x00400000;
}
// note: disabled as we don't need to handle NaNs
//// Handle special case for NaN (all bits set)
//else if (x == 0xFF) {
// // Standard quiet NaN pattern:
// // - Sign bit: 0
// // - Exponent: all 1s (0xFF)
// // - Mantissa: 0x400000 (quiet NaN flag)
// bits = 0x7FC00000;
//}
// Normalized values (most common case)
else {
// Construct normalized float by shifting exponent into position:
// - Exponent field: 8 bits (positions 30-23)
// - Mantissa: 0 (implicit leading 1)
// Value = 2^(x - 127)
bits = 0x00400000; // 2^(-127)
} else {
bits = (uint32_t) x << 23;
}
float result; // Final float value
// Safely reinterpret bit pattern as float without type-punning issues
float result;
memcpy(&result, &bits, sizeof(float));
return result;
}
// Equal to ggml_e8m0_to_fp32/2
// Useful with MXFP4 quantization since the E0M2 values are doubled
// E8M0 to float/2. Canonical source: ggml_mxfp_e8m0_to_fp32_half() in ggml-common.h.
static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
uint32_t bits;
// For x < 2: use precomputed denormal patterns
if (x < 2) {
// 0x00200000 = 2^(-128), 0x00400000 = 2^(-127)
bits = 0x00200000 << x;
}
// For x >= 2: normalized exponent adjustment
else {
// 0.5 * 2^(x-127) = 2^(x-128) = normalized with exponent (x-1)
} else {
bits = (uint32_t)(x - 1) << 23;
}
// Note: NaNs are not handled here
float result;
memcpy(&result, &bits, sizeof(float));
return result;

View File

@ -3760,7 +3760,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
} else if (op->src[0]->type == GGML_TYPE_F32) {
return op->src[1]->type == GGML_TYPE_F32;
} else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 ||
op->src[0]->type == GGML_TYPE_MXFP4 ||
op->src[0]->type == GGML_TYPE_MXFP4_E2M1 ||
op->src[0]->type == GGML_TYPE_Q4_K ||
op->src[0]->type == GGML_TYPE_Q6_K) {
return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
@ -3771,7 +3771,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
case GGML_OP_MUL_MAT_ID:
if (op->src[0]->type == GGML_TYPE_Q4_0 ||
op->src[0]->type == GGML_TYPE_Q8_0 ||
op->src[0]->type == GGML_TYPE_MXFP4) {
op->src[0]->type == GGML_TYPE_MXFP4_E2M1) {
if (op->src[1]->type == GGML_TYPE_F32) {
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
}
@ -4559,7 +4559,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
return;
}
if (tensor->type == GGML_TYPE_MXFP4) {
if (tensor->type == GGML_TYPE_MXFP4_E2M1) {
ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");
@ -5136,7 +5136,7 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
CL_CHECK(clReleaseMemObject(data_device));
return;
}
if (tensor->type == GGML_TYPE_MXFP4) {
if (tensor->type == GGML_TYPE_MXFP4_E2M1) {
ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *)tensor->extra;
cl_int err;
@ -5585,7 +5585,7 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tenso
CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL));
CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_d, buf_d, 0, NULL, NULL));
CL_CHECK(clFinish(queue));
} else if (tensor->type == GGML_TYPE_MXFP4) {
} else if (tensor->type == GGML_TYPE_MXFP4_E2M1) {
ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *) tensor->extra;
GGML_ASSERT(extra);
@ -10550,7 +10550,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3));
#endif // GGML_OPENCL_SOA_Q
break;
case GGML_TYPE_MXFP4: {
case GGML_TYPE_MXFP4_E2M1: {
#ifdef GGML_OPENCL_SOA_Q
kernel = backend_ctx->kernel_mul_mv_mxfp4_f32_flat;
@ -10630,7 +10630,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
GGML_ASSERT(false && "not implemented");
}
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4 ||
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4_E2M1 ||
src0t == GGML_TYPE_Q4_1 ||
src0t == GGML_TYPE_Q8_0 ||
src0t == GGML_TYPE_Q2_K) {
@ -10864,7 +10864,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
#endif // GGML_OPENCL_SOA_Q
break;
}
case GGML_TYPE_MXFP4: {
case GGML_TYPE_MXFP4_E2M1: {
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
if (use_adreno_moe_kernels(backend_ctx, src0)) {
cl_int status;

View File

@ -257,19 +257,188 @@ void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_REST
}
}
static inline int best_index_mxfp4(float x, float e) {
int best_index = 0;
float best_err = fabsf(kvalues_mxfp4[0]*e - x);
for (int i = 1; i < 16; i++) {
float err = fabsf(kvalues_mxfp4[i]*e - x);
if (err < best_err) {
best_index = i;
best_err = err;
}
}
return best_index;
// ============================================================================
// MXFP Element Conversion Functions
// ============================================================================
//
// Reference implementations for OCP Microscaling (MX) format element types.
// Spec: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
//
// All converters use IEEE-754 bit manipulation via memcpy (C99 safe, no strict
// aliasing issues). Quantization uses round-to-nearest-even (RNE) per MX spec.
//
// These functions are exposed in ggml-quants.h for use by CPU backends and tests.
// GPU backends (CUDA, Vulkan, Metal) provide their own optimized versions using
// hardware intrinsics (e.g., __nv_cvt_float_to_fp8, SIMD groups, LUT lookups).
//
// Key design decisions validated empirically on CUDA (Qwen3-Coder-30B-A3B):
//
// 1. SATURATION, NOT NaN PROPAGATION: FP8 E4M3 saturates to max (0x7E = 448)
// rather than producing NaN. The single NaN encoding (0x7F) is avoided.
// This matches the MX spec behavior and prevents NaN corruption in KV caches.
//
// 2. MX FP6 HAS NO NaN/Inf: Unlike IEEE-754, the MX spec defines exp=max as a
// valid normal value for FP6 types. Dequantizers must NOT special-case it.
//
// 3. RNE ROUNDING IN SUBNORMALS: Both normal and subnormal paths use proper
// round-to-nearest-even with sticky bit tracking. This was a P0 bug fix —
// truncation caused measurable PPL regression.
//
// 4. E3M2 SUBNORMAL SCALE: mant * 2^(1-bias-m) = mant * 2^(-4) = mant/16.
// NOT mant/4. This was a critical bug — the exponent bias and mantissa width
// both affect the subnormal multiplier.
//
// FP8 E4M3: 1 sign, 4 exponent (bias 7), 3 mantissa
// Max finite: 448 (exp=15, mant=6), NaN: exp=15, mant=7
// Thin wrappers around canonical implementations in ggml-common.h.
// Verified bit-for-bit identical by test-mxfp-converters.
// Decode one FP8 E4M3 code to float via the canonical ggml_mxfp_* converter.
float fp8_e4m3_to_float(uint8_t v) {
    return ggml_mxfp_fp8_e4m3_to_float(v);
}

// Encode a float as an FP8 E4M3 code via the canonical ggml_mxfp_* converter.
uint8_t float_to_fp8_e4m3_rn(float x) {
    return ggml_mxfp_float_to_fp8_e4m3(x);
}
// ============================================================================
// MXFP Quantization Infrastructure
// ============================================================================
//
// The MX format uses a shared E8M0 exponent per block of 32 elements. Choosing
// the optimal exponent is critical for quantization quality.
//
// The OCP MX v1.0 spec (§5.3) specifies floor(log2(amax)) for the shared exponent.
// We improve on this with an MSE-optimal ±1 search that tests 3 candidate exponents
// {e-1, e, e+1} around round(log2(amax)) and picks whichever minimizes the total
// round-trip quantization error for the block. This consistently improves perplexity
// by 0.05-0.2 across all MX types versus floor-only or round-only approaches.
//
// The round(log2(amax)) base is computed via IEEE-754 integer bit extraction rather
// than log2f(), avoiding GPU Special Function Unit (SFU) bottlenecks. The rounding
// threshold 0x3504F3 is the fractional part of sqrt(2) in IEEE-754 mantissa bits:
// if mantissa >= (sqrt(2)-1)*2^23 ≈ 0x3504F3, then log2(x) >= n+0.5, so round up.
//
// Each MX element type provides an mse_error function that computes the round-trip
// quantization error for a single value at a given scale. The traits structure
// encapsulates this per-type behavior.
//
// Per-type traits for MSE-optimal E8M0 scale computation.
// emax_offset: type-specific offset from E8M0 bias to type's max representable exponent
// to_elem/to_float: element conversion function pointers (NULL for MXFP4 which uses LUT)
// mse_error: round-trip error function for a single value at a given scale
// Bundle of per-element-type behavior consumed by the E8M0 scale search and the
// traits-parameterized quantize/dequantize helpers. Instances are built with
// positional initializers, so the field order below must not change.
typedef struct {
int emax_offset; // passed to ggml_mxfp_e8m0_base_estimate() when estimating the base exponent
uint8_t (*to_elem)(float); // float -> element code; NULL in the MXFP4 traits
float (*to_float)(uint8_t); // element code -> float; NULL in the MXFP4 traits
float (*mse_error)(float val, float inv_scale, float scale); // squared round-trip error of val at the given scale
} mxfp_elem_traits_t;
// Forward declaration — defined after kvalues_mxfp4 lookup table section.
static inline int best_index_mxfp4(float x, float e);
// MXFP4 E2M1 MSE error: decision boundary quantization with HALF scale factor.
//
// This CPU implementation uses the doubled int8 kvalues_mxfp4 LUT {0,1,2,3,4,6,8,12}
// with GGML_E8M0_TO_FP32_HALF(e) = scale/2 for efficient nibble-indexed integer arithmetic.
// The MSE interface passes GGML_E8M0_TO_FP32(e) as scale, so we halve it.
//
// Canonical E2M1 values are {0, 0.5, 1, 1.5, 2, 3, 4, 6} (kvalues_mxfp4_float in ggml-common.h).
// Doubled boundaries {0.5, 1.5, 2.5, 3.5, 5, 7, 10} ÷ 2 = canonical {0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5}.
// Mathematically identical — the doubling is an implementation detail.
// This is the Lloyd-Max quantizer for uniform input density.
// Squared round-trip error of |val| when quantized to MXFP4 at the given scale.
//
// The magnitude is normalized by half the scale and snapped to one of the
// doubled E2M1 magnitudes {0, 1, 2, 3, 4, 6, 8, 12} using the midpoints
// {0.5, 1.5, 2.5, 3.5, 5, 7, 10} as decision boundaries. inv_scale is part of
// the common mse_error signature but is not used: the reciprocal is taken from
// the halved scale directly (and is 0 when that scale is not positive).
static float mse_error_mxfp4(float val, float inv_scale, float scale) {
    static const float level[8] = { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 6.0f, 8.0f, 12.0f };

    (void)inv_scale;

    const float half_scale = scale * 0.5f;
    const float recip      = (half_scale > 0.0f) ? 1.0f / half_scale : 0.0f;
    const float t          = fabsf(val) * recip;

    // Count the boundaries that t is NOT below. The boundaries are increasing,
    // so this is the index of the first boundary above t; a NaN fails every
    // "<" test and therefore lands on the top level, like the if/else chain.
    const int slot = !(t < 0.5f) + !(t < 1.5f) + !(t < 2.5f) + !(t < 3.5f)
                   + !(t < 5.0f) + !(t < 7.0f) + !(t < 10.0f);

    const float diff = fabsf(val) - level[slot] * half_scale;
    return diff * diff;
}
// Traits for MXFP4 E2M1. Both converter slots are NULL, so this instance only
// supplies the exponent offset and the MSE function to the scale search.
static const mxfp_elem_traits_t mxfp4_traits = { MXFP4_E2M1_EMAX_OFFSET, NULL, NULL, mse_error_mxfp4 };
// MSE-optimal E8M0 shared exponent computation.
//
// Algorithm:
// 1. Find amax = max(|x[0..qk-1]|); an all-zero block returns exponent 0.
// 2. Compute e_base = round(log2(amax)) - emax_offset + 127 via integer bit ops
//    (ggml_mxfp_e8m0_base_estimate).
// 3. Test {e_base-R .. e_base+R}, clamped to [1, 254], and pick the exponent
//    minimizing the total round-trip squared error, where
//    R = MXFP_E8M0_MSE_RANGE (defined in ggml-common.h).
//
// The ±R search improves on the OCP spec's floor(log2(amax)). Wider search finds
// better scales for blocks with non-uniform value distributions (especially FP4).
// Cost is (2R+1) × qk roundtrip evaluations per block — negligible vs attention compute.
//
// Integer log2 avoids log2f() (SFU-dependent on GPU). The sqrt(2) rounding threshold
// ensures we start from round() not floor().
//
// Parameters:
//   x      - the qk floats of one block
//   qk     - number of elements in the block
//   traits - supplies emax_offset and the per-element mse_error function
// Returns the chosen E8M0 exponent byte.
//
// Ref: OCP MX v1.0 §5.3; Four Over Six (arXiv:2512.02010)
static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_elem_traits_t * traits) {
    float amax = 0.0f;
    for (int j = 0; j < qk; j++) {
        const float a = fabsf(x[j]);
        if (a > amax) amax = a;
    }
    if (amax == 0.0f) return 0;

    const int e_base = ggml_mxfp_e8m0_base_estimate(amax, traits->emax_offset);

    // ±R MSE search: test 2R+1 candidates around e_base, pick lowest total MSE.
    int e_lo = e_base - MXFP_E8M0_MSE_RANGE;
    int e_hi = e_base + MXFP_E8M0_MSE_RANGE;
    if (e_lo < 1) e_lo = 1;
    if (e_hi < 1) e_hi = 1;
    if (e_hi > 254) e_hi = 254;

    // Fallback when no candidate yields a usable (finite) MSE: the clamped base estimate.
    int best_e = e_base < 0 ? 0 : (e_base > 254 ? 254 : e_base);

    // Start from +inf rather than a finite sentinel: a block's summed squared
    // error can legitimately exceed any fixed "large" constant when the inputs
    // have very large magnitude, which would otherwise disable the search.
    // Candidates whose MSE is inf or NaN still never win the comparison.
    float best_mse = INFINITY;

    for (int test_e = e_lo; test_e <= e_hi; ++test_e) {
        const float test_scale = GGML_E8M0_TO_FP32((uint8_t)test_e);
        const float test_inv = 1.0f / test_scale;
        float mse = 0.0f;
        for (int j = 0; j < qk; ++j) {
            mse += traits->mse_error(x[j], test_inv, test_scale);
        }
        if (mse < best_mse) {
            best_mse = mse;
            best_e = test_e;
        }
    }
    return (uint8_t)best_e;
}
// Map x to a 4-bit index into kvalues_mxfp4 for block scale e.
//
// |x| / e is compared against the midpoints {0.5, 1.5, 2.5, 3.5, 5, 7, 10}
// between the sorted positive table magnitudes {0, 1, 2, 3, 4, 6, 8, 12}, which
// replaces a scan over all 16 table entries. The magnitude index (0..7) is
// offset by 8 when x is negative. A non-positive e is treated as a reciprocal
// of 0, so every finite x then maps to magnitude index 0.
static inline int best_index_mxfp4(float x, float e) {
    const float recip = (e > 0.0f) ? 1.0f / e : 0.0f;
    const float t     = fabsf(x) * recip;

    // Number of boundaries that t is NOT below. The boundaries are increasing,
    // so this equals the index of the first boundary above t (7 if none is,
    // which is also where a NaN ends up since it fails every "<" test).
    const int mag = !(t < 0.5f) + !(t < 1.5f) + !(t < 2.5f) + !(t < 3.5f)
                  + !(t < 5.0f) + !(t < 7.0f) + !(t < 10.0f);

    if (x < 0.0f) {
        return mag + 8;
    }
    return mag;
}
// FP4 E2M1: decision-boundary quantization via best_index_mxfp4.
// Unlike FP6/FP8 which use direct float->element conversion, FP4 maps each scaled
// value to an index into the kvalues_mxfp4 lookup table by comparing its magnitude
// against the midpoints between adjacent table values.
// Scale uses GGML_E8M0_TO_FP32_HALF (includes 0.5x factor for E2M1 mantissa range).
void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) {
static const int qk = QK_MXFP4;
@ -278,18 +447,7 @@ void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RE
const int nb = k / qk;
for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max
for (int j = 0; j < qk; j++) {
const float v = x[i*qk + j];
if (amax < fabsf(v)) {
amax = fabsf(v);
}
}
const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 2 + 127) : 0;
const uint8_t e = mxfp_compute_e8m0_mse(&x[i*qk], qk, &mxfp4_traits);
const float d = GGML_E8M0_TO_FP32_HALF(e);
y[i].e = e;
@ -494,6 +652,305 @@ void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_REST
}
}
// ============================================================================
// Hadamard Rotation (reference scalar implementation)
// ============================================================================
//
// 32-element Walsh-Hadamard transform, applied to MX blocks before quantization
// to spread outlier energy uniformly across the shared-exponent group.
//
// Without rotation, a single outlier in a block of 32 forces the shared E8M0
// exponent high, wasting precision for all 31 other elements. The Hadamard
// transform is orthogonal (H^T·H = I), so H(K)·H(Q) = K·Q — attention scores
// are preserved exactly when both K and Q undergo the same rotation.
//
// Implementation: 5 butterfly stages (log2(32) = 5) of the fast Walsh-Hadamard
// transform, followed by normalization by 1/sqrt(32). Total: 160 FP add/sub +
// 32 FP mul. This is the standard "in-place" FWHT with O(n·log(n)) operations.
//
// The 1/sqrt(32) normalization factor makes the transform orthonormal:
// H_normalized = H_unnormalized / sqrt(N)
// This ensures the transform preserves vector norms (energy), which is critical
// for maintaining attention score magnitudes after rotation.
//
// Prior art: QuIP# (Tseng et al. 2024), BRQ (Huang et al. 2024) apply Hadamard
// for weight quantization. Our novel contribution: applying it to KV cache
// quantization at the MX block boundary (block-32), where it matches the shared
// exponent group size. Tested alternatives (block-8, block-16, sign flips,
// permutations) all degraded quality — block-32 Hadamard is uniquely optimal
// because it spreads energy across exactly the elements sharing an exponent.
//
// Empirical PPL impact WITHOUT Hadamard rotation (Qwen3-Coder-30B-A3B):
// MXFP8 E4M3: +0.22, MXFP8 E5M2: +1.38, MXFP6 E2M3: +3.34, MXFP6 E3M2: +4.60
//
// Public wrapper: transforms the 32 floats in vals in place by forwarding to
// ggml_mxfp_hadamard_32_inplace().
void ggml_hadamard_32_inplace(float vals[32]) {
    float * const block = vals;
    ggml_mxfp_hadamard_32_inplace(block);
}
// Public element-conversion and packing API. Every function below is a thin
// forwarder to the ggml_mxfp_* routine of the matching name.

// FP6 E2M3 code -> float.
float fp6_e2m3_to_float(uint8_t v) {
    return ggml_mxfp_fp6_e2m3_to_float(v);
}

// float -> FP6 E2M3 code.
uint8_t float_to_fp6_e2m3_rn(float x) {
    return ggml_mxfp_float_to_fp6_e2m3(x);
}

// FP6 E3M2 code -> float.
float fp6_e3m2_to_float(uint8_t v) {
    return ggml_mxfp_fp6_e3m2_to_float(v);
}

// float -> FP6 E3M2 code.
uint8_t float_to_fp6_e3m2_rn(float x) {
    return ggml_mxfp_float_to_fp6_e3m2(x);
}

// FP8 E5M2 code -> float.
float fp8_e5m2_to_float(uint8_t v) {
    return ggml_mxfp_fp8_e5m2_to_float(v);
}

// float -> FP8 E5M2 code.
uint8_t float_to_fp8_e5m2_rn(float x) {
    return ggml_mxfp_float_to_fp8_e5m2(x);
}

// Pack four 6-bit codes from v into the three bytes of out.
void pack_fp6x4(const uint8_t v[4], uint8_t out[3]) {
    ggml_mxfp_pack_fp6x4(v, out);
}

// Unpack the three bytes of in into four 6-bit codes in v.
void unpack_fp6x4(const uint8_t in[3], uint8_t v[4]) {
    ggml_mxfp_unpack_fp6x4(in, v);
}
// MSE error functions for FP8/FP6: quantize at given scale → dequantize → squared error.
// Used by mxfp_compute_e8m0_mse() to evaluate candidate E8M0 exponents.
// These call the public API wrappers which delegate to canonical ggml_mxfp_* in ggml-common.h.
// Squared round-trip error of val for FP8 E4M3 at one candidate scale:
// scale down by inv_scale, encode, decode, scale back up, compare with val.
static float mse_error_fp8_e4m3(float val, float inv_scale, float scale) {
    const uint8_t code      = float_to_fp8_e4m3_rn(val * inv_scale);
    const float   roundtrip = fp8_e4m3_to_float(code) * scale;
    const float   delta     = val - roundtrip;
    return delta * delta;
}
// Squared round-trip error of val for FP6 E2M3 at one candidate scale:
// scale down by inv_scale, encode, decode, scale back up, compare with val.
static float mse_error_fp6_e2m3(float val, float inv_scale, float scale) {
    const uint8_t code      = float_to_fp6_e2m3_rn(val * inv_scale);
    const float   roundtrip = fp6_e2m3_to_float(code) * scale;
    const float   delta     = val - roundtrip;
    return delta * delta;
}
// emax_offset = ceil(log2(max_finite_value)) for each element type.
// This centers the E8M0 exponent search around the optimal scale for the type's range.
// E4M3: max=448, ceil(log2(448)) = 9, but offset=8 matches CUDA (empirically better)
// E2M3: max=7.5, ceil(log2(7.5)) = 3
// Trait tables for the two byte/6-bit element types handled below. Field order
// is { emax_offset, to_elem, to_float, mse_error }.
static const mxfp_elem_traits_t mxfp8_e4m3_traits = { MXFP8_E4M3_EMAX_OFFSET, float_to_fp8_e4m3_rn, fp8_e4m3_to_float, mse_error_fp8_e4m3 };
static const mxfp_elem_traits_t mxfp6_e2m3_traits = { MXFP6_E2M3_EMAX_OFFSET, float_to_fp6_e2m3_rn, fp6_e2m3_to_float, mse_error_fp6_e2m3 };
// FP8 quantize/dequantize: byte-per-element, shared by E4M3 and E5M2
// Quantize k floats from x into block_mxfp8 blocks (one byte per element plus a
// shared E8M0 exponent per block). k must be a multiple of QK_MXFP8.
// For each block the exponent is chosen by mxfp_compute_e8m0_mse(); every
// element is then divided by the resulting scale and encoded with
// traits->to_elem.
static void quantize_row_mxfp8_impl(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y,
                                    int64_t k, const mxfp_elem_traits_t * traits) {
    assert(k % QK_MXFP8 == 0);
    const int n_blocks = k / QK_MXFP8;

    for (int b = 0; b < n_blocks; b++) {
        const float * src = x + b*QK_MXFP8;
        block_mxfp8 * blk = y + b;

        const uint8_t shared_exp = mxfp_compute_e8m0_mse(src, QK_MXFP8, traits);
        const float   scale      = GGML_E8M0_TO_FP32(shared_exp);
        // Guard the reciprocal so a non-positive scale maps everything to to_elem(0).
        const float   rscale     = scale > 0.0f ? 1.0f / scale : 0.0f;

        blk->e = shared_exp;
        for (int j = 0; j < QK_MXFP8; ++j) {
            blk->qs[j] = traits->to_elem(src[j] * rscale);
        }
    }
}
// Expand block_mxfp8 blocks from x into k floats in y.
// k must be a multiple of QK_MXFP8. Each element code is decoded with
// traits->to_float and multiplied by the block's E8M0 scale.
static void dequantize_row_mxfp8_impl(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y,
                                      int64_t k, const mxfp_elem_traits_t * traits) {
    assert(k % QK_MXFP8 == 0);
    const int n_blocks = k / QK_MXFP8;

    for (int b = 0; b < n_blocks; b++) {
        const block_mxfp8 * blk = x + b;
        float * out = y + b*QK_MXFP8;

        const float scale = GGML_E8M0_TO_FP32(blk->e);
        for (int j = 0; j < QK_MXFP8; ++j) {
            out[j] = traits->to_float(blk->qs[j]) * scale;
        }
    }
}
// FP6 quantize/dequantize: tight 6-bit packing (4 values per 3 bytes), shared by E2M3 and E3M2
// Quantize k floats from x into block_mxfp6 blocks. k must be a multiple of
// QK_MXFP6. Elements are encoded with traits->to_elem after division by the
// block scale and stored tightly: every group of four 6-bit codes is packed
// into three consecutive bytes of qs.
static void quantize_row_mxfp6_impl(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y,
                                    int64_t k, const mxfp_elem_traits_t * traits) {
    assert(k % QK_MXFP6 == 0);
    const int n_blocks = k / QK_MXFP6;

    for (int b = 0; b < n_blocks; b++) {
        const float * src = x + b*QK_MXFP6;
        block_mxfp6 * blk = y + b;

        const uint8_t shared_exp = mxfp_compute_e8m0_mse(src, QK_MXFP6, traits);
        const float   scale      = GGML_E8M0_TO_FP32(shared_exp);
        const float   rscale     = scale > 0.0f ? 1.0f / scale : 0.0f;

        blk->e = shared_exp;
        for (int j = 0; j < QK_MXFP6; j += 4) {
            uint8_t codes[4];
            codes[0] = traits->to_elem(src[j + 0] * rscale);
            codes[1] = traits->to_elem(src[j + 1] * rscale);
            codes[2] = traits->to_elem(src[j + 2] * rscale);
            codes[3] = traits->to_elem(src[j + 3] * rscale);
            // j is a multiple of 4, so group j/4 starts at byte 3*(j/4).
            pack_fp6x4(codes, &blk->qs[(j / 4) * 3]);
        }
    }
}
// Expand block_mxfp6 blocks from x into k floats in y. k must be a multiple of
// QK_MXFP6. Each three-byte group of qs is unpacked into four 6-bit codes,
// which are decoded with traits->to_float and multiplied by the block scale.
static void dequantize_row_mxfp6_impl(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y,
                                      int64_t k, const mxfp_elem_traits_t * traits) {
    assert(k % QK_MXFP6 == 0);
    const int n_blocks = k / QK_MXFP6;

    for (int b = 0; b < n_blocks; b++) {
        const block_mxfp6 * blk = x + b;
        float * out = y + b*QK_MXFP6;

        const float scale = GGML_E8M0_TO_FP32(blk->e);
        for (int j = 0; j < QK_MXFP6; j += 4) {
            uint8_t codes[4];
            // j is a multiple of 4, so group j/4 starts at byte 3*(j/4).
            unpack_fp6x4(&blk->qs[(j / 4) * 3], codes);
            out[j + 0] = traits->to_float(codes[0]) * scale;
            out[j + 1] = traits->to_float(codes[1]) * scale;
            out[j + 2] = traits->to_float(codes[2]) * scale;
            out[j + 3] = traits->to_float(codes[3]) * scale;
        }
    }
}
// Public API wrappers — one-line delegates to the traits-parameterized impl
// Reference MXFP8 quantizer: the shared FP8 implementation bound to the E4M3 traits.
void quantize_row_mxfp8_ref(
        const float * GGML_RESTRICT x,
        block_mxfp8 * GGML_RESTRICT y,
        int64_t                     k) {
    const mxfp_elem_traits_t * const traits = &mxfp8_e4m3_traits;
    quantize_row_mxfp8_impl(x, y, k, traits);
}
// Reference MXFP8 dequantizer: the shared FP8 implementation bound to the E4M3 traits.
void dequantize_row_mxfp8(
        const block_mxfp8 * GGML_RESTRICT x,
        float             * GGML_RESTRICT y,
        int64_t                           k) {
    const mxfp_elem_traits_t * const traits = &mxfp8_e4m3_traits;
    dequantize_row_mxfp8_impl(x, y, k, traits);
}
// Reference MXFP6 quantizer: the shared FP6 implementation bound to the E2M3 traits.
void quantize_row_mxfp6_e2m3_ref(
        const float * GGML_RESTRICT x,
        block_mxfp6 * GGML_RESTRICT y,
        int64_t                     k) {
    const mxfp_elem_traits_t * const traits = &mxfp6_e2m3_traits;
    quantize_row_mxfp6_impl(x, y, k, traits);
}
// Reference MXFP6 dequantizer: the shared FP6 implementation bound to the E2M3 traits.
void dequantize_row_mxfp6_e2m3(
        const block_mxfp6 * GGML_RESTRICT x,
        float             * GGML_RESTRICT y,
        int64_t                           k) {
    const mxfp_elem_traits_t * const traits = &mxfp6_e2m3_traits;
    dequantize_row_mxfp6_impl(x, y, k, traits);
}
// ============================================================================
// SoA (Struct-of-Arrays) quantize/dequantize — canonical reference for FA
// ============================================================================
//
// SoA layout per row: [qs_block0|qs_block1|...|qs_blockN][e8m0_0|e8m0_1|...|e8m0_N]
// Total bytes per row = nblocks * (QS_PER_BLOCK + 1) = identical to AoS.
// This is the ONLY layout used by flash attention across all backends.
// Quantize k floats from x into one MXFP4 row in SoA layout at dst.
// k must be a multiple of QK_MXFP4. The packed nibbles of all blocks come
// first; the per-block E8M0 exponents follow at the offset given by
// MXFP_SOA_E8M0_OFFSET. Within a block, byte j holds element j in its low
// nibble and element j + QK_MXFP4/2 in its high nibble.
void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) {
    assert(k % QK_MXFP4 == 0);
    const int n_blocks = k / QK_MXFP4;
    const int half     = QK_MXFP4/2;

    char * const qs_region  = (char *)dst;
    char * const exp_region = qs_region + MXFP_SOA_E8M0_OFFSET(n_blocks, MXFP4_SOA_QS_PER_BLOCK);

    for (int b = 0; b < n_blocks; b++) {
        const float * src = x + b*QK_MXFP4;

        const uint8_t shared_exp = mxfp_compute_e8m0_mse(src, QK_MXFP4, &mxfp4_traits);
        // FP4 works with the halved scale because kvalues_mxfp4 holds doubled values.
        const float   half_scale = GGML_E8M0_TO_FP32_HALF(shared_exp);
        exp_region[b] = (char)shared_exp;

        uint8_t * packed = (uint8_t *)(qs_region + MXFP_SOA_QS_OFFSET(b, MXFP4_SOA_QS_PER_BLOCK));
        for (int j = 0; j < half; ++j) {
            const uint8_t lo = best_index_mxfp4(src[j],        half_scale);
            const uint8_t hi = best_index_mxfp4(src[half + j], half_scale);
            packed[j] = lo | (hi << 4);
        }
    }
}
// Dequantize one MXFP4 SoA row back to k floats.
//
//   src : SoA row: packed nibbles for all blocks, then one E8M0 scale byte per
//         block at MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK).
//   y   : k output floats; k must be a multiple of QK_MXFP4 (asserted).
//
// The low nibble of byte j maps to element j, the high nibble to element
// j + QK_MXFP4/2; each nibble indexes kvalues_mxfp4 and is multiplied by the
// block scale.
void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) {
    assert(k % QK_MXFP4 == 0);

    const int n_blocks = k / QK_MXFP4;
    const int n_pairs  = QK_MXFP4/2;

    const char * const packed = (const char *)src;
    const char * const scales = packed + MXFP_SOA_E8M0_OFFSET(n_blocks, MXFP4_SOA_QS_PER_BLOCK);

    for (int ib = 0; ib < n_blocks; ib++) {
        const float     scale = GGML_E8M0_TO_FP32_HALF((uint8_t)scales[ib]);
        const uint8_t * in    = (const uint8_t *)(packed + MXFP_SOA_QS_OFFSET(ib, MXFP4_SOA_QS_PER_BLOCK));
        float         * out   = y + ib*QK_MXFP4;

        for (int j = 0; j < n_pairs; ++j) {
            const int8_t lo = kvalues_mxfp4[in[j] & 0x0F];
            const int8_t hi = kvalues_mxfp4[in[j] >> 4];
            out[j]           = lo*scale;
            out[j + n_pairs] = hi*scale;
        }
    }
}
// Quantize one row of k floats to an 8-bit MX element type in SoA layout.
//
//   x      : k input floats; k must be a multiple of QK_MXFP8 (asserted).
//   dst    : output row: element bytes for all blocks, then one E8M0 scale byte
//            per block at MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK).
//   traits : element-format callbacks; to_elem encodes one pre-scaled float.
//
// Each input is multiplied by the reciprocal of the block scale before
// encoding; a non-positive scale yields a reciprocal of 0.
static void quantize_row_mxfp8_soa_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst,
                                        int64_t k, const mxfp_elem_traits_t * traits) {
    assert(k % QK_MXFP8 == 0);

    const int n_blocks = k / QK_MXFP8;

    char * const codes  = (char *)dst;
    char * const scales = codes + MXFP_SOA_E8M0_OFFSET(n_blocks, MXFP8_SOA_QS_PER_BLOCK);

    for (int ib = 0; ib < n_blocks; ib++) {
        const float * blk = x + ib*QK_MXFP8;

        const uint8_t e8m0  = mxfp_compute_e8m0_mse(blk, QK_MXFP8, traits);
        const float   scale = GGML_E8M0_TO_FP32(e8m0);

        float inv_scale = 0.0f;
        if (scale > 0.0f) {
            inv_scale = 1.0f / scale;
        }

        scales[ib] = (char)e8m0;

        uint8_t * out = (uint8_t *)(codes + MXFP_SOA_QS_OFFSET(ib, MXFP8_SOA_QS_PER_BLOCK));
        for (int j = 0; j < QK_MXFP8; ++j) {
            out[j] = traits->to_elem(blk[j] * inv_scale);
        }
    }
}
// Dequantize one SoA row of 8-bit MX elements back to k floats.
//
//   src    : SoA row: element bytes for all blocks, then one E8M0 scale byte
//            per block at MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK).
//   y      : k output floats; k must be a multiple of QK_MXFP8 (asserted).
//   traits : element-format callbacks; to_float decodes one element byte.
static void dequantize_row_mxfp8_soa_impl(const void * GGML_RESTRICT src, float * GGML_RESTRICT y,
                                          int64_t k, const mxfp_elem_traits_t * traits) {
    assert(k % QK_MXFP8 == 0);

    const int n_blocks = k / QK_MXFP8;

    const char * const codes  = (const char *)src;
    const char * const scales = codes + MXFP_SOA_E8M0_OFFSET(n_blocks, MXFP8_SOA_QS_PER_BLOCK);

    for (int ib = 0; ib < n_blocks; ib++) {
        const float     scale = GGML_E8M0_TO_FP32((uint8_t)scales[ib]);
        const uint8_t * in    = (const uint8_t *)(codes + MXFP_SOA_QS_OFFSET(ib, MXFP8_SOA_QS_PER_BLOCK));
        float         * out   = y + ib*QK_MXFP8;

        for (int j = 0; j < QK_MXFP8; ++j) {
            out[j] = traits->to_float(in[j]) * scale;
        }
    }
}
// Quantize one row of k floats to a 6-bit MX element type in SoA layout.
//
//   x      : k input floats; k must be a multiple of QK_MXFP6 (asserted).
//   dst    : output row: tightly packed 6-bit codes for all blocks, then one
//            E8M0 scale byte per block at
//            MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK).
//   traits : element-format callbacks; to_elem encodes one pre-scaled float.
//
// Elements are handled in groups of four: each group is encoded and handed to
// pack_fp6x4, which writes it into three consecutive output bytes.
static void quantize_row_mxfp6_soa_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst,
                                        int64_t k, const mxfp_elem_traits_t * traits) {
    assert(k % QK_MXFP6 == 0);

    const int n_blocks = k / QK_MXFP6;

    char * const codes  = (char *)dst;
    char * const scales = codes + MXFP_SOA_E8M0_OFFSET(n_blocks, MXFP6_SOA_QS_PER_BLOCK);

    for (int ib = 0; ib < n_blocks; ib++) {
        const float * blk = x + ib*QK_MXFP6;

        const uint8_t e8m0  = mxfp_compute_e8m0_mse(blk, QK_MXFP6, traits);
        const float   scale = GGML_E8M0_TO_FP32(e8m0);

        float inv_scale = 0.0f;
        if (scale > 0.0f) {
            inv_scale = 1.0f / scale;
        }

        scales[ib] = (char)e8m0;

        uint8_t * out = (uint8_t *)(codes + MXFP_SOA_QS_OFFSET(ib, MXFP6_SOA_QS_PER_BLOCK));
        for (int g = 0; g < QK_MXFP6/4; g++) {
            const float * in = blk + 4*g;
            uint8_t quad[4];
            for (int t = 0; t < 4; t++) {
                quad[t] = traits->to_elem(in[t] * inv_scale);
            }
            // group g occupies bytes [3g, 3g+2] of the block's packed data
            pack_fp6x4(quad, &out[3*g]);
        }
    }
}
// Dequantize one SoA row of 6-bit MX elements back to k floats.
//
//   src    : SoA row: tightly packed 6-bit codes for all blocks, then one E8M0
//            scale byte per block at
//            MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK).
//   y      : k output floats; k must be a multiple of QK_MXFP6 (asserted).
//   traits : element-format callbacks; to_float decodes one 6-bit code.
//
// Codes are unpacked four at a time from three bytes via unpack_fp6x4.
static void dequantize_row_mxfp6_soa_impl(const void * GGML_RESTRICT src, float * GGML_RESTRICT y,
                                          int64_t k, const mxfp_elem_traits_t * traits) {
    assert(k % QK_MXFP6 == 0);

    const int n_blocks = k / QK_MXFP6;

    const char * const codes  = (const char *)src;
    const char * const scales = codes + MXFP_SOA_E8M0_OFFSET(n_blocks, MXFP6_SOA_QS_PER_BLOCK);

    for (int ib = 0; ib < n_blocks; ib++) {
        const float     scale = GGML_E8M0_TO_FP32((uint8_t)scales[ib]);
        const uint8_t * in    = (const uint8_t *)(codes + MXFP_SOA_QS_OFFSET(ib, MXFP6_SOA_QS_PER_BLOCK));
        float         * out   = y + ib*QK_MXFP6;

        for (int g = 0; g < QK_MXFP6/4; g++) {
            uint8_t quad[4];
            // group g occupies bytes [3g, 3g+2] of the block's packed data
            unpack_fp6x4(&in[3*g], quad);
            for (int t = 0; t < 4; t++) {
                out[4*g + t] = traits->to_float(quad[t]) * scale;
            }
        }
    }
}
// Public MXFP8 SoA row quantizer.
// Forwards x, dst and k unchanged to quantize_row_mxfp8_soa_impl with the
// E4M3 element traits table (mxfp8_e4m3_traits).
void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) {
quantize_row_mxfp8_soa_impl(x, dst, k, &mxfp8_e4m3_traits);
}
// Public MXFP8 SoA row dequantizer.
// Forwards src, y and k unchanged to dequantize_row_mxfp8_soa_impl with the
// E4M3 element traits table (mxfp8_e4m3_traits).
void dequantize_row_mxfp8_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) {
dequantize_row_mxfp8_soa_impl(src, y, k, &mxfp8_e4m3_traits);
}
// Public MXFP6 SoA row quantizer.
// Forwards x, dst and k unchanged to quantize_row_mxfp6_soa_impl with the
// E2M3 element traits table (mxfp6_e2m3_traits).
void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) {
quantize_row_mxfp6_soa_impl(x, dst, k, &mxfp6_e2m3_traits);
}
// Public MXFP6 SoA row dequantizer.
// Forwards src, y and k unchanged to dequantize_row_mxfp6_soa_impl with the
// E2M3 element traits table (mxfp6_e2m3_traits).
void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) {
dequantize_row_mxfp6_soa_impl(src, y, k, &mxfp6_e2m3_traits);
}
//
// 2-6 bit quantization in super-blocks
//
@ -2155,7 +2612,7 @@ size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
// Quantize nrow rows of n_per_row floats each to MXFP4 (AoS block layout).
//
//   src           : nrow * n_per_row input floats, handed to the row quantizer
//                   as a single run of that many elements.
//   dst           : destination buffer for the quantized blocks.
//   quant_weights : importance weights; unused by this type.
//
// Returns the number of bytes written: nrow times the row size of
// GGML_TYPE_MXFP4_E2M1 for n_per_row elements.
size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    GGML_UNUSED(quant_weights);
    quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);
    // Single return using the renamed enumerator; the stale GGML_TYPE_MXFP4
    // return that preceded it made this statement unreachable.
    return nrow * ggml_row_size(GGML_TYPE_MXFP4_E2M1, n_per_row);
}
size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
@ -2164,6 +2621,18 @@ size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
return nrow * ggml_row_size(GGML_TYPE_NVFP4, n_per_row);
}
// Quantize nrow rows of n_per_row floats each to MXFP8 E4M3 (AoS block layout).
//
//   src           : nrow * n_per_row input floats, handed to the row quantizer
//                   as a single run of that many elements.
//   dst           : destination buffer for the quantized blocks.
//   quant_weights : importance weights; unused by this type.
//
// Returns nrow times the row size of GGML_TYPE_MXFP8_E4M3 for n_per_row elements.
size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    GGML_UNUSED(quant_weights);

    const int64_t n_total = nrow * n_per_row;
    quantize_row_mxfp8_ref(src, dst, n_total);

    const size_t row_bytes = ggml_row_size(GGML_TYPE_MXFP8_E4M3, n_per_row);
    return nrow * row_bytes;
}
// Quantize nrow rows of n_per_row floats each to MXFP6 E2M3 (AoS block layout).
//
//   src           : nrow * n_per_row input floats, handed to the row quantizer
//                   as a single run of that many elements.
//   dst           : destination buffer for the quantized blocks.
//   quant_weights : importance weights; unused by this type.
//
// Returns nrow times the row size of GGML_TYPE_MXFP6_E2M3 for n_per_row elements.
size_t quantize_mxfp6_e2m3(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    GGML_UNUSED(quant_weights);

    const int64_t n_total = nrow * n_per_row;
    quantize_row_mxfp6_e2m3_ref(src, dst, n_total);

    const size_t row_bytes = ggml_row_size(GGML_TYPE_MXFP6_E2M3, n_per_row);
    return nrow * row_bytes;
}
// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) {
@ -5306,7 +5775,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
{
VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
} break;
case GGML_TYPE_MXFP4:
case GGML_TYPE_MXFP4_E2M1:
{
VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb);
} break;

View File

@ -23,6 +23,8 @@ GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 *
GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp8_ref(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp6_e2m3_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
@ -50,6 +52,17 @@ GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GG
GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_mxfp6_e2m3(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
// SoA (Struct-of-Arrays) quantize/dequantize — canonical reference for flash attention.
// Layout: [qs contiguous][e8m0 contiguous] per row. Same total bytes as AoS.
GGML_API void quantize_row_mxfp4_soa (const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
GGML_API void dequantize_row_mxfp4_soa (const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp8_soa (const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
GGML_API void dequantize_row_mxfp8_soa (const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
GGML_API void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@ -98,6 +111,87 @@ GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTR
GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_mxfp6_e2m3(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
//
// MXFP element-level conversion functions (reference implementations)
//
// These implement the OCP Microscaling (MX) format element types as defined in:
// OCP Microscaling Formats (MX) Specification v1.0, Sep 2023
// https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
//
// Each MX block contains 32 elements sharing a single E8M0 exponent. The element
// types define the per-element mantissa format within each block.
//
// All converters use IEEE-754 bit manipulation for exact results (no floating-point
// rounding in the conversion itself). Quantization functions use round-to-nearest-even
// (RNE) per the MX specification.
//
// GPU backends (CUDA, Metal, Vulkan) provide their own optimized versions — these
// functions serve as the canonical reference and are used by the CPU backend.
//
// FP8 E4M3: 1 sign, 4 exponent (bias 7), 3 mantissa bits
// Range: ±[2^-9, 448], NaN: exp=15 mant=7 (only NaN encoding)
// Ref: OCP MX v1.0 §4.2
GGML_API float fp8_e4m3_to_float(uint8_t v);
GGML_API uint8_t float_to_fp8_e4m3_rn(float x);
// FP8 E5M2: 1 sign, 5 exponent (bias 15), 2 mantissa bits
// Range: ±[2^-16, 57344], NaN/Inf: exp=31 (standard IEEE-like)
// Ref: OCP MX v1.0 §4.2
GGML_API float fp8_e5m2_to_float(uint8_t v);
GGML_API uint8_t float_to_fp8_e5m2_rn(float x);
// FP6 E2M3: 1 sign, 2 exponent (bias 1), 3 mantissa bits
// Range: ±[2^-3, 7.5], stored as low 6 bits of a byte (00xxxxxx)
// MX format: NO NaN/Inf — all bit patterns are valid numbers
// Ref: OCP MX v1.0 §4.2
GGML_API float fp6_e2m3_to_float(uint8_t v);
GGML_API uint8_t float_to_fp6_e2m3_rn(float x);
// FP6 E3M2: 1 sign, 3 exponent (bias 3), 2 mantissa bits
// Range: ±[2^-4, 28.0], stored as low 6 bits of a byte (00xxxxxx)
// MX format: NO NaN/Inf — exp=7 is a valid normal value (unlike IEEE-754)
// CRITICAL: subnormal scale is 2^(1-bias-m) = 2^(-4) = 1/16, NOT 1/4
// Ref: OCP MX v1.0 §4.2
GGML_API float fp6_e3m2_to_float(uint8_t v);
GGML_API uint8_t float_to_fp6_e3m2_rn(float x);
// FP6 tight packing: pack/unpack 4 six-bit values into/from 3 bytes
// Layout: v[0]=bits[5:0], v[1]=bits[11:6], v[2]=bits[17:12], v[3]=bits[23:18]
// Saves 25% memory vs byte-padded storage (24B vs 32B per MX block)
GGML_API void pack_fp6x4(const uint8_t v[4], uint8_t out[3]);
GGML_API void unpack_fp6x4(const uint8_t in[3], uint8_t v[4]);
//
// Hadamard rotation (reference scalar implementation)
//
// 32-element Walsh-Hadamard transform applied to MX blocks before quantization.
// Distributes outlier energy uniformly across the block, dramatically improving
// quantization quality for types with shared exponents.
//
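// Mathematical property: H^T·H = I (orthogonal), so (H·k)·(H·q) = k·q for any vectors k, q.
// Flash attention applies the matching rotation to Q, so attention scores are preserved
// (exactly in real arithmetic; in practice up to rounding and quantization error).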
//
// Implementation: 5 butterfly stages (log2(32) = 5) + normalization by 1/sqrt(32).
// This is the standard "fast Walsh-Hadamard transform" with O(n log n) operations.
//
// Applied in set_rows (K cache quantization) and flash_attn (Q quantization).
// Skipped for MLA models (DK != DV) where V is a view of K — rotation would corrupt V.
//
// Empirical impact (PPL degradation WITHOUT rotation, Qwen3-Coder-30B):
// MXFP8 E4M3: +0.22, MXFP8 E5M2: +1.38, MXFP6 E2M3: +3.34, MXFP6 E3M2: +4.60
//
// Prior art: QuIP# (Tseng et al. 2024), BRQ (Huang et al. 2024) use Hadamard for
// weight quantization. Our contribution applies it to KV cache quantization at the
// MX block boundary, where block-32 is optimal because it matches the shared exponent
// group size exactly.
//
// GPU backends provide optimized versions (CUDA warp shuffles, Metal SIMD groups).
//
GGML_API void ggml_hadamard_32_inplace(float vals[32]);
GGML_API void iq2xs_init_impl(enum ggml_type type);
GGML_API void iq2xs_free_impl(enum ggml_type type);

View File

@ -639,7 +639,7 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
return dequantize_row_iq4_xs_sycl;
case GGML_TYPE_IQ4_NL:
return dequantize_row_iq4_nl_sycl;
case GGML_TYPE_MXFP4:
case GGML_TYPE_MXFP4_E2M1:
return dequantize_row_mxfp4_sycl;
case GGML_TYPE_F32:
return convert_unary_sycl<float>;
@ -706,7 +706,7 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
return dequantize_row_iq4_xs_sycl;
case GGML_TYPE_IQ4_NL:
return dequantize_row_iq4_nl_sycl;
case GGML_TYPE_MXFP4:
case GGML_TYPE_MXFP4_E2M1:
return dequantize_row_mxfp4_sycl;
case GGML_TYPE_F16:
return convert_unary_sycl<sycl::half>;

View File

@ -1142,7 +1142,7 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
case GGML_TYPE_IQ4_XS:
mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
case GGML_TYPE_MXFP4:
case GGML_TYPE_MXFP4_E2M1:
mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
default:

View File

@ -710,8 +710,8 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
.is_quantized = true,
.from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref,
},
[GGML_TYPE_MXFP4] = {
.type_name = "mxfp4",
[GGML_TYPE_MXFP4_E2M1] = {
.type_name = "mxfp4_e2m1",
.blck_size = QK_MXFP4,
.type_size = sizeof(block_mxfp4),
.is_quantized = true,
@ -726,6 +726,22 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
.to_float = (ggml_to_float_t) dequantize_row_nvfp4,
.from_float_ref = (ggml_from_float_t)quantize_row_nvfp4_ref,
},
[GGML_TYPE_MXFP8_E4M3] = {
.type_name = "mxfp8_e4m3",
.blck_size = QK_MXFP8,
.type_size = sizeof(block_mxfp8),
.is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_mxfp8,
.from_float_ref = (ggml_from_float_t)quantize_row_mxfp8_ref,
},
[GGML_TYPE_MXFP6_E2M3] = {
.type_name = "mxfp6_e2m3",
.blck_size = QK_MXFP6,
.type_size = sizeof(block_mxfp6),
.is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_mxfp6_e2m3,
.from_float_ref = (ggml_from_float_t)quantize_row_mxfp6_e2m3_ref,
},
[GGML_TYPE_Q2_K] = {
.type_name = "q2_K",
.blck_size = QK_K,
@ -1306,6 +1322,30 @@ bool ggml_is_quantized(enum ggml_type type) {
return type_traits[type].is_quantized;
}
// Returns true when type is one of the three MXFP tensor types
// (MXFP4 E2M1, MXFP8 E4M3, MXFP6 E2M3), false otherwise.
bool ggml_is_type_mxfp(enum ggml_type type) {
    switch (type) {
        case GGML_TYPE_MXFP4_E2M1:
        case GGML_TYPE_MXFP8_E4M3:
        case GGML_TYPE_MXFP6_E2M3:
            return true;
        default:
            return false;
    }
}
// Reports the per-format Hadamard setting for an MXFP type by returning the
// matching MXFP_USE_HADAMARD_* constant; any other type yields false.
bool ggml_mxfp_use_hadamard(enum ggml_type type) {
    if (type == GGML_TYPE_MXFP4_E2M1) {
        return MXFP_USE_HADAMARD_E2M1;
    }
    if (type == GGML_TYPE_MXFP8_E4M3) {
        return MXFP_USE_HADAMARD_E4M3;
    }
    if (type == GGML_TYPE_MXFP6_E2M3) {
        return MXFP_USE_HADAMARD_E2M3;
    }
    return false;
}
// Returns the MXFP_QS_PER_BLOCK_* constant for an MXFP type,
// or 0 when type is not an MXFP type.
int ggml_mxfp_qs_per_block(enum ggml_type type) {
    if (type == GGML_TYPE_MXFP4_E2M1) {
        return MXFP_QS_PER_BLOCK_E2M1;
    }
    if (type == GGML_TYPE_MXFP8_E4M3) {
        return MXFP_QS_PER_BLOCK_E4M3;
    }
    if (type == GGML_TYPE_MXFP6_E2M3) {
        return MXFP_QS_PER_BLOCK_E2M3;
    }
    return 0;
}
// Returns the GGML_OP_NAME table entry indexed by op.
// NOTE(review): no range check is performed here — presumably callers always
// pass a valid enum ggml_op value; confirm.
const char * ggml_op_name(enum ggml_op op) {
return GGML_OP_NAME[op];
}
@ -1381,7 +1421,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break;
case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break;
case GGML_FTYPE_MOSTLY_MXFP4: wtype = GGML_TYPE_MXFP4; break;
case GGML_FTYPE_MOSTLY_MXFP4_E2M1: wtype = GGML_TYPE_MXFP4_E2M1; break;
case GGML_FTYPE_MOSTLY_NVFP4: wtype = GGML_TYPE_NVFP4; break;
case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break;
case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break;
@ -7649,8 +7689,10 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_MXFP4: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_MXFP4_E2M1: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_MXFP8_E4M3: result = quantize_mxfp8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_MXFP6_E2M3: result = quantize_mxfp6_e2m3(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;

View File

@ -51,6 +51,7 @@ llama_kv_cache::llama_kv_cache(
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
ggml_init_params params = {
// per KV layer: 2 base tensors (K+V) + 2*n_stream view tensors
/*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
@ -135,7 +136,17 @@ llama_kv_cache::llama_kv_cache(
const bool has_k = true;
const bool has_v = !is_mla;
ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream) : nullptr;
// MXFP K cache: align block count to 16 for cp.async.
uint32_t n_embd_k_alloc = n_embd_k_gqa;
const bool is_mxfp_k = ggml_is_type_mxfp(type_k);
if (is_mxfp_k) {
const int qk = (int)ggml_blck_size(type_k); // 32 for all MXFP types
const int blocks = (int)n_embd_k_gqa / qk;
const int blocks_aligned = (blocks + 15) & ~15; // align to 16
n_embd_k_alloc = (uint32_t)(blocks_aligned * qk);
}
ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, type_k, n_embd_k_alloc, kv_size, n_stream) : nullptr;
ggml_tensor * v = has_v ? ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream) : nullptr;
has_k && ggml_format_name(k, "cache_k_l%d", il);
@ -1025,19 +1036,16 @@ ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_k
auto * k = layers[ikv].k;
const uint64_t kv_size = get_size();
const uint64_t n_embd_k_gqa = k->ne[0];
assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));
// For MXFP types: k->ne[0] may include alignment padding (blocks aligned to 16).
// The row stride (k->nb[1]) reflects the padded allocation.
const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
return ggml_view_4d(ctx, k,
hparams.n_embd_head_k(il), hparams.n_head_kv(il), n_kv, ns,
ggml_row_size(k->type, hparams.n_embd_head_k(il)),
ggml_row_size(k->type, n_embd_k_gqa),
ggml_row_size(k->type, n_embd_k_gqa*kv_size),
ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
k->nb[1],
k->nb[2],
k->nb[2]*sinfo.s0);
}
ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
@ -1092,19 +1100,38 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm
k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0);
const int64_t n_stream = k->ne[2];
const int64_t kv_size = get_size();
if (n_stream > 1) {
const int64_t kv_size = get_size();
assert(n_embd_gqa == k->ne[0]);
assert(kv_size == k->ne[1]);
assert(kv_size == k->ne[1]);
// merge the buffer across all streams because the idxs are global
k = ggml_reshape_2d(ctx, k, n_embd_gqa, kv_size*n_stream);
// Use view_2d to preserve nb[1] (which includes alignment padding for MXFP types)
k = ggml_view_2d(ctx, k, k->ne[0], kv_size*n_stream, k->nb[1], 0);
}
const bool is_mxfp = ggml_is_type_mxfp(k->type);
// For MXFP: ne[0] may be padded for block alignment, but k_cur has n_embd_gqa.
// Create view with ne[0]=n_embd_gqa, preserving the larger row stride nb[1].
ggml_tensor * k_dst = k;
if (is_mxfp) {
k_dst = ggml_view_2d(ctx, k, n_embd_gqa, k->ne[1], k->nb[1], 0);
}
// store the current K values into the cache
return ggml_set_rows(ctx, k, k_cur, k_idxs);
ggml_tensor * result = ggml_set_rows(ctx, k_dst, k_cur, k_idxs);
// Flag K cache writes for Walsh-Hadamard rotation (QuaRot, arXiv:2404.00456; BRQ, arXiv:2511.04214).
// The flash attention kernel applies matching rotation to Q so H(Q)·H(K)^T = Q·K^T.
// V cache writes are NOT rotated (op_params[0] defaults to 0).
// Skipped for: MLA (V is a view of K — rotation would corrupt V),
// E5M2/E3M2 (2-bit mantissa — Hadamard provides no quality benefit).
if (is_mxfp && !hparams.is_mla() && ggml_mxfp_use_hadamard(k->type)) {
((int32_t *)result->op_params)[0] = 1;
}
return result;
}
ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {

View File

@ -457,7 +457,7 @@ static ggml_type llama_tensor_get_type_impl(quantize_state_impl & qs, ggml_type
// MoE tensors -> MXFP4
// other tensors -> Q8_0
if (tensor->ne[2] > 1) {
new_type = GGML_TYPE_MXFP4;
new_type = GGML_TYPE_MXFP4_E2M1;
} else {
new_type = GGML_TYPE_Q8_0;
}
@ -795,7 +795,7 @@ static ggml_type llama_ftype_get_default_type(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_BF16: return GGML_TYPE_BF16;
case LLAMA_FTYPE_ALL_F32: return GGML_TYPE_F32;
case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return GGML_TYPE_MXFP4;
case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return GGML_TYPE_MXFP4_E2M1;
// K-quants
case LLAMA_FTYPE_MOSTLY_Q2_K_S:

View File

@ -150,6 +150,91 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
}
}
// SoA quantize/dequantize functions — declared here because ggml-quants.h is not in the test include path.
typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t);
extern "C" {
void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_mxfp8_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
}
// Fill an MXFP tensor with random values quantized into SoA (Struct-of-Arrays) layout.
//
// Values are drawn uniformly from [min, max) using a fixed seed (42), quantized one
// SoA region at a time into a zero-initialized staging buffer of ggml_nbytes(tensor)
// bytes, and uploaded with ggml_backend_tensor_set.
//
// soa_bytes: byte width of one SoA region. 0 (default) means one ggml row (ne[0] elements).
//            For FA K/V tensors, pass nb[1] so that when heads are physically contiguous
//            within one KV-position stride, the SoA region spans all heads (matching FA's
//            read pattern).
static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f, size_t soa_bytes = 0) {
    GGML_ASSERT(ggml_is_type_mxfp(tensor->type));

    using soa_quantize_fn = void (*)(const float *, void *, int64_t);
    soa_quantize_fn quantize_soa = nullptr;
    switch (tensor->type) {
        case GGML_TYPE_MXFP4_E2M1: quantize_soa = quantize_row_mxfp4_soa; break;
        case GGML_TYPE_MXFP8_E4M3: quantize_soa = quantize_row_mxfp8_soa; break;
        case GGML_TYPE_MXFP6_E2M3: quantize_soa = quantize_row_mxfp6_soa; break;
        default: GGML_ABORT("unsupported MXFP type for SoA init");
    }

    const int    elems_per_block = (int)ggml_blck_size(tensor->type);
    const size_t bytes_per_block = ggml_type_size(tensor->type);
    const size_t row_bytes       = ggml_row_size(tensor->type, tensor->ne[0]);

    if (soa_bytes == 0) {
        soa_bytes = row_bytes;
    }

    // number of float elements covered by one SoA region
    const int64_t soa_elems = (int64_t)(soa_bytes / bytes_per_block) * elems_per_block;

    std::default_random_engine gen(42);
    std::uniform_real_distribution<float> dist(min, max);
    std::vector<float> region_f32(soa_elems);

    // A region spanning more than one head row covers `heads_per_region` consecutive
    // ne[2] entries, so dimension 2 is walked in groups of that size. Otherwise every
    // (i1, i2, i3) row is its own region. Offsets are derived from the tensor strides.
    const int64_t heads_per_region = (int64_t)(soa_bytes / row_bytes);
    const int64_t head_step        = heads_per_region > 1 ? heads_per_region : 1;
    const int64_t n_groups         = tensor->ne[2] / head_step;

    std::vector<uint8_t> buf(ggml_nbytes(tensor), 0);

    for (int64_t i3 = 0; i3 < tensor->ne[3]; i3++) {
        for (int64_t ig = 0; ig < n_groups; ig++) {
            for (int64_t i1 = 0; i1 < tensor->ne[1]; i1++) {
                const size_t offset = i3*tensor->nb[3] + ig*head_step*tensor->nb[2] + i1*tensor->nb[1];
                for (float & value : region_f32) {
                    value = dist(gen);
                }
                quantize_soa(region_f32.data(), buf.data() + offset, soa_elems);
            }
        }
    }

    ggml_backend_tensor_set(tensor, buf.data(), 0, buf.size());
}
// generate an F16 mask where certain blocks are randomly masked with -INF value
static void init_tensor_kq_mask(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
GGML_ASSERT(tensor->type == GGML_TYPE_F16);
@ -239,11 +324,30 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) {
size_t bs = ggml_blck_size(t->type);
std::vector<float> vq(ggml_blck_size(t->type));
bool quantized = ggml_is_quantized(t->type);
const bool is_mxfp = ggml_is_type_mxfp(t->type);
// SoA dequant for MXFP readback
mxfp_soa_dequantize_fn mxfp_dequant_soa = nullptr;
if (is_mxfp) {
switch (t->type) {
case GGML_TYPE_MXFP4_E2M1: mxfp_dequant_soa = dequantize_row_mxfp4_soa; break;
case GGML_TYPE_MXFP8_E4M3: mxfp_dequant_soa = dequantize_row_mxfp8_soa; break;
case GGML_TYPE_MXFP6_E2M3: mxfp_dequant_soa = dequantize_row_mxfp6_soa; break;
default: GGML_ABORT("unsupported MXFP type in tensor_to_float");
}
}
// access elements by index to avoid gaps in views
for (int64_t i3 = 0; i3 < t->ne[3]; i3++) {
for (int64_t i2 = 0; i2 < t->ne[2]; i2++) {
for (int64_t i1 = 0; i1 < t->ne[1]; i1++) {
if (is_mxfp) {
size_t row_off = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1];
std::vector<float> row_f32(t->ne[0]);
mxfp_dequant_soa(&buf[row_off], row_f32.data(), t->ne[0]);
tv.insert(tv.end(), row_f32.begin(), row_f32.end());
continue;
}
for (int64_t i0 = 0; i0 < t->ne[0]; i0 += bs) {
size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0];
if (t->type == GGML_TYPE_F16) {
@ -2309,8 +2413,12 @@ struct test_set_rows : public test_case {
const std::array<int, 2> nr23; // broadcast only dims 2 and 3
const int r; // rows to set
const bool v; // view (non-contiguous src1)
const bool hadamard; // apply Walsh-Hadamard rotation before quantization
std::string vars() override {
if (hadamard) {
return VARS_TO_STR6(type, type_idx, ne, nr23, r, v) + ",hadamard=1";
}
return VARS_TO_STR6(type, type_idx, ne, nr23, r, v);
}
@ -2318,8 +2426,8 @@ struct test_set_rows : public test_case {
ggml_type type_idx,
std::array<int64_t, 4> ne,
std::array<int, 2> nr23,
int r, bool v = false)
: type(type), type_idx(type_idx), ne(ne), nr23(nr23), r(r), v(v) {}
int r, bool v = false, bool hadamard = false)
: type(type), type_idx(type_idx), ne(ne), nr23(nr23), r(r), v(v), hadamard(hadamard) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * dst = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);
@ -2338,6 +2446,11 @@ struct test_set_rows : public test_case {
}
ggml_tensor * out = ggml_set_rows(ctx, dst, src, row_idxs);
if (hadamard) {
((int32_t *)out->op_params)[0] = 1;
}
ggml_set_name(out, "out");
return out;
@ -2351,6 +2464,10 @@ struct test_set_rows : public test_case {
}
init_set_rows_row_ids(t, ne[1]);
} else if (ggml_is_type_mxfp(t->type)) {
// MXFP dst tensors must use SoA layout — set_rows writes SoA,
// and tensor_to_float reads back assuming SoA for MXFP types.
init_tensor_mxfp_soa(t);
} else {
init_tensor_uniform(t);
}
@ -3798,7 +3915,7 @@ struct test_mul_mat : public test_case {
double max_nmse_err(ggml_backend_t backend) override {
// for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance
if (type_a == GGML_TYPE_MXFP4 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
if (type_a == GGML_TYPE_MXFP4_E2M1 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
return 2e-2;
}
return max_nmse_err();
@ -3932,9 +4049,10 @@ struct test_mul_mat_id : public test_case {
return 5e-4;
}
// Same Blackwell FP4 tolerance as test_mul_mat above.
double max_nmse_err(ggml_backend_t backend) override {
// for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance
if (type_a == GGML_TYPE_MXFP4 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
if (type_a == GGML_TYPE_MXFP4_E2M1 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
return 2e-2;
}
return max_nmse_err();
@ -6180,9 +6298,14 @@ struct test_flash_attn_ext : public test_case {
const ggml_prec prec;
const ggml_type type_KV;
const ggml_type type_V; // V type, defaults to type_KV for same-type K/V
std::array<int32_t, 4> permute;
std::string vars() override {
if (type_V != type_KV) {
return VARS_TO_STR13(hsk, hsv, nh, nr23, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV, permute)
+ ",type_V=" + ggml_type_name(type_V);
}
return VARS_TO_STR13(hsk, hsv, nh, nr23, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV, permute);
}
@ -6199,12 +6322,14 @@ struct test_flash_attn_ext : public test_case {
test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, std::array<int64_t, 2> nr23 = {1, 1}, int64_t kv = 96, int64_t nb = 8,
bool mask = true, bool sinks = false, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
: hsk(hsk), hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), sinks(sinks), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3},
ggml_type type_V_override = GGML_TYPE_COUNT)
: hsk(hsk), hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), sinks(sinks), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV),
type_V(type_V_override == GGML_TYPE_COUNT ? type_KV : type_V_override), permute(permute) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_KV));
const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_V));
auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, bool is_view) -> ggml_tensor * {
int64_t ne[4] = {ne0, ne1, ne2, ne3};
@ -6242,7 +6367,7 @@ struct test_flash_attn_ext : public test_case {
// - https://github.com/ggml-org/llama.cpp/pull/18986
v = ggml_view_4d(ctx, k, hsv_padded, kv, nh, nr23[1], k->nb[1], k->nb[2], k->nb[3], 0);
} else {
v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache
v = create_permuted(type_V, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache
}
ggml_set_name(v, "v");
@ -6273,6 +6398,11 @@ struct test_flash_attn_ext : public test_case {
init_tensor_uniform(t, -10.0f, 10.0f);
} else if (strcmp(t->name, "m") == 0) {
init_tensor_kq_mask(t);
} else if ((strcmp(t->name, "k") == 0 || strcmp(t->name, "v") == 0) && ggml_is_type_mxfp(t->type)) {
// MXFP K/V tensors use SoA layout. Pass nb[1] (KV-position stride) as the
// SoA region width — when heads are physically contiguous within that stride,
// the FA kernel dequants the full multi-head region as one SoA block.
init_tensor_mxfp_soa(t, -1.0f, 1.0f, t->nb[1]);
} else {
init_tensor_uniform(t);
}
@ -7279,7 +7409,8 @@ static const ggml_type all_types[] = {
GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
GGML_TYPE_Q8_0,
GGML_TYPE_MXFP4,
GGML_TYPE_MXFP4_E2M1, GGML_TYPE_MXFP8_E4M3,
GGML_TYPE_MXFP6_E2M3,
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
GGML_TYPE_Q6_K,
@ -7295,7 +7426,7 @@ static const ggml_type base_types[] = {
GGML_TYPE_Q4_0,
GGML_TYPE_Q4_1, // for I8MM tests
GGML_TYPE_Q4_K,
GGML_TYPE_MXFP4, // TODO: or "other"
GGML_TYPE_MXFP4_E2M1,
GGML_TYPE_IQ2_XXS
};
@ -7413,6 +7544,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
// SET_ROWS with Hadamard rotation (exercises the op_params[0] flag used by MXFP KV cache)
for (ggml_type type : {GGML_TYPE_MXFP4_E2M1, GGML_TYPE_MXFP8_E4M3,
GGML_TYPE_MXFP6_E2M3}) {
// ne[0] must be divisible by 32 (Hadamard block size)
test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 128, 5, 1, 1 }, { 1, 1 }, 1, false, true));
test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 256, 5, 1, 3 }, { 1, 1 }, 1, false, true));
}
for (int mode : { GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_VISION }) {
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
for (int ne2 : {1, 8, 512}) {
@ -8143,7 +8282,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat_id_fusion(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));
// gpt-oss issue with Vulkan mmq_id
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4_E2M1, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
@ -8603,8 +8742,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (int nb : { 1, 3, 32, 75, }) {
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72) continue;
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0,
GGML_TYPE_MXFP4_E2M1, GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3,
}) {
// Non-F16 types: test at D=64, D=72, and D=128.
if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72 && hsk != 128) continue;
// MXFP types require D % 32 == 0, skip D=72.
if (ggml_is_type_mxfp(type_KV) && hsk == 72) continue;
test_cases.emplace_back(new test_flash_attn_ext(
hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV));
// run fewer test cases permuted
@ -8626,6 +8770,26 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
// MXFP-specific K/V type combinations (mixed and same-type)
// Mixed: mxfp8 K + mxfp4 V, mxfp6 K + mxfp4 V (our recommended configs)
for (ggml_type type_K : {GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3}) {
for (ggml_type type_V : {GGML_TYPE_MXFP4_E2M1}) {
if (type_K == type_V) continue;
for (int nb : {1, 3, 32}) {
// hsk hsv nh nr23 kv nb mask sinks bias softcap prec type_K permute type_V
test_cases.emplace_back(new test_flash_attn_ext(
128, 128, 4, {1, 1}, 512, nb, true, false, 0.0f, 0.0f, GGML_PREC_F32, type_K, {0, 1, 2, 3}, type_V));
}
}
}
// Same-type: mxfp8/mxfp8, mxfp6/mxfp6
for (ggml_type type_KV : {GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3}) {
for (int nb : {1, 3, 32}) {
test_cases.emplace_back(new test_flash_attn_ext(
128, 128, 4, {1, 1}, 512, nb, true, false, 0.0f, 0.0f, GGML_PREC_F32, type_KV, {0, 1, 2, 3}, type_KV));
}
}
test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, { 10, 5, 4, 3}));
test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, {30000, 1, 1, 1}));
test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, { 10, 5, 4, 3}));
@ -8849,7 +9013,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
// gpt-oss-20b
for (int bs : {1, 4, 8, 512}) {
for (ggml_type type_a : {GGML_TYPE_MXFP4}) {
for (ggml_type type_a : {GGML_TYPE_MXFP4_E2M1}) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 2880, bs, 2880));
test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 2880, bs, 2880, 1));

View File

@ -483,7 +483,15 @@ static ggml_type ggml_type_from_name(const std::string & s) {
if (s == "iq4_nl") {
return GGML_TYPE_IQ4_NL;
}
if (s == "mxfp4" || s == "mxfp4_e2m1") {
return GGML_TYPE_MXFP4_E2M1;
}
if (s == "mxfp8" || s == "mxfp8_e4m3") {
return GGML_TYPE_MXFP8_E4M3;
}
if (s == "mxfp6" || s == "mxfp6_e2m3") {
return GGML_TYPE_MXFP6_E2M3;
}
return GGML_TYPE_COUNT;
}