cleanup: remove unused, untested code and improve consistency

* cleanup: consolidate MXFP type aliases, fix SoA linker bug on 5 platforms

- Add GGML_TYPE_MXFP8 and GGML_TYPE_MXFP6 short aliases (matching
  existing GGML_TYPE_MXFP4 pattern) and use short names consistently
  throughout the codebase instead of mixing long/short forms.

- Fix missing SoA dequant symbols (dequantize_row_mxfp{4,8,6}_soa_cpu)
  on loongarch, powerpc, riscv, s390, and wasm by adding proper aliases
  to each arch section in arch-fallback.h. Previously these were only
  defined under GGML_CPU_GENERIC, causing linker failures on those
  platforms when using MXFP flash attention.

- Remove 10 files from the PR diff:
  - 5 arch stub files replaced by arch-fallback.h aliases
  - 5 rename-only files (sycl, opencl, repack, llama-quant) reverted
    since the GGML_TYPE_MXFP4 compat alias handles them

* cleanup: DRY FP6 unpack, extract mxfp_kv_params + mxfp_dequant_head helper

- FP6 unpack: x86 and ARM SIMD versions now call ggml_mxfp_unpack_fp6x4()
  from ggml-common.h instead of duplicating the scalar bit manipulation.
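For reference, a sketch of what the shared helper plausibly looks like: the name comes from the commit text, and the body is reconstructed from the scalar unpack code the SIMD paths previously duplicated (the same bit manipulation visible in the deleted unpack_fp6x4_avx2 in the diff), so treat it as an illustration rather than the exact ggml-common.h source.

```c
#include <stdint.h>

// Hedged sketch of the shared FP6 unpack helper. Four 6-bit codes are packed
// little-endian into 3 bytes; assemble them into one 24-bit word and mask out
// each 6-bit field.
static inline void ggml_mxfp_unpack_fp6x4(const uint8_t * p, uint8_t out[4]) {
    const uint32_t pk = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
    out[0] = (pk >>  0) & 0x3F;
    out[1] = (pk >>  6) & 0x3F;
    out[2] = (pk >> 12) & 0x3F;
    out[3] = (pk >> 18) & 0x3F;
}
```

Both the AVX2 and NEON paths can then widen the four unpacked bytes with their own intrinsics, which is exactly the split the commit describes.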

- Extract mxfp_kv_params sub-struct from mxfp_fa_params: the 7 symmetric
  K/V fields (dequantize, multihead, soa_elems, qs_per_block,
  head_qs_bytes, head_e8m0_offset, blocks_per_head) are now in a reusable
  struct accessed as mxfp.k and mxfp.v.

- Add mxfp_dequant_head() helper: replaces 4 instances of the multihead
  SoA extraction pattern (2x memcpy + dequant, with multihead/single-head
  branching) with a single function call. Future backends get the pattern
  for free.

* cleanup: extract mxfp_kv_params_init to DRY the K/V init blocks

The K and V initialization in mxfp_fa_params_init were structurally
identical 10-line blocks differing only by tensor/dimension. Extract
into mxfp_kv_params_init(type, D, nb2, ne2) so future MXFP formats
get the multihead SoA addressing logic automatically.

* cleanup: generic MSE round-trip, replace magic buffer sizes with constants

- Remove mse_error_fp8_e4m3 and mse_error_fp6_e2m3: these were identical
  round-trip functions differing only by converter. mxfp_compute_e8m0_mse
  now uses to_elem/to_float directly when mse_error is NULL (FP8/FP6).
  MXFP4 keeps its custom decision-tree MSE. New formats get MSE for free
  by just setting to_elem/to_float in their traits.

- Replace magic 1024/1088 buffer sizes in flash attention with named
  constants MXFP_FA_MAX_D and MXFP_FA_SOA_BUF. One place to change if
  max head dimension grows.
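As an illustration of the generic round-trip MSE described above: quantize each value at a candidate scale, dequantize it back, and accumulate the squared reconstruction error. The to_elem/to_float converter names follow the trait fields mentioned in the commit; the function name and signatures here are assumptions for illustration, not the actual ggml API.

```c
#include <stdint.h>

// Hypothetical trait converters (signatures assumed for illustration).
typedef uint8_t (*mxfp_to_elem_fn)(float v);
typedef float   (*mxfp_to_float_fn)(uint8_t e);

// Minimal sketch of a format-agnostic round-trip MSE: any format that
// provides to_elem/to_float gets this for free, with no per-format copy.
static float mxfp_roundtrip_mse(const float * x, int n, float scale,
                                mxfp_to_elem_fn to_elem, mxfp_to_float_fn to_float) {
    const float inv_scale = scale != 0.0f ? 1.0f / scale : 0.0f;
    float err = 0.0f;
    for (int i = 0; i < n; ++i) {
        // quantize at the candidate scale, then reconstruct
        const float q = to_float(to_elem(x[i] * inv_scale)) * scale;
        const float d = x[i] - q;
        err += d * d;
    }
    return err;
}
```

The scale-search loop would evaluate this for each candidate e8m0 exponent and keep the minimum; MXFP4's decision-tree MSE stays separate, per the commit.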

* cleanup: remove dead AoS vec_dot for MXFP8/MXFP6, unify SoA impls

MXFP8 and MXFP6 are KV-cache-only types that use SoA layout for flash
attention. The AoS vec_dot functions (scalar generic, AVX2, NEON) were
dead code — no matmul path uses them.

Removed:
- ggml_vec_dot_mxfp{8,6}_q8_0 from scalar, x86, ARM, quants.h
- ggml_vec_dot_mxfp_q8_0_impl shared helper
- arch-fallback.h aliases for vec_dot mxfp8/mxfp6 (12 lines)
- vec_dot/vec_dot_type registration in ggml-cpu.c

Also unified SoA quantize/dequant: the separate mxfp8_soa_impl and
mxfp6_soa_impl functions (4 functions, ~80 lines) are replaced by two
generic functions (quantize_row_mxfp_soa_impl, dequantize_row_mxfp_soa_impl)
that use traits->bits_per_elem and traits->qs_per_block to handle both
byte-aligned (FP8) and 6-bit packed (FP6) formats. New MXFP formats
get SoA for free by setting these trait fields.
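A minimal sketch of how a single bits_per_elem trait can cover both layouts (a hypothetical helper, not the actual ggml function; the 6-bit little-endian packing matches the scalar unpack code shown in the diffs):

```c
#include <stdint.h>

// Sketch: fetch the raw code of element i from a packed qs buffer.
// 8-bit (FP8) elements are byte-aligned; 6-bit (FP6) elements are packed
// tightly, little-endian within the bit stream. Note: the 6-bit path reads
// one byte past the element's first byte, so the buffer needs a byte of
// slack (or a guarded final read) at the very end.
static uint8_t mxfp_get_elem(const uint8_t * qs, int bits_per_elem, int i) {
    if (bits_per_elem == 8) {
        return qs[i];                       // FP8: one byte per element
    }
    const int bit = i * bits_per_elem;      // FP6: bit offset into the stream
    const uint32_t word = (uint32_t)qs[bit / 8] | ((uint32_t)qs[bit / 8 + 1] << 8);
    return (word >> (bit % 8)) & ((1u << bits_per_elem) - 1);
}
```

With element access factored out like this, one quantize/dequantize loop body can serve both formats, which is the unification the commit describes.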

* cleanup: remove all AoS MXFP8/MXFP6 quantize/dequant — SoA only

MXFP8 and MXFP6 are KV-cache-only types. All quantization and
dequantization goes through the SoA (Struct-of-Arrays) path for flash
attention. The AoS (block_mxfp8/block_mxfp6 struct) implementations
were dead code that should never have been added.

Removed:
- quantize_row_mxfp{8,6}_impl, dequantize_row_mxfp{8,6}_impl
- quantize_row_mxfp{8,6}_ref, dequantize_row_mxfp{8,6}
- quantize_mxfp{8,6} (ggml_quantize_chunk wrappers)
- All declarations from ggml-quants.h and quants.h
- to_float/from_float_ref registrations from ggml.c type traits
- from_float registration from ggml-cpu.c CPU traits

Block struct definitions (block_mxfp8, block_mxfp6) are retained for
sizeof() in type traits and validate_row_data.

* cleanup: fail fast in ggml_quantize_chunk for KV-cache-only types

Add explicit GGML_ABORT for MXFP8/MXFP6 in ggml_quantize_chunk —
these are KV-cache-only types that use SoA layout via from_float_soa.
Attempting AoS quantization through this entry point is a bug.
This commit is contained in:
Tim Burke 2026-03-22 02:44:56 -04:00 committed by GitHub
parent 0e3304fbca
commit c919bc471b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 277 additions and 769 deletions


@@ -398,20 +398,20 @@ const std::vector<ggml_type> kv_cache_types = {
     GGML_TYPE_IQ4_NL,
     GGML_TYPE_Q5_0,
     GGML_TYPE_Q5_1,
-    GGML_TYPE_MXFP4_E2M1,
-    GGML_TYPE_MXFP8_E4M3,
-    GGML_TYPE_MXFP6_E2M3,
+    GGML_TYPE_MXFP4,
+    GGML_TYPE_MXFP8,
+    GGML_TYPE_MXFP6,
 };
 static ggml_type kv_cache_type_from_str(const std::string & s) {
     if (s == "mxfp4") {
-        return GGML_TYPE_MXFP4_E2M1;
+        return GGML_TYPE_MXFP4;
     }
     if (s == "mxfp6") {
-        return GGML_TYPE_MXFP6_E2M3;
+        return GGML_TYPE_MXFP6;
     }
     if (s == "mxfp8") {
-        return GGML_TYPE_MXFP8_E4M3;
+        return GGML_TYPE_MXFP8;
     }
     for (const auto & type : kv_cache_types) {
         if (ggml_type_name(type) == s) {


@@ -430,7 +430,9 @@ extern "C" {
     GGML_TYPE_MXFP4 = GGML_TYPE_MXFP4_E2M1, // compat alias
     GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale)
     GGML_TYPE_MXFP8_E4M3 = 41, // MX FP8 E4M3
+    GGML_TYPE_MXFP8 = GGML_TYPE_MXFP8_E4M3, // compat alias
     GGML_TYPE_MXFP6_E2M3 = 42, // MX FP6 E2M3
+    GGML_TYPE_MXFP6 = GGML_TYPE_MXFP6_E2M3, // compat alias
     GGML_TYPE_COUNT = 43,
 };


@@ -15,8 +15,9 @@
 #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
-#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0
-#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0
+#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
+#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
+#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
@@ -113,6 +114,9 @@
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
+#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
+#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
+#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
@@ -161,6 +165,9 @@
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
+#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
+#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
+#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -201,6 +208,9 @@
 #elif defined(__riscv)
 // quants.c
 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
+#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
+#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
+#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
@@ -241,6 +251,9 @@
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
+#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
+#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
+#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
@@ -291,6 +304,9 @@
 #elif defined(__wasm__)
 // quants.c
 #define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1
+#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
+#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
+#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K
@@ -342,10 +358,3 @@
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #endif
-// MXFP dequantize fallbacks (same GGML_CPU_GENERIC guard as above)
-#if defined(GGML_CPU_GENERIC)
-#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
-#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
-#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu
-#endif


@@ -4191,12 +4191,8 @@ static inline float32x4_t mxfp6_dequant_neon(
 // Unpack 4 tightly-packed 6-bit values from 3 bytes, widen to uint32x4_t.
 static inline uint32x4_t unpack_fp6x4_neon(const uint8_t * p) {
-    const uint32_t pk = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
     uint8_t u[4];
-    u[0] = (pk >> 0) & 0x3F;
-    u[1] = (pk >> 6) & 0x3F;
-    u[2] = (pk >> 12) & 0x3F;
-    u[3] = (pk >> 18) & 0x3F;
+    ggml_mxfp_unpack_fp6x4(p, u);
     const uint8x8_t raw8 = vcreate_u8(
         (uint64_t)u[0] | ((uint64_t)u[1] << 8) |
         ((uint64_t)u[2] << 16) | ((uint64_t)u[3] << 24));
@@ -4221,96 +4217,6 @@ static inline void widen_s8x8_to_f32x4x2(const int8_t * src,
     *hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16)));
 }
-
-// MXFP FP8/FP6 vec_dot
-
-static void ggml_vec_dot_mxfp8_q8_0_neon(
-        int n, float * GGML_RESTRICT s,
-        const void * GGML_RESTRICT vx,
-        const void * GGML_RESTRICT vy,
-        const mxfp_neon_traits_t * t) {
-    assert(n % QK_MXFP8 == 0);
-    const int nb = n / QK_MXFP8;
-    const block_mxfp8 * GGML_RESTRICT x = vx;
-    const block_q8_0 * GGML_RESTRICT y = vy;
-    const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask);
-    const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask);
-    const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off);
-    const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale);
-    const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift);
-    const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift);
-    float32x4_t acc0 = vdupq_n_f32(0.0f);
-    float32x4_t acc1 = vdupq_n_f32(0.0f);
-    for (int ib = 0; ib < nb; ++ib) {
-        const float32x4_t v_scale = vdupq_n_f32(
-            GGML_E8M0_TO_FP32(x[ib].e) * GGML_CPU_FP16_TO_FP32(y[ib].d));
-        for (int j = 0; j < 32; j += 8) {
-            uint32x4_t v_lo, v_hi;
-            widen_u8x8_to_u32x4x2(x[ib].qs + j, &v_lo, &v_hi);
-            float32x4_t qf_lo, qf_hi;
-            widen_s8x8_to_f32x4x2(y[ib].qs + j, &qf_lo, &qf_hi);
-            const float32x4_t val_lo = mxfp8_dequant_neon(v_lo,
-                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
-            const float32x4_t val_hi = mxfp8_dequant_neon(v_hi,
-                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
-            acc0 = vfmaq_f32(acc0, vmulq_f32(val_lo, v_scale), qf_lo);
-            acc1 = vfmaq_f32(acc1, vmulq_f32(val_hi, v_scale), qf_hi);
-        }
-    }
-    *s = vaddvq_f32(vaddq_f32(acc0, acc1));
-}
-
-static void ggml_vec_dot_mxfp6_q8_0_neon(
-        int n, float * GGML_RESTRICT s,
-        const void * GGML_RESTRICT vx,
-        const void * GGML_RESTRICT vy,
-        const mxfp_neon_traits_t * t) {
-    assert(n % QK_MXFP6 == 0);
-    const int nb = n / QK_MXFP6;
-    const block_q8_0 * GGML_RESTRICT y = vy;
-    const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask);
-    const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask);
-    const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off);
-    const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale);
-    const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift);
-    const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift);
-    float32x4_t acc0 = vdupq_n_f32(0.0f);
-    float32x4_t acc1 = vdupq_n_f32(0.0f);
-    for (int ib = 0; ib < nb; ++ib) {
-        const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib;
-        const float32x4_t v_scale = vdupq_n_f32(
-            GGML_E8M0_TO_FP32(xb->e) * GGML_CPU_FP16_TO_FP32(y[ib].d));
-        for (int j = 0; j < 32; j += 8) {
-            const uint32x4_t v_lo = unpack_fp6x4_neon(xb->qs + (j * 3 / 4));
-            const uint32x4_t v_hi = unpack_fp6x4_neon(xb->qs + ((j + 4) * 3 / 4));
-            float32x4_t qf_lo, qf_hi;
-            widen_s8x8_to_f32x4x2(y[ib].qs + j, &qf_lo, &qf_hi);
-            const float32x4_t val_lo = mxfp6_dequant_neon(v_lo,
-                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
-            const float32x4_t val_hi = mxfp6_dequant_neon(v_hi,
-                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
-            acc0 = vfmaq_f32(acc0, vmulq_f32(val_lo, v_scale), qf_lo);
-            acc1 = vfmaq_f32(acc1, vmulq_f32(val_hi, v_scale), qf_hi);
-        }
-    }
-    *s = vaddvq_f32(vaddq_f32(acc0, acc1));
-}
 // MXFP SoA dequant (flash attention)
 static void dequantize_row_mxfp8_soa_neon(
@@ -4424,26 +4330,6 @@ static void dequantize_row_mxfp4_soa_neon(
 // Public dispatch functions
-void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    assert(nrc == 1);
-    UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
-#if defined(__ARM_NEON)
-    ggml_vec_dot_mxfp8_q8_0_neon(n, s, vx, vy, &MXFP_TRAITS_E4M3);
-#else
-    ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-#endif
-}
-
-void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    assert(nrc == 1);
-    UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
-#if defined(__ARM_NEON)
-    ggml_vec_dot_mxfp6_q8_0_neon(n, s, vx, vy, &MXFP_TRAITS_E2M3);
-#else
-    ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-#endif
-}
 void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
 #if defined(__ARM_NEON)
     dequantize_row_mxfp4_soa_neon(x, y, k);


@@ -2157,14 +2157,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
 #endif
 }
-
-void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-}
-
-void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-}
-
-void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-}


@@ -2303,10 +2303,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
 #endif
 }
-
-void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-}
-
-void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-}


@@ -3621,11 +3621,3 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
     ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
-
-void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-}
-
-void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-}


@@ -1464,10 +1464,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
 #endif
 }
-
-void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-}
-
-void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-}


@@ -1219,14 +1219,3 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 #endif
 }
-
-void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-}
-
-void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-}
-
-void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-}


@@ -3850,106 +3850,14 @@ static inline __m256 mxfp_dequant_avx2(
     return _mm256_blendv_ps(normal, sub_val, is_sub);
 }
-
-// Unpack 4 tightly-packed 6-bit values from 3 bytes into separate bytes.
-static inline void unpack_fp6x4_avx2(const uint8_t * p, uint8_t out[4]) {
-    const uint32_t pk = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
-    out[0] = (pk >> 0) & 0x3F;
-    out[1] = (pk >> 6) & 0x3F;
-    out[2] = (pk >> 12) & 0x3F;
-    out[3] = (pk >> 18) & 0x3F;
-}
-
 // Unpack 8 FP6 values (two groups of 4) from packed qs data at offset j.
 static inline __m256i unpack_fp6x8_avx2(const uint8_t * qs, int j) {
     uint8_t unpacked[8];
-    unpack_fp6x4_avx2(qs + (j * 3 / 4), unpacked);
-    unpack_fp6x4_avx2(qs + ((j + 4) * 3 / 4), unpacked + 4);
+    ggml_mxfp_unpack_fp6x4(qs + (j * 3 / 4), unpacked);
+    ggml_mxfp_unpack_fp6x4(qs + ((j + 4) * 3 / 4), unpacked + 4);
     return _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)unpacked));
 }
-
-// MXFP FP8/FP6 vec_dot
-
-// FP8 x Q8_0 dot product (E4M3/E5M2).
-static void ggml_vec_dot_mxfp8_q8_0_avx2(
-        int n, float * GGML_RESTRICT s,
-        const void * GGML_RESTRICT vx,
-        const void * GGML_RESTRICT vy,
-        const mxfp_avx2_traits_t * t) {
-    assert(n % QK_MXFP8 == 0);
-    const int nb = n / QK_MXFP8;
-    const block_mxfp8 * GGML_RESTRICT x = vx;
-    const block_q8_0 * GGML_RESTRICT y = vy;
-    const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask);
-    const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask);
-    const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off);
-    const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale);
-    const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask);
-    const __m256i v_zero = _mm256_setzero_si256();
-    __m256 acc = _mm256_setzero_ps();
-    for (int ib = 0; ib < nb; ++ib) {
-        const __m256 v_scale = _mm256_set1_ps(
-            GGML_E8M0_TO_FP32(x[ib].e) * GGML_CPU_FP16_TO_FP32(y[ib].d));
-        for (int j = 0; j < 32; j += 8) {
-            const __m256i v_raw = _mm256_cvtepu8_epi32(
-                _mm_loadl_epi64((const __m128i *)(x[ib].qs + j)));
-            const __m256 qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
-                _mm_loadl_epi64((const __m128i *)(y[ib].qs + j))));
-            const __m256 val = mxfp_dequant_avx2(v_raw,
-                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc,
-                v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift);
-            acc = _mm256_fmadd_ps(_mm256_mul_ps(val, v_scale), qf, acc);
-        }
-    }
-    *s = hsum_float_8(acc);
-}
-
-// FP6 x Q8_0 dot product (E2M3/E3M2).
-static void ggml_vec_dot_mxfp6_q8_0_avx2(
-        int n, float * GGML_RESTRICT s,
-        const void * GGML_RESTRICT vx,
-        const void * GGML_RESTRICT vy,
-        const mxfp_avx2_traits_t * t) {
-    assert(n % QK_MXFP6 == 0);
-    const int nb = n / QK_MXFP6;
-    const block_q8_0 * GGML_RESTRICT y = vy;
-    const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask);
-    const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask);
-    const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off);
-    const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale);
-    const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask);
-    const __m256i v_zero = _mm256_setzero_si256();
-    __m256 acc = _mm256_setzero_ps();
-    for (int ib = 0; ib < nb; ++ib) {
-        const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib;
-        const __m256 v_scale = _mm256_set1_ps(
-            GGML_E8M0_TO_FP32(xb->e) * GGML_CPU_FP16_TO_FP32(y[ib].d));
-        for (int j = 0; j < 32; j += 8) {
-            const __m256i v_raw = unpack_fp6x8_avx2(xb->qs, j);
-            const __m256 qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
-                _mm_loadl_epi64((const __m128i *)(y[ib].qs + j))));
-            const __m256 val = mxfp_dequant_avx2(v_raw,
-                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc,
-                v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift);
-            acc = _mm256_fmadd_ps(_mm256_mul_ps(val, v_scale), qf, acc);
-        }
-    }
-    *s = hsum_float_8(acc);
-}
 // MXFP SoA dequant (flash attention)
 static void dequantize_row_mxfp8_soa_avx2(
@@ -4052,26 +3960,6 @@ static void dequantize_row_mxfp4_soa_avx2(
 // Public dispatch functions
-void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    assert(nrc == 1);
-    UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
-#if defined(__AVX2__)
-    ggml_vec_dot_mxfp8_q8_0_avx2(n, s, vx, vy, &MXFP_TRAITS_E4M3);
-#else
-    ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-#endif
-}
-
-void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    assert(nrc == 1);
-    UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
-#if defined(__AVX2__)
-    ggml_vec_dot_mxfp6_q8_0_avx2(n, s, vx, vy, &MXFP_TRAITS_E2M3);
-#else
-    ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-#endif
-}
 void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
 #if defined(__AVX2__)
     dequantize_row_mxfp4_soa_avx2(x, y, k);


@@ -265,7 +265,7 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .vec_dot_type = GGML_TYPE_Q8_1,
         .nrows = 1,
     },
-    [GGML_TYPE_MXFP4_E2M1] = {
+    [GGML_TYPE_MXFP4] = {
         .from_float = quantize_row_mxfp4,
         .from_float_soa = quantize_row_mxfp4_soa,
         .to_float_soa = dequantize_row_mxfp4_soa_cpu,
@@ -279,20 +279,14 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .vec_dot_type = GGML_TYPE_Q8_0,
         .nrows = 1,
     },
-    [GGML_TYPE_MXFP8_E4M3] = {
-        .from_float = quantize_row_mxfp8,
+    [GGML_TYPE_MXFP8] = {
         .from_float_soa = quantize_row_mxfp8_soa,
         .to_float_soa = dequantize_row_mxfp8_soa_cpu,
-        .vec_dot = ggml_vec_dot_mxfp8_q8_0,
-        .vec_dot_type = GGML_TYPE_Q8_0,
         .nrows = 1,
     },
-    [GGML_TYPE_MXFP6_E2M3] = {
-        .from_float = quantize_row_mxfp6,
+    [GGML_TYPE_MXFP6] = {
         .from_float_soa = quantize_row_mxfp6_soa,
         .to_float_soa = dequantize_row_mxfp6_soa_cpu,
-        .vec_dot = ggml_vec_dot_mxfp6_q8_0,
-        .vec_dot_type = GGML_TYPE_Q8_0,
         .nrows = 1,
     },
     [GGML_TYPE_Q2_K] = {


@ -672,10 +672,10 @@ void ggml_compute_forward_add(
case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4_E2M1: case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4: case GGML_TYPE_NVFP4:
case GGML_TYPE_MXFP8_E4M3: case GGML_TYPE_MXFP8:
case GGML_TYPE_MXFP6_E2M3: case GGML_TYPE_MXFP6:
case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K: case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K: case GGML_TYPE_Q4_K:
@ -1124,10 +1124,10 @@ void ggml_compute_forward_add1(
case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1: case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4_E2M1: case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4: case GGML_TYPE_NVFP4:
case GGML_TYPE_MXFP8_E4M3: case GGML_TYPE_MXFP8:
case GGML_TYPE_MXFP6_E2M3: case GGML_TYPE_MXFP6:
case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K: case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K: case GGML_TYPE_Q4_K:
@ -1255,10 +1255,10 @@ void ggml_compute_forward_acc(
case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1: case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4_E2M1: case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4: case GGML_TYPE_NVFP4:
case GGML_TYPE_MXFP8_E4M3: case GGML_TYPE_MXFP8:
case GGML_TYPE_MXFP6_E2M3: case GGML_TYPE_MXFP6:
case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K: case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K: case GGML_TYPE_Q4_K:
@ -4345,10 +4345,10 @@ void ggml_compute_forward_out_prod(
case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4_E2M1: case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4: case GGML_TYPE_NVFP4:
case GGML_TYPE_MXFP8_E4M3: case GGML_TYPE_MXFP8:
case GGML_TYPE_MXFP6_E2M3: case GGML_TYPE_MXFP6:
case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K: case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K: case GGML_TYPE_Q4_K:
@ -4623,10 +4623,10 @@ void ggml_compute_forward_set(
case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1: case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4_E2M1: case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4: case GGML_TYPE_NVFP4:
case GGML_TYPE_MXFP8_E4M3: case GGML_TYPE_MXFP8:
case GGML_TYPE_MXFP6_E2M3: case GGML_TYPE_MXFP6:
case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K: case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K: case GGML_TYPE_Q4_K:
@ -4848,10 +4848,10 @@ void ggml_compute_forward_get_rows(
case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1: case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4_E2M1: case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4: case GGML_TYPE_NVFP4:
case GGML_TYPE_MXFP8_E4M3: case GGML_TYPE_MXFP8:
case GGML_TYPE_MXFP6_E2M3: case GGML_TYPE_MXFP6:
case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K: case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K: case GGML_TYPE_Q4_K:
@ -5686,10 +5686,10 @@ void ggml_compute_forward_clamp(
case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1: case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4_E2M1: case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4: case GGML_TYPE_NVFP4:
case GGML_TYPE_MXFP8_E4M3: case GGML_TYPE_MXFP8:
case GGML_TYPE_MXFP6_E2M3: case GGML_TYPE_MXFP6:
case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K: case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K: case GGML_TYPE_Q4_K:
@ -8255,31 +8255,65 @@ void ggml_compute_forward_top_k(
} }
} }
+// Max head dimension for stack-allocated MXFP buffers.
+static constexpr int64_t MXFP_FA_MAX_D = 1024;
+// SoA buffer size for MXFP_FA_MAX_D with MXFP8 (worst case: 1024 + 32 e8m0 = 1056, rounded up).
+static constexpr int MXFP_FA_SOA_BUF = 1088;
 // SoA function pointer types for MXFP flash attention paths.
 typedef void (*mxfp_soa_quantize_fn)(const float *, void *, int64_t);
 typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t);
+// Per-KV-type MXFP parameters (shared between K and V).
+struct mxfp_kv_params {
+    mxfp_soa_dequantize_fn dequantize;
+    bool multihead;
+    int64_t soa_elems;
+    int qs_per_block;
+    int head_qs_bytes;
+    int64_t head_e8m0_offset;
+    int blocks_per_head;
+};
 // MXFP dispatch parameters for flash attention.
 struct mxfp_fa_params {
     mxfp_soa_quantize_fn q_quantize;
-    mxfp_soa_dequantize_fn k_dequantize;
-    mxfp_soa_dequantize_fn v_dequantize;
-    bool k_multihead;
-    bool v_multihead;
-    int64_t k_soa_elems;
-    int64_t v_soa_elems;
-    bool apply_hadamard;
-    // Per-head SoA addressing (avoids dequanting all heads in multihead mode).
-    int k_qs_per_block;
-    int v_qs_per_block;
-    int k_head_qs_bytes;
-    int v_head_qs_bytes;
-    int64_t k_head_e8m0_offset;
-    int64_t v_head_e8m0_offset;
-    int k_blocks_per_head;
-    int v_blocks_per_head;
+    mxfp_kv_params k;
+    mxfp_kv_params v;
+    bool apply_hadamard;
 };
+// Extract one head's SoA data from a multihead row and dequantize.
+static inline void mxfp_dequant_head(
+        const mxfp_kv_params & kv, const char * row, int head_idx,
+        char * soa_buf, float * out, int64_t D) {
+    if (kv.multihead) {
+        const int qs_off   = head_idx * kv.head_qs_bytes;
+        const int e8m0_off = (int)kv.head_e8m0_offset + head_idx * kv.blocks_per_head;
+        memcpy(soa_buf, row + qs_off, kv.head_qs_bytes);
+        memcpy(soa_buf + kv.head_qs_bytes, row + e8m0_off, kv.blocks_per_head);
+        kv.dequantize(soa_buf, out, D);
+    } else {
+        kv.dequantize(row, out, D);
+    }
+}
+// Initialize per-KV-type params from tensor metadata.
+// Multihead detection: nb2 == row_size(D) means heads are contiguous within
+// one KV-position stride, so SoA spans all heads. Otherwise SoA is per-head.
+static mxfp_kv_params mxfp_kv_params_init(ggml_type type, int64_t D, size_t nb2, int64_t ne2) {
+    mxfp_kv_params kv = {};
+    kv.dequantize      = ggml_get_type_traits_cpu(type)->to_float_soa;
+    kv.multihead       = (nb2 == (size_t)ggml_row_size(type, D));
+    kv.soa_elems       = kv.multihead ? ne2 * D : D;
+    kv.qs_per_block    = ggml_mxfp_qs_per_block(type);
+    kv.blocks_per_head = (int)(D / 32);
+    kv.head_qs_bytes   = kv.blocks_per_head * kv.qs_per_block;
+    const int64_t total_blocks = kv.multihead ? ne2 * kv.blocks_per_head : kv.blocks_per_head;
+    kv.head_e8m0_offset = total_blocks * kv.qs_per_block;
+    return kv;
+}
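As a sanity check on the SoA addressing above, the sketch below (hypothetical names, not part of the patch) mirrors the arithmetic of `mxfp_kv_params_init` under the same assumptions: 32-element blocks, one e8m0 scale byte per block, and all qs bytes laid out before all e8m0 bytes. It also confirms the `MXFP_FA_SOA_BUF` bound: DK=1024 with MXFP8 (32 qs bytes per block) needs 1024 + 32 = 1056 bytes, under 1088.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical standalone mirror of the per-head SoA addressing.
// qs_per_block is 32 for MXFP8, 24 for MXFP6 (6-bit packed), 16 for MXFP4.
struct soa_layout {
    int blocks_per_head;      // D / 32
    int head_qs_bytes;        // quantized-scalar bytes per head
    int64_t head_e8m0_offset; // where the e8m0 scale bytes start in the row
};

static soa_layout soa_layout_init(int qs_per_block, int64_t D, bool multihead, int64_t n_head) {
    soa_layout L;
    L.blocks_per_head = (int)(D / 32);
    L.head_qs_bytes   = L.blocks_per_head * qs_per_block;
    // all heads' qs bytes come first, then all heads' e8m0 bytes
    const int64_t total_blocks = multihead ? n_head * L.blocks_per_head : L.blocks_per_head;
    L.head_e8m0_offset = total_blocks * qs_per_block;
    return L;
}
```

With D=128, MXFP8, 8 heads: 4 blocks/head, 128 qs bytes/head, e8m0 bytes starting at offset 1024.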
 static mxfp_fa_params mxfp_fa_params_init(
         const ggml_tensor * k, const ggml_tensor * v,
         int64_t DK, int64_t DV,
@@ -8291,44 +8325,17 @@ static mxfp_fa_params mxfp_fa_params_init(
     const bool is_mxfp_v = ggml_is_type_mxfp(v->type);
     if (is_mxfp_k) {
-        const struct ggml_type_traits_cpu * k_traits = ggml_get_type_traits_cpu(k->type);
-        p.q_quantize = k_traits->from_float_soa;
-        p.k_dequantize = k_traits->to_float_soa;
+        p.q_quantize = ggml_get_type_traits_cpu(k->type)->from_float_soa;
+        p.k = mxfp_kv_params_init(k->type, DK, nbk2, nek2);
     }
     if (is_mxfp_v) {
-        p.v_dequantize = ggml_get_type_traits_cpu(v->type)->to_float_soa;
+        p.v = mxfp_kv_params_init(v->type, DV, nbv2, nev2);
     }
     // Hadamard rotation must match K rotation.
-    // Skipped for: MLA (DK != DV, V is a view of K).
+    // Skipped for MLA (DK != DV, V is a view of K).
     p.apply_hadamard = is_mxfp_k && (DK == DV) && ggml_mxfp_use_hadamard(k->type);
-    // SoA layout detection: in the real KV cache, heads are contiguous within
-    // one KV-position stride (nb[2] == row_size(DK)), so SoA spans all heads.
-    // In test tensors, heads may be at distant offsets (nb[2] >> row_size(DK)),
-    // so SoA is per-head. Detect which case and set dequant parameters accordingly.
-    p.k_multihead = is_mxfp_k && (nbk2 == (size_t)ggml_row_size(k->type, DK));
-    p.k_soa_elems = is_mxfp_k ? (p.k_multihead ? nek2 * DK : DK) : 0;
-    p.v_multihead = is_mxfp_v && (nbv2 == (size_t)ggml_row_size(v->type, DV));
-    p.v_soa_elems = is_mxfp_v ? (p.v_multihead ? nev2 * DV : DV) : 0;
-    if (is_mxfp_k) {
-        p.k_qs_per_block = ggml_mxfp_qs_per_block(k->type);
-        p.k_blocks_per_head = (int)(DK / 32);
-        p.k_head_qs_bytes = p.k_blocks_per_head * p.k_qs_per_block;
-        const int64_t k_total_blocks = p.k_multihead ? nek2 * p.k_blocks_per_head : p.k_blocks_per_head;
-        p.k_head_e8m0_offset = k_total_blocks * p.k_qs_per_block;
-    }
-    if (is_mxfp_v) {
-        p.v_qs_per_block = ggml_mxfp_qs_per_block(v->type);
-        p.v_blocks_per_head = (int)(DV / 32);
-        p.v_head_qs_bytes = p.v_blocks_per_head * p.v_qs_per_block;
-        const int64_t v_total_blocks = p.v_multihead ? nev2 * p.v_blocks_per_head : p.v_blocks_per_head;
-        p.v_head_e8m0_offset = v_total_blocks * p.v_qs_per_block;
-    }
     return p;
 }
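The per-head extraction used by the MXFP paths (two memcpys splicing one head's qs bytes and e8m0 bytes out of a multihead row into a contiguous SoA buffer) can be illustrated with toy sizes; `extract_head` below is a hypothetical standalone version of the multihead branch, not the patch's own code.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// One KV-position row stores [head0 qs | head1 qs | ... | head0 e8m0 | head1 e8m0 | ...].
// Splice one head's qs + e8m0 into a contiguous buffer suitable for SoA dequant.
static void extract_head(const uint8_t * row, int head_idx,
                         int head_qs_bytes, int blocks_per_head,
                         int64_t head_e8m0_offset, uint8_t * soa_buf) {
    const int qs_off   = head_idx * head_qs_bytes;
    const int e8m0_off = (int)head_e8m0_offset + head_idx * blocks_per_head;
    memcpy(soa_buf, row + qs_off, head_qs_bytes);
    memcpy(soa_buf + head_qs_bytes, row + e8m0_off, blocks_per_head);
}
```

Toy sizes: 2 heads, 2 blocks/head, 4 qs bytes/block gives head_qs_bytes = 8 and the e8m0 section starting at byte 16.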
@@ -8430,14 +8437,14 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
     int ith = params->ith;
-    if (is_mxfp_k) { GGML_ASSERT(DK <= 1024); }
-    if (is_mxfp_v) { GGML_ASSERT(DV <= 1024); }
-    float k_dequant_buf[1024];
-    float v_dequant_buf[1024];
-    char k_head_soa[1088]; // max: DK=1024 MXFP8 -> 1056 bytes, rounded up
-    char v_head_soa[1088];
+    if (is_mxfp_k) { GGML_ASSERT(DK <= MXFP_FA_MAX_D); }
+    if (is_mxfp_v) { GGML_ASSERT(DV <= MXFP_FA_MAX_D); }
+    float k_dequant_buf[MXFP_FA_MAX_D];
+    float v_dequant_buf[MXFP_FA_MAX_D];
+    char k_head_soa[MXFP_FA_SOA_BUF]; // max: DK=1024 MXFP8 -> 1056 bytes, rounded up
+    char v_head_soa[MXFP_FA_SOA_BUF];
     float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32);
     float * V32 = (VKQ32 + 1*DV);
@@ -8479,31 +8486,25 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
     const char * k_base = (const char *) k->data + k_base_offset;
     const char * v_base = (const char *) v->data + v_base_offset;
-    // Per-head SoA byte offsets
-    const int k_head_qs_start = mxfp.k_multihead ? ik2 * mxfp.k_head_qs_bytes : 0;
-    const int k_head_e8m0_start = mxfp.k_multihead ? (int)mxfp.k_head_e8m0_offset + ik2 * mxfp.k_blocks_per_head : 0;
-    const int v_head_qs_start = mxfp.v_multihead ? iv2 * mxfp.v_head_qs_bytes : 0;
-    const int v_head_e8m0_start = mxfp.v_multihead ? (int)mxfp.v_head_e8m0_offset + iv2 * mxfp.v_blocks_per_head : 0;
-    const char * k_row_base = mxfp.k_multihead ? ((const char *) k->data + ik3*nbk3) : nullptr;
-    const char * v_row_base = mxfp.v_multihead ? ((const char *) v->data + iv3*nbv3) : nullptr;
+    const char * k_row_base = mxfp.k.multihead ? ((const char *) k->data + ik3*nbk3) : nullptr;
+    const char * v_row_base = mxfp.v.multihead ? ((const char *) v->data + iv3*nbv3) : nullptr;
     const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
-    float Q_f32[1024];
+    float Q_f32[MXFP_FA_MAX_D];
     if (is_mxfp_k) {
         // Q preprocessing: Hadamard + SoA round-trip captures same quantization loss as K.
         if (mxfp.apply_hadamard) {
-            float q_tmp[1024];
+            float q_tmp[MXFP_FA_MAX_D];
             memcpy(q_tmp, pq, DK * sizeof(float));
             ggml_apply_hadamard_blocks(q_tmp, DK);
             mxfp.q_quantize(q_tmp, Q_q, DK);
         } else {
             mxfp.q_quantize(pq, Q_q, DK);
         }
-        mxfp.k_dequantize(Q_q, Q_f32, DK);
+        mxfp.k.dequantize(Q_q, Q_f32, DK);
     } else {
         if (mxfp.apply_hadamard) {
-            float q_tmp[1024];
+            float q_tmp[MXFP_FA_MAX_D];
             memcpy(q_tmp, pq, DK * sizeof(float));
             ggml_apply_hadamard_blocks(q_tmp, DK);
             q_to_vec_dot(q_tmp, Q_q, DK);
@@ -8525,15 +8526,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
         float s; // KQ value
         if (is_mxfp_k) {
-            if (mxfp.k_multihead) {
-                // Extract this head's SoA blocks
-                const char * row = k_row_base + ic*nbk1;
-                memcpy(k_head_soa, row + k_head_qs_start, mxfp.k_head_qs_bytes);
-                memcpy(k_head_soa + mxfp.k_head_qs_bytes, row + k_head_e8m0_start, mxfp.k_blocks_per_head);
-                mxfp.k_dequantize(k_head_soa, k_dequant_buf, DK);
-            } else {
-                mxfp.k_dequantize(k_base + ic*nbk1, k_dequant_buf, DK);
-            }
+            const char * k_row = mxfp.k.multihead ? k_row_base + ic*nbk1 : k_base + ic*nbk1;
+            mxfp_dequant_head(mxfp.k, k_row, ik2, k_head_soa, k_dequant_buf, DK);
             ggml_vec_dot_f32(DK, &s, 0, k_dequant_buf, 0, Q_f32, 0, 1);
         } else {
             kq_vec_dot(DK, &s, 0, k_base + ic*nbk1, 0, Q_q, 0, 1);
@@ -8577,15 +8571,9 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
         }
         // V += v*expf(s - M)
-        if (mxfp.v_dequantize) {
-            if (mxfp.v_multihead) {
-                const char * row = v_row_base + ic*nbv1;
-                memcpy(v_head_soa, row + v_head_qs_start, mxfp.v_head_qs_bytes);
-                memcpy(v_head_soa + mxfp.v_head_qs_bytes, row + v_head_e8m0_start, mxfp.v_blocks_per_head);
-                mxfp.v_dequantize(v_head_soa, v_dequant_buf, DV);
-            } else {
-                mxfp.v_dequantize(v_base + ic*nbv1, v_dequant_buf, DV);
-            }
+        if (mxfp.v.dequantize) {
+            const char * v_row = mxfp.v.multihead ? v_row_base + ic*nbv1 : v_base + ic*nbv1;
+            mxfp_dequant_head(mxfp.v, v_row, iv2, v_head_soa, v_dequant_buf, DV);
             ggml_vec_mad_f32(DV, VKQ32, v_dequant_buf, vs);
         } else if (v_to_float) {
             v_to_float(v_base + ic*nbv1, V32, DV);
@@ -8731,14 +8719,14 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
     static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
     static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
-    if (is_mxfp_k) { GGML_ASSERT(DK <= 1024); }
-    if (is_mxfp_v) { GGML_ASSERT(DV <= 1024); }
-    float k_dequant_buf[1024];
-    float v_dequant_buf[1024];
-    char k_head_soa[1088];
-    char v_head_soa[1088];
+    if (is_mxfp_k) { GGML_ASSERT(DK <= MXFP_FA_MAX_D); }
+    if (is_mxfp_v) { GGML_ASSERT(DV <= MXFP_FA_MAX_D); }
+    float k_dequant_buf[MXFP_FA_MAX_D];
+    float v_dequant_buf[MXFP_FA_MAX_D];
+    char k_head_soa[MXFP_FA_SOA_BUF];
+    char v_head_soa[MXFP_FA_SOA_BUF];
     int ir = ir0;
     while (ir < ir1) {
@@ -8802,9 +8790,9 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
             if (mxfp.apply_hadamard) {
                 ggml_apply_hadamard_blocks(Q_f32 + tq * DK, DK);
             }
-            uint8_t q_mxfp_buf[1088]; // max: DK=1024 MXFP8 -> 1056 bytes
+            uint8_t q_mxfp_buf[MXFP_FA_SOA_BUF];
             mxfp.q_quantize(Q_f32 + tq * DK, q_mxfp_buf, DK);
-            mxfp.k_dequantize(q_mxfp_buf, Q_f32 + tq * DK, DK);
+            mxfp.k.dequantize(q_mxfp_buf, Q_f32 + tq * DK, DK);
         }
     }
     for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
@@ -8854,23 +8842,13 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
             for (int64_t dk = 0; dk < DK; dk++) {
                 K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
             }
-        } else if (mxfp.k_dequantize) {
-            if (mxfp.k_multihead) {
-                // Per-head extraction: copy only this head's SoA blocks + e8m0, dequant DK elements.
-                const char * row = (const char *)k->data + (ic + tk)*nbk1 + ik3*nbk3;
-                const int kqs = ik2 * mxfp.k_head_qs_bytes;
-                const int ke8 = (int)mxfp.k_head_e8m0_offset + ik2 * mxfp.k_blocks_per_head;
-                memcpy(k_head_soa, row + kqs, mxfp.k_head_qs_bytes);
-                memcpy(k_head_soa + mxfp.k_head_qs_bytes, row + ke8, mxfp.k_blocks_per_head);
-                mxfp.k_dequantize(k_head_soa, k_dequant_buf, DK);
-            } else {
-                mxfp.k_dequantize(k_data, k_dequant_buf, DK);
-            }
+        } else if (mxfp.k.dequantize) {
+            mxfp_dequant_head(mxfp.k, k_data, ik2, k_head_soa, k_dequant_buf, DK);
             for (int64_t dk = 0; dk < DK; dk++) {
                 K_f32[dk * KV_TILE_SZ + tk] = k_dequant_buf[dk];
             }
         } else {
-            float k_tmp[1024];
+            float k_tmp[MXFP_FA_MAX_D];
             k_to_float(k_data, k_tmp, DK);
             for (int64_t dk = 0; dk < DK; dk++) {
                 K_f32[dk * KV_TILE_SZ + tk] = k_tmp[dk];
@@ -8934,18 +8912,8 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
             ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);
         } else if (v_type == GGML_TYPE_F32) {
             memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
-        } else if (mxfp.v_dequantize) {
-            if (mxfp.v_multihead) {
-                // Per-head extraction: copy only this head's SoA blocks + e8m0, dequant DV elements.
-                const char * row = (const char *)v->data + (ic + tk)*nbv1 + iv3*nbv3;
-                const int vqs = iv2 * mxfp.v_head_qs_bytes;
-                const int ve8 = (int)mxfp.v_head_e8m0_offset + iv2 * mxfp.v_blocks_per_head;
-                memcpy(v_head_soa, row + vqs, mxfp.v_head_qs_bytes);
-                memcpy(v_head_soa + mxfp.v_head_qs_bytes, row + ve8, mxfp.v_blocks_per_head);
-                mxfp.v_dequantize(v_head_soa, v_dequant_buf, DV);
-            } else {
-                mxfp.v_dequantize(v_data, v_dequant_buf, DV);
-            }
+        } else if (mxfp.v.dequantize) {
+            mxfp_dequant_head(mxfp.v, v_data, iv2, v_head_soa, v_dequant_buf, DV);
             memcpy(V32 + tk * DV, v_dequant_buf, DV * sizeof(float));
         } else {
             v_to_float(v_data, V32 + tk * DV, DV);


@@ -54,14 +54,6 @@ void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, i
     quantize_row_nvfp4_ref(x, y, k);
 }
-void quantize_row_mxfp8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
-    quantize_row_mxfp8_ref(x, y, k);
-}
-void quantize_row_mxfp6(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
-    quantize_row_mxfp6_ref(x, y, k);
-}
 //
 // 2-6 bit quantization in super-blocks
 //
@@ -264,51 +256,6 @@ void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
     *s = sumf;
 }
-// Generic MXFP x Q8_0 dot product (scalar, not SIMD-optimized)
-static void ggml_vec_dot_mxfp_q8_0_impl(
-        int n, float * GGML_RESTRICT s,
-        const void * GGML_RESTRICT vx, size_t block_size,
-        const void * GGML_RESTRICT vy,
-        ggml_to_float_t dequant) {
-    assert(n % QK8_0 == 0);
-    const int nb = n / QK8_0;
-    const block_q8_0 * GGML_RESTRICT y = vy;
-    float sumf = 0;
-    for (int ib = 0; ib < nb; ib++) {
-        float tmp[QK8_0];
-        dequant((const char *)vx + ib * block_size, tmp, QK8_0);
-        const float y_d = GGML_CPU_FP16_TO_FP32(y[ib].d);
-        float block_sum = 0;
-        for (int j = 0; j < QK8_0; j++) {
-            block_sum += tmp[j] * (float)y[ib].qs[j];
-        }
-        sumf += block_sum * y_d;
-    }
-    *s = sumf;
-}
-void ggml_vec_dot_mxfp8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    assert(nrc == 1);
-    UNUSED(nrc);
-    UNUSED(bx);
-    UNUSED(by);
-    UNUSED(bs);
-    ggml_vec_dot_mxfp_q8_0_impl(n, s, vx, sizeof(block_mxfp8), vy,
-                                (ggml_to_float_t)dequantize_row_mxfp8);
-}
-void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    assert(nrc == 1);
-    UNUSED(nrc);
-    UNUSED(bx);
-    UNUSED(by);
-    UNUSED(bs);
-    ggml_vec_dot_mxfp_q8_0_impl(n, s, vx, sizeof(block_mxfp6), vy,
-                                (ggml_to_float_t)dequantize_row_mxfp6);
-}
 // Generic SoA dequant wrappers — arch-specific SIMD versions override via fallback.h.
 void dequantize_row_mxfp4_soa_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     dequantize_row_mxfp4_soa(x, y, k);


@@ -21,9 +21,6 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
 void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
-void quantize_row_mxfp8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
-void quantize_row_mxfp6(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@@ -46,9 +43,6 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
 void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
-void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
-void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -80,9 +74,6 @@ void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, c
 void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
-void ggml_vec_dot_mxfp8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
-void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 // SoA dequant (SIMD-dispatched, CPU backend)
 void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);


@@ -3770,7 +3770,7 @@ static block_mxfp4x4 make_block_mxfp4x4(block_mxfp4 * in, unsigned int blck_size
 }
 static int repack_mxfp4_to_mxfp4_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
-    GGML_ASSERT(t->type == GGML_TYPE_MXFP4_E2M1);
+    GGML_ASSERT(t->type == GGML_TYPE_MXFP4);
     GGML_ASSERT(interleave_block == 4);
     const block_mxfp4 * src = (const block_mxfp4 *)data;
@@ -3827,7 +3827,7 @@ static block_mxfp4x8 make_block_mxfp4x8(block_mxfp4 * in, unsigned int blck_size
 }
 static int repack_mxfp4_to_mxfp4_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
-    GGML_ASSERT(t->type == GGML_TYPE_MXFP4_E2M1);
+    GGML_ASSERT(t->type == GGML_TYPE_MXFP4);
     GGML_ASSERT(interleave_block == 8);
     const block_mxfp4 * src = (const block_mxfp4 *)data;
@@ -4685,7 +4685,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
         }
 #endif
     }
-    } else if (cur->type == GGML_TYPE_MXFP4_E2M1) {
+    } else if (cur->type == GGML_TYPE_MXFP4) {
         if (ggml_cpu_has_avx2()) {
             if (cur->ne[1] % 8 == 0) {
                 return &mxfp4_8x8_q8_0;


@@ -1014,7 +1014,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
     // MXFP4: has AoS shaders (MUL_MAT, GET_ROWS) but no SoA/flash attention support yet.
     for (size_t i = 0, n = 3; i < n; ++i) {
         if (op->src[i] != NULL && ggml_is_type_mxfp(op->src[i]->type)) {
-            if (op->src[i]->type != GGML_TYPE_MXFP4_E2M1) {
+            if (op->src[i]->type != GGML_TYPE_MXFP4) {
                 return false;
             }
             if (op->op == GGML_OP_FLASH_ATTN_EXT || op->op == GGML_OP_SET_ROWS) {


@@ -3760,7 +3760,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
     } else if (op->src[0]->type == GGML_TYPE_F32) {
         return op->src[1]->type == GGML_TYPE_F32;
     } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 ||
-               op->src[0]->type == GGML_TYPE_MXFP4_E2M1 ||
+               op->src[0]->type == GGML_TYPE_MXFP4 ||
                op->src[0]->type == GGML_TYPE_Q4_K ||
                op->src[0]->type == GGML_TYPE_Q6_K) {
         return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
@@ -3771,7 +3771,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
     case GGML_OP_MUL_MAT_ID:
         if (op->src[0]->type == GGML_TYPE_Q4_0 ||
             op->src[0]->type == GGML_TYPE_Q8_0 ||
-            op->src[0]->type == GGML_TYPE_MXFP4_E2M1) {
+            op->src[0]->type == GGML_TYPE_MXFP4) {
             if (op->src[1]->type == GGML_TYPE_F32) {
                 return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
             }
@@ -4559,7 +4559,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
 #endif // GGML_OPENCL_USE_ADRENO_KERNELS
         return;
     }
-    if (tensor->type == GGML_TYPE_MXFP4_E2M1) {
+    if (tensor->type == GGML_TYPE_MXFP4) {
         ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
         GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");
@@ -5136,7 +5136,7 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
         CL_CHECK(clReleaseMemObject(data_device));
         return;
     }
-    if (tensor->type == GGML_TYPE_MXFP4_E2M1) {
+    if (tensor->type == GGML_TYPE_MXFP4) {
         ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *)tensor->extra;
         cl_int err;
@@ -5585,7 +5585,7 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tenso
         CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL));
         CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_d, buf_d, 0, NULL, NULL));
         CL_CHECK(clFinish(queue));
-    } else if (tensor->type == GGML_TYPE_MXFP4_E2M1) {
+    } else if (tensor->type == GGML_TYPE_MXFP4) {
         ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *) tensor->extra;
         GGML_ASSERT(extra);
@@ -10550,7 +10550,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
         CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3));
 #endif // GGML_OPENCL_SOA_Q
         break;
-    case GGML_TYPE_MXFP4_E2M1: {
+    case GGML_TYPE_MXFP4: {
 #ifdef GGML_OPENCL_SOA_Q
         kernel = backend_ctx->kernel_mul_mv_mxfp4_f32_flat;
@@ -10630,7 +10630,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
         GGML_ASSERT(false && "not implemented");
     }
-    if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4_E2M1 ||
+    if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4 ||
         src0t == GGML_TYPE_Q4_1 ||
         src0t == GGML_TYPE_Q8_0 ||
         src0t == GGML_TYPE_Q2_K) {
@@ -10864,7 +10864,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
 #endif // GGML_OPENCL_SOA_Q
         break;
     }
-    case GGML_TYPE_MXFP4_E2M1: {
+    case GGML_TYPE_MXFP4: {
 #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
         if (use_adreno_moe_kernels(backend_ctx, src0)) {
             cl_int status;


@@ -267,10 +267,12 @@ uint8_t float_to_fp8_e4m3_rn(float x) { return ggml_mxfp_float_to_fp8_e4m3(x); }
 // MSE-optimal E8M0: tests candidates around round(log2(amax)), picks lowest quantization error.
 typedef struct {
     int emax_offset;   // type-specific offset to max representable exponent
+    int qs_per_block;  // quantized scalar bytes per 32-element block
+    int bits_per_elem; // 8 = byte-aligned, 6 = packed via fp6x4
     uint8_t (*to_elem)(float);
     float (*to_float)(uint8_t);
-    float (*mse_error)(float val, float inv_scale, float scale);
+    float (*mse_error)(float val, float inv_scale, float scale); // NULL = use generic round-trip via to_elem/to_float
 } mxfp_elem_traits_t;
 static inline int best_index_mxfp4(float x, float e);
@@ -294,7 +296,7 @@ static float mse_error_mxfp4(float val, float inv_scale, float scale) {
     return err * err;
 }
-static const mxfp_elem_traits_t mxfp4_traits = { MXFP4_E2M1_EMAX_OFFSET, NULL, NULL, mse_error_mxfp4 };
+static const mxfp_elem_traits_t mxfp4_traits = { MXFP4_E2M1_EMAX_OFFSET, MXFP4_SOA_QS_PER_BLOCK, 4, NULL, NULL, mse_error_mxfp4 };
 static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_elem_traits_t * traits) {
     float amax = 0.0f;
@@ -319,7 +321,13 @@ static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_
         const float test_inv = 1.0f / test_scale;
         float mse = 0.0f;
         for (int j = 0; j < qk; ++j) {
-            mse += traits->mse_error(x[j], test_inv, test_scale);
+            if (traits->mse_error) {
+                mse += traits->mse_error(x[j], test_inv, test_scale);
+            } else {
+                const float recon = traits->to_float(traits->to_elem(x[j] * test_inv)) * test_scale;
+                const float err = x[j] - recon;
+                mse += err * err;
+            }
         }
         if (mse < best_mse) {
             best_mse = mse;
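The generic round-trip branch added above computes exactly what a specialized mse_error function would. The sketch below demonstrates the equivalence with a toy codec (round-to-nearest integer codes 0..7 standing in for the real fp8/fp6 element types — an assumption for illustration only), which is why the per-type mse_error variants could be dropped.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Toy element codec: quantize to integer codes 0..7 (stand-in for fp8/fp6).
static uint8_t toy_to_elem(float x) {
    int q = (int)std::lroundf(x);
    if (q < 0) q = 0;
    if (q > 7) q = 7;
    return (uint8_t)q;
}
static float toy_to_float(uint8_t q) { return (float)q; }

// Generic round-trip error: (to_float(to_elem(x * inv)) * scale - x)^2.
static float generic_mse(float val, float inv_scale, float scale) {
    const float recon = toy_to_float(toy_to_elem(val * inv_scale)) * scale;
    const float err = val - recon;
    return err * err;
}

// Hand-specialized error for the same toy codec, as the removed
// mse_error_fp8_e4m3 / mse_error_fp6_e2m3 were for theirs.
static float specialized_mse(float val, float inv_scale, float scale) {
    float q = (float)std::lroundf(val * inv_scale);
    if (q < 0.0f) q = 0.0f;
    if (q > 7.0f) q = 7.0f;
    const float err = val - q * scale;
    return err * err;
}
```

Both paths perform the identical round-trip, so the results match bit for bit.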
@@ -574,102 +582,8 @@ uint8_t float_to_fp8_e5m2_rn(float x) { return ggml_mxfp_float_to_fp8_e5m2(x); }
 void pack_fp6x4(const uint8_t v[4], uint8_t out[3]) { ggml_mxfp_pack_fp6x4(v, out); }
 void unpack_fp6x4(const uint8_t in[3], uint8_t v[4]) { ggml_mxfp_unpack_fp6x4(in, v); }
-static float mse_error_fp8_e4m3(float val, float inv_scale, float scale) {
-    const float recon = fp8_e4m3_to_float(float_to_fp8_e4m3_rn(val * inv_scale)) * scale;
-    const float err = val - recon;
-    return err * err;
-}
-static float mse_error_fp6_e2m3(float val, float inv_scale, float scale) {
-    const float recon = fp6_e2m3_to_float(float_to_fp6_e2m3_rn(val * inv_scale)) * scale;
-    const float err = val - recon;
-    return err * err;
-}
-static const mxfp_elem_traits_t mxfp8_e4m3_traits = { MXFP8_E4M3_EMAX_OFFSET, float_to_fp8_e4m3_rn, fp8_e4m3_to_float, mse_error_fp8_e4m3 };
-static const mxfp_elem_traits_t mxfp6_e2m3_traits = { MXFP6_E2M3_EMAX_OFFSET, float_to_fp6_e2m3_rn, fp6_e2m3_to_float, mse_error_fp6_e2m3 };
+static const mxfp_elem_traits_t mxfp8_e4m3_traits = { MXFP8_E4M3_EMAX_OFFSET, MXFP8_SOA_QS_PER_BLOCK, 8, float_to_fp8_e4m3_rn, fp8_e4m3_to_float, NULL };
+static const mxfp_elem_traits_t mxfp6_e2m3_traits = { MXFP6_E2M3_EMAX_OFFSET, MXFP6_SOA_QS_PER_BLOCK, 6, float_to_fp6_e2m3_rn, fp6_e2m3_to_float, NULL };
static void quantize_row_mxfp8_impl(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y,
int64_t k, const mxfp_elem_traits_t * traits) {
assert(k % QK_MXFP8 == 0);
const int nb = k / QK_MXFP8;
for (int i = 0; i < nb; i++) {
const uint8_t e = mxfp_compute_e8m0_mse(&x[i*QK_MXFP8], QK_MXFP8, traits);
const float d = GGML_E8M0_TO_FP32(e);
const float inv_d = d > 0.0f ? 1.0f / d : 0.0f;
y[i].e = e;
for (int j = 0; j < QK_MXFP8; ++j) {
y[i].qs[j] = traits->to_elem(x[i*QK_MXFP8 + j] * inv_d);
}
}
}
static void dequantize_row_mxfp8_impl(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y,
int64_t k, const mxfp_elem_traits_t * traits) {
assert(k % QK_MXFP8 == 0);
const int nb = k / QK_MXFP8;
for (int i = 0; i < nb; i++) {
const float d = GGML_E8M0_TO_FP32(x[i].e);
for (int j = 0; j < QK_MXFP8; ++j) {
y[i*QK_MXFP8 + j] = traits->to_float(x[i].qs[j]) * d;
}
}
}
static void quantize_row_mxfp6_impl(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y,
int64_t k, const mxfp_elem_traits_t * traits) {
assert(k % QK_MXFP6 == 0);
const int nb = k / QK_MXFP6;
for (int i = 0; i < nb; i++) {
const uint8_t e = mxfp_compute_e8m0_mse(&x[i*QK_MXFP6], QK_MXFP6, traits);
const float d = GGML_E8M0_TO_FP32(e);
const float inv_d = d > 0.0f ? 1.0f / d : 0.0f;
y[i].e = e;
for (int j = 0; j < QK_MXFP6; j += 4) {
uint8_t vals[4];
for (int jj = 0; jj < 4; jj++) {
vals[jj] = traits->to_elem(x[i*QK_MXFP6 + j + jj] * inv_d);
}
pack_fp6x4(vals, &y[i].qs[j * 3 / 4]);
}
}
}
static void dequantize_row_mxfp6_impl(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y,
int64_t k, const mxfp_elem_traits_t * traits) {
assert(k % QK_MXFP6 == 0);
const int nb = k / QK_MXFP6;
for (int i = 0; i < nb; i++) {
const float d = GGML_E8M0_TO_FP32(x[i].e);
for (int j = 0; j < QK_MXFP6; j += 4) {
uint8_t vals[4];
unpack_fp6x4(&x[i].qs[j * 3 / 4], vals);
for (int jj = 0; jj < 4; jj++) {
y[i*QK_MXFP6 + j + jj] = traits->to_float(vals[jj]) * d;
}
}
}
}
void quantize_row_mxfp8_ref(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k) {
quantize_row_mxfp8_impl(x, y, k, &mxfp8_e4m3_traits);
}
void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
dequantize_row_mxfp8_impl(x, y, k, &mxfp8_e4m3_traits);
}
void quantize_row_mxfp6_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k) {
quantize_row_mxfp6_impl(x, y, k, &mxfp6_e2m3_traits);
}
void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
dequantize_row_mxfp6_impl(x, y, k, &mxfp6_e2m3_traits);
}
// ====================== SoA (Struct-of-Arrays) quantize/dequantize for flash attention // ====================== SoA (Struct-of-Arrays) quantize/dequantize for flash attention
@@ -715,101 +629,79 @@ void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTR
     }
 }

-static void quantize_row_mxfp8_soa_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst,
-                                        int64_t k, const mxfp_elem_traits_t * traits) {
-    assert(k % QK_MXFP8 == 0);
-    const int nb = k / QK_MXFP8;
-    char * row = (char *)dst;
-    char * qs_base = row;
-    char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK);
-    for (int i = 0; i < nb; i++) {
-        const uint8_t e = mxfp_compute_e8m0_mse(&x[i*QK_MXFP8], QK_MXFP8, traits);
-        const float d = GGML_E8M0_TO_FP32(e);
-        const float inv_d = d > 0.0f ? 1.0f / d : 0.0f;
-        e8m0_base[i] = (char)e;
-        uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP8_SOA_QS_PER_BLOCK));
-        for (int j = 0; j < QK_MXFP8; ++j) {
-            qs[j] = traits->to_elem(x[i*QK_MXFP8 + j] * inv_d);
-        }
-    }
-}
-
-static void dequantize_row_mxfp8_soa_impl(const void * GGML_RESTRICT src, float * GGML_RESTRICT y,
-                                          int64_t k, const mxfp_elem_traits_t * traits) {
-    assert(k % QK_MXFP8 == 0);
-    const int nb = k / QK_MXFP8;
-    const char * row = (const char *)src;
-    const char * qs_base = row;
-    const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK);
-    for (int i = 0; i < nb; i++) {
-        const float d = GGML_E8M0_TO_FP32((uint8_t)e8m0_base[i]);
-        const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP8_SOA_QS_PER_BLOCK));
-        for (int j = 0; j < QK_MXFP8; ++j) {
-            y[i*QK_MXFP8 + j] = traits->to_float(qs[j]) * d;
-        }
-    }
-}
-
-static void quantize_row_mxfp6_soa_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst,
-                                        int64_t k, const mxfp_elem_traits_t * traits) {
-    assert(k % QK_MXFP6 == 0);
-    const int nb = k / QK_MXFP6;
-    char * row = (char *)dst;
-    char * qs_base = row;
-    char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK);
-    for (int i = 0; i < nb; i++) {
-        const uint8_t e = mxfp_compute_e8m0_mse(&x[i*QK_MXFP6], QK_MXFP6, traits);
-        const float d = GGML_E8M0_TO_FP32(e);
-        const float inv_d = d > 0.0f ? 1.0f / d : 0.0f;
-        e8m0_base[i] = (char)e;
-        uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP6_SOA_QS_PER_BLOCK));
-        for (int j = 0; j < QK_MXFP6; j += 4) {
-            uint8_t vals[4];
-            for (int jj = 0; jj < 4; jj++) {
-                vals[jj] = traits->to_elem(x[i*QK_MXFP6 + j + jj] * inv_d);
-            }
-            pack_fp6x4(vals, &qs[j * 3 / 4]);
-        }
-    }
-}
-
-static void dequantize_row_mxfp6_soa_impl(const void * GGML_RESTRICT src, float * GGML_RESTRICT y,
-                                          int64_t k, const mxfp_elem_traits_t * traits) {
-    assert(k % QK_MXFP6 == 0);
-    const int nb = k / QK_MXFP6;
-    const char * row = (const char *)src;
-    const char * qs_base = row;
-    const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK);
-    for (int i = 0; i < nb; i++) {
-        const float d = GGML_E8M0_TO_FP32((uint8_t)e8m0_base[i]);
-        const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP6_SOA_QS_PER_BLOCK));
-        for (int j = 0; j < QK_MXFP6; j += 4) {
-            uint8_t vals[4];
-            unpack_fp6x4(&qs[j * 3 / 4], vals);
-            for (int jj = 0; jj < 4; jj++) {
-                y[i*QK_MXFP6 + j + jj] = traits->to_float(vals[jj]) * d;
-            }
-        }
-    }
-}
+// Unified SoA quantize for byte-aligned (FP8) and 6-bit packed (FP6) MXFP formats.
+static void quantize_row_mxfp_soa_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst,
+                                       int64_t k, const mxfp_elem_traits_t * traits) {
+    const int qk = 32;
+    assert(k % qk == 0);
+    const int nb = k / qk;
+    const int qpb = traits->qs_per_block;
+    char * qs_base = (char *)dst;
+    char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, qpb);
+    for (int i = 0; i < nb; i++) {
+        const uint8_t e = mxfp_compute_e8m0_mse(&x[i*qk], qk, traits);
+        const float d = GGML_E8M0_TO_FP32(e);
+        const float inv_d = d > 0.0f ? 1.0f / d : 0.0f;
+        e8m0_base[i] = (char)e;
+        uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, qpb));
+        if (traits->bits_per_elem == 8) {
+            for (int j = 0; j < qk; ++j) {
+                qs[j] = traits->to_elem(x[i*qk + j] * inv_d);
+            }
+        } else {
+            for (int j = 0; j < qk; j += 4) {
+                uint8_t vals[4];
+                for (int jj = 0; jj < 4; jj++) {
+                    vals[jj] = traits->to_elem(x[i*qk + j + jj] * inv_d);
+                }
+                pack_fp6x4(vals, &qs[j * 3 / 4]);
+            }
+        }
+    }
+}
+
+// Unified SoA dequantize for byte-aligned (FP8) and 6-bit packed (FP6) MXFP formats.
+static void dequantize_row_mxfp_soa_impl(const void * GGML_RESTRICT src, float * GGML_RESTRICT y,
+                                         int64_t k, const mxfp_elem_traits_t * traits) {
+    const int qk = 32;
+    assert(k % qk == 0);
+    const int nb = k / qk;
+    const int qpb = traits->qs_per_block;
+    const char * qs_base = (const char *)src;
+    const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, qpb);
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_E8M0_TO_FP32((uint8_t)e8m0_base[i]);
+        const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, qpb));
+        if (traits->bits_per_elem == 8) {
+            for (int j = 0; j < qk; ++j) {
+                y[i*qk + j] = traits->to_float(qs[j]) * d;
+            }
+        } else {
+            for (int j = 0; j < qk; j += 4) {
+                uint8_t vals[4];
+                unpack_fp6x4(&qs[j * 3 / 4], vals);
+                for (int jj = 0; jj < 4; jj++) {
+                    y[i*qk + j + jj] = traits->to_float(vals[jj]) * d;
+                }
+            }
+        }
+    }
+}

 void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) {
-    quantize_row_mxfp8_soa_impl(x, dst, k, &mxfp8_e4m3_traits);
+    quantize_row_mxfp_soa_impl(x, dst, k, &mxfp8_e4m3_traits);
 }

 void dequantize_row_mxfp8_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) {
-    dequantize_row_mxfp8_soa_impl(src, y, k, &mxfp8_e4m3_traits);
+    dequantize_row_mxfp_soa_impl(src, y, k, &mxfp8_e4m3_traits);
 }

 void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) {
-    quantize_row_mxfp6_soa_impl(x, dst, k, &mxfp6_e2m3_traits);
+    quantize_row_mxfp_soa_impl(x, dst, k, &mxfp6_e2m3_traits);
 }

 void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) {
-    dequantize_row_mxfp6_soa_impl(src, y, k, &mxfp6_e2m3_traits);
+    dequantize_row_mxfp_soa_impl(src, y, k, &mxfp6_e2m3_traits);
 }
 //
 // 2-6 bit quantization in super-blocks
@@ -2472,7 +2364,7 @@ size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
 size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     GGML_UNUSED(quant_weights);
     quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);
-    return nrow * ggml_row_size(GGML_TYPE_MXFP4_E2M1, n_per_row);
+    return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
 }
 size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
@@ -2481,18 +2373,6 @@ size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
     return nrow * ggml_row_size(GGML_TYPE_NVFP4, n_per_row);
 }

-size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
-    GGML_UNUSED(quant_weights);
-    quantize_row_mxfp8_ref(src, dst, (int64_t)nrow*n_per_row);
-    return nrow * ggml_row_size(GGML_TYPE_MXFP8_E4M3, n_per_row);
-}
-
-size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
-    GGML_UNUSED(quant_weights);
-    quantize_row_mxfp6_ref(src, dst, (int64_t)nrow*n_per_row);
-    return nrow * ggml_row_size(GGML_TYPE_MXFP6_E2M3, n_per_row);
-}
-
 // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)

 void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) {
@@ -5635,15 +5515,15 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
             {
                 VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
             } break;
-        case GGML_TYPE_MXFP4_E2M1:
+        case GGML_TYPE_MXFP4:
             {
                 VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb);
             } break;
-        case GGML_TYPE_MXFP8_E4M3:
+        case GGML_TYPE_MXFP8:
             {
                 VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp8, data, nb);
             } break;
-        case GGML_TYPE_MXFP6_E2M3:
+        case GGML_TYPE_MXFP6:
             {
                 VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp6, data, nb);
             } break;
@@ -23,9 +23,6 @@ GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 *
 GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
 GGML_API void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k);
-GGML_API void quantize_row_mxfp8_ref(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k);
-GGML_API void quantize_row_mxfp6_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k);
 GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
 GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
 GGML_API void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
@@ -52,9 +49,6 @@ GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GG
 GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 GGML_API void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
-GGML_API void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
-GGML_API void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 // SoA quantize/dequantize for flash attention
 GGML_API void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
 GGML_API void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
@@ -110,9 +104,6 @@ GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTR
 GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 GGML_API size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
-GGML_API size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
-GGML_API size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 // MXFP element converters
 GGML_API float fp8_e4m3_to_float(uint8_t v);
 GGML_API uint8_t float_to_fp8_e4m3_rn(float x);
@@ -639,7 +639,7 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
             return dequantize_row_iq4_xs_sycl;
         case GGML_TYPE_IQ4_NL:
             return dequantize_row_iq4_nl_sycl;
-        case GGML_TYPE_MXFP4_E2M1:
+        case GGML_TYPE_MXFP4:
             return dequantize_row_mxfp4_sycl;
         case GGML_TYPE_F32:
             return convert_unary_sycl<float>;
@@ -706,7 +706,7 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
             return dequantize_row_iq4_xs_sycl;
         case GGML_TYPE_IQ4_NL:
             return dequantize_row_iq4_nl_sycl;
-        case GGML_TYPE_MXFP4_E2M1:
+        case GGML_TYPE_MXFP4:
             return dequantize_row_mxfp4_sycl;
         case GGML_TYPE_F16:
             return convert_unary_sycl<sycl::half>;
@@ -1142,7 +1142,7 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
         case GGML_TYPE_IQ4_XS:
             mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
             break;
-        case GGML_TYPE_MXFP4_E2M1:
+        case GGML_TYPE_MXFP4:
             mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
             break;
         default:
@@ -710,7 +710,7 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref,
     },
-    [GGML_TYPE_MXFP4_E2M1] = {
+    [GGML_TYPE_MXFP4] = {
         .type_name = "mxfp4",
         .blck_size = QK_MXFP4,
         .type_size = sizeof(block_mxfp4),
@@ -726,21 +726,17 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
         .to_float = (ggml_to_float_t) dequantize_row_nvfp4,
         .from_float_ref = (ggml_from_float_t)quantize_row_nvfp4_ref,
     },
-    [GGML_TYPE_MXFP8_E4M3] = {
+    [GGML_TYPE_MXFP8] = {
         .type_name = "mxfp8_e4m3",
         .blck_size = QK_MXFP8,
         .type_size = sizeof(block_mxfp8),
         .is_quantized = true,
-        .to_float = (ggml_to_float_t) dequantize_row_mxfp8,
-        .from_float_ref = (ggml_from_float_t)quantize_row_mxfp8_ref,
     },
-    [GGML_TYPE_MXFP6_E2M3] = {
+    [GGML_TYPE_MXFP6] = {
         .type_name = "mxfp6_e2m3",
         .blck_size = QK_MXFP6,
         .type_size = sizeof(block_mxfp6),
         .is_quantized = true,
-        .to_float = (ggml_to_float_t) dequantize_row_mxfp6,
-        .from_float_ref = (ggml_from_float_t)quantize_row_mxfp6_ref,
     },
     [GGML_TYPE_Q2_K] = {
         .type_name = "q2_K",
@@ -1329,25 +1325,25 @@ bool ggml_is_quantized(enum ggml_type type) {
 }

 bool ggml_is_type_mxfp(enum ggml_type type) {
-    return type == GGML_TYPE_MXFP4_E2M1 ||
-           type == GGML_TYPE_MXFP8_E4M3 ||
-           type == GGML_TYPE_MXFP6_E2M3;
+    return type == GGML_TYPE_MXFP4 ||
+           type == GGML_TYPE_MXFP8 ||
+           type == GGML_TYPE_MXFP6;
 }

 bool ggml_mxfp_use_hadamard(enum ggml_type type) {
     switch (type) {
-        case GGML_TYPE_MXFP4_E2M1: return MXFP_USE_HADAMARD_E2M1;
-        case GGML_TYPE_MXFP8_E4M3: return MXFP_USE_HADAMARD_E4M3;
-        case GGML_TYPE_MXFP6_E2M3: return MXFP_USE_HADAMARD_E2M3;
+        case GGML_TYPE_MXFP4: return MXFP_USE_HADAMARD_E2M1;
+        case GGML_TYPE_MXFP8: return MXFP_USE_HADAMARD_E4M3;
+        case GGML_TYPE_MXFP6: return MXFP_USE_HADAMARD_E2M3;
         default: return false;
     }
 }

 int ggml_mxfp_qs_per_block(enum ggml_type type) {
     switch (type) {
-        case GGML_TYPE_MXFP4_E2M1: return MXFP_QS_PER_BLOCK_E2M1;
-        case GGML_TYPE_MXFP8_E4M3: return MXFP_QS_PER_BLOCK_E4M3;
-        case GGML_TYPE_MXFP6_E2M3: return MXFP_QS_PER_BLOCK_E2M3;
+        case GGML_TYPE_MXFP4: return MXFP_QS_PER_BLOCK_E2M1;
+        case GGML_TYPE_MXFP8: return MXFP_QS_PER_BLOCK_E4M3;
+        case GGML_TYPE_MXFP6: return MXFP_QS_PER_BLOCK_E2M3;
         default: return 0;
     }
 }
@@ -7695,10 +7691,10 @@ size_t ggml_quantize_chunk(
         case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
-        case GGML_TYPE_MXFP4_E2M1: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_MXFP4: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
-        case GGML_TYPE_MXFP8_E4M3: result = quantize_mxfp8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
-        case GGML_TYPE_MXFP6_E2M3: result = quantize_mxfp6(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_MXFP8: GGML_ABORT("MXFP8 is KV-cache-only (SoA layout) — use from_float_soa");
+        case GGML_TYPE_MXFP6: GGML_ABORT("MXFP6 is KV-cache-only (SoA layout) — use from_float_soa");
         case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
@@ -457,7 +457,7 @@ static ggml_type llama_tensor_get_type_impl(quantize_state_impl & qs, ggml_type
         // MoE tensors -> MXFP4
         // other tensors -> Q8_0
         if (tensor->ne[2] > 1) {
-            new_type = GGML_TYPE_MXFP4_E2M1;
+            new_type = GGML_TYPE_MXFP4;
         } else {
             new_type = GGML_TYPE_Q8_0;
         }
@@ -795,7 +795,7 @@ static ggml_type llama_ftype_get_default_type(llama_ftype ftype) {
         case LLAMA_FTYPE_MOSTLY_BF16: return GGML_TYPE_BF16;
         case LLAMA_FTYPE_ALL_F32: return GGML_TYPE_F32;
-        case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return GGML_TYPE_MXFP4_E2M1;
+        case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return GGML_TYPE_MXFP4;
         // K-quants
         case LLAMA_FTYPE_MOSTLY_Q2_K_S:
@@ -170,9 +170,9 @@ struct mxfp_soa_fns {
 };

 static const mxfp_soa_fns mxfp_soa_table[] = {
-    { GGML_TYPE_MXFP4_E2M1, quantize_row_mxfp4_soa, dequantize_row_mxfp4_soa },
-    { GGML_TYPE_MXFP8_E4M3, quantize_row_mxfp8_soa, dequantize_row_mxfp8_soa },
-    { GGML_TYPE_MXFP6_E2M3, quantize_row_mxfp6_soa, dequantize_row_mxfp6_soa },
+    { GGML_TYPE_MXFP4, quantize_row_mxfp4_soa, dequantize_row_mxfp4_soa },
+    { GGML_TYPE_MXFP8, quantize_row_mxfp8_soa, dequantize_row_mxfp8_soa },
+    { GGML_TYPE_MXFP6, quantize_row_mxfp6_soa, dequantize_row_mxfp6_soa },
 };

 static const mxfp_soa_fns * get_mxfp_soa(ggml_type type) {
@@ -3908,7 +3908,7 @@ struct test_mul_mat : public test_case {
     double max_nmse_err(ggml_backend_t backend) override {
         // for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance
-        if (type_a == GGML_TYPE_MXFP4_E2M1 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
+        if (type_a == GGML_TYPE_MXFP4 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
             return 2e-2;
         }
         return max_nmse_err();
@@ -4044,7 +4044,7 @@ struct test_mul_mat_id : public test_case {
     double max_nmse_err(ggml_backend_t backend) override {
         // for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance
-        if (type_a == GGML_TYPE_MXFP4_E2M1 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
+        if (type_a == GGML_TYPE_MXFP4 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
             return 2e-2;
         }
         return max_nmse_err();
@@ -7398,7 +7398,7 @@ static const ggml_type all_types[] = {
     GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
     GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
     GGML_TYPE_Q8_0,
-    GGML_TYPE_MXFP4_E2M1,
+    GGML_TYPE_MXFP4,
     GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
     GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
     GGML_TYPE_Q6_K,
@@ -7414,7 +7414,7 @@ static const ggml_type base_types[] = {
     GGML_TYPE_Q4_0,
     GGML_TYPE_Q4_1, // for I8MM tests
     GGML_TYPE_Q4_K,
-    GGML_TYPE_MXFP4_E2M1,
+    GGML_TYPE_MXFP4,
     GGML_TYPE_IQ2_XXS
 };
@@ -7533,8 +7533,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     }

     // SET_ROWS with Hadamard rotation (exercises the op_params[0] flag used by MXFP KV cache)
-    for (ggml_type type : {GGML_TYPE_MXFP4_E2M1, GGML_TYPE_MXFP8_E4M3,
-                           GGML_TYPE_MXFP6_E2M3}) {
+    for (ggml_type type : {GGML_TYPE_MXFP4, GGML_TYPE_MXFP8,
+                           GGML_TYPE_MXFP6}) {
         // ne[0] must be divisible by 32 (Hadamard block size)
         test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 128, 5, 1, 1 }, { 1, 1 }, 1, false, true));
         test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 256, 5, 1, 3 }, { 1, 1 }, 1, false, true));
@@ -8270,7 +8270,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_mul_mat_id_fusion(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));

     // gpt-oss issue with Vulkan mmq_id
-    test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4_E2M1, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));
+    test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));

     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
@@ -8731,7 +8731,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
                 if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
                 for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0,
-                                          GGML_TYPE_MXFP4_E2M1, GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3,
+                                          GGML_TYPE_MXFP4, GGML_TYPE_MXFP8, GGML_TYPE_MXFP6,
                                          }) {
                     // Non-F16 types: test at D=64, D=72, and D=128.
                     if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72 && hsk != 128) continue;
@@ -8760,8 +8760,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     // MXFP-specific K/V type combinations (mixed and same-type)
     // Mixed: mxfp8 K + mxfp4 V, mxfp6 K + mxfp4 V (our recommended configs)
-    for (ggml_type type_K : {GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3}) {
-        for (ggml_type type_V : {GGML_TYPE_MXFP4_E2M1}) {
+    for (ggml_type type_K : {GGML_TYPE_MXFP8, GGML_TYPE_MXFP6}) {
+        for (ggml_type type_V : {GGML_TYPE_MXFP4}) {
             if (type_K == type_V) continue;
             for (int nb : {1, 3, 32}) {
                 test_cases.emplace_back(new test_flash_attn_ext(
@@ -8770,7 +8770,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
     // Same-type: mxfp8/mxfp8, mxfp6/mxfp6
-    for (ggml_type type_KV : {GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3}) {
+    for (ggml_type type_KV : {GGML_TYPE_MXFP8, GGML_TYPE_MXFP6}) {
         for (int nb : {1, 3, 32}) {
             test_cases.emplace_back(new test_flash_attn_ext(
                 128, 128, 4, {1, 1}, 512, nb, true, false, 0.0f, 0.0f, GGML_PREC_F32, type_KV, {0, 1, 2, 3}, type_KV));
@@ -9000,7 +9000,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     // gpt-oss-20b
     for (int bs : {1, 4, 8, 512}) {
-        for (ggml_type type_a : {GGML_TYPE_MXFP4_E2M1}) {
+        for (ggml_type type_a : {GGML_TYPE_MXFP4}) {
             for (ggml_type type_b : {GGML_TYPE_F32}) {
                 test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 2880, bs, 2880));
                 test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 2880, bs, 2880, 1));


@@ -178,9 +178,9 @@ int main(int argc, char * argv[]) {
             type == GGML_TYPE_IQ3_S ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
             type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS :
             type == GGML_TYPE_NVFP4 ? MAX_QUANTIZATION_TOTAL_ERROR_FP4 :
-            type == GGML_TYPE_MXFP4_E2M1 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 :
-            type == GGML_TYPE_MXFP6_E2M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 :
-            type == GGML_TYPE_MXFP8_E4M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 : MAX_QUANTIZATION_TOTAL_ERROR;
+            type == GGML_TYPE_MXFP4 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 :
+            type == GGML_TYPE_MXFP6 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 :
+            type == GGML_TYPE_MXFP8 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 : MAX_QUANTIZATION_TOTAL_ERROR;
         failed = !(total_error < max_quantization_error);
         num_failed += failed;
         if (failed || verbose) {
@@ -202,7 +202,7 @@ int main(int argc, char * argv[]) {
                 ? MAX_DOT_PRODUCT_ERROR_TERNARY
                 : type == GGML_TYPE_NVFP4
                 ? MAX_DOT_PRODUCT_ERROR_FP4
-                : type == GGML_TYPE_MXFP4_E2M1 || type == GGML_TYPE_MXFP6_E2M3 || type == GGML_TYPE_MXFP8_E4M3
+                : type == GGML_TYPE_MXFP4 || type == GGML_TYPE_MXFP6 || type == GGML_TYPE_MXFP8
                 ? MAX_DOT_PRODUCT_ERROR_MXFP
                 : MAX_DOT_PRODUCT_ERROR;
         failed = !(vec_dot_error < max_allowed_error);
@@ -231,9 +231,9 @@ int main(int argc, char * argv[]) {
         const float soa_error = array_rmse(test_data.data(), tmp_out.data(), test_size);
         const float max_soa_error =
-            type == GGML_TYPE_MXFP4_E2M1 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 :
-            type == GGML_TYPE_MXFP6_E2M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 :
-            type == GGML_TYPE_MXFP8_E4M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 : MAX_QUANTIZATION_TOTAL_ERROR;
+            type == GGML_TYPE_MXFP4 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 :
+            type == GGML_TYPE_MXFP6 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 :
+            type == GGML_TYPE_MXFP8 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 : MAX_QUANTIZATION_TOTAL_ERROR;
         failed = !(soa_error < max_soa_error);
         num_failed += failed;
         if (failed || verbose) {
@@ -243,7 +243,7 @@ int main(int argc, char * argv[]) {
     // MXFP traits: SoA required, MXFP6/MXFP8 are KV-cache-only (no AoS dequant)
     {
-        const ggml_type all_mxfp_types[] = { GGML_TYPE_MXFP4_E2M1, GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3 };
+        const ggml_type all_mxfp_types[] = { GGML_TYPE_MXFP4, GGML_TYPE_MXFP8, GGML_TYPE_MXFP6 };
         for (ggml_type type : all_mxfp_types) {
             const auto * cpu = ggml_get_type_traits_cpu(type);
@@ -255,7 +255,7 @@ int main(int argc, char * argv[]) {
         }
         // KV-cache-only types: no AoS dequant
-        const ggml_type kv_only_types[] = { GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3 };
+        const ggml_type kv_only_types[] = { GGML_TYPE_MXFP8, GGML_TYPE_MXFP6 };
         for (ggml_type type : kv_only_types) {
             const auto * cpu = ggml_get_type_traits_cpu(type);
             failed = (cpu->to_float != nullptr);
@@ -297,9 +297,9 @@ int main(int argc, char * argv[]) {
         };
         const soa_cross_check checks[] = {
-            { GGML_TYPE_MXFP4_E2M1, dequantize_row_mxfp4_soa },
-            { GGML_TYPE_MXFP8_E4M3, dequantize_row_mxfp8_soa },
-            { GGML_TYPE_MXFP6_E2M3, dequantize_row_mxfp6_soa },
+            { GGML_TYPE_MXFP4, dequantize_row_mxfp4_soa },
+            { GGML_TYPE_MXFP8, dequantize_row_mxfp8_soa },
+            { GGML_TYPE_MXFP6, dequantize_row_mxfp6_soa },
         };
         for (const auto & c : checks) {
@@ -774,9 +774,9 @@ int main(int argc, char * argv[]) {
     // SoA layout: verify offset macros produce correct byte positions
     {
         const struct { ggml_type type; int qs_per_block; } soa_types[] = {
-            { GGML_TYPE_MXFP4_E2M1, MXFP4_SOA_QS_PER_BLOCK },
-            { GGML_TYPE_MXFP8_E4M3, MXFP8_SOA_QS_PER_BLOCK },
-            { GGML_TYPE_MXFP6_E2M3, MXFP6_SOA_QS_PER_BLOCK },
+            { GGML_TYPE_MXFP4, MXFP4_SOA_QS_PER_BLOCK },
+            { GGML_TYPE_MXFP8, MXFP8_SOA_QS_PER_BLOCK },
+            { GGML_TYPE_MXFP6, MXFP6_SOA_QS_PER_BLOCK },
         };
         for (const auto & st : soa_types) {
@@ -864,7 +864,7 @@ int main(int argc, char * argv[]) {
         dequantize_row_mxfp4(aos_q.data(), aos_out.data(), nelems);
         // Quantize and dequant via SoA
-        const size_t soa_buf_size = ggml_row_size(GGML_TYPE_MXFP4_E2M1, nelems);
+        const size_t soa_buf_size = ggml_row_size(GGML_TYPE_MXFP4, nelems);
         std::vector<uint8_t> soa_q(soa_buf_size);
         std::vector<float> soa_out(nelems);
         quantize_row_mxfp4_soa(input, soa_q.data(), nelems);
@@ -901,9 +901,9 @@ int main(int argc, char * argv[]) {
         };
         const hadamard_pipeline_check pipeline_checks[] = {
-            { "mxfp4", GGML_TYPE_MXFP4_E2M1, MAX_MXFP_PIPELINE_ERROR_MXFP4 },
-            { "mxfp8", GGML_TYPE_MXFP8_E4M3, MAX_MXFP_PIPELINE_ERROR_MXFP8 },
-            { "mxfp6", GGML_TYPE_MXFP6_E2M3, MAX_MXFP_PIPELINE_ERROR_MXFP6 },
+            { "mxfp4", GGML_TYPE_MXFP4, MAX_MXFP_PIPELINE_ERROR_MXFP4 },
+            { "mxfp8", GGML_TYPE_MXFP8, MAX_MXFP_PIPELINE_ERROR_MXFP8 },
+            { "mxfp6", GGML_TYPE_MXFP6, MAX_MXFP_PIPELINE_ERROR_MXFP6 },
         };
         for (const auto & p : pipeline_checks) {
@@ -963,7 +963,7 @@ int main(int argc, char * argv[]) {
     // zero block produces E8M0=0
    {
         float zeros[32] = {};
-        const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP8_E4M3, 32);
+        const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP8, 32);
         std::vector<uint8_t> buf(buf_size, 0xFF); // fill with 0xFF to detect non-writes
         quantize_row_mxfp8_soa(zeros, buf.data(), 32);
@@ -991,7 +991,7 @@ int main(int argc, char * argv[]) {
     // MXFP4
     {
-        const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP4_E2M1, nelems);
+        const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP4, nelems);
         std::vector<uint8_t> buf(buf_size);
         std::vector<float> ref_out(nelems);
         std::vector<float> manual_out(nelems);
@@ -1032,7 +1032,7 @@ int main(int argc, char * argv[]) {
     // MXFP8
     {
-        const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP8_E4M3, nelems);
+        const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP8, nelems);
         std::vector<uint8_t> buf(buf_size);
         std::vector<float> ref_out(nelems);
         std::vector<float> manual_out(nelems);
@@ -1069,7 +1069,7 @@ int main(int argc, char * argv[]) {
     // MXFP6
     {
-        const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP6_E2M3, nelems);
+        const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP6, nelems);
         std::vector<uint8_t> buf(buf_size);
         std::vector<float> ref_out(nelems);
         std::vector<float> manual_out(nelems);


@@ -484,13 +484,13 @@ static ggml_type ggml_type_from_name(const std::string & s) {
         return GGML_TYPE_IQ4_NL;
     }
     if (s == "mxfp4" || s == "mxfp4_e2m1") {
-        return GGML_TYPE_MXFP4_E2M1;
+        return GGML_TYPE_MXFP4;
     }
     if (s == "mxfp8" || s == "mxfp8_e4m3") {
-        return GGML_TYPE_MXFP8_E4M3;
+        return GGML_TYPE_MXFP8;
     }
     if (s == "mxfp6" || s == "mxfp6_e2m3") {
-        return GGML_TYPE_MXFP6_E2M3;
+        return GGML_TYPE_MXFP6;
     }
     return GGML_TYPE_COUNT;
 }