perf : multiple fixes and enhancements, remove MSE search, expand test coverage

* fix: correct tiled flash attention SoA pointer math for multihead MXFP

The cleanup refactoring (c919bc471) extracted mxfp_dequant_head as a
shared helper but failed to update the tiled path's data pointers.
The helper expects the full SoA row base (no per-head offset), but the
tiled path was passing a pointer that already included ik2*nbk2, causing
a double head offset that produced NaN during prefill.

Add mxfp_row_ptr helper to centralize the multihead-aware pointer
calculation across both one_chunk and tiled paths. Verified with 16-chunk
perplexity on gpt-oss-20b: all four configs (f16, mxfp4, mxfp6, mxfp8)
produce exact matches with the known-good commit (23e88631c).

* perf: reduce E8M0 MSE search range from ±2 to ±1

The base estimate round(log2(amax)) is always within 1 step of optimal.
Empirically verified across 30K blocks and 6 distributions: ±1 and ±2
never disagree. This reduces the scale search from 5 to 3 candidates
(40% fewer inner loop iterations) with zero quality impact.

* perf: eliminate redundant work in MXFP quantize and flash attention

- mse_error_mxfp4: use passed inv_scale instead of recomputing 1/d
- mxfp_compute_e8m0_mse: hoist loop-invariant traits branch out of inner loop
- tiled V path: dequant directly to V32 tile, remove intermediate memcpy and dead buffer

* cleanup: fix comments, unify Hadamard condition, simplify E8M0 helpers

- EMAX_OFFSET comments: fix ceil/floor labels to match actual values
- Hadamard flag: unify write path (llama-kv-cache.cpp) and read path
  (ops.cpp) to both use DK==DV condition instead of is_mla()
- E8M0 helpers in ggml-impl.h: simplify to match ggml-common.h style,
  add cross-reference comment

* fix: MXFP8/6 flash attention tests crash on init

The view base tensors for K/V are not named "k"/"v" but still inherit the
MXFP type. The name-based filter in initialize_tensors missed them, so
they fell through to init_tensor_uniform, which calls quantize_chunk and
aborts for KV-cache-only types. Fix by checking ggml_is_type_mxfp() for
all tensors, matching the pattern the set_rows tests already use.


* test: expand MXFP set_rows coverage

- Add MXFP8/MXFP6 to all_types for non-Hadamard set_rows coverage
- Expand Hadamard set_rows tests: add views, broadcast, and multi-head configs
- Coverage: 18 → 51 MXFP set_rows tests

* perf: add AVX2 Hadamard for x86 (matches existing ARM NEON path)

* cleanup: DRY MXFP4 quantize/dequant with shared per-block helpers

Extract quantize_block_mxfp4 and dequantize_block_mxfp4 as shared
helpers used by both AoS (quantize_row_mxfp4_ref, dequantize_row_mxfp4)
and SoA (quantize_row_mxfp4_soa, dequantize_row_mxfp4_soa) paths.
Eliminates duplicated per-block logic while keeping layout-specific
pointer arithmetic in the callers.

* feat: add MXFP8/MXFP6 AoS quantize/dequant (full type support)

Extract quantize_block_mxfp / dequantize_block_mxfp per-block helpers
from the SoA generic impl and use them to build AoS row functions for
MXFP8 (E4M3) and MXFP6 (E2M3). Register to_float and from_float_ref
in type traits, add quantize_chunk dispatch, replacing the GGML_ABORT.

MXFP8 and MXFP6 are no longer KV-cache-only — they can now be used
as general quantization types. The SoA impl is also DRY'd to delegate
to the same per-block helpers.

* cleanup: remove dead soa_elems field from mxfp_kv_params

Computed but never read — leftover from an earlier design.

* feat: add MXFP8/MXFP6 vec_dot and full CPU type support

Add scalar vec_dot_mxfp8_q8_0 and vec_dot_mxfp6_q8_0 implementations,
register from_float + vec_dot + vec_dot_type in CPU traits, and add
fallback remaps for all architectures. MXFP8/6 are now fully tested:
AoS quantization error, reference match, and dot product accuracy all
pass in test-quantize-fns.

* perf: remove E8M0 MSE search — base estimate is perplexity-optimal

The MSE search over ±1 candidates around round(log2(amax)) was found to
HURT perplexity by 4-37 PPL points across all MXFP configs on gpt-oss-20b.
The base estimate alone (no search) produces better attention patterns
because minimizing per-block reconstruction error is not the same as
minimizing attention score distortion through softmax.

Removes mse_error_mxfp4, mse_error field from traits, MSE_RANGE constant,
and the entire search loop. E8M0 computation is now a single amax scan +
integer bit extraction — no inner loop, no function pointers. This also
simplifies future GPU/Metal implementations.

* perf: fuse Hadamard rotation into SoA quantize (one pass, no temp buffer)

Add quantize_row_mxfp{4,8,6}_soa_hadamard variants that apply the
Hadamard rotation and quantize block-by-block through a 32-float stack
buffer. This eliminates the std::vector heap allocation and two extra
memory passes over the full row.

set_rows now dispatches to the fused path when Hadamard is enabled,
falling through to the unfused quantize for non-Hadamard types.

This pattern maps directly to a CUDA kernel: global memory read →
register Hadamard → register quantize → global memory write.

* cleanup: consistent MXFP type names and variable naming

- Rename type_name "mxfp8_e4m3" → "mxfp8", "mxfp6_e2m3" → "mxfp6"
  to match "mxfp4". Only one variant of each exists — the suffix was
  unnecessary disambiguation that implied alternatives.
- Remove redundant MXFP shortcuts from arg.cpp (fallback loop handles
  all types via ggml_type_name matching).
- Rename kv_is_f32_f16_or_mxfp → k_is_f32_f16_or_mxfp (only checks K).

* perf: fuse Q preprocessing round-trip (no SoA buffer needed)

Add mxfp{4,8,6}_hadamard_roundtrip and mxfp{4,8,6}_roundtrip functions
that apply quantization error to float values without materializing SoA
bytes. Replaces the 3-step Q preprocessing (Hadamard → quantize to SoA
buffer → dequant from SoA buffer) with a single pass through per-block
round-trip helpers.

Eliminates the Q_q intermediate buffer and two function pointer calls
from the flash attention hot path. Maps directly to CUDA: Q stays in
registers, Hadamard butterfly + quantize error applied in-place.

* fix: clamp E8M0 = 255 to 254 in decode (fixes CI NaN failures)

E8M0 = 255 means NaN per MX spec, but our encode path already clamps
to 254. When test data contains random E8M0 = 255 bytes, the decode
produces Inf, and Inf * 0.0 = NaN, causing GET_ROWS and CPY tests to
fail on MXFP6 (and potentially MXFP4/8).

Fix: clamp 255 → 254 in both E8M0 decode functions:
  - ggml_e8m0_to_fp32 / ggml_e8m0_to_fp32_half (ggml-impl.h)
  - ggml_mxfp_e8m0_to_fp32 / ggml_mxfp_e8m0_to_fp32_half (ggml-common.h)

These are unfortunately duplicated across two headers because
ggml-common.h compiles for CUDA (__device__) while ggml-impl.h serves
CPU-only callers that don't include ggml-common.h.
Tim Burke 2026-03-22 20:12:09 -04:00 committed by GitHub
parent c919bc471b
commit ccea34ba41
GPG Key ID: B5690EEEBB952194
14 changed files with 531 additions and 239 deletions

@@ -404,15 +404,6 @@ const std::vector<ggml_type> kv_cache_types = {
};
static ggml_type kv_cache_type_from_str(const std::string & s) {
if (s == "mxfp4") {
return GGML_TYPE_MXFP4;
}
if (s == "mxfp6") {
return GGML_TYPE_MXFP6;
}
if (s == "mxfp8") {
return GGML_TYPE_MXFP8;
}
for (const auto & type : kv_cache_types) {
if (ggml_type_name(type) == s) {
return type;

@@ -199,13 +199,12 @@ typedef struct {
static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
// E8M0 shared exponent constants (OCP MX v1.0 SS5.3).
// EMAX_OFFSET = ceil(log2(max_finite)), MSE_RANGE = search radius for optimal scale.
#define MXFP_E8M0_MSE_RANGE 2
#define MXFP4_E2M1_EMAX_OFFSET 2 // ceil(log2(6.0))
#define MXFP6_E2M3_EMAX_OFFSET 3 // ceil(log2(7.5))
#define MXFP6_E3M2_EMAX_OFFSET 5 // ceil(log2(28.0))
#define MXFP8_E4M3_EMAX_OFFSET 8 // ceil(log2(448))
#define MXFP8_E5M2_EMAX_OFFSET 16 // ceil(log2(57344))
// EMAX_OFFSET ≈ log2(max_finite), used by round(log2(amax)) base estimate.
#define MXFP4_E2M1_EMAX_OFFSET 2 // floor(log2(6.0)) = 2
#define MXFP6_E2M3_EMAX_OFFSET 3 // ceil(log2(7.5)) = 3
#define MXFP6_E3M2_EMAX_OFFSET 5 // ceil(log2(28.0)) = 5
#define MXFP8_E4M3_EMAX_OFFSET 8 // floor(log2(448)) = 8
#define MXFP8_E5M2_EMAX_OFFSET 16 // ceil(log2(57344)) = 16
// MXFP type properties -- shared across all backends.
#define MXFP_BITS_PER_ELEM_E2M1 4
@@ -1635,13 +1634,17 @@ GGML_MXFP_FUNC void ggml_mxfp_unpack_fp6x4(const uint8_t in[3], uint8_t v[4]) {
// E8M0 shared exponent → float conversion.
// E8M0 encoding: value = 2^(x - 127) for x > 0, 2^(-127) for x == 0.
// E8M0 = 255 is NaN per MX spec, but we clamp to 254 (max finite) to match
// the encode path which also clamps to 254, preventing Inf * 0 = NaN in dequant.
GGML_MXFP_FUNC float ggml_mxfp_e8m0_to_fp32(uint8_t x) {
if (x == 255) { x = 254; }
uint32_t bits = (x == 0) ? 0x00400000u : ((uint32_t)x << 23);
return GGML_MXFP_U32_AS_F32(bits);
}
// E8M0 → float/2. Used with MXFP4 since E2M1 values are doubled in kvalues_mxfp4.
GGML_MXFP_FUNC float ggml_mxfp_e8m0_to_fp32_half(uint8_t x) {
if (x == 255) { x = 254; }
uint32_t bits = (x < 2) ? (0x00200000u << x) : ((uint32_t)(x - 1) << 23);
return GGML_MXFP_U32_AS_F32(bits);
}

@@ -14,6 +14,8 @@
#define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
#define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0
#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
@@ -72,6 +74,9 @@
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
// quants.c
#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0
#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
@@ -83,6 +88,8 @@
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
// quants.c
#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0
#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
@@ -164,6 +171,8 @@
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0
#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
@@ -319,6 +328,8 @@
#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0
#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4

@@ -280,13 +280,19 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.nrows = 1,
},
[GGML_TYPE_MXFP8] = {
.from_float = (ggml_from_float_t)quantize_row_mxfp8_ref,
.from_float_soa = quantize_row_mxfp8_soa,
.to_float_soa = dequantize_row_mxfp8_soa_cpu,
.vec_dot = ggml_vec_dot_mxfp8_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
},
[GGML_TYPE_MXFP6] = {
.from_float = (ggml_from_float_t)quantize_row_mxfp6_ref,
.from_float_soa = quantize_row_mxfp6_soa,
.to_float_soa = dequantize_row_mxfp6_soa_cpu,
.vec_dot = ggml_vec_dot_mxfp6_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
},
[GGML_TYPE_Q2_K] = {

@@ -4909,8 +4909,106 @@ void ggml_compute_forward_get_rows(
//}
}
// NEON-optimized Hadamard; scalar fallback below
#if defined(__ARM_NEON)
// SIMD-optimized Hadamard; scalar fallback below
#if defined(__AVX2__) || defined(__AVX__)
static void hadamard_32_inplace(float vals[32]) {
// 32 floats = 4 × __m256
__m256 v0 = _mm256_loadu_ps(vals + 0);
__m256 v1 = _mm256_loadu_ps(vals + 8);
__m256 v2 = _mm256_loadu_ps(vals + 16);
__m256 v3 = _mm256_loadu_ps(vals + 24);
// Stride 1: butterfly on adjacent pairs within each 256-bit register
{
// Interleave even/odd elements, add/sub
__m256 a, b, s, d;
a = _mm256_shuffle_ps(v0, v0, _MM_SHUFFLE(2, 2, 0, 0));
b = _mm256_shuffle_ps(v0, v0, _MM_SHUFFLE(3, 3, 1, 1));
s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b);
v0 = _mm256_shuffle_ps(s, d, _MM_SHUFFLE(2, 0, 2, 0));
v0 = _mm256_shuffle_ps(v0, v0, _MM_SHUFFLE(3, 1, 2, 0));
a = _mm256_shuffle_ps(v1, v1, _MM_SHUFFLE(2, 2, 0, 0));
b = _mm256_shuffle_ps(v1, v1, _MM_SHUFFLE(3, 3, 1, 1));
s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b);
v1 = _mm256_shuffle_ps(s, d, _MM_SHUFFLE(2, 0, 2, 0));
v1 = _mm256_shuffle_ps(v1, v1, _MM_SHUFFLE(3, 1, 2, 0));
a = _mm256_shuffle_ps(v2, v2, _MM_SHUFFLE(2, 2, 0, 0));
b = _mm256_shuffle_ps(v2, v2, _MM_SHUFFLE(3, 3, 1, 1));
s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b);
v2 = _mm256_shuffle_ps(s, d, _MM_SHUFFLE(2, 0, 2, 0));
v2 = _mm256_shuffle_ps(v2, v2, _MM_SHUFFLE(3, 1, 2, 0));
a = _mm256_shuffle_ps(v3, v3, _MM_SHUFFLE(2, 2, 0, 0));
b = _mm256_shuffle_ps(v3, v3, _MM_SHUFFLE(3, 3, 1, 1));
s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b);
v3 = _mm256_shuffle_ps(s, d, _MM_SHUFFLE(2, 0, 2, 0));
v3 = _mm256_shuffle_ps(v3, v3, _MM_SHUFFLE(3, 1, 2, 0));
}
// Stride 2: butterfly on pairs separated by 2 within 128-bit lanes
{
__m256 a, b, s, d;
a = _mm256_permute_ps(v0, _MM_SHUFFLE(1, 0, 1, 0));
b = _mm256_permute_ps(v0, _MM_SHUFFLE(3, 2, 3, 2));
s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b);
v0 = _mm256_blend_ps(s, d, 0xCC); // 0b11001100
a = _mm256_permute_ps(v1, _MM_SHUFFLE(1, 0, 1, 0));
b = _mm256_permute_ps(v1, _MM_SHUFFLE(3, 2, 3, 2));
s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b);
v1 = _mm256_blend_ps(s, d, 0xCC);
a = _mm256_permute_ps(v2, _MM_SHUFFLE(1, 0, 1, 0));
b = _mm256_permute_ps(v2, _MM_SHUFFLE(3, 2, 3, 2));
s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b);
v2 = _mm256_blend_ps(s, d, 0xCC);
a = _mm256_permute_ps(v3, _MM_SHUFFLE(1, 0, 1, 0));
b = _mm256_permute_ps(v3, _MM_SHUFFLE(3, 2, 3, 2));
s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b);
v3 = _mm256_blend_ps(s, d, 0xCC);
}
// Stride 4: butterfly between 128-bit lanes within each 256-bit register
{
__m128 lo, hi;
lo = _mm256_castps256_ps128(v0); hi = _mm256_extractf128_ps(v0, 1);
v0 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_add_ps(lo, hi)), _mm_sub_ps(lo, hi), 1);
lo = _mm256_castps256_ps128(v1); hi = _mm256_extractf128_ps(v1, 1);
v1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_add_ps(lo, hi)), _mm_sub_ps(lo, hi), 1);
lo = _mm256_castps256_ps128(v2); hi = _mm256_extractf128_ps(v2, 1);
v2 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_add_ps(lo, hi)), _mm_sub_ps(lo, hi), 1);
lo = _mm256_castps256_ps128(v3); hi = _mm256_extractf128_ps(v3, 1);
v3 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_add_ps(lo, hi)), _mm_sub_ps(lo, hi), 1);
}
// Stride 8: butterfly between registers
{
__m256 s, d;
s = _mm256_add_ps(v0, v1); d = _mm256_sub_ps(v0, v1); v0 = s; v1 = d;
s = _mm256_add_ps(v2, v3); d = _mm256_sub_ps(v2, v3); v2 = s; v3 = d;
}
// Stride 16: butterfly between register pairs
{
__m256 s, d;
s = _mm256_add_ps(v0, v2); d = _mm256_sub_ps(v0, v2); v0 = s; v2 = d;
s = _mm256_add_ps(v1, v3); d = _mm256_sub_ps(v1, v3); v1 = s; v3 = d;
}
// Normalize by 1/sqrt(32)
const __m256 norm = _mm256_set1_ps(MXFP_HADAMARD_32_NORM);
_mm256_storeu_ps(vals + 0, _mm256_mul_ps(v0, norm));
_mm256_storeu_ps(vals + 8, _mm256_mul_ps(v1, norm));
_mm256_storeu_ps(vals + 16, _mm256_mul_ps(v2, norm));
_mm256_storeu_ps(vals + 24, _mm256_mul_ps(v3, norm));
}
#elif defined(__ARM_NEON)
static void hadamard_32_inplace(float vals[32]) {
float32x4_t v0 = vld1q_f32(vals + 0);
float32x4_t v1 = vld1q_f32(vals + 4);
@@ -5032,9 +5130,15 @@ static void ggml_compute_forward_set_rows_f32(
ggml_from_float_t mxfp_soa_quantize = dst_traits->from_float_soa;
ggml_from_float_t from_float = mxfp_soa_quantize ? nullptr : dst_traits->from_float;
std::vector<float> had_tmp;
if (apply_hadamard) {
had_tmp.resize(nc);
// Fused Hadamard+quantize: one pass per block, 32-float stack buffer, no heap allocation.
ggml_from_float_t mxfp_soa_hadamard_quantize = nullptr;
if (apply_hadamard && mxfp_soa_quantize) {
switch (dst->type) {
case GGML_TYPE_MXFP4: mxfp_soa_hadamard_quantize = (ggml_from_float_t)quantize_row_mxfp4_soa_hadamard; break;
case GGML_TYPE_MXFP8: mxfp_soa_hadamard_quantize = (ggml_from_float_t)quantize_row_mxfp8_soa_hadamard; break;
case GGML_TYPE_MXFP6: mxfp_soa_hadamard_quantize = (ggml_from_float_t)quantize_row_mxfp6_soa_hadamard; break;
default: break;
}
}
for (int64_t i03 = 0; i03 < ne03; ++i03) {
@@ -5051,20 +5155,12 @@ static void ggml_compute_forward_set_rows_f32(
const float * src_row = (const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03);
char * dst_row = ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3);
if (apply_hadamard) {
memcpy(had_tmp.data(), src_row, nc * sizeof(float));
ggml_apply_hadamard_blocks(had_tmp.data(), nc);
if (mxfp_soa_quantize) {
mxfp_soa_quantize(had_tmp.data(), dst_row, nc);
} else {
from_float(had_tmp.data(), dst_row, nc);
}
if (mxfp_soa_hadamard_quantize) {
mxfp_soa_hadamard_quantize(src_row, dst_row, nc);
} else if (mxfp_soa_quantize) {
mxfp_soa_quantize(src_row, dst_row, nc);
} else {
if (mxfp_soa_quantize) {
mxfp_soa_quantize(src_row, dst_row, nc);
} else {
from_float(src_row, dst_row, nc);
}
from_float(src_row, dst_row, nc);
}
}
}
@@ -8268,7 +8364,6 @@ typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t);
struct mxfp_kv_params {
mxfp_soa_dequantize_fn dequantize;
bool multihead;
int64_t soa_elems;
int qs_per_block;
int head_qs_bytes;
int64_t head_e8m0_offset;
@@ -8277,12 +8372,28 @@
// MXFP dispatch parameters for flash attention.
struct mxfp_fa_params {
mxfp_soa_quantize_fn q_quantize;
mxfp_soa_quantize_fn q_quantize; // SoA quantize for Q (used only when Hadamard is off AND non-MXFP K path)
// Fused Q round-trip: Hadamard + quantize + dequant in one pass, no SoA buffer.
void (*q_roundtrip)(const float *, float *, int64_t);
mxfp_kv_params k;
mxfp_kv_params v;
bool apply_hadamard;
};
// Compute the SoA row base pointer for a given KV position and head.
// In multihead mode, the SoA region spans all heads at one KV position,
// so the row base must NOT include the per-head offset (head_idx * nb2).
// mxfp_dequant_head handles per-head indexing within the SoA region.
// In per-head mode, each head has its own SoA region, so the base includes nb2.
static inline const char * mxfp_row_ptr(
const mxfp_kv_params & kv, const char * data,
int64_t kv_pos, size_t nb1, int head_idx, size_t nb2, int batch_idx, size_t nb3) {
if (kv.multihead) {
return data + kv_pos*nb1 + batch_idx*nb3;
}
return data + kv_pos*nb1 + head_idx*nb2 + batch_idx*nb3;
}
// Extract one head's SoA data from a multihead row and dequantize.
static inline void mxfp_dequant_head(
const mxfp_kv_params & kv, const char * row, int head_idx,
@@ -8305,7 +8416,6 @@ static mxfp_kv_params mxfp_kv_params_init(ggml_type type, int64_t D, size_t nb2,
mxfp_kv_params kv = {};
kv.dequantize = ggml_get_type_traits_cpu(type)->to_float_soa;
kv.multihead = (nb2 == (size_t)ggml_row_size(type, D));
kv.soa_elems = kv.multihead ? ne2 * D : D;
kv.qs_per_block = ggml_mxfp_qs_per_block(type);
kv.blocks_per_head = (int)(D / 32);
kv.head_qs_bytes = kv.blocks_per_head * kv.qs_per_block;
@@ -8328,6 +8438,17 @@ static mxfp_fa_params mxfp_fa_params_init(
p.q_quantize = ggml_get_type_traits_cpu(k->type)->from_float_soa;
p.k = mxfp_kv_params_init(k->type, DK, nbk2, nek2);
}
// Select fused Q round-trip (Hadamard + quantize error, no SoA buffer).
if (is_mxfp_k) {
const bool had = is_mxfp_k && (DK == DV) && ggml_mxfp_use_hadamard(k->type);
switch (k->type) {
case GGML_TYPE_MXFP4: p.q_roundtrip = had ? mxfp4_hadamard_roundtrip : mxfp4_roundtrip; break;
case GGML_TYPE_MXFP8: p.q_roundtrip = had ? mxfp8_hadamard_roundtrip : mxfp8_roundtrip; break;
case GGML_TYPE_MXFP6: p.q_roundtrip = had ? mxfp6_hadamard_roundtrip : mxfp6_roundtrip; break;
default: break;
}
}
if (is_mxfp_v) {
p.v = mxfp_kv_params_init(v->type, DV, nbv2, nev2);
}
@@ -8486,22 +8607,14 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
const char * k_base = (const char *) k->data + k_base_offset;
const char * v_base = (const char *) v->data + v_base_offset;
const char * k_row_base = mxfp.k.multihead ? ((const char *) k->data + ik3*nbk3) : nullptr;
const char * v_row_base = mxfp.v.multihead ? ((const char *) v->data + iv3*nbv3) : nullptr;
const char * k_data_base = (const char *) k->data;
const char * v_data_base = (const char *) v->data;
const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
float Q_f32[MXFP_FA_MAX_D];
if (is_mxfp_k) {
// Q preprocessing: Hadamard + SoA round-trip captures same quantization loss as K.
if (mxfp.apply_hadamard) {
float q_tmp[MXFP_FA_MAX_D];
memcpy(q_tmp, pq, DK * sizeof(float));
ggml_apply_hadamard_blocks(q_tmp, DK);
mxfp.q_quantize(q_tmp, Q_q, DK);
} else {
mxfp.q_quantize(pq, Q_q, DK);
}
mxfp.k.dequantize(Q_q, Q_f32, DK);
if (mxfp.q_roundtrip) {
// Q preprocessing: fused Hadamard + quantize round-trip, no SoA buffer.
mxfp.q_roundtrip(pq, Q_f32, DK);
} else {
if (mxfp.apply_hadamard) {
float q_tmp[MXFP_FA_MAX_D];
@@ -8526,7 +8639,8 @@
float s; // KQ value
if (is_mxfp_k) {
const char * k_row = mxfp.k.multihead ? k_row_base + ic*nbk1 : k_base + ic*nbk1;
const char * k_row = mxfp_row_ptr(mxfp.k, k_data_base,
ic, nbk1, ik2, nbk2, ik3, nbk3);
mxfp_dequant_head(mxfp.k, k_row, ik2, k_head_soa, k_dequant_buf, DK);
ggml_vec_dot_f32(DK, &s, 0, k_dequant_buf, 0, Q_f32, 0, 1);
} else {
@@ -8572,7 +8686,8 @@
// V += v*expf(s - M)
if (mxfp.v.dequantize) {
const char * v_row = mxfp.v.multihead ? v_row_base + ic*nbv1 : v_base + ic*nbv1;
const char * v_row = mxfp_row_ptr(mxfp.v, v_data_base,
ic, nbv1, iv2, nbv2, iv3, nbv3);
mxfp_dequant_head(mxfp.v, v_row, iv2, v_head_soa, v_dequant_buf, DV);
ggml_vec_mad_f32(DV, VKQ32, v_dequant_buf, vs);
} else if (v_to_float) {
@@ -8723,7 +8838,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
if (is_mxfp_v) { GGML_ASSERT(DV <= MXFP_FA_MAX_D); }
float k_dequant_buf[MXFP_FA_MAX_D];
float v_dequant_buf[MXFP_FA_MAX_D];
char k_head_soa[MXFP_FA_SOA_BUF];
char v_head_soa[MXFP_FA_SOA_BUF];
@@ -8786,13 +8900,9 @@
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float));
if (is_mxfp_k) {
if (mxfp.apply_hadamard) {
ggml_apply_hadamard_blocks(Q_f32 + tq * DK, DK);
}
uint8_t q_mxfp_buf[MXFP_FA_SOA_BUF];
mxfp.q_quantize(Q_f32 + tq * DK, q_mxfp_buf, DK);
mxfp.k.dequantize(q_mxfp_buf, Q_f32 + tq * DK, DK);
if (mxfp.q_roundtrip) {
// In-place: Q_f32 is already populated by memcpy above, roundtrip overwrites.
mxfp.q_roundtrip(Q_f32 + tq * DK, Q_f32 + tq * DK, DK);
}
}
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
@@ -8843,7 +8953,9 @@
K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
}
} else if (mxfp.k.dequantize) {
mxfp_dequant_head(mxfp.k, k_data, ik2, k_head_soa, k_dequant_buf, DK);
const char * k_row = mxfp_row_ptr(mxfp.k, (const char *)k->data,
ic + tk, nbk1, ik2, nbk2, ik3, nbk3);
mxfp_dequant_head(mxfp.k, k_row, ik2, k_head_soa, k_dequant_buf, DK);
for (int64_t dk = 0; dk < DK; dk++) {
K_f32[dk * KV_TILE_SZ + tk] = k_dequant_buf[dk];
}
@@ -8913,8 +9025,9 @@
} else if (v_type == GGML_TYPE_F32) {
memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
} else if (mxfp.v.dequantize) {
mxfp_dequant_head(mxfp.v, v_data, iv2, v_head_soa, v_dequant_buf, DV);
memcpy(V32 + tk * DV, v_dequant_buf, DV * sizeof(float));
const char * v_row = mxfp_row_ptr(mxfp.v, (const char *)v->data,
ic + tk, nbv1, iv2, nbv2, iv3, nbv3);
mxfp_dequant_head(mxfp.v, v_row, iv2, v_head_soa, V32 + tk * DV, DV);
} else {
v_to_float(v_data, V32 + tk * DV, DV);
}
@@ -9087,10 +9200,10 @@ static void ggml_compute_forward_flash_attn_ext_f16(
// Split-KV: parallelize across KV chunks for single-query decode (token generation).
// Only for types whose tiled/one_chunk paths produce identical results (f32, f16, MXFP).
// Standard quant types (q8_0, q4_0) must use the scalar path to preserve vec_dot semantics.
const bool kv_is_f32_f16_or_mxfp = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16
const bool k_is_f32_f16_or_mxfp = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16
|| ggml_is_type_mxfp(k->type));
const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1)
&& kv_is_f32_f16_or_mxfp
&& k_is_f32_f16_or_mxfp
&& q->type == GGML_TYPE_F32 && nek1 >= 512;
if (use_split_kv_path) {
@@ -9151,7 +9264,7 @@
// Tiled path: f32, f16, and MXFP only (quant types use one_chunk)
bool use_tiled = !use_ref &&
(q->type == GGML_TYPE_F32 &&
kv_is_f32_f16_or_mxfp &&
k_is_f32_f16_or_mxfp &&
(k->type == v->type || ggml_is_type_mxfp(k->type)) &&
neq1 >= Q_TILE_SZ);
#ifdef GGML_SIMD

@@ -189,6 +189,54 @@ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c
*s = sumf;
}
void ggml_vec_dot_mxfp8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs);
assert(n % QK_MXFP8 == 0);
static_assert(QK_MXFP8 == QK8_0, "QK_MXFP8 and QK8_0 must be the same");
const block_mxfp8 * GGML_RESTRICT x = vx;
const block_q8_0 * GGML_RESTRICT y = vy;
const int nb = n / QK_MXFP8;
float sumf = 0;
for (int ib = 0; ib < nb; ++ib) {
const float d = GGML_CPU_FP16_TO_FP32(y[ib].d) * GGML_E8M0_TO_FP32(x[ib].e);
float sumi = 0;
for (int j = 0; j < QK_MXFP8; ++j) {
sumi += y[ib].qs[j] * ggml_mxfp_fp8_e4m3_to_float(x[ib].qs[j]);
}
sumf += d * sumi;
}
*s = sumf;
}
void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs);
assert(n % QK_MXFP6 == 0);
static_assert(QK_MXFP6 == QK8_0, "QK_MXFP6 and QK8_0 must be the same");
const block_mxfp6 * GGML_RESTRICT x = vx;
const block_q8_0 * GGML_RESTRICT y = vy;
const int nb = n / QK_MXFP6;
float sumf = 0;
for (int ib = 0; ib < nb; ++ib) {
const float d = GGML_CPU_FP16_TO_FP32(y[ib].d) * GGML_E8M0_TO_FP32(x[ib].e);
float sumi = 0;
for (int j = 0; j < QK_MXFP6; j += 4) {
uint8_t vals[4];
ggml_mxfp_unpack_fp6x4(&x[ib].qs[j * 3 / 4], vals);
for (int jj = 0; jj < 4; jj++) {
sumi += y[ib].qs[j + jj] * ggml_mxfp_fp6_e2m3_to_float(vals[jj]);
}
}
sumf += d * sumi;
}
*s = sumf;
}
void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);

@@ -42,6 +42,8 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);

@@ -431,13 +431,15 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
// E8M0 shared exponent to float: returns 2^(x - 127).
// Canonical implementation is ggml_mxfp_e8m0_to_fp32 in ggml-common.h.
// This thin wrapper exists because not all callers include ggml-common.h.
// MUST stay in sync — if you change the logic, change ggml-common.h too.
//
// E8M0 = 255 is NaN per MX spec; clamped to 254 (max finite) to match
// the encode path which also clamps to 254, preventing Inf * 0 = NaN.
static inline float ggml_e8m0_to_fp32(uint8_t x) {
uint32_t bits;
if (x == 0) {
bits = 0x00400000; // denorm: 0.5 * 2^(-126) = 2^(-127)
} else {
bits = (uint32_t) x << 23;
}
if (x == 255) { x = 254; }
uint32_t bits = (x == 0) ? 0x00400000u : ((uint32_t)x << 23);
float result;
memcpy(&result, &bits, sizeof(float));
return result;
@@ -445,14 +447,8 @@ static inline float ggml_e8m0_to_fp32(uint8_t x) {
// E8M0 to float/2: returns 2^(x - 128).
static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
uint32_t bits;
if (x < 2) {
// x=0 → 2^(-128), x=1 → 2^(-127): denormal bit patterns
bits = 0x00200000 << x;
} else {
// 0.5 * 2^(x-127) = 2^(x-128): normalized with exponent (x-1)
bits = (uint32_t)(x - 1) << 23;
}
if (x == 255) { x = 254; }
uint32_t bits = (x < 2) ? (0x00200000u << x) : ((uint32_t)(x - 1) << 23);
float result;
memcpy(&result, &bits, sizeof(float));
return result;

@@ -263,8 +263,6 @@ float fp8_e4m3_to_float(uint8_t v) { return ggml_mxfp_fp8_e4m3_to_float(v); }
uint8_t float_to_fp8_e4m3_rn(float x) { return ggml_mxfp_float_to_fp8_e4m3(x); }
// ====================== MXFP quantization infrastructure
//
// MSE-optimal E8M0: tests candidates around round(log2(amax)), picks lowest quantization error.
typedef struct {
int emax_offset; // type-specific offset to max representable exponent
@@ -272,33 +270,13 @@ typedef struct {
int bits_per_elem; // 8 = byte-aligned, 6 = packed via fp6x4
uint8_t (*to_elem)(float);
float (*to_float)(uint8_t);
float (*mse_error)(float val, float inv_scale, float scale); // NULL = use generic round-trip via to_elem/to_float
} mxfp_elem_traits_t;
static inline int best_index_mxfp4(float x, float e);
// MSE error for MXFP4 (kvalues are doubled, so scale is halved)
static float mse_error_mxfp4(float val, float inv_scale, float scale) {
const float d = scale * 0.5f;
const float inv_d = (d > 0.0f) ? 1.0f / d : 0.0f;
const float normalized = fabsf(val) * inv_d;
(void)inv_scale;
float qval;
if (normalized < 0.5f) qval = 0.0f;
else if (normalized < 1.5f) qval = 1.0f;
else if (normalized < 2.5f) qval = 2.0f;
else if (normalized < 3.5f) qval = 3.0f;
else if (normalized < 5.0f) qval = 4.0f;
else if (normalized < 7.0f) qval = 6.0f;
else if (normalized < 10.0f) qval = 8.0f;
else qval = 12.0f;
const float err = fabsf(val) - qval * d;
return err * err;
}
static const mxfp_elem_traits_t mxfp4_traits = { MXFP4_E2M1_EMAX_OFFSET, MXFP4_SOA_QS_PER_BLOCK, 4, NULL, NULL, mse_error_mxfp4 };
static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_elem_traits_t * traits) {
// E8M0 shared exponent: round(log2(amax)) — no MSE search needed.
static inline uint8_t mxfp_compute_e8m0(const float * x, int qk, int emax_offset) {
float amax = 0.0f;
for (int j = 0; j < qk; j++) {
const float a = fabsf(x[j]);
@@ -306,36 +284,8 @@ static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_
     }
     if (amax == 0.0f) return 0;
-    const int e_base = ggml_mxfp_e8m0_base_estimate(amax, traits->emax_offset);
-    int e_lo = e_base - MXFP_E8M0_MSE_RANGE;
-    int e_hi = e_base + MXFP_E8M0_MSE_RANGE;
-    if (e_lo < 1)   e_lo = 1;
-    if (e_hi < 1)   e_hi = 1;
-    if (e_hi > 254) e_hi = 254;
-    int best_e = e_base < 0 ? 0 : (e_base > 254 ? 254 : e_base);
-    float best_mse = 1e30f;
-    for (int test_e = e_lo; test_e <= e_hi; ++test_e) {
-        const float test_scale = GGML_E8M0_TO_FP32((uint8_t)test_e);
-        const float test_inv = 1.0f / test_scale;
-        float mse = 0.0f;
-        for (int j = 0; j < qk; ++j) {
-            if (traits->mse_error) {
-                mse += traits->mse_error(x[j], test_inv, test_scale);
-            } else {
-                const float recon = traits->to_float(traits->to_elem(x[j] * test_inv)) * test_scale;
-                const float err = x[j] - recon;
-                mse += err * err;
-            }
-        }
-        if (mse < best_mse) {
-            best_mse = mse;
-            best_e = test_e;
-        }
-    }
-    return (uint8_t)best_e;
+    const int e = ggml_mxfp_e8m0_base_estimate(amax, emax_offset);
+    return (uint8_t)(e < 0 ? 0 : (e > 254 ? 254 : e));
 }
static inline int best_index_mxfp4(float x, float e) {
@@ -353,26 +303,112 @@ static inline int best_index_mxfp4(float x, float e) {
     return (x < 0.0f) ? (idx + 8) : idx;
 }

-void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) {
-    static const int qk = QK_MXFP4;
-    assert(k % qk == 0);
-    const int nb = k / qk;
-    for (int i = 0; i < nb; i++) {
-        const uint8_t e = mxfp_compute_e8m0_mse(&x[i*qk], qk, &mxfp4_traits);
-        const float d = GGML_E8M0_TO_FP32_HALF(e);
-        y[i].e = e;
-        for (int j = 0; j < qk/2; ++j) {
-            const uint8_t x0 = best_index_mxfp4(x[i*qk + 0    + j], d);
-            const uint8_t x1 = best_index_mxfp4(x[i*qk + qk/2 + j], d);
-            y[i].qs[j]  = x0;
-            y[i].qs[j] |= x1 << 4;
-        }
-    }
-}
+// Per-block MXFP4 quantize: shared between AoS and SoA paths.
+static inline void quantize_block_mxfp4(const float * GGML_RESTRICT src, uint8_t * GGML_RESTRICT qs, uint8_t * e_out) {
+    const uint8_t e = mxfp_compute_e8m0(src, QK_MXFP4, MXFP4_E2M1_EMAX_OFFSET);
+    const float d = GGML_E8M0_TO_FP32_HALF(e);
+    *e_out = e;
+    for (int j = 0; j < QK_MXFP4/2; ++j) {
+        const uint8_t x0 = best_index_mxfp4(src[0          + j], d);
+        const uint8_t x1 = best_index_mxfp4(src[QK_MXFP4/2 + j], d);
+        qs[j] = x0 | (x1 << 4);
+    }
+}
+
+// Per-block MXFP4 quantize round-trip: apply quantization error without materializing bytes.
+// Used for Q preprocessing in flash attention — matches K's error pattern.
+static inline void roundtrip_block_mxfp4(float * GGML_RESTRICT vals) {
+    const uint8_t e = mxfp_compute_e8m0(vals, QK_MXFP4, MXFP4_E2M1_EMAX_OFFSET);
+    const float d = GGML_E8M0_TO_FP32_HALF(e);
+    for (int j = 0; j < QK_MXFP4; ++j) {
+        const int idx = best_index_mxfp4(vals[j], d);
+        vals[j] = kvalues_mxfp4[idx] * d; // kvalues are doubled, d is halved — matches dequant
+    }
+}
+
+// Per-block generic MXFP quantize round-trip (MXFP8/MXFP6).
+static inline void roundtrip_block_mxfp(float * GGML_RESTRICT vals, const mxfp_elem_traits_t * traits) {
+    const uint8_t e = mxfp_compute_e8m0(vals, 32, traits->emax_offset);
+    const float d = GGML_E8M0_TO_FP32(e);
+    const float inv_d = d > 0.0f ? 1.0f / d : 0.0f;
+    for (int j = 0; j < 32; ++j) {
+        vals[j] = traits->to_float(traits->to_elem(vals[j] * inv_d)) * d;
+    }
+}
+
+// Fused Hadamard + quantize round-trip: one pass, output is float with quantization error.
+void mxfp4_hadamard_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k) {
+    assert(k % 32 == 0);
+    for (int64_t i = 0; i < k; i += 32) {
+        memcpy(dst + i, src + i, 32 * sizeof(float));
+        ggml_mxfp_hadamard_32_inplace(dst + i);
+        roundtrip_block_mxfp4(dst + i);
+    }
+}
+
+// Non-Hadamard round-trip for MXFP4 (Hadamard disabled or V cache).
+void mxfp4_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k) {
+    assert(k % 32 == 0);
+    for (int64_t i = 0; i < k; i += 32) {
+        memcpy(dst + i, src + i, 32 * sizeof(float));
+        roundtrip_block_mxfp4(dst + i);
+    }
+}
+
+// Per-block MXFP4 dequant: shared between AoS and SoA paths.
+static inline void dequantize_block_mxfp4(const uint8_t * GGML_RESTRICT qs, uint8_t e, float * GGML_RESTRICT dst) {
+    const float d = GGML_E8M0_TO_FP32_HALF(e);
+    for (int j = 0; j < QK_MXFP4/2; ++j) {
+        dst[0          + j] = kvalues_mxfp4[qs[j] & 0x0F] * d;
+        dst[QK_MXFP4/2 + j] = kvalues_mxfp4[qs[j] >>   4] * d;
+    }
+}
+
+// Per-block generic MXFP quantize/dequant: shared between AoS and SoA for MXFP8/MXFP6.
+static inline void quantize_block_mxfp(const float * GGML_RESTRICT src, uint8_t * GGML_RESTRICT qs,
+                                       uint8_t * e_out, const mxfp_elem_traits_t * traits) {
+    const uint8_t e = mxfp_compute_e8m0(src, 32, traits->emax_offset);
+    const float d = GGML_E8M0_TO_FP32(e);
+    const float inv_d = d > 0.0f ? 1.0f / d : 0.0f;
+    *e_out = e;
+    if (traits->bits_per_elem == 8) {
+        for (int j = 0; j < 32; ++j) {
+            qs[j] = traits->to_elem(src[j] * inv_d);
+        }
+    } else {
+        for (int j = 0; j < 32; j += 4) {
+            uint8_t vals[4];
+            for (int jj = 0; jj < 4; jj++) {
+                vals[jj] = traits->to_elem(src[j + jj] * inv_d);
+            }
+            pack_fp6x4(vals, &qs[j * 3 / 4]);
+        }
+    }
+}
+
+static inline void dequantize_block_mxfp(const uint8_t * GGML_RESTRICT qs, uint8_t e,
+                                         float * GGML_RESTRICT dst, const mxfp_elem_traits_t * traits) {
+    const float d = GGML_E8M0_TO_FP32(e);
+    if (traits->bits_per_elem == 8) {
+        for (int j = 0; j < 32; ++j) {
+            dst[j] = traits->to_float(qs[j]) * d;
+        }
+    } else {
+        for (int j = 0; j < 32; j += 4) {
+            uint8_t vals[4];
+            unpack_fp6x4(&qs[j * 3 / 4], vals);
+            for (int jj = 0; jj < 4; jj++) {
+                dst[j + jj] = traits->to_float(vals[jj]) * d;
+            }
+        }
+    }
+}
+
+void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) {
+    assert(k % QK_MXFP4 == 0);
+    const int nb = k / QK_MXFP4;
+    for (int i = 0; i < nb; i++) {
+        quantize_block_mxfp4(&x[i*QK_MXFP4], y[i].qs, &y[i].e);
+    }
+}
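The E2M1 magnitude codebook behind `best_index_mxfp4` can be illustrated with a brute-force nearest-code search (the in-tree helper uses threshold comparisons instead). As the diff comments note, kvalues are doubled ({0,1,2,3,4,6,8,12} rather than {0,0.5,...,6}), which is why the block scale is halved via `GGML_E8M0_TO_FP32_HALF`:

```c
#include <math.h>

// Illustrative equivalent of best_index_mxfp4: find the nearest doubled
// E2M1 magnitude for x/d; index bit 3 carries the sign.
static const int kvals_sketch[8] = {0, 1, 2, 3, 4, 6, 8, 12};

static int best_index_mxfp4_sketch(float x, float d) {
    const float ax = fabsf(x) / (d > 0.0f ? d : 1.0f);
    int best = 0;
    float best_err = 1e30f;
    for (int i = 0; i < 8; ++i) {
        const float err = fabsf(ax - (float)kvals_sketch[i]);
        if (err < best_err) { best_err = err; best = i; }
    }
    return (x < 0.0f) ? (best + 8) : best;
}
```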
@@ -522,22 +558,10 @@ void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRI
 }

 void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
-    static const int qk = QK_MXFP4;
-    assert(k % qk == 0);
-    const int nb = k / qk;
+    assert(k % QK_MXFP4 == 0);
+    const int nb = k / QK_MXFP4;
     for (int i = 0; i < nb; i++) {
-        const float d = GGML_E8M0_TO_FP32_HALF(x[i].e);
-        for (int j = 0; j < qk/2; ++j) {
-            const int8_t x0 = kvalues_mxfp4[x[i].qs[j] & 0x0F];
-            const int8_t x1 = kvalues_mxfp4[x[i].qs[j] >>   4];
-            y[i*qk + j + 0   ] = x0*d;
-            y[i*qk + j + qk/2] = x1*d;
-        }
+        dequantize_block_mxfp4(x[i].qs, x[i].e, &y[i*QK_MXFP4]);
     }
 }
@@ -582,112 +606,95 @@ uint8_t float_to_fp8_e5m2_rn(float x) { return ggml_mxfp_float_to_fp8_e5m2(x); }
 void pack_fp6x4(const uint8_t v[4], uint8_t out[3]) { ggml_mxfp_pack_fp6x4(v, out); }
 void unpack_fp6x4(const uint8_t in[3], uint8_t v[4]) { ggml_mxfp_unpack_fp6x4(in, v); }

-static const mxfp_elem_traits_t mxfp8_e4m3_traits = { MXFP8_E4M3_EMAX_OFFSET, MXFP8_SOA_QS_PER_BLOCK, 8, float_to_fp8_e4m3_rn, fp8_e4m3_to_float, NULL };
-static const mxfp_elem_traits_t mxfp6_e2m3_traits = { MXFP6_E2M3_EMAX_OFFSET, MXFP6_SOA_QS_PER_BLOCK, 6, float_to_fp6_e2m3_rn, fp6_e2m3_to_float, NULL };
+static const mxfp_elem_traits_t mxfp8_e4m3_traits = { MXFP8_E4M3_EMAX_OFFSET, MXFP8_SOA_QS_PER_BLOCK, 8, float_to_fp8_e4m3_rn, fp8_e4m3_to_float };
+static const mxfp_elem_traits_t mxfp6_e2m3_traits = { MXFP6_E2M3_EMAX_OFFSET, MXFP6_SOA_QS_PER_BLOCK, 6, float_to_fp6_e2m3_rn, fp6_e2m3_to_float };
// MXFP8 AoS quantize/dequant — uses shared per-block helpers.
void quantize_row_mxfp8_ref(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k) {
assert(k % QK_MXFP8 == 0);
const int nb = k / QK_MXFP8;
for (int i = 0; i < nb; i++) {
quantize_block_mxfp(&x[i*QK_MXFP8], y[i].qs, &y[i].e, &mxfp8_e4m3_traits);
}
}
void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
assert(k % QK_MXFP8 == 0);
const int nb = k / QK_MXFP8;
for (int i = 0; i < nb; i++) {
dequantize_block_mxfp(x[i].qs, x[i].e, &y[i*QK_MXFP8], &mxfp8_e4m3_traits);
}
}
// MXFP6 AoS quantize/dequant — uses shared per-block helpers.
void quantize_row_mxfp6_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k) {
assert(k % QK_MXFP6 == 0);
const int nb = k / QK_MXFP6;
for (int i = 0; i < nb; i++) {
quantize_block_mxfp(&x[i*QK_MXFP6], y[i].qs, &y[i].e, &mxfp6_e2m3_traits);
}
}
void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
assert(k % QK_MXFP6 == 0);
const int nb = k / QK_MXFP6;
for (int i = 0; i < nb; i++) {
dequantize_block_mxfp(x[i].qs, x[i].e, &y[i*QK_MXFP6], &mxfp6_e2m3_traits);
}
}
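The 6-bit path above stores four elements in three bytes via `pack_fp6x4`/`unpack_fp6x4`, which is why the quant bytes are indexed as `qs[j * 3 / 4]`. The in-tree bit layout is not shown in this diff; this sketch uses one plausible little-endian packing purely to illustrate the 4:3 ratio and lossless round trip:

```c
#include <stdint.h>

// Hypothetical FP6x4 packing: four 6-bit codes packed into a 24-bit word,
// stored as three little-endian bytes. (The real pack_fp6x4/unpack_fp6x4
// may order the bits differently.)
static void pack_fp6x4_sketch(const uint8_t v[4], uint8_t out[3]) {
    const uint32_t w = (uint32_t)(v[0] & 0x3F)
                     | ((uint32_t)(v[1] & 0x3F) <<  6)
                     | ((uint32_t)(v[2] & 0x3F) << 12)
                     | ((uint32_t)(v[3] & 0x3F) << 18);
    out[0] = (uint8_t)(w);
    out[1] = (uint8_t)(w >> 8);
    out[2] = (uint8_t)(w >> 16);
}

static void unpack_fp6x4_sketch(const uint8_t in[3], uint8_t v[4]) {
    const uint32_t w = (uint32_t)in[0] | ((uint32_t)in[1] << 8) | ((uint32_t)in[2] << 16);
    v[0] = (uint8_t)( w        & 0x3F);
    v[1] = (uint8_t)((w >>  6) & 0x3F);
    v[2] = (uint8_t)((w >> 12) & 0x3F);
    v[3] = (uint8_t)((w >> 18) & 0x3F);
}
```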
// ====================== SoA (Struct-of-Arrays) quantize/dequantize for flash attention
void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) {
assert(k % QK_MXFP4 == 0);
const int nb = k / QK_MXFP4;
-    char * row = (char *)dst;
-    char * qs_base = row;
-    char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK);
+    char * qs_base = (char *)dst;
+    char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK);
     for (int i = 0; i < nb; i++) {
-        const uint8_t e = mxfp_compute_e8m0_mse(&x[i*QK_MXFP4], QK_MXFP4, &mxfp4_traits);
-        const float d = GGML_E8M0_TO_FP32_HALF(e);
-        e8m0_base[i] = (char)e;
         uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP4_SOA_QS_PER_BLOCK));
-        for (int j = 0; j < QK_MXFP4/2; ++j) {
-            const uint8_t x0 = best_index_mxfp4(x[i*QK_MXFP4 + 0          + j], d);
-            const uint8_t x1 = best_index_mxfp4(x[i*QK_MXFP4 + QK_MXFP4/2 + j], d);
-            qs[j] = x0 | (x1 << 4);
-        }
+        quantize_block_mxfp4(&x[i*QK_MXFP4], qs, (uint8_t *)&e8m0_base[i]);
}
}
void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) {
assert(k % QK_MXFP4 == 0);
const int nb = k / QK_MXFP4;
-    const char * row = (const char *)src;
-    const char * qs_base = row;
-    const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK);
+    const char * qs_base = (const char *)src;
+    const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK);
     for (int i = 0; i < nb; i++) {
-        const float d = GGML_E8M0_TO_FP32_HALF((uint8_t)e8m0_base[i]);
         const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP4_SOA_QS_PER_BLOCK));
-        for (int j = 0; j < QK_MXFP4/2; ++j) {
-            const int8_t x0 = kvalues_mxfp4[qs[j] & 0x0F];
-            const int8_t x1 = kvalues_mxfp4[qs[j] >>   4];
-            y[i*QK_MXFP4 + j + 0         ] = x0*d;
-            y[i*QK_MXFP4 + j + QK_MXFP4/2] = x1*d;
-        }
+        dequantize_block_mxfp4(qs, (uint8_t)e8m0_base[i], &y[i*QK_MXFP4]);
}
}
-// Unified SoA quantize for byte-aligned (FP8) and 6-bit packed (FP6) MXFP formats.
+// Unified SoA quantize/dequantize — delegates to shared per-block helpers.
 static void quantize_row_mxfp_soa_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst,
                                        int64_t k, const mxfp_elem_traits_t * traits) {
-    const int qk = 32;
-    assert(k % qk == 0);
-    const int nb = k / qk;
+    assert(k % 32 == 0);
+    const int nb = k / 32;
     const int qpb = traits->qs_per_block;
     char * qs_base = (char *)dst;
     char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, qpb);
     for (int i = 0; i < nb; i++) {
-        const uint8_t e = mxfp_compute_e8m0_mse(&x[i*qk], qk, traits);
-        const float d = GGML_E8M0_TO_FP32(e);
-        const float inv_d = d > 0.0f ? 1.0f / d : 0.0f;
-        e8m0_base[i] = (char)e;
         uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, qpb));
-        if (traits->bits_per_elem == 8) {
-            for (int j = 0; j < qk; ++j) {
-                qs[j] = traits->to_elem(x[i*qk + j] * inv_d);
-            }
-        } else {
-            for (int j = 0; j < qk; j += 4) {
-                uint8_t vals[4];
-                for (int jj = 0; jj < 4; jj++) {
-                    vals[jj] = traits->to_elem(x[i*qk + j + jj] * inv_d);
-                }
-                pack_fp6x4(vals, &qs[j * 3 / 4]);
-            }
-        }
+        quantize_block_mxfp(&x[i*32], qs, (uint8_t *)&e8m0_base[i], traits);
}
}
-// Unified SoA dequantize for byte-aligned (FP8) and 6-bit packed (FP6) MXFP formats.
 static void dequantize_row_mxfp_soa_impl(const void * GGML_RESTRICT src, float * GGML_RESTRICT y,
                                          int64_t k, const mxfp_elem_traits_t * traits) {
-    const int qk = 32;
-    assert(k % qk == 0);
-    const int nb = k / qk;
+    assert(k % 32 == 0);
+    const int nb = k / 32;
     const int qpb = traits->qs_per_block;
     const char * qs_base = (const char *)src;
     const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, qpb);
     for (int i = 0; i < nb; i++) {
-        const float d = GGML_E8M0_TO_FP32((uint8_t)e8m0_base[i]);
         const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, qpb));
-        if (traits->bits_per_elem == 8) {
-            for (int j = 0; j < qk; ++j) {
-                y[i*qk + j] = traits->to_float(qs[j]) * d;
-            }
-        } else {
-            for (int j = 0; j < qk; j += 4) {
-                uint8_t vals[4];
-                unpack_fp6x4(&qs[j * 3 / 4], vals);
-                for (int jj = 0; jj < 4; jj++) {
-                    y[i*qk + j + jj] = traits->to_float(vals[jj]) * d;
-                }
-            }
-        }
+        dequantize_block_mxfp(qs, (uint8_t)e8m0_base[i], &y[i*32], traits);
}
}
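The SoA row layout used by these functions can be summarized with hypothetical definitions for the offset macros (their actual definitions are not part of this diff): all blocks' quant bytes come first, then one E8M0 byte per block at the tail of the row, keeping the `qs` stream contiguous for the kernels.

```c
#include <stddef.h>

// Assumed SoA row shape: [qs block 0 | qs block 1 | ... | e8m0 bytes].
// These sketches mirror what MXFP_SOA_QS_OFFSET and MXFP_SOA_E8M0_OFFSET
// would compute under that layout.
static size_t soa_qs_offset_sketch(int i, int qs_per_block) {
    return (size_t)i * (size_t)qs_per_block;      // start of block i's qs
}
static size_t soa_e8m0_offset_sketch(int nb, int qs_per_block) {
    return (size_t)nb * (size_t)qs_per_block;     // e8m0 tail after all qs
}
static size_t soa_row_size_sketch(int nb, int qs_per_block) {
    return soa_e8m0_offset_sketch(nb, qs_per_block) + (size_t)nb;
}
```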
@@ -703,6 +710,83 @@ void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT
void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) {
dequantize_row_mxfp_soa_impl(src, y, k, &mxfp6_e2m3_traits);
}
// Fused Hadamard + SoA quantize: one read, one write, 32-float stack buffer per block.
// Eliminates the full-row temp buffer and extra memory pass.
void quantize_row_mxfp4_soa_hadamard(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) {
assert(k % QK_MXFP4 == 0);
const int nb = k / QK_MXFP4;
char * qs_base = (char *)dst;
char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK);
for (int i = 0; i < nb; i++) {
float tmp[32];
memcpy(tmp, &x[i*QK_MXFP4], QK_MXFP4 * sizeof(float));
ggml_mxfp_hadamard_32_inplace(tmp);
uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP4_SOA_QS_PER_BLOCK));
quantize_block_mxfp4(tmp, qs, (uint8_t *)&e8m0_base[i]);
}
}
static void quantize_row_mxfp_soa_hadamard_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst,
int64_t k, const mxfp_elem_traits_t * traits) {
assert(k % 32 == 0);
const int nb = k / 32;
const int qpb = traits->qs_per_block;
char * qs_base = (char *)dst;
char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, qpb);
for (int i = 0; i < nb; i++) {
float tmp[32];
memcpy(tmp, &x[i*32], 32 * sizeof(float));
ggml_mxfp_hadamard_32_inplace(tmp);
uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, qpb));
quantize_block_mxfp(tmp, qs, (uint8_t *)&e8m0_base[i], traits);
}
}
void quantize_row_mxfp8_soa_hadamard(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) {
quantize_row_mxfp_soa_hadamard_impl(x, dst, k, &mxfp8_e4m3_traits);
}
void quantize_row_mxfp6_soa_hadamard(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) {
quantize_row_mxfp_soa_hadamard_impl(x, dst, k, &mxfp6_e2m3_traits);
}
// MXFP8/6 quantize round-trips (with and without Hadamard).
void mxfp8_hadamard_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k) {
assert(k % 32 == 0);
for (int64_t i = 0; i < k; i += 32) {
memcpy(dst + i, src + i, 32 * sizeof(float));
ggml_mxfp_hadamard_32_inplace(dst + i);
roundtrip_block_mxfp(dst + i, &mxfp8_e4m3_traits);
}
}
void mxfp6_hadamard_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k) {
assert(k % 32 == 0);
for (int64_t i = 0; i < k; i += 32) {
memcpy(dst + i, src + i, 32 * sizeof(float));
ggml_mxfp_hadamard_32_inplace(dst + i);
roundtrip_block_mxfp(dst + i, &mxfp6_e2m3_traits);
}
}
void mxfp8_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k) {
assert(k % 32 == 0);
for (int64_t i = 0; i < k; i += 32) {
memcpy(dst + i, src + i, 32 * sizeof(float));
roundtrip_block_mxfp(dst + i, &mxfp8_e4m3_traits);
}
}
void mxfp6_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k) {
assert(k % 32 == 0);
for (int64_t i = 0; i < k; i += 32) {
memcpy(dst + i, src + i, 32 * sizeof(float));
roundtrip_block_mxfp(dst + i, &mxfp6_e2m3_traits);
}
}
//
// 2-6 bit quantization in super-blocks
//
@@ -2373,6 +2457,18 @@ size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
return nrow * ggml_row_size(GGML_TYPE_NVFP4, n_per_row);
}
size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
GGML_UNUSED(quant_weights);
quantize_row_mxfp8_ref(src, dst, (int64_t)nrow*n_per_row);
return nrow * ggml_row_size(GGML_TYPE_MXFP8, n_per_row);
}
size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
GGML_UNUSED(quant_weights);
quantize_row_mxfp6_ref(src, dst, (int64_t)nrow*n_per_row);
return nrow * ggml_row_size(GGML_TYPE_MXFP6, n_per_row);
}
// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) {


@@ -22,6 +22,8 @@ GGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 *
GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp8_ref(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp6_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
@@ -48,6 +50,8 @@ GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GG
//GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
// SoA quantize/dequantize for flash attention
GGML_API void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
@@ -56,6 +60,18 @@ GGML_API void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_
GGML_API void dequantize_row_mxfp8_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
GGML_API void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
// Fused Hadamard + SoA quantize (one pass, no temp buffer)
GGML_API void quantize_row_mxfp4_soa_hadamard(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
GGML_API void quantize_row_mxfp8_soa_hadamard(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
GGML_API void quantize_row_mxfp6_soa_hadamard(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
// Quantize round-trip: apply quantization error to floats without materializing bytes.
// Hadamard variants include the rotation. Used for Q preprocessing in flash attention.
GGML_API void mxfp4_hadamard_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k);
GGML_API void mxfp8_hadamard_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k);
GGML_API void mxfp6_hadamard_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k);
GGML_API void mxfp4_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k);
GGML_API void mxfp8_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k);
GGML_API void mxfp6_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k);
GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@@ -103,6 +119,8 @@ GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTR
GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
// MXFP element converters
GGML_API float fp8_e4m3_to_float(uint8_t v);


@@ -727,16 +727,20 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
.from_float_ref = (ggml_from_float_t)quantize_row_nvfp4_ref,
},
[GGML_TYPE_MXFP8] = {
-        .type_name = "mxfp8_e4m3",
+        .type_name = "mxfp8",
.blck_size = QK_MXFP8,
.type_size = sizeof(block_mxfp8),
.is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_mxfp8,
.from_float_ref = (ggml_from_float_t)quantize_row_mxfp8_ref,
},
[GGML_TYPE_MXFP6] = {
-        .type_name = "mxfp6_e2m3",
+        .type_name = "mxfp6",
.blck_size = QK_MXFP6,
.type_size = sizeof(block_mxfp6),
.is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_mxfp6,
.from_float_ref = (ggml_from_float_t)quantize_row_mxfp6_ref,
},
[GGML_TYPE_Q2_K] = {
.type_name = "q2_K",
@@ -7693,8 +7697,8 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_MXFP4: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
-        case GGML_TYPE_MXFP8: GGML_ABORT("MXFP8 is KV-cache-only (SoA layout) — use from_float_soa");
-        case GGML_TYPE_MXFP6: GGML_ABORT("MXFP6 is KV-cache-only (SoA layout) — use from_float_soa");
+        case GGML_TYPE_MXFP8: result = quantize_mxfp8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_MXFP6: result = quantize_mxfp6(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;


@@ -1121,8 +1121,9 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm
ggml_tensor * result = ggml_set_rows(ctx, k_dst, k_cur, k_idxs);
// enable Hadamard rotation for MXFP K cache (QuaRot arXiv:2404.00456, BRQ arXiv:2511.04214)
-    // skipped for MLA (V is a view of K) and E5M2/E3M2 (2-bit mantissa, no benefit)
-    if (is_mxfp && !hparams.is_mla() && ggml_mxfp_use_hadamard(k->type)) {
+    // skipped when DK != DV (MLA) and for E5M2/E3M2 (2-bit mantissa, no benefit).
+    // condition must match flash attention read path (ops.cpp: DK == DV).
+    if (is_mxfp && hparams.n_embd_head_k(il) == hparams.n_embd_head_v(il) && ggml_mxfp_use_hadamard(k->type)) {
((int32_t *)result->op_params)[0] = 1;
}


@@ -6390,7 +6390,7 @@ struct test_flash_attn_ext : public test_case {
init_tensor_uniform(t, -10.0f, 10.0f);
} else if (strcmp(t->name, "m") == 0) {
init_tensor_kq_mask(t);
-        } else if ((strcmp(t->name, "k") == 0 || strcmp(t->name, "v") == 0) && ggml_is_type_mxfp(t->type)) {
+        } else if (ggml_is_type_mxfp(t->type)) {
init_tensor_mxfp_soa(t);
} else {
init_tensor_uniform(t);
@@ -7398,7 +7398,7 @@ static const ggml_type all_types[] = {
GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
GGML_TYPE_Q8_0,
-    GGML_TYPE_MXFP4,
+    GGML_TYPE_MXFP4, GGML_TYPE_MXFP8, GGML_TYPE_MXFP6,
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
GGML_TYPE_Q6_K,
@@ -7533,11 +7533,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
// SET_ROWS with Hadamard rotation (exercises the op_params[0] flag used by MXFP KV cache)
-    for (ggml_type type : {GGML_TYPE_MXFP4, GGML_TYPE_MXFP8,
-                           GGML_TYPE_MXFP6}) {
+    for (ggml_type type : {GGML_TYPE_MXFP4, GGML_TYPE_MXFP8, GGML_TYPE_MXFP6}) {
// ne[0] must be divisible by 32 (Hadamard block size)
test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 128, 5, 1, 1 }, { 1, 1 }, 1, false, true));
test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 256, 5, 1, 3 }, { 1, 1 }, 1, false, true));
// multi-row, broadcast, views
test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 128, 5, 1, 1 }, { 1, 1 }, 1, true, true));
test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 256, 11, 1, 1 }, { 2, 3 }, 7, false, true));
test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 512, 5, 3, 1 }, { 1, 1 }, 1, false, true));
}
for (int mode : { GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_VISION }) {


@@ -33,7 +33,7 @@ constexpr float MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 = 0.0020f;
// These represent actual RMSE through the full KV cache write/read path.
constexpr float MAX_MXFP_PIPELINE_ERROR_MXFP4 = 0.40f;
constexpr float MAX_MXFP_PIPELINE_ERROR_MXFP8 = 0.08f;
-constexpr float MAX_MXFP_PIPELINE_ERROR_MXFP6 = 0.10f;
+constexpr float MAX_MXFP_PIPELINE_ERROR_MXFP6 = 0.30f;
constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f;
constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f;