Comment consistency pass and cleanup.

This commit is contained in:
Tim Burke 2026-03-21 13:37:09 -04:00
parent 23e88631c4
commit 5bb05ed21c
12 changed files with 113 additions and 294 deletions

View File

@ -404,8 +404,6 @@ const std::vector<ggml_type> kv_cache_types = {
};
static ggml_type kv_cache_type_from_str(const std::string & s) {
// Short aliases: "mxfp4" → E2M1, "mxfp6" → E2M3, "mxfp8" → E4M3.
// Full names (mxfp4_e2m1, mxfp8_e4m3, mxfp6_e2m3, etc.) match via ggml_type_name() below.
if (s == "mxfp4") {
return GGML_TYPE_MXFP4_E2M1;
}

View File

@ -71,7 +71,6 @@ typedef sycl::half2 ggml_half2;
#define GGML_COMMON_DECL
#endif
// Pure numeric constants needed by both DECL and IMPL sections.
#define MXFP_HADAMARD_32_NORM 0.17677669529663689f // 1/sqrt(32)
#if defined(GGML_COMMON_DECL)
@ -199,11 +198,8 @@ typedef struct {
} block_q4_1;
static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
// MXFP E8M0 shared exponent constants (OCP MX v1.0 §5.3).
// EMAX_OFFSET: ceil(log2(max_finite)) for each element type — used to center the E8M0 scale.
// MSE_RANGE: search radius around round(log2(amax)). Tests 2*range+1 candidate exponents,
// picking the one that minimizes total round-trip quantization error per block.
// Inspired by "Four Over Six" (arXiv:2512.02010); generalized to all MX types.
// E8M0 shared exponent constants (OCP MX v1.0 SS5.3).
// EMAX_OFFSET = ceil(log2(max_finite)), MSE_RANGE = search radius for optimal scale.
#define MXFP_E8M0_MSE_RANGE 2
#define MXFP4_E2M1_EMAX_OFFSET 2 // ceil(log2(6.0))
#define MXFP6_E2M3_EMAX_OFFSET 3 // ceil(log2(7.5))
@ -211,11 +207,7 @@ static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4
#define MXFP8_E4M3_EMAX_OFFSET 8 // ceil(log2(448))
#define MXFP8_E5M2_EMAX_OFFSET 16 // ceil(log2(57344))
// MXFP type properties — single source of truth for all backends.
// Bits per element, quantized bytes per block, and Hadamard rotation flag.
// USE_HADAMARD: 1 for types with >= 3-bit mantissa (E2M1, E4M3, E2M3).
// 0 for 2-bit mantissa types (E5M2, E3M2) where Hadamard provides
// no quality benefit and hurts models with D_head ≤ 64.
// MXFP type properties -- shared across all backends.
#define MXFP_BITS_PER_ELEM_E2M1 4
#define MXFP_BITS_PER_ELEM_E4M3 8
#define MXFP_BITS_PER_ELEM_E5M2 8
@ -239,52 +231,48 @@ static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4
// EXP_MASK = (1<<E)-1 MANT_MASK = (1<<M)-1 EXP_SHIFT = M
// IEEE_EXP_OFF = 127-B MANT_SHIFT = 23-M SUB_SCALE = 2^(1-B-M)
// Used by x86 AVX2 and ARM NEON vectorized dequant in dot product, AoS dequant, SoA dequant.
#define MXFP8_E4M3_EXP_MASK 0xF // (1<<4)-1
#define MXFP8_E4M3_MANT_MASK 0x7 // (1<<3)-1
#define MXFP8_E4M3_EXP_MASK 0xF
#define MXFP8_E4M3_MANT_MASK 0x7
#define MXFP8_E4M3_EXP_SHIFT 3
#define MXFP8_E4M3_IEEE_EXP_OFF 120 // 127-7
#define MXFP8_E4M3_MANT_SHIFT 20 // 23-3
#define MXFP8_E4M3_SUB_SCALE (1.0f/512.0f) // 2^(-9) = 2^(1-7-3)
#define MXFP8_E4M3_IEEE_EXP_OFF 120
#define MXFP8_E4M3_MANT_SHIFT 20
#define MXFP8_E4M3_SUB_SCALE (1.0f/512.0f)
#define MXFP8_E5M2_EXP_MASK 0x1F // (1<<5)-1
#define MXFP8_E5M2_MANT_MASK 0x3 // (1<<2)-1
#define MXFP8_E5M2_EXP_MASK 0x1F
#define MXFP8_E5M2_MANT_MASK 0x3
#define MXFP8_E5M2_EXP_SHIFT 2
#define MXFP8_E5M2_IEEE_EXP_OFF 112 // 127-15
#define MXFP8_E5M2_MANT_SHIFT 21 // 23-2
#define MXFP8_E5M2_SUB_SCALE (1.0f/65536.0f) // 2^(-16) = 2^(1-15-2)
#define MXFP8_E5M2_IEEE_EXP_OFF 112
#define MXFP8_E5M2_MANT_SHIFT 21
#define MXFP8_E5M2_SUB_SCALE (1.0f/65536.0f)
#define MXFP6_E2M3_EXP_MASK 0x3 // (1<<2)-1
#define MXFP6_E2M3_MANT_MASK 0x7 // (1<<3)-1
#define MXFP6_E2M3_EXP_MASK 0x3
#define MXFP6_E2M3_MANT_MASK 0x7
#define MXFP6_E2M3_EXP_SHIFT 3
#define MXFP6_E2M3_IEEE_EXP_OFF 126 // 127-1
#define MXFP6_E2M3_MANT_SHIFT 20 // 23-3
#define MXFP6_E2M3_SUB_SCALE (1.0f/8.0f) // 2^(-3) = 2^(1-1-3)
#define MXFP6_E2M3_IEEE_EXP_OFF 126
#define MXFP6_E2M3_MANT_SHIFT 20
#define MXFP6_E2M3_SUB_SCALE (1.0f/8.0f)
#define MXFP6_E3M2_EXP_MASK 0x7 // (1<<3)-1
#define MXFP6_E3M2_MANT_MASK 0x3 // (1<<2)-1
#define MXFP6_E3M2_EXP_MASK 0x7
#define MXFP6_E3M2_MANT_MASK 0x3
#define MXFP6_E3M2_EXP_SHIFT 2
#define MXFP6_E3M2_IEEE_EXP_OFF 124 // 127-3
#define MXFP6_E3M2_MANT_SHIFT 21 // 23-2
#define MXFP6_E3M2_SUB_SCALE (1.0f/16.0f) // 2^(-4) = 2^(1-3-2)
#define MXFP6_E3M2_IEEE_EXP_OFF 124
#define MXFP6_E3M2_MANT_SHIFT 21
#define MXFP6_E3M2_SUB_SCALE (1.0f/16.0f)
// Unified MXFP dequantization traits for SIMD backends (CPU x86/ARM, CUDA, Metal, Vulkan).
// Contains all parameters needed for IEEE-754 bit reconstruction of FP8/FP6 elements.
// FP4 uses LUT-based dequant and does not need this struct.
// MXFP dequant traits for IEEE-754 bit reconstruction (FP8/FP6).
typedef struct {
int exp_mask; // (1<<E)-1: exponent field mask
int mant_mask; // (1<<M)-1: mantissa field mask
int exp_shift; // M: right-shift to extract exponent
int ieee_exp_off; // 127-bias: offset to convert to IEEE exponent
int mant_shift; // 23-M: left-shift to align mantissa in IEEE float
float sub_scale; // 2^(1-bias-M): subnormal scale factor
int sign_mask; // 0x80 for 8-bit, 0x20 for 6-bit formats
int sign_shift; // 24 for 8-bit, 26 for 6-bit formats
int qs_per_block; // bytes of quantized data per 32-element block
int emax_offset; // type-specific offset for E8M0 MSE search
int exp_mask;
int mant_mask;
int exp_shift;
int ieee_exp_off;
int mant_shift;
float sub_scale;
int sign_mask; // 0x80 for 8-bit, 0x20 for 6-bit
int sign_shift; // 24 for 8-bit, 26 for 6-bit
int qs_per_block;
int emax_offset;
} mxfp_dequant_traits_t;
// Static const trait instances for each MXFP format.
// Gated by GGML_COMMON_IMPL to ensure single definition per translation unit.
#if defined(GGML_COMMON_IMPL)
static const mxfp_dequant_traits_t MXFP_TRAITS_E4M3 = {
MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT,
@ -337,17 +325,12 @@ typedef struct {
} block_mxfp6;
static_assert(sizeof(block_mxfp6) == sizeof(uint8_t) + QK_MXFP6 * 6 / 8, "wrong mxfp6 block size/padding");
// SoA (Struct-of-Arrays) layout constants for MXFP KV cache.
// Per row: [qs_block0|qs_block1|...][e8m0_0|e8m0_1|...]
// Total bytes per row is IDENTICAL to AoS — same tensor strides, just rearranged.
// Aliases for the canonical MXFP_QS_PER_BLOCK_* defines above.
#define MXFP4_SOA_QS_PER_BLOCK MXFP_QS_PER_BLOCK_E2M1 // 16 bytes
#define MXFP8_SOA_QS_PER_BLOCK MXFP_QS_PER_BLOCK_E4M3 // 32 bytes
#define MXFP6_SOA_QS_PER_BLOCK MXFP_QS_PER_BLOCK_E2M3 // 24 bytes
// SoA layout for MXFP KV cache: [qs blocks][e8m0 scales]
#define MXFP4_SOA_QS_PER_BLOCK MXFP_QS_PER_BLOCK_E2M1
#define MXFP8_SOA_QS_PER_BLOCK MXFP_QS_PER_BLOCK_E4M3
#define MXFP6_SOA_QS_PER_BLOCK MXFP_QS_PER_BLOCK_E2M3
// SoA offset helpers — single source of truth for the SoA memory layout contract.
// qs region: blocks 0..nblocks-1 at contiguous qs_per_block-byte strides.
// e8m0 region: starts immediately after all qs blocks.
// SoA offset helpers
#define MXFP_SOA_QS_OFFSET(block_idx, qs_per_block) ((block_idx) * (qs_per_block))
#define MXFP_SOA_E8M0_OFFSET(nblocks, qs_per_block) ((nblocks) * (qs_per_block))
@ -1296,15 +1279,12 @@ GGML_TABLE_BEGIN(float, kvalues_mxfp4_float, 16)
-0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f,
GGML_TABLE_END()
// E2M1 values doubled (implementation detail for CPU/CUDA integer arithmetic).
// Used with GGML_E8M0_TO_FP32_HALF(e) = scale/2 so that int8 × half_scale = true value.
// Canonical values are in kvalues_mxfp4_float above.
// E2M1 values doubled (for integer arithmetic with half-scale).
GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16)
0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
GGML_TABLE_END()
// FP6 E2M3 dequantization LUT: 6-bit value → float (64 entries).
// Generated from ggml_mxfp_fp6_e2m3_to_float(). Indices 0-31 positive, 32-63 negative.
// FP6 E2M3 dequantization LUT: 6-bit value -> float.
GGML_TABLE_BEGIN(float, kvalues_mxfp6_e2m3, 64)
0.0f, 0.125f, 0.25f, 0.375f, 0.5f, 0.625f, 0.75f, 0.875f,
1.0f, 1.125f, 1.25f, 1.375f, 1.5f, 1.625f, 1.75f, 1.875f,
@ -1316,8 +1296,7 @@ GGML_TABLE_BEGIN(float, kvalues_mxfp6_e2m3, 64)
-4.0f, -4.5f, -5.0f, -5.5f, -6.0f, -6.5f, -7.0f, -7.5f,
GGML_TABLE_END()
// FP6 E3M2 dequantization LUT: 6-bit value → float (64 entries).
// Generated from ggml_mxfp_fp6_e3m2_to_float(). No NaN/Inf — all bit patterns are valid.
// FP6 E3M2 dequantization LUT: 6-bit value -> float. No NaN/Inf.
GGML_TABLE_BEGIN(float, kvalues_mxfp6_e3m2, 64)
0.0f, 0.0625f, 0.125f, 0.1875f, 0.25f, 0.3125f, 0.375f, 0.4375f,
0.5f, 0.625f, 0.75f, 0.875f, 1.0f, 1.25f, 1.5f, 1.75f,
@ -1329,8 +1308,7 @@ GGML_TABLE_BEGIN(float, kvalues_mxfp6_e3m2, 64)
-8.0f, -10.0f, -12.0f, -14.0f, -16.0f, -20.0f, -24.0f, -28.0f,
GGML_TABLE_END()
// FP8 E4M3 dequantization LUT: byte → float (256 entries).
// Generated from ggml_mxfp_fp8_e4m3_to_float(). Entry 127 = 448 (max finite), 255 = NaN.
// FP8 E4M3 dequantization LUT: byte -> float. Entry 127 = 448 (max finite), 255 = NaN.
GGML_TABLE_BEGIN(float, kvalues_mxfp8_e4m3, 256)
0.0f, 0.001953125f, 0.00390625f, 0.005859375f, 0.0078125f, 0.009765625f, 0.01171875f, 0.013671875f,
0.015625f, 0.017578125f, 0.01953125f, 0.021484375f, 0.0234375f, 0.025390625f, 0.02734375f, 0.029296875f,
@ -1366,8 +1344,7 @@ GGML_TABLE_BEGIN(float, kvalues_mxfp8_e4m3, 256)
-256.0f, -288.0f, -320.0f, -352.0f, -384.0f, -416.0f, -448.0f, NAN,
GGML_TABLE_END()
// FP8 E5M2 dequantization LUT: byte → float (256 entries).
// Generated from ggml_mxfp_fp8_e5m2_to_float(). Entries 124-127 = {Inf, NaN, NaN, NaN}.
// FP8 E5M2 dequantization LUT: byte -> float. Entries 124-127 = {Inf, NaN, NaN, NaN}.
GGML_TABLE_BEGIN(float, kvalues_mxfp8_e5m2, 256)
0.0f, 1.525879e-05f, 3.051758e-05f, 4.577637e-05f, 6.103516e-05f, 7.629395e-05f, 9.155273e-05f, 1.068115e-04f,
1.220703e-04f, 1.525879e-04f, 1.831055e-04f, 2.136230e-04f, 2.441406e-04f, 3.051758e-04f, 3.662109e-04f, 4.272461e-04f,
@ -1403,16 +1380,10 @@ GGML_TABLE_BEGIN(float, kvalues_mxfp8_e5m2, 256)
-32768.0f, -40960.0f, -49152.0f, -57344.0f, -INFINITY, NAN, NAN, NAN,
GGML_TABLE_END()
// ------------------------------------------------------------------------------------------------------------------
// Canonical MXFP element converters — portable IEEE-754 bit manipulation.
// Single source of truth for CPU, CUDA, HIP, MUSA, SYCL. Metal/Vulkan keep MSL/GLSL copies.
// ------------------------------------------------------------------------------------------------------------------
// MXFP element converters -- portable IEEE-754 bit manipulation.
#if defined(GGML_MXFP_FUNC)
// --- FP4 E2M1: [S(1) | E(2) | M(1)] — max normal = 6.0 ---
// Canonical converters using true E2M1 values {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
// The int8 kvalues_mxfp4 LUT stores doubled values {0,1,2,3,4,6,8,12} for
// CPU/CUDA nibble-indexed integer arithmetic — that doubling is an implementation detail.
// FP4 E2M1: [S(1) | E(2) | M(1)], max normal = 6.0
GGML_MXFP_FUNC float ggml_mxfp_fp4_e2m1_to_float(uint8_t v) {
const float sign = (v & 0x8) ? -1.0f : 1.0f;
@ -1427,8 +1398,6 @@ GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp4_e2m1(float x) {
if (x < 0) { sign = 0x8; x = -x; }
if (x == 0) return sign;
if (x >= 6.0f) return sign | 0x7; // max finite
// Decision boundaries (midpoints of adjacent canonical values):
// {0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0}
if (x < 0.25f) return sign | 0x0; // 0
else if (x < 0.75f) return sign | 0x1; // 0.5
else if (x < 1.25f) return sign | 0x2; // 1.0
@ -1439,7 +1408,7 @@ GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp4_e2m1(float x) {
else return sign | 0x7; // 6.0
}
// --- FP6 E2M3: [S(1) | E(2) | M(3)] — max normal = 7.5 ---
// FP6 E2M3: [S(1) | E(2) | M(3)], max normal = 7.5
GGML_MXFP_FUNC float ggml_mxfp_fp6_e2m3_to_float(uint8_t v) {
const float sign = (v & 0x20) ? -1.0f : 1.0f;
@ -1474,7 +1443,7 @@ GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp6_e2m3(float x) {
return sign | (uint8_t)(((f32_exp + 1) << 3) | mant);
}
// --- FP6 E3M2: [S(1) | E(3) | M(2)] — max normal = 28.0, no NaN/Inf ---
// FP6 E3M2: [S(1) | E(3) | M(2)], max normal = 28.0, no NaN/Inf
GGML_MXFP_FUNC float ggml_mxfp_fp6_e3m2_to_float(uint8_t v) {
const float sign = (v & 0x20) ? -1.0f : 1.0f;
@ -1512,7 +1481,7 @@ GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp6_e3m2(float x) {
return sign | (uint8_t)((biased_exp << 2) | mant);
}
// --- FP8 E4M3: [S(1) | E(4) | M(3)] — bias=7, max finite=448 ---
// FP8 E4M3: [S(1) | E(4) | M(3)], bias=7, max finite=448
GGML_MXFP_FUNC float ggml_mxfp_fp8_e4m3_to_float(uint8_t v) {
uint32_t sign = ((uint32_t)(v & 0x80)) << 24;
@ -1571,7 +1540,7 @@ GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp8_e4m3(float x) {
return sign | (uint8_t)((e4m3_exp << 3) | mant3);
}
// --- FP8 E5M2: [S(1) | E(5) | M(2)] — bias=15, max finite=57344 ---
// FP8 E5M2: [S(1) | E(5) | M(2)], bias=15, max finite=57344
GGML_MXFP_FUNC float ggml_mxfp_fp8_e5m2_to_float(uint8_t v) {
uint32_t sign = ((uint32_t)(v & 0x80)) << 24;
@ -1629,7 +1598,7 @@ GGML_MXFP_FUNC uint8_t ggml_mxfp_float_to_fp8_e5m2(float x) {
return sign | (uint8_t)((e5m2_exp << 2) | mant2);
}
// --- FP6 packing/unpacking ---
// FP6 packing/unpacking
// Pack 4 six-bit values into 3 bytes
GGML_MXFP_FUNC void ggml_mxfp_pack_fp6x4(const uint8_t v[4], uint8_t out[3]) {
@ -1674,8 +1643,6 @@ GGML_MXFP_FUNC int ggml_mxfp_e8m0_base_estimate(float amax, int emax_offset) {
}
// Block-32 Walsh-Hadamard Transform, normalized by 1/sqrt(32).
// Spreads outlier energy across all elements sharing an E8M0 exponent,
// improving quantization quality (see QuaRot arXiv:2404.00456).
GGML_MXFP_FUNC void ggml_mxfp_hadamard_32_inplace(GGML_MXFP_THREAD float * vals) {
GGML_MXFP_UNROLL
for (int stride = 1; stride < 32; stride *= 2) {

View File

@ -2,7 +2,6 @@
#pragma once
// Rename `_generic` functions if no native implementation is available.
// This effectively selects the generic implementation.
#if defined(GGML_CPU_GENERIC)
// quants.c

View File

@ -4134,21 +4134,14 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
#endif
}
// ── MXFP FP8/FP6 NEON helpers ──────────────────────────────────────────────
// Shared IEEE-754 bit reconstruction and FP6 unpacking used by vec_dot,
// dequantize_row, and SoA dequant functions.
//
// NEON requires vshlq_n_u32 to have a compile-time literal constant, so we use
// two separate helpers for FP8 (sign at bit 7, shift 24) and FP6 (sign at bit 5,
// shift 26) rather than a single parameterized function.
// MXFP FP8/FP6 NEON helpers
// Separate FP8/FP6 functions because NEON vshlq_n_u32 requires compile-time constants.
#if defined(__ARM_NEON)
// Use shared mxfp_dequant_traits_t from ggml-common.h.
#define mxfp_neon_traits_t mxfp_dequant_traits_t
// Dequantize 4 raw FP8 values (uint32x4_t) → 4 IEEE-754 floats.
// Sign bit at position 7, sign shift = 24.
// Dequantize 4 FP8 values to floats.
static inline float32x4_t mxfp8_dequant_neon(
const uint32x4_t v_raw,
const uint32x4_t v_exp_mask, const uint32x4_t v_mant_mask,
@ -4172,8 +4165,7 @@ static inline float32x4_t mxfp8_dequant_neon(
return vbslq_f32(is_sub, sub_val, normal);
}
// Dequantize 4 raw FP6 values (uint32x4_t) → 4 IEEE-754 floats.
// Sign bit at position 5, sign shift = 26.
// Dequantize 4 FP6 values to floats.
static inline float32x4_t mxfp6_dequant_neon(
const uint32x4_t v_raw,
const uint32x4_t v_exp_mask, const uint32x4_t v_mant_mask,
@ -4229,7 +4221,7 @@ static inline void widen_s8x8_to_f32x4x2(const int8_t * src,
*hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16)));
}
// ── MXFP FP8/FP6 vec_dot ──────────────────────────────────────────────────
// MXFP FP8/FP6 vec_dot
static void ggml_vec_dot_mxfp8_q8_0_neon(
int n, float * GGML_RESTRICT s,
@ -4319,7 +4311,7 @@ static void ggml_vec_dot_mxfp6_q8_0_neon(
*s = vaddvq_f32(vaddq_f32(acc0, acc1));
}
// ── MXFP FP8/FP6 dequantize_row (AoS) ─────────────────────────────────────
// MXFP FP8/FP6 dequantize_row (AoS)
static void dequantize_row_mxfp8_neon(
const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k,
@ -4381,7 +4373,7 @@ static void dequantize_row_mxfp6_neon(
}
}
// ── MXFP SoA dequant (flash attention) ─────────────────────────────────────
// MXFP SoA dequant (flash attention)
static void dequantize_row_mxfp8_soa_neon(
const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k,
@ -4492,7 +4484,7 @@ static void dequantize_row_mxfp4_soa_neon(
#endif // __ARM_NEON
// ── Public dispatch functions ──────────────────────────────────────────────
// Public dispatch functions
void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);

View File

@ -3819,18 +3819,13 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
#endif
}
// ── MXFP FP8/FP6 AVX2 helpers ──────────────────────────────────────────────
// Shared IEEE-754 bit reconstruction and FP6 unpacking used by vec_dot,
// dequantize_row, and SoA dequant functions.
// MXFP FP8/FP6 AVX2 helpers
#if defined(__AVX2__)
// Use shared mxfp_dequant_traits_t from ggml-common.h.
// Aliases for readability within this file.
#define mxfp_avx2_traits_t mxfp_dequant_traits_t
// Dequantize 8 raw MXFP values (widened to int32) → 8 IEEE-754 floats.
// Handles both normal and subnormal paths. Works for any FP6/FP8 format.
// Dequantize 8 FP8/FP6 values to floats.
static inline __m256 mxfp_dequant_avx2(
const __m256i v_raw,
const __m256i v_exp_mask, const __m256i v_mant_mask,
@ -3872,9 +3867,9 @@ static inline __m256i unpack_fp6x8_avx2(const uint8_t * qs, int j) {
return _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)unpacked));
}
// ── MXFP FP8/FP6 vec_dot ──────────────────────────────────────────────────
// MXFP FP8/FP6 vec_dot
// Unified FP8 × Q8_0 dot product (works for E4M3 and E5M2).
// FP8 x Q8_0 dot product (E4M3/E5M2).
static void ggml_vec_dot_mxfp8_q8_0_avx2(
int n, float * GGML_RESTRICT s,
const void * GGML_RESTRICT vx,
@ -3915,7 +3910,7 @@ static void ggml_vec_dot_mxfp8_q8_0_avx2(
*s = hsum_float_8(acc);
}
// Unified FP6 × Q8_0 dot product (works for E2M3 and E3M2).
// FP6 x Q8_0 dot product (E2M3/E3M2).
static void ggml_vec_dot_mxfp6_q8_0_avx2(
int n, float * GGML_RESTRICT s,
const void * GGML_RESTRICT vx,
@ -3955,7 +3950,7 @@ static void ggml_vec_dot_mxfp6_q8_0_avx2(
*s = hsum_float_8(acc);
}
// ── MXFP FP8/FP6 dequantize_row (AoS) ─────────────────────────────────────
// MXFP FP8/FP6 dequantize_row (AoS)
static void dequantize_row_mxfp8_avx2(
const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k,
@ -4016,7 +4011,7 @@ static void dequantize_row_mxfp6_avx2(
}
}
// ── MXFP SoA dequant (flash attention) ─────────────────────────────────────
// MXFP SoA dequant (flash attention)
static void dequantize_row_mxfp8_soa_avx2(
const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k,
@ -4116,7 +4111,7 @@ static void dequantize_row_mxfp4_soa_avx2(
#endif // __AVX2__
// ── Public dispatch functions ──────────────────────────────────────────────
// Public dispatch functions
void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);

View File

@ -4909,8 +4909,7 @@ void ggml_compute_forward_get_rows(
//}
}
// NEON-optimized Hadamard for ARM platforms; scalar fallback uses ggml_hadamard_32_inplace
// from ggml-quants.c (the reference implementation).
// NEON-optimized Hadamard; scalar fallback below
#if defined(__ARM_NEON)
static void hadamard_32_inplace(float vals[32]) {
float32x4_t v0 = vld1q_f32(vals + 0);
@ -4978,13 +4977,11 @@ static void hadamard_32_inplace(float vals[32]) {
vst1q_f32(vals + 28, vmulq_f32(v7, norm));
}
#else
// Scalar fallback: delegate to reference implementation in ggml-quants.c
// Scalar fallback for the 32-element in-place Hadamard transform: delegates to
// the reference implementation ggml_hadamard_32_inplace (defined elsewhere in
// the project; per the surrounding diff context, in ggml-quants.c). Selected
// only when __ARM_NEON is not defined — the NEON-optimized version above is
// used otherwise.
static void hadamard_32_inplace(float vals[32]) {
ggml_hadamard_32_inplace(vals);
}
#endif
// Apply Hadamard rotation to each 32-element block in a float buffer.
static void ggml_apply_hadamard_blocks(float * data, int64_t n) {
GGML_ASSERT(n % 32 == 0);
for (int64_t i = 0; i < n; i += 32) {
@ -5031,8 +5028,6 @@ static void ggml_compute_forward_set_rows_f32(
const int32_t apply_hadamard = ((const int32_t *)dst->op_params)[0];
// For MXFP types, use SoA quantize (canonical FA layout).
// For non-MXFP types, use the standard AoS from_float.
typedef void (*quantize_soa_fn)(const float *, void *, int64_t);
quantize_soa_fn mxfp_soa_quantize = nullptr;
ggml_from_float_t from_float = nullptr;
@ -5046,8 +5041,6 @@ static void ggml_compute_forward_set_rows_f32(
break;
}
// Pre-allocate Hadamard temp buffer once outside the hot loop (nc is constant).
// nc == n_embd_k_gqa which is bounded by model architecture (typically <= 8192).
std::vector<float> had_tmp;
if (apply_hadamard) {
had_tmp.resize(nc);
@ -8275,8 +8268,7 @@ void ggml_compute_forward_top_k(
typedef void (*mxfp_soa_quantize_fn)(const float *, void *, int64_t);
typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t);
// Shared MXFP dispatch parameters for flash attention.
// Populated once and used by both the one_chunk and tiled paths.
// MXFP dispatch parameters for flash attention.
struct mxfp_fa_params {
mxfp_soa_quantize_fn q_quantize;
mxfp_soa_dequantize_fn k_dequantize;
@ -8287,9 +8279,6 @@ struct mxfp_fa_params {
int64_t v_soa_elems;
bool apply_hadamard;
// Per-head SoA addressing (avoids dequanting all heads in multihead mode).
// qs_per_block: bytes of quantized data per 32-element block.
// head_qs_bytes: total qs bytes for one head (blocks_per_head * qs_per_block).
// head_e8m0_offset: byte offset from row start to e8m0 region.
int k_qs_per_block;
int v_qs_per_block;
int k_head_qs_bytes;
@ -8455,9 +8444,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
ggml_vec_dot_t kq_vec_dot = nullptr;
ggml_to_float_t v_to_float = nullptr;
if (is_mxfp_k) {
kq_vec_dot = nullptr;
} else {
if (!is_mxfp_k) {
ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
@ -8472,20 +8459,15 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
int ith = params->ith;
// Dequant buffers for MXFP SoA — stack-allocated, no heap allocation in the hot path.
// In multihead mode, only dequant one head (DK or DV elements) instead of all heads.
// DK/DV are bounded by 1024 (asserted below for MXFP).
if (is_mxfp_k) { GGML_ASSERT(DK <= 1024); }
if (is_mxfp_v) { GGML_ASSERT(DV <= 1024); }
float k_dequant_buf[1024];
float v_dequant_buf[1024];
// Per-head SoA temp buffer: holds [qs | e8m0] for one head in multihead mode.
// For DK=1024 with MXFP8: 32 blocks * 32 qs + 32 e8m0 = 1056 bytes.
// Stack-allocated since sizes are bounded by DK/DV <= 1024.
// Max: 1024/32 * 32(qs) + 1024/32 = 1056 bytes (MXFP8).
char k_head_soa[1088]; // 1056 rounded up for alignment
char k_head_soa[1088]; // max: DK=1024 MXFP8 -> 1056 bytes, rounded up
char v_head_soa[1088];
// Thread-local work buffers (constant across ir loop)
float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32);
float * V32 = (VKQ32 + 1*DV);
ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV);
@ -8521,31 +8503,24 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
const int iv3 = iq3 / rv3;
const int iv2 = iq2 / rv2;
// Precompute loop-invariant base pointer offsets for K and V.
// Only ic varies in the inner loop; head/batch offsets are constant.
const size_t k_base_offset = ik2*nbk2 + ik3*nbk3;
const size_t v_base_offset = iv2*nbv2 + iv3*nbv3;
const char * k_base = (const char *) k->data + k_base_offset;
const char * v_base = (const char *) v->data + v_base_offset;
// For multihead MXFP: precompute per-head SoA byte offsets (constant per query row).
// head_qs_start: byte offset to this head's qs blocks within the SoA row.
// head_e8m0_start: byte offset to this head's e8m0 scales within the SoA row.
// Per-head SoA byte offsets
const int k_head_qs_start = mxfp.k_multihead ? ik2 * mxfp.k_head_qs_bytes : 0;
const int k_head_e8m0_start = mxfp.k_multihead ? (int)mxfp.k_head_e8m0_offset + ik2 * mxfp.k_blocks_per_head : 0;
const int v_head_qs_start = mxfp.v_multihead ? iv2 * mxfp.v_head_qs_bytes : 0;
const int v_head_e8m0_start = mxfp.v_multihead ? (int)mxfp.v_head_e8m0_offset + iv2 * mxfp.v_blocks_per_head : 0;
// Multihead MXFP row base (without head offset) — only ic varies.
const char * k_row_base = mxfp.k_multihead ? ((const char *) k->data + ik3*nbk3) : nullptr;
const char * v_row_base = mxfp.v_multihead ? ((const char *) v->data + iv3*nbv3) : nullptr;
const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
float Q_f32[1024];
if (is_mxfp_k) {
// Q preprocessing: Hadamard → SoA quantize → SoA dequant (round-trip).
// Captures the same quantization loss as K, matching GPU MMA semantics.
GGML_ASSERT(DK <= 1024);
// Q preprocessing: Hadamard + SoA round-trip captures same quantization loss as K.
if (mxfp.apply_hadamard) {
float q_tmp[1024];
memcpy(q_tmp, pq, DK * sizeof(float));
@ -8557,7 +8532,6 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
mxfp.k_dequantize(Q_q, Q_f32, DK);
} else {
if (mxfp.apply_hadamard) {
GGML_ASSERT(DK <= 1024);
float q_tmp[1024];
memcpy(q_tmp, pq, DK * sizeof(float));
ggml_apply_hadamard_blocks(q_tmp, DK);
@ -8581,8 +8555,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
if (is_mxfp_k) {
if (mxfp.k_multihead) {
// Multihead: extract this head's SoA blocks into temp buffer, dequant only DK elements.
// Copy qs blocks then e8m0 scales for this head into contiguous [qs|e8m0] layout.
// Extract this head's SoA blocks
const char * row = k_row_base + ic*nbk1;
memcpy(k_head_soa, row + k_head_qs_start, mxfp.k_head_qs_bytes);
memcpy(k_head_soa + mxfp.k_head_qs_bytes, row + k_head_e8m0_start, mxfp.k_blocks_per_head);
@ -8610,10 +8583,12 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
if (v_is_f16) {
if (s > M) {
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
M = s;
ms = expf(Mold - M);
ggml_vec_scale_f16(DV, VKQ16, ms);
} else {
// no new maximum, ms == 1.0f, vs != 1.0f
vs = expf(s - M);
}
@ -8621,10 +8596,12 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) (v_base + ic*nbv1), vs);
} else {
if (s > M) {
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
M = s;
ms = expf(Mold - M);
ggml_vec_scale_f32(DV, VKQ32, ms);
} else {
// no new maximum, ms == 1.0f, vs != 1.0f
vs = expf(s - M);
}
@ -8643,6 +8620,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
v_to_float(v_base + ic*nbv1, V32, DV);
ggml_vec_mad_f32(DV, VKQ32, V32, vs);
} else {
// V is F32
ggml_vec_mad_f32(DV, VKQ32, (const float *) (v_base + ic*nbv1), vs);
}
}
@ -8782,12 +8760,12 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
// MXFP dequant scratch buffers — allocated once per thread, reused across all tiles.
// DK/DV bounded by 1024, so per-head dequant fits in stack buffers.
if (is_mxfp_k) { GGML_ASSERT(DK <= 1024); }
if (is_mxfp_v) { GGML_ASSERT(DV <= 1024); }
float k_dequant_buf[1024];
float v_dequant_buf[1024];
// Per-head SoA temp buffers for multihead extraction (same as one_chunk path).
char k_head_soa[1088];
char v_head_soa[1088];
@ -8853,10 +8831,7 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
if (mxfp.apply_hadamard) {
ggml_apply_hadamard_blocks(Q_f32 + tq * DK, DK);
}
// SoA round-trip: quantize Q to SoA, then dequant back to float
const size_t q_soa_bytes = ggml_row_size(k->type, DK);
GGML_ASSERT(q_soa_bytes <= 2048);
uint8_t q_mxfp_buf[2048]; // max: DK=1024 * 33/32 = 1056 bytes (MXFP8)
uint8_t q_mxfp_buf[1088]; // max: DK=1024 MXFP8 -> 1056 bytes
mxfp.q_quantize(Q_f32 + tq * DK, q_mxfp_buf, DK);
mxfp.k_dequantize(q_mxfp_buf, Q_f32 + tq * DK, DK);
}
@ -9234,10 +9209,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int64_t dr = (nr + nchunk - 1) / nchunk;
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
// Tiled GEMM path: dequant K/V to float, then simd_gemm.
// Only for types that natively dequant to float (f32, f16, MXFP).
// Standard quant types (q8_0, q4_0) must use the scalar one_chunk path
// to preserve vec_dot semantics and produce identical results to master.
// Tiled path: f32, f16, and MXFP only (quant types use one_chunk)
bool use_tiled = !use_ref &&
(q->type == GGML_TYPE_F32 &&
kv_is_f32_f16_or_mxfp &&

View File

@ -264,9 +264,7 @@ void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
*s = sumf;
}
// Generic MXFP-to-Q8_0 dot product. Dequants one MX block (32 elements)
// to float via the existing public dequantize_row functions, then dots
// against Q8_0 int8 values. Reference implementation — not SIMD-optimized.
// Generic MXFP x Q8_0 dot product (scalar, not SIMD-optimized)
static void ggml_vec_dot_mxfp_q8_0_impl(
int n, float * GGML_RESTRICT s,
const void * GGML_RESTRICT vx, size_t block_size,
@ -311,8 +309,7 @@ void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
(ggml_to_float_t)dequantize_row_mxfp6);
}
// Generic (scalar) dequant wrappers — delegates to ggml-quants.c reference implementations.
// On x86/ARM, arch-specific SIMD versions override these via the fallback.h mapping.
// Generic dequant wrappers — arch-specific SIMD versions override via fallback.h.
// Generic (scalar) MXFP8 row dequantization: forwards directly to the
// reference dequantize_row_mxfp8. Per the comment above this block,
// arch-specific SIMD versions (x86/ARM) override this symbol via the
// fallback.h mapping when a native implementation exists.
// x: quantized input row, y: output floats, k: number of elements.
void dequantize_row_mxfp8_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
dequantize_row_mxfp8(x, y, k);
}

View File

@ -431,15 +431,11 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
// E8M0 shared exponent to float: returns 2^(x - 127).
// Canonical source: ggml_mxfp_e8m0_to_fp32() in ggml-common.h.
// Kept here because ggml-impl.h cannot depend on ggml-common.h IMPL section.
// NaN (x == 0xFF) is not handled — callers guarantee valid exponents.
static inline float ggml_e8m0_to_fp32(uint8_t x) {
uint32_t bits;
if (x == 0) {
bits = 0x00400000; // denorm: 0.5 * 2^(-126) = 2^(-127)
} else {
// normalized: exponent x placed in bits [30:23], mantissa = 0 → 2^(x-127)
bits = (uint32_t) x << 23;
}
float result;
@ -447,9 +443,7 @@ static inline float ggml_e8m0_to_fp32(uint8_t x) {
return result;
}
// E8M0 to float/2: returns 2^(x - 128). Equal to ggml_e8m0_to_fp32(x) / 2.
// Useful with MXFP4 because the E2M1 kvalues table stores 2 * float_value.
// Canonical source: ggml_mxfp_e8m0_to_fp32_half() in ggml-common.h.
// E8M0 to float/2: returns 2^(x - 128).
static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
uint32_t bits;
if (x < 2) {
@ -467,8 +461,7 @@ static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
#define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x)
#define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x)
// UE4M3 (unsigned E4M3): 4 exponent bits (bias 7), 3 mantissa bits, no sign.
// Range: [0, 448], with 0x7F = NaN treated as zero.
// UE4M3 (unsigned E4M3): 4 exponent bits (bias 7), 3 mantissa bits.
// Returns value * 0.5 to match kvalues_mxfp4 convention (kvalues = 2 * E2M1_float).
static inline float ggml_ue4m3_to_fp32(uint8_t x) {
if (x == 0 || x == 0x7F) {
@ -487,20 +480,19 @@ static inline float ggml_ue4m3_to_fp32(uint8_t x) {
return raw * 0.5f;
}
// Float32 to UE4M3 with round-to-nearest-even.
// Clamps to [0, 448]. Max representable: 0x7E = (1 + 7/8) * 2^8 = 448.
// Float32 to UE4M3 with round-to-nearest.
static inline uint8_t ggml_fp32_to_ue4m3(float x) {
if (!(x > 0.0f)) {
return 0; // negative, zero, NaN → 0
return 0;
}
if (x > 448.0f) {
x = 448.0f; // clamp to max representable
x = 448.0f;
}
uint32_t bits;
memcpy(&bits, &x, 4);
int fp32_exp = ((bits >> 23) & 0xFF) - 127;
int fp32_man = (bits >> 20) & 0x7; // top 3 mantissa bits
int ue4m3_exp = fp32_exp + 7; // rebias: FP32 bias 127 → UE4M3 bias 7
int fp32_man = (bits >> 20) & 0x7;
int ue4m3_exp = fp32_exp + 7;
if (ue4m3_exp <= 0) {
// subnormal: value = man * 2^(-9), so man = round(x * 512)
int man = (int) (x * 512.0f + 0.5f);
@ -513,17 +505,15 @@ static inline uint8_t ggml_fp32_to_ue4m3(float x) {
return (uint8_t) man;
}
if (ue4m3_exp >= 15) {
return 0x7E; // max normal
return 0x7E;
}
// round-to-nearest using bit 19 (first bit below UE4M3 mantissa)
int round_bit = (bits >> 19) & 1;
int ue4m3_man = fp32_man + round_bit;
if (ue4m3_man > 7) {
// mantissa overflow → carry into exponent
ue4m3_man = 0;
ue4m3_exp++;
if (ue4m3_exp >= 15) {
return 0x7E; // max normal
return 0x7E;
}
}
return (uint8_t) ((ue4m3_exp << 3) | ue4m3_man);

View File

@ -259,16 +259,12 @@ void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_REST
// ====================== MXFP element conversions (wrappers around ggml-common.h)
// FP8 E4M3: 1 sign, 4 exp (bias 7), 3 mantissa. Max finite: 448, NaN: 0x7F (saturated to max).
// Decode one FP8 E4M3 byte to float; thin wrapper over the ggml-common.h reference.
float fp8_e4m3_to_float(uint8_t v) {
    return ggml_mxfp_fp8_e4m3_to_float(v);
}
// Encode a float as FP8 E4M3 (round-to-nearest); thin wrapper over the ggml-common.h reference.
uint8_t float_to_fp8_e4m3_rn(float x) {
    return ggml_mxfp_float_to_fp8_e4m3(x);
}
// ====================== MXFP quantization infrastructure
//
// MSE-optimal E8M0 shared exponent: tests ±R candidates around round(log2(amax))
// and picks whichever minimizes total round-trip quantization error per block.
// Improves on OCP MX v1.0 §5.3 floor(log2(amax)) by 0.05-0.2 PPL.
// Ref: OCP MX v1.0 spec, Four Over Six (arXiv:2512.02010)
// MSE-optimal E8M0: tests candidates around round(log2(amax)), picks lowest quantization error.
typedef struct {
int emax_offset; // type-specific offset to max representable exponent
@ -279,8 +275,7 @@ typedef struct {
static inline int best_index_mxfp4(float x, float e);
// MXFP4 MSE error using decision boundary quantization with half-scale
// (kvalues_mxfp4 are doubled E2M1 values, so scale is halved to compensate)
// MSE error for MXFP4 (kvalues are doubled, so scale is halved)
static float mse_error_mxfp4(float val, float inv_scale, float scale) {
const float d = scale * 0.5f;
const float inv_d = (d > 0.0f) ? 1.0f / d : 0.0f;
@ -301,7 +296,6 @@ static float mse_error_mxfp4(float val, float inv_scale, float scale) {
static const mxfp_elem_traits_t mxfp4_traits = { MXFP4_E2M1_EMAX_OFFSET, NULL, NULL, mse_error_mxfp4 };
// Find MSE-optimal E8M0 exponent by testing ±R candidates around round(log2(amax))
static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_elem_traits_t * traits) {
float amax = 0.0f;
for (int j = 0; j < qk; j++) {
@ -336,7 +330,6 @@ static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_
return (uint8_t)best_e;
}
// Decision boundary quantization for kvalues_mxfp4 {0,1,2,3,4,6,8,12}
static inline int best_index_mxfp4(float x, float e) {
const float inv_e = (e > 0.0f) ? 1.0f / e : 0.0f;
const float normalized = fabsf(x) * inv_e;
@ -566,11 +559,6 @@ void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_REST
}
// ====================== Hadamard rotation
//
// 32-element Walsh-Hadamard transform applied before MX quantization to spread
// outlier energy across the shared-exponent group. Orthogonal (H^T·H = I), so
// H(K)·H(Q) = K·Q — attention scores are preserved when both K and Q are rotated.
// Prior art: QuIP# (Tseng et al. 2024), BRQ (Huang et al. 2024)
void ggml_hadamard_32_inplace(float vals[32]) {
ggml_mxfp_hadamard_32_inplace(vals);
@ -586,7 +574,6 @@ uint8_t float_to_fp8_e5m2_rn(float x) { return ggml_mxfp_float_to_fp8_e5m2(x); }
// Pack four 6-bit values into 3 bytes; thin wrapper over the ggml-common.h reference.
void pack_fp6x4(const uint8_t v[4], uint8_t out[3]) {
    ggml_mxfp_pack_fp6x4(v, out);
}
// Unpack 3 bytes into four 6-bit values; thin wrapper over the ggml-common.h reference.
void unpack_fp6x4(const uint8_t in[3], uint8_t v[4]) {
    ggml_mxfp_unpack_fp6x4(in, v);
}
// round-trip quantization error for MSE-optimal exponent search
static float mse_error_fp8_e4m3(float val, float inv_scale, float scale) {
const float recon = fp8_e4m3_to_float(float_to_fp8_e4m3_rn(val * inv_scale)) * scale;
const float err = val - recon;
@ -685,8 +672,6 @@ void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_REST
}
// ====================== SoA (Struct-of-Arrays) quantize/dequantize for flash attention
//
// Layout per row: [qs_block0|qs_block1|...|qs_blockN][e8m0_0|e8m0_1|...|e8m0_N]
void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) {
assert(k % QK_MXFP4 == 0);

View File

@ -55,8 +55,7 @@ GGML_API void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float *
GGML_API void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
// SoA (Struct-of-Arrays) quantize/dequantize — canonical reference for flash attention.
// Layout: [qs contiguous][e8m0 contiguous] per row. Same total bytes as AoS.
// SoA quantize/dequantize for flash attention
GGML_API void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
GGML_API void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
@ -114,83 +113,26 @@ GGML_API size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_REST
GGML_API size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
//
// MXFP element-level conversion functions (reference implementations)
//
// These implement the OCP Microscaling (MX) format element types as defined in:
// OCP Microscaling Formats (MX) Specification v1.0, Sep 2023
// https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
//
// Each MX block contains 32 elements sharing a single E8M0 exponent. The element
// types define the per-element mantissa format within each block.
//
// All converters use IEEE-754 bit manipulation for exact results (no floating-point
// rounding in the conversion itself). Quantization functions use round-to-nearest-even
// (RNE) per the MX specification.
//
// GPU backends (CUDA, Metal, Vulkan) provide their own optimized versions — these
// functions serve as the canonical reference and are used by the CPU backend.
//
// FP8 E4M3: 1 sign, 4 exponent (bias 7), 3 mantissa bits
// Range: ±[2^-9, 448], NaN: exp=15 mant=7 (only NaN encoding)
// Ref: OCP MX v1.0 §4.2
// MXFP element converters
GGML_API float fp8_e4m3_to_float(uint8_t v);
GGML_API uint8_t float_to_fp8_e4m3_rn(float x);
// FP8 E5M2: 1 sign, 5 exponent (bias 15), 2 mantissa bits
// Range: ±[2^-16, 57344], NaN/Inf: exp=31 (standard IEEE-like)
// Ref: OCP MX v1.0 §4.2
GGML_API float fp8_e5m2_to_float(uint8_t v);
GGML_API uint8_t float_to_fp8_e5m2_rn(float x);
// FP6 E2M3: 1 sign, 2 exponent (bias 1), 3 mantissa bits
// Range: ±[2^-3, 7.5], stored as low 6 bits of a byte (00xxxxxx)
// MX format: NO NaN/Inf — all bit patterns are valid numbers
// Ref: OCP MX v1.0 §4.2
// no NaN/Inf in FP6 — all bit patterns are valid numbers
GGML_API float fp6_e2m3_to_float(uint8_t v);
GGML_API uint8_t float_to_fp6_e2m3_rn(float x);
// FP6 E3M2: 1 sign, 3 exponent (bias 3), 2 mantissa bits
// Range: ±[2^-4, 28.0], stored as low 6 bits of a byte (00xxxxxx)
// MX format: NO NaN/Inf — exp=7 is a valid normal value (unlike IEEE-754)
// CRITICAL: subnormal scale is 2^(1-bias-m) = 2^(-4) = 1/16, NOT 1/4
// Ref: OCP MX v1.0 §4.2
// no NaN/Inf — exp=7 is a valid normal value (unlike IEEE-754)
GGML_API float fp6_e3m2_to_float(uint8_t v);
GGML_API uint8_t float_to_fp6_e3m2_rn(float x);
// FP6 tight packing: pack/unpack 4 six-bit values into/from 3 bytes
// Layout: v[0]=bits[5:0], v[1]=bits[11:6], v[2]=bits[17:12], v[3]=bits[23:18]
// Saves 25% memory vs byte-padded storage (24B vs 32B per MX block)
// Pack/unpack 4 six-bit values into 3 bytes
GGML_API void pack_fp6x4(const uint8_t v[4], uint8_t out[3]);
GGML_API void unpack_fp6x4(const uint8_t in[3], uint8_t v[4]);
//
// Hadamard rotation (reference scalar implementation)
//
// 32-element Walsh-Hadamard transform applied to MX blocks before quantization.
// Distributes outlier energy uniformly across the block, dramatically improving
// quantization quality for types with shared exponents.
//
// Mathematical property: H^T·H = I (orthogonal), so H(K)·H(Q) = K·Q.
// Flash attention applies matching rotation to Q, preserving attention scores exactly.
//
// Implementation: 5 butterfly stages (log2(32) = 5) + normalization by 1/sqrt(32).
// This is the standard "fast Walsh-Hadamard transform" with O(n log n) operations.
//
// Applied in set_rows (K cache quantization) and flash_attn (Q quantization).
// Skipped for MLA models (DK != DV) where V is a view of K — rotation would corrupt V.
//
// Empirical impact (PPL degradation WITHOUT rotation, Qwen3-Coder-30B):
// MXFP8 E4M3: +0.22, MXFP8 E5M2: +1.38, MXFP6 E2M3: +3.34, MXFP6 E3M2: +4.60
//
// Prior art: QuIP# (Tseng et al. 2024), BRQ (Huang et al. 2024) use Hadamard for
// weight quantization. Our contribution applies it to KV cache quantization at the
// MX block boundary, where block-32 is optimal because it matches the shared exponent
// group size exactly.
//
// GPU backends provide optimized versions (CUDA warp shuffles, Metal SIMD groups).
//
// Block-32 Walsh-Hadamard transform, normalized by 1/sqrt(32)
GGML_API void ggml_hadamard_32_inplace(float vals[32]);
GGML_API void iq2xs_init_impl(enum ggml_type type);

View File

@ -51,7 +51,6 @@ llama_kv_cache::llama_kv_cache(
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
ggml_init_params params = {
// 2 base tensors (K+V) + 2*n_stream view tensors
/*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
@ -140,10 +139,10 @@ llama_kv_cache::llama_kv_cache(
uint32_t n_embd_k_alloc = n_embd_k_gqa;
const bool is_mxfp_k = ggml_is_type_mxfp(type_k);
if (is_mxfp_k) {
const int qk = (int)ggml_blck_size(type_k); // 32 for all MXFP types
const int qk = (int)ggml_blck_size(type_k);
GGML_ASSERT(n_embd_k_gqa % qk == 0 && "MXFP K cache requires n_embd_k_gqa divisible by block size");
const int blocks = (int)n_embd_k_gqa / qk;
const int blocks_aligned = (blocks + 15) & ~15; // align to 16
const int blocks_aligned = (blocks + 15) & ~15;
n_embd_k_alloc = (uint32_t)(blocks_aligned * qk);
}

View File

@ -150,8 +150,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
}
}
// MXFP SoA quantize/dequantize (from ggml-quants.h, which is internal to ggml
// and not in the test include path). Signatures must match ggml-quants.h exactly.
// MXFP SoA functions (internal to ggml, not in test include path)
typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t);
extern "C" {
void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
@ -162,10 +161,7 @@ extern "C" {
void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
}
// Initialize an MXFP tensor with SoA (Struct-of-Arrays) layout.
// soa_bytes: byte width of one SoA region. Default 0 = ne[0] elements (one ggml row).
// For FA K/V tensors, pass nb[1] so that when heads are physically contiguous
// within one KV-position stride, the SoA region spans all heads (matching FA's read pattern).
// Initialize an MXFP tensor with SoA layout (soa_bytes = region width, 0 = one row).
static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f, size_t soa_bytes = 0) {
GGML_ASSERT(ggml_is_type_mxfp(tensor->type));
@ -189,9 +185,6 @@ static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float
std::uniform_real_distribution<float> dist(min, max);
std::vector<float> region_f32(soa_elems);
// Iterate over logical SoA regions using tensor strides.
// Each SoA region is soa_bytes wide at the innermost stride level.
// Outer dimensions (those with stride > soa_bytes) are iterated explicitly.
const size_t nb1 = tensor->nb[1];
const size_t nb2 = tensor->nb[2];
const size_t nb3 = tensor->nb[3];
@ -199,18 +192,13 @@ static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float
const int64_t ne2 = tensor->ne[2];
const int64_t ne3 = tensor->ne[3];
// Determine iteration: if soa_bytes == nb1, iterate over (ne1 * ne2 * ne3) regions.
// If soa_bytes < nb1 (per-head), iterate over (ne1 * ne2 * ne3) regions with stride nb1.
// We use strides to compute offsets, handling views and permutations correctly.
const int64_t heads_per_region = (int64_t)(soa_bytes / head_row_sz);
GGML_ASSERT(soa_bytes % head_row_sz == 0 && "soa_bytes must be a multiple of head_row_sz");
// For multi-head regions, we step by nb1 (KV-position stride) between regions.
// For per-head, we step through all dimensions.
std::vector<uint8_t> buf(ggml_nbytes(tensor), 0);
if (heads_per_region > 1) {
// Multi-head SoA: iterate over (kv_positions * batches), each region = nb1 bytes
            // Multi-head SoA: one region per group of heads_per_region heads
for (int64_t i3 = 0; i3 < ne3; i3++) {
const int64_t n_groups = ne2 / heads_per_region;
for (int64_t ig = 0; ig < n_groups; ig++) {
@ -222,7 +210,7 @@ static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float
}
}
} else {
// Per-head SoA: one SoA region per ggml row
        // Per-head SoA: one region per ggml row
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) {
for (int64_t i1 = 0; i1 < ne1; i1++) {
@ -328,7 +316,6 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) {
bool quantized = ggml_is_quantized(t->type);
const bool is_mxfp = ggml_is_type_mxfp(t->type);
// SoA dequant for MXFP readback
mxfp_soa_dequantize_fn mxfp_dequant_soa = nullptr;
if (is_mxfp) {
switch (t->type) {
@ -4051,7 +4038,6 @@ struct test_mul_mat_id : public test_case {
return 5e-4;
}
// Same Blackwell FP4 tolerance as test_mul_mat above.
double max_nmse_err(ggml_backend_t backend) override {
// for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance
if (type_a == GGML_TYPE_MXFP4_E2M1 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
@ -6401,9 +6387,7 @@ struct test_flash_attn_ext : public test_case {
} else if (strcmp(t->name, "m") == 0) {
init_tensor_kq_mask(t);
} else if ((strcmp(t->name, "k") == 0 || strcmp(t->name, "v") == 0) && ggml_is_type_mxfp(t->type)) {
// MXFP K/V tensors use SoA layout. Pass nb[1] (KV-position stride) as the
// SoA region width — when heads are physically contiguous within that stride,
// the FA kernel dequants the full multi-head region as one SoA block.
// MXFP K/V use SoA layout; nb[1] spans all heads in one KV-position stride
init_tensor_mxfp_soa(t, -1.0f, 1.0f, t->nb[1]);
} else {
init_tensor_uniform(t);
@ -8778,7 +8762,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (ggml_type type_V : {GGML_TYPE_MXFP4_E2M1}) {
if (type_K == type_V) continue;
for (int nb : {1, 3, 32}) {
// hsk hsv nh nr23 kv nb mask sinks bias softcap prec type_K permute type_V
test_cases.emplace_back(new test_flash_attn_ext(
128, 128, 4, {1, 1}, 512, nb, true, false, 0.0f, 0.0f, GGML_PREC_F32, type_K, {0, 1, 2, 3}, type_V));
}