From c919bc471bbddf0d5df1cec52a50339ce4f279c3 Mon Sep 17 00:00:00 2001 From: Tim Burke Date: Sun, 22 Mar 2026 02:44:56 -0400 Subject: [PATCH] cleanup : remove unused untested code and improve consistency MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * cleanup: consolidate MXFP type aliases, fix SoA linker bug on 5 platforms - Add GGML_TYPE_MXFP8 and GGML_TYPE_MXFP6 short aliases (matching existing GGML_TYPE_MXFP4 pattern) and use short names consistently throughout the codebase instead of mixing long/short forms. - Fix missing SoA dequant symbols (dequantize_row_mxfp{4,8,6}_soa_cpu) on loongarch, powerpc, riscv, s390, and wasm by adding proper aliases to each arch section in arch-fallback.h. Previously these were only defined under GGML_CPU_GENERIC, causing linker failures on those platforms when using MXFP flash attention. - Remove 10 files from the PR diff: - 5 arch stub files replaced by arch-fallback.h aliases - 5 rename-only files (sycl, opencl, repack, llama-quant) reverted since the GGML_TYPE_MXFP4 compat alias handles them * cleanup: DRY FP6 unpack, extract mxfp_kv_params + mxfp_dequant_head helper - FP6 unpack: x86 and ARM SIMD versions now call ggml_mxfp_unpack_fp6x4() from ggml-common.h instead of duplicating the scalar bit manipulation. - Extract mxfp_kv_params sub-struct from mxfp_fa_params: the 7 symmetric K/V fields (dequantize, multihead, soa_elems, qs_per_block, head_qs_bytes, head_e8m0_offset, blocks_per_head) are now in a reusable struct accessed as mxfp.k and mxfp.v. - Add mxfp_dequant_head() helper: replaces 4 instances of the multihead SoA extraction pattern (2x memcpy + dequant, with multihead/single-head branching) with a single function call. Future backends get the pattern for free. * cleanup: extract mxfp_kv_params_init to DRY the K/V init blocks The K and V initialization in mxfp_fa_params_init were structurally identical 10-line blocks differing only by tensor/dimension. Extract into mxfp_kv_params_init(type, D, nb2, ne2) so future MXFP formats get the multihead SoA addressing logic automatically. * cleanup: generic MSE round-trip, replace magic buffer sizes with constants - Remove mse_error_fp8_e4m3 and mse_error_fp6_e2m3: these were identical round-trip functions differing only by converter. mxfp_compute_e8m0_mse now uses to_elem/to_float directly when mse_error is NULL (FP8/FP6). MXFP4 keeps its custom decision-tree MSE. New formats get MSE for free by just setting to_elem/to_float in their traits. - Replace magic 1024/1088 buffer sizes in flash attention with named constants MXFP_FA_MAX_D and MXFP_FA_SOA_BUF. One place to change if max head dimension grows. * cleanup: remove dead AoS vec_dot for MXFP8/MXFP6, unify SoA impls MXFP8 and MXFP6 are KV-cache-only types that use SoA layout for flash attention. The AoS vec_dot functions (scalar generic, AVX2, NEON) were dead code — no matmul path uses them. Removed: - ggml_vec_dot_mxfp{8,6}_q8_0 from scalar, x86, ARM, quants.h - ggml_vec_dot_mxfp_q8_0_impl shared helper - arch-fallback.h aliases for vec_dot mxfp8/mxfp6 (12 lines) - vec_dot/vec_dot_type registration in ggml-cpu.c Also unified SoA quantize/dequant: the separate mxfp8_soa_impl and mxfp6_soa_impl functions (4 functions, ~80 lines) are replaced by two generic functions (quantize_row_mxfp_soa_impl, dequantize_row_mxfp_soa_impl) that use traits->bits_per_elem and traits->qs_per_block to handle both byte-aligned (FP8) and 6-bit packed (FP6) formats. New MXFP formats get SoA for free by setting these trait fields. * cleanup: remove all AoS MXFP8/MXFP6 quantize/dequant — SoA only MXFP8 and MXFP6 are KV-cache-only types. All quantization and dequantization goes through the SoA (Struct-of-Arrays) path for flash attention. The AoS (block_mxfp8/block_mxfp6 struct) implementations were dead code that should never have been added. Removed: - quantize_row_mxfp{8,6}_impl, dequantize_row_mxfp{8,6}_impl - quantize_row_mxfp{8,6}_ref, dequantize_row_mxfp{8,6} - quantize_mxfp{8,6} (ggml_quantize_chunk wrappers) - All declarations from ggml-quants.h and quants.h - to_float/from_float_ref registrations from ggml.c type traits - from_float registration from ggml-cpu.c CPU traits Block struct definitions (block_mxfp8, block_mxfp6) are retained for sizeof() in type traits and validate_row_data. * cleanup: fail fast in ggml_quantize_chunk for KV-cache-only types Add explicit GGML_ABORT for MXFP8/MXFP6 in ggml_quantize_chunk — these are KV-cache-only types that use SoA layout via from_float_soa. Attempting AoS quantization through this entry point is a bug. --- common/arg.cpp | 12 +- ggml/include/ggml.h | 2 + ggml/src/ggml-cpu/arch-fallback.h | 27 ++- ggml/src/ggml-cpu/arch/arm/quants.c | 116 +--------- ggml/src/ggml-cpu/arch/loongarch/quants.c | 11 - ggml/src/ggml-cpu/arch/powerpc/quants.c | 7 - ggml/src/ggml-cpu/arch/riscv/quants.c | 8 - ggml/src/ggml-cpu/arch/s390/quants.c | 7 - ggml/src/ggml-cpu/arch/wasm/quants.c | 11 - ggml/src/ggml-cpu/arch/x86/quants.c | 116 +--------- ggml/src/ggml-cpu/ggml-cpu.c | 12 +- ggml/src/ggml-cpu/ops.cpp | 244 +++++++++------------ ggml/src/ggml-cpu/quants.c | 53 ----- ggml/src/ggml-cpu/quants.h | 9 - ggml/src/ggml-cpu/repack.cpp | 6 +- ggml/src/ggml-metal/ggml-metal-device.m | 2 +- ggml/src/ggml-opencl/ggml-opencl.cpp | 16 +- ggml/src/ggml-quants.c | 250 ++++++---------------- ggml/src/ggml-quants.h | 9 - ggml/src/ggml-sycl/convert.cpp | 4 +- ggml/src/ggml-sycl/mmvq.cpp | 2 +- ggml/src/ggml.c | 36 ++-- src/llama-quant.cpp | 4 +- tests/test-backend-ops.cpp | 30 +-- tests/test-quantize-fns.cpp | 46 ++-- tools/llama-bench/llama-bench.cpp | 6 +- 26 files changed, 277 insertions(+), 769 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index ff12646a70..26c1904a2a 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -398,20 +398,20 @@ const std::vector kv_cache_types = { GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, - GGML_TYPE_MXFP4_E2M1, - GGML_TYPE_MXFP8_E4M3, - GGML_TYPE_MXFP6_E2M3, + GGML_TYPE_MXFP4, + GGML_TYPE_MXFP8, + GGML_TYPE_MXFP6, }; static ggml_type kv_cache_type_from_str(const std::string & s) { if (s == "mxfp4") { - return GGML_TYPE_MXFP4_E2M1; + return GGML_TYPE_MXFP4; } if (s == "mxfp6") { - return GGML_TYPE_MXFP6_E2M3; + return GGML_TYPE_MXFP6; } if (s == "mxfp8") { - return GGML_TYPE_MXFP8_E4M3; + return GGML_TYPE_MXFP8; } for (const auto & type : kv_cache_types) { if (ggml_type_name(type) == s) { diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 1f66550459..c068113932 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -430,7 +430,9 @@ extern "C" { GGML_TYPE_MXFP4 = GGML_TYPE_MXFP4_E2M1, // compat alias GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale) GGML_TYPE_MXFP8_E4M3 = 41, // MX FP8 E4M3 + GGML_TYPE_MXFP8 = GGML_TYPE_MXFP8_E4M3, // compat alias GGML_TYPE_MXFP6_E2M3 = 42, // MX FP6 E2M3 + GGML_TYPE_MXFP6 = GGML_TYPE_MXFP6_E2M3, // compat alias GGML_TYPE_COUNT = 43, }; diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index ddeee1fa7e..eac031e68e 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -15,8 +15,9 @@ #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 -#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0 -#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0 +#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu +#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu +#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -113,6 +114,9 @@ // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu +#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu +#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K @@ -161,6 +165,9 @@ #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu +#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu +#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 @@ -201,6 +208,9 @@ #elif defined(__riscv) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu +#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu +#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu // repack.cpp #define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 @@ -241,6 +251,9 @@ // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu +#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu +#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -291,6 +304,9 @@ #elif defined(__wasm__) // quants.c #define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1 +#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu +#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu +#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K @@ -342,10 +358,3 @@ #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #endif - -// MXFP dequantize fallbacks (same GGML_CPU_GENERIC guard as above) -#if defined(GGML_CPU_GENERIC) -#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu -#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu -#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu -#endif diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index 53507b97ce..9a4ac95aab 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -4191,12 +4191,8 @@ static inline float32x4_t mxfp6_dequant_neon( // Unpack 4 tightly-packed 6-bit values from 3 bytes, widen to uint32x4_t. static inline uint32x4_t unpack_fp6x4_neon(const uint8_t * p) { - const uint32_t pk = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); uint8_t u[4]; - u[0] = (pk >> 0) & 0x3F; - u[1] = (pk >> 6) & 0x3F; - u[2] = (pk >> 12) & 0x3F; - u[3] = (pk >> 18) & 0x3F; + ggml_mxfp_unpack_fp6x4(p, u); const uint8x8_t raw8 = vcreate_u8( (uint64_t)u[0] | ((uint64_t)u[1] << 8) | ((uint64_t)u[2] << 16) | ((uint64_t)u[3] << 24)); @@ -4221,96 +4217,6 @@ static inline void widen_s8x8_to_f32x4x2(const int8_t * src, *hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16))); } -// MXFP FP8/FP6 vec_dot - -static void ggml_vec_dot_mxfp8_q8_0_neon( - int n, float * GGML_RESTRICT s, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - const mxfp_neon_traits_t * t) { - assert(n % QK_MXFP8 == 0); - const int nb = n / QK_MXFP8; - const block_mxfp8 * GGML_RESTRICT x = vx; - const block_q8_0 * GGML_RESTRICT y = vy; - - const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask); - const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask); - const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off); - const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale); - const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift); - const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift); - - float32x4_t acc0 = vdupq_n_f32(0.0f); - float32x4_t acc1 = vdupq_n_f32(0.0f); - - for (int ib = 0; ib < nb; ++ib) { - const float32x4_t v_scale = vdupq_n_f32( - GGML_E8M0_TO_FP32(x[ib].e) * GGML_CPU_FP16_TO_FP32(y[ib].d)); - - for (int j = 0; j < 32; j += 8) { - uint32x4_t v_lo, v_hi; - widen_u8x8_to_u32x4x2(x[ib].qs + j, &v_lo, &v_hi); - - float32x4_t qf_lo, qf_hi; - widen_s8x8_to_f32x4x2(y[ib].qs + j, &qf_lo, &qf_hi); - - const float32x4_t val_lo = mxfp8_dequant_neon(v_lo, - v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); - const float32x4_t val_hi = mxfp8_dequant_neon(v_hi, - v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); - - acc0 = vfmaq_f32(acc0, vmulq_f32(val_lo, v_scale), qf_lo); - acc1 = vfmaq_f32(acc1, vmulq_f32(val_hi, v_scale), qf_hi); - } - } - - *s = vaddvq_f32(vaddq_f32(acc0, acc1)); -} - -static void ggml_vec_dot_mxfp6_q8_0_neon( - int n, float * GGML_RESTRICT s, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - const mxfp_neon_traits_t * t) { - assert(n % QK_MXFP6 == 0); - const int nb = n / QK_MXFP6; - const block_q8_0 * GGML_RESTRICT y = vy; - - const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask); - const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask); - const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off); - const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale); - const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift); - const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift); - - float32x4_t acc0 = vdupq_n_f32(0.0f); - float32x4_t acc1 = vdupq_n_f32(0.0f); - - for (int ib = 0; ib < nb; ++ib) { - const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib; - const float32x4_t v_scale = vdupq_n_f32( - GGML_E8M0_TO_FP32(xb->e) * GGML_CPU_FP16_TO_FP32(y[ib].d)); - - for (int j = 0; j < 32; j += 8) { - const uint32x4_t v_lo = unpack_fp6x4_neon(xb->qs + (j * 3 / 4)); - const uint32x4_t v_hi = unpack_fp6x4_neon(xb->qs + ((j + 4) * 3 / 4)); - - float32x4_t qf_lo, qf_hi; - widen_s8x8_to_f32x4x2(y[ib].qs + j, &qf_lo, &qf_hi); - - const float32x4_t val_lo = mxfp6_dequant_neon(v_lo, - v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); - const float32x4_t val_hi = mxfp6_dequant_neon(v_hi, - v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); - - acc0 = vfmaq_f32(acc0, vmulq_f32(val_lo, v_scale), qf_lo); - acc1 = vfmaq_f32(acc1, vmulq_f32(val_hi, v_scale), qf_hi); - } - } - - *s = vaddvq_f32(vaddq_f32(acc0, acc1)); -} - // MXFP SoA dequant (flash attention) static void dequantize_row_mxfp8_soa_neon( @@ -4424,26 +4330,6 @@ static void dequantize_row_mxfp4_soa_neon( // Public dispatch functions -void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); -#if defined(__ARM_NEON) - ggml_vec_dot_mxfp8_q8_0_neon(n, s, vx, vy, &MXFP_TRAITS_E4M3); -#else - ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif -} - -void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); -#if defined(__ARM_NEON) - ggml_vec_dot_mxfp6_q8_0_neon(n, s, vx, vy, &MXFP_TRAITS_E2M3); -#else - ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif -} - void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { #if defined(__ARM_NEON) dequantize_row_mxfp4_soa_neon(x, y, k); diff --git a/ggml/src/ggml-cpu/arch/loongarch/quants.c b/ggml/src/ggml-cpu/arch/loongarch/quants.c index fa05e49c5d..f531e916b9 100644 --- a/ggml/src/ggml-cpu/arch/loongarch/quants.c +++ b/ggml/src/ggml-cpu/arch/loongarch/quants.c @@ -2157,14 +2157,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -} - -void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -} - -void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -} diff --git a/ggml/src/ggml-cpu/arch/powerpc/quants.c b/ggml/src/ggml-cpu/arch/powerpc/quants.c index efb669da09..d3dfd049ea 100644 --- a/ggml/src/ggml-cpu/arch/powerpc/quants.c +++ b/ggml/src/ggml-cpu/arch/powerpc/quants.c @@ -2303,10 +2303,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -} - -void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -} diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index 10e0ff04d8..d7e9ba4634 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -3621,11 +3621,3 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - -void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -} - -void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -} diff --git a/ggml/src/ggml-cpu/arch/s390/quants.c b/ggml/src/ggml-cpu/arch/s390/quants.c index e696fd4570..34184ed851 100644 --- a/ggml/src/ggml-cpu/arch/s390/quants.c +++ b/ggml/src/ggml-cpu/arch/s390/quants.c @@ -1464,10 +1464,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -} - -void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -} diff --git a/ggml/src/ggml-cpu/arch/wasm/quants.c b/ggml/src/ggml-cpu/arch/wasm/quants.c index a3ae8e8885..74a359e6d1 100644 --- a/ggml/src/ggml-cpu/arch/wasm/quants.c +++ b/ggml/src/ggml-cpu/arch/wasm/quants.c @@ -1219,14 +1219,3 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } -void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -} - -void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -} - -void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -} diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c index 4b8f3386fa..775a7c742f 100644 --- a/ggml/src/ggml-cpu/arch/x86/quants.c +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -3850,106 +3850,14 @@ static inline __m256 mxfp_dequant_avx2( return _mm256_blendv_ps(normal, sub_val, is_sub); } -// Unpack 4 tightly-packed 6-bit values from 3 bytes into separate bytes. -static inline void unpack_fp6x4_avx2(const uint8_t * p, uint8_t out[4]) { - const uint32_t pk = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); - out[0] = (pk >> 0) & 0x3F; - out[1] = (pk >> 6) & 0x3F; - out[2] = (pk >> 12) & 0x3F; - out[3] = (pk >> 18) & 0x3F; -} - // Unpack 8 FP6 values (two groups of 4) from packed qs data at offset j. static inline __m256i unpack_fp6x8_avx2(const uint8_t * qs, int j) { uint8_t unpacked[8]; - unpack_fp6x4_avx2(qs + (j * 3 / 4), unpacked); - unpack_fp6x4_avx2(qs + ((j + 4) * 3 / 4), unpacked + 4); + ggml_mxfp_unpack_fp6x4(qs + (j * 3 / 4), unpacked); + ggml_mxfp_unpack_fp6x4(qs + ((j + 4) * 3 / 4), unpacked + 4); return _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)unpacked)); } -// MXFP FP8/FP6 vec_dot - -// FP8 x Q8_0 dot product (E4M3/E5M2). -static void ggml_vec_dot_mxfp8_q8_0_avx2( - int n, float * GGML_RESTRICT s, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - const mxfp_avx2_traits_t * t) { - assert(n % QK_MXFP8 == 0); - const int nb = n / QK_MXFP8; - const block_mxfp8 * GGML_RESTRICT x = vx; - const block_q8_0 * GGML_RESTRICT y = vy; - - const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask); - const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask); - const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off); - const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale); - const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask); - const __m256i v_zero = _mm256_setzero_si256(); - - __m256 acc = _mm256_setzero_ps(); - - for (int ib = 0; ib < nb; ++ib) { - const __m256 v_scale = _mm256_set1_ps( - GGML_E8M0_TO_FP32(x[ib].e) * GGML_CPU_FP16_TO_FP32(y[ib].d)); - - for (int j = 0; j < 32; j += 8) { - const __m256i v_raw = _mm256_cvtepu8_epi32( - _mm_loadl_epi64((const __m128i *)(x[ib].qs + j))); - const __m256 qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32( - _mm_loadl_epi64((const __m128i *)(y[ib].qs + j)))); - - const __m256 val = mxfp_dequant_avx2(v_raw, - v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, - v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift); - - acc = _mm256_fmadd_ps(_mm256_mul_ps(val, v_scale), qf, acc); - } - } - - *s = hsum_float_8(acc); -} - -// FP6 x Q8_0 dot product (E2M3/E3M2). -static void ggml_vec_dot_mxfp6_q8_0_avx2( - int n, float * GGML_RESTRICT s, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - const mxfp_avx2_traits_t * t) { - assert(n % QK_MXFP6 == 0); - const int nb = n / QK_MXFP6; - const block_q8_0 * GGML_RESTRICT y = vy; - - const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask); - const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask); - const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off); - const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale); - const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask); - const __m256i v_zero = _mm256_setzero_si256(); - - __m256 acc = _mm256_setzero_ps(); - - for (int ib = 0; ib < nb; ++ib) { - const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib; - const __m256 v_scale = _mm256_set1_ps( - GGML_E8M0_TO_FP32(xb->e) * GGML_CPU_FP16_TO_FP32(y[ib].d)); - - for (int j = 0; j < 32; j += 8) { - const __m256i v_raw = unpack_fp6x8_avx2(xb->qs, j); - const __m256 qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32( - _mm_loadl_epi64((const __m128i *)(y[ib].qs + j)))); - - const __m256 val = mxfp_dequant_avx2(v_raw, - v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, - v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift); - - acc = _mm256_fmadd_ps(_mm256_mul_ps(val, v_scale), qf, acc); - } - } - - *s = hsum_float_8(acc); -} - // MXFP SoA dequant (flash attention) static void dequantize_row_mxfp8_soa_avx2( @@ -4052,26 +3960,6 @@ static void dequantize_row_mxfp4_soa_avx2( // Public dispatch functions -void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); -#if defined(__AVX2__) - ggml_vec_dot_mxfp8_q8_0_avx2(n, s, vx, vy, &MXFP_TRAITS_E4M3); -#else - ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif -} - -void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); -#if defined(__AVX2__) - ggml_vec_dot_mxfp6_q8_0_avx2(n, s, vx, vy, &MXFP_TRAITS_E2M3); -#else - ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif -} - void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { #if defined(__AVX2__) dequantize_row_mxfp4_soa_avx2(x, y, k); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index b84b6e0031..782a54392f 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -265,7 +265,7 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_1, .nrows = 1, }, - [GGML_TYPE_MXFP4_E2M1] = { + [GGML_TYPE_MXFP4] = { .from_float = quantize_row_mxfp4, .from_float_soa = quantize_row_mxfp4_soa, .to_float_soa = dequantize_row_mxfp4_soa_cpu, @@ -279,20 +279,14 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, }, - [GGML_TYPE_MXFP8_E4M3] = { - .from_float = quantize_row_mxfp8, + [GGML_TYPE_MXFP8] = { .from_float_soa = quantize_row_mxfp8_soa, .to_float_soa = dequantize_row_mxfp8_soa_cpu, - .vec_dot = ggml_vec_dot_mxfp8_q8_0, - .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, }, - [GGML_TYPE_MXFP6_E2M3] = { - .from_float = quantize_row_mxfp6, + [GGML_TYPE_MXFP6] = { .from_float_soa = quantize_row_mxfp6_soa, .to_float_soa = dequantize_row_mxfp6_soa_cpu, - .vec_dot = ggml_vec_dot_mxfp6_q8_0, - .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, }, [GGML_TYPE_Q2_K] = { diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 9291af62dc..15424d40c4 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -672,10 +672,10 @@ void ggml_compute_forward_add( case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: - case GGML_TYPE_MXFP4_E2M1: + case GGML_TYPE_MXFP4: case GGML_TYPE_NVFP4: - case GGML_TYPE_MXFP8_E4M3: - case GGML_TYPE_MXFP6_E2M3: + case GGML_TYPE_MXFP8: + case GGML_TYPE_MXFP6: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -1124,10 +1124,10 @@ void ggml_compute_forward_add1( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: - case GGML_TYPE_MXFP4_E2M1: + case GGML_TYPE_MXFP4: case GGML_TYPE_NVFP4: - case GGML_TYPE_MXFP8_E4M3: - case GGML_TYPE_MXFP6_E2M3: + case GGML_TYPE_MXFP8: + case GGML_TYPE_MXFP6: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -1255,10 +1255,10 @@ void ggml_compute_forward_acc( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: - case GGML_TYPE_MXFP4_E2M1: + case GGML_TYPE_MXFP4: case GGML_TYPE_NVFP4: - case GGML_TYPE_MXFP8_E4M3: - case GGML_TYPE_MXFP6_E2M3: + case GGML_TYPE_MXFP8: + case GGML_TYPE_MXFP6: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4345,10 +4345,10 @@ void ggml_compute_forward_out_prod( case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: - case GGML_TYPE_MXFP4_E2M1: + case GGML_TYPE_MXFP4: case GGML_TYPE_NVFP4: - case GGML_TYPE_MXFP8_E4M3: - case GGML_TYPE_MXFP6_E2M3: + case GGML_TYPE_MXFP8: + case GGML_TYPE_MXFP6: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4623,10 +4623,10 @@ void ggml_compute_forward_set( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: - case GGML_TYPE_MXFP4_E2M1: + case GGML_TYPE_MXFP4: case GGML_TYPE_NVFP4: - case GGML_TYPE_MXFP8_E4M3: - case GGML_TYPE_MXFP6_E2M3: + case GGML_TYPE_MXFP8: + case GGML_TYPE_MXFP6: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4848,10 +4848,10 @@ void ggml_compute_forward_get_rows( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: - case GGML_TYPE_MXFP4_E2M1: + case GGML_TYPE_MXFP4: case GGML_TYPE_NVFP4: - case GGML_TYPE_MXFP8_E4M3: - case GGML_TYPE_MXFP6_E2M3: + case GGML_TYPE_MXFP8: + case GGML_TYPE_MXFP6: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -5686,10 +5686,10 @@ void ggml_compute_forward_clamp( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: - case GGML_TYPE_MXFP4_E2M1: + case GGML_TYPE_MXFP4: case GGML_TYPE_NVFP4: - case GGML_TYPE_MXFP8_E4M3: - case GGML_TYPE_MXFP6_E2M3: + case GGML_TYPE_MXFP8: + case GGML_TYPE_MXFP6: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -8255,31 +8255,65 @@ void ggml_compute_forward_top_k( } } +// Max head dimension for stack-allocated MXFP buffers. +static constexpr int64_t MXFP_FA_MAX_D = 1024; +// SoA buffer size for MXFP_FA_MAX_D with MXFP8 (worst case: 1024 + 32 e8m0 = 1056, rounded up). +static constexpr int MXFP_FA_SOA_BUF = 1088; + // SoA function pointer types for MXFP flash attention paths. typedef void (*mxfp_soa_quantize_fn)(const float *, void *, int64_t); typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t); +// Per-KV-type MXFP parameters (shared between K and V). +struct mxfp_kv_params { + mxfp_soa_dequantize_fn dequantize; + bool multihead; + int64_t soa_elems; + int qs_per_block; + int head_qs_bytes; + int64_t head_e8m0_offset; + int blocks_per_head; +}; + // MXFP dispatch parameters for flash attention. struct mxfp_fa_params { - mxfp_soa_quantize_fn q_quantize; - mxfp_soa_dequantize_fn k_dequantize; - mxfp_soa_dequantize_fn v_dequantize; - bool k_multihead; - bool v_multihead; - int64_t k_soa_elems; - int64_t v_soa_elems; - bool apply_hadamard; - // Per-head SoA addressing (avoids dequanting all heads in multihead mode). - int k_qs_per_block; - int v_qs_per_block; - int k_head_qs_bytes; - int v_head_qs_bytes; - int64_t k_head_e8m0_offset; - int64_t v_head_e8m0_offset; - int k_blocks_per_head; - int v_blocks_per_head; + mxfp_soa_quantize_fn q_quantize; + mxfp_kv_params k; + mxfp_kv_params v; + bool apply_hadamard; }; +// Extract one head's SoA data from a multihead row and dequantize. +static inline void mxfp_dequant_head( + const mxfp_kv_params & kv, const char * row, int head_idx, + char * soa_buf, float * out, int64_t D) { + if (kv.multihead) { + const int qs_off = head_idx * kv.head_qs_bytes; + const int e8m0_off = (int)kv.head_e8m0_offset + head_idx * kv.blocks_per_head; + memcpy(soa_buf, row + qs_off, kv.head_qs_bytes); + memcpy(soa_buf + kv.head_qs_bytes, row + e8m0_off, kv.blocks_per_head); + kv.dequantize(soa_buf, out, D); + } else { + kv.dequantize(row, out, D); + } +} + +// Initialize per-KV-type params from tensor metadata. +// Multihead detection: nb2 == row_size(D) means heads are contiguous within +// one KV-position stride, so SoA spans all heads. Otherwise SoA is per-head. +static mxfp_kv_params mxfp_kv_params_init(ggml_type type, int64_t D, size_t nb2, int64_t ne2) { + mxfp_kv_params kv = {}; + kv.dequantize = ggml_get_type_traits_cpu(type)->to_float_soa; + kv.multihead = (nb2 == (size_t)ggml_row_size(type, D)); + kv.soa_elems = kv.multihead ? ne2 * D : D; + kv.qs_per_block = ggml_mxfp_qs_per_block(type); + kv.blocks_per_head = (int)(D / 32); + kv.head_qs_bytes = kv.blocks_per_head * kv.qs_per_block; + const int64_t total_blocks = kv.multihead ? ne2 * kv.blocks_per_head : kv.blocks_per_head; + kv.head_e8m0_offset = total_blocks * kv.qs_per_block; + return kv; +} + static mxfp_fa_params mxfp_fa_params_init( const ggml_tensor * k, const ggml_tensor * v, int64_t DK, int64_t DV, @@ -8291,44 +8325,17 @@ static mxfp_fa_params mxfp_fa_params_init( const bool is_mxfp_v = ggml_is_type_mxfp(v->type); if (is_mxfp_k) { - const struct ggml_type_traits_cpu * k_traits = ggml_get_type_traits_cpu(k->type); - p.q_quantize = k_traits->from_float_soa; - p.k_dequantize = k_traits->to_float_soa; + p.q_quantize = ggml_get_type_traits_cpu(k->type)->from_float_soa; + p.k = mxfp_kv_params_init(k->type, DK, nbk2, nek2); } - if (is_mxfp_v) { - p.v_dequantize = ggml_get_type_traits_cpu(v->type)->to_float_soa; + p.v = mxfp_kv_params_init(v->type, DV, nbv2, nev2); } // Hadamard rotation must match K rotation. - // Skipped for: MLA (DK != DV, V is a view of K). + // Skipped for MLA (DK != DV, V is a view of K). p.apply_hadamard = is_mxfp_k && (DK == DV) && ggml_mxfp_use_hadamard(k->type); - // SoA layout detection: in the real KV cache, heads are contiguous within - // one KV-position stride (nb[2] == row_size(DK)), so SoA spans all heads. - // In test tensors, heads may be at distant offsets (nb[2] >> row_size(DK)), - // so SoA is per-head. Detect which case and set dequant parameters accordingly. - p.k_multihead = is_mxfp_k && (nbk2 == (size_t)ggml_row_size(k->type, DK)); - p.k_soa_elems = is_mxfp_k ? (p.k_multihead ? nek2 * DK : DK) : 0; - p.v_multihead = is_mxfp_v && (nbv2 == (size_t)ggml_row_size(v->type, DV)); - p.v_soa_elems = is_mxfp_v ? (p.v_multihead ? nev2 * DV : DV) : 0; - - if (is_mxfp_k) { - p.k_qs_per_block = ggml_mxfp_qs_per_block(k->type); - p.k_blocks_per_head = (int)(DK / 32); - p.k_head_qs_bytes = p.k_blocks_per_head * p.k_qs_per_block; - const int64_t k_total_blocks = p.k_multihead ? nek2 * p.k_blocks_per_head : p.k_blocks_per_head; - p.k_head_e8m0_offset = k_total_blocks * p.k_qs_per_block; - } - - if (is_mxfp_v) { - p.v_qs_per_block = ggml_mxfp_qs_per_block(v->type); - p.v_blocks_per_head = (int)(DV / 32); - p.v_head_qs_bytes = p.v_blocks_per_head * p.v_qs_per_block; - const int64_t v_total_blocks = p.v_multihead ? nev2 * p.v_blocks_per_head : p.v_blocks_per_head; - p.v_head_e8m0_offset = v_total_blocks * p.v_qs_per_block; - } - return p; } @@ -8430,14 +8437,14 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( int ith = params->ith; - if (is_mxfp_k) { GGML_ASSERT(DK <= 1024); } - if (is_mxfp_v) { GGML_ASSERT(DV <= 1024); } + if (is_mxfp_k) { GGML_ASSERT(DK <= MXFP_FA_MAX_D); } + if (is_mxfp_v) { GGML_ASSERT(DV <= MXFP_FA_MAX_D); } - float k_dequant_buf[1024]; - float v_dequant_buf[1024]; + float k_dequant_buf[MXFP_FA_MAX_D]; + float v_dequant_buf[MXFP_FA_MAX_D]; - char k_head_soa[1088]; // max: DK=1024 MXFP8 -> 1056 bytes, rounded up - char v_head_soa[1088]; + char k_head_soa[MXFP_FA_SOA_BUF]; // max: DK=1024 MXFP8 -> 1056 bytes, rounded up + char v_head_soa[MXFP_FA_SOA_BUF]; float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); float * V32 = (VKQ32 + 1*DV); @@ -8479,31 +8486,25 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( const char * k_base = (const char *) k->data + k_base_offset; const char * v_base = (const char *) v->data + v_base_offset; - // Per-head SoA byte offsets - const int k_head_qs_start = mxfp.k_multihead ? ik2 * mxfp.k_head_qs_bytes : 0; - const int k_head_e8m0_start = mxfp.k_multihead ? (int)mxfp.k_head_e8m0_offset + ik2 * mxfp.k_blocks_per_head : 0; - const int v_head_qs_start = mxfp.v_multihead ? iv2 * mxfp.v_head_qs_bytes : 0; - const int v_head_e8m0_start = mxfp.v_multihead ? (int)mxfp.v_head_e8m0_offset + iv2 * mxfp.v_blocks_per_head : 0; - - const char * k_row_base = mxfp.k_multihead ? ((const char *) k->data + ik3*nbk3) : nullptr; - const char * v_row_base = mxfp.v_multihead ? ((const char *) v->data + iv3*nbv3) : nullptr; + const char * k_row_base = mxfp.k.multihead ? ((const char *) k->data + ik3*nbk3) : nullptr; + const char * v_row_base = mxfp.v.multihead ? ((const char *) v->data + iv3*nbv3) : nullptr; const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); - float Q_f32[1024]; + float Q_f32[MXFP_FA_MAX_D]; if (is_mxfp_k) { // Q preprocessing: Hadamard + SoA round-trip captures same quantization loss as K. if (mxfp.apply_hadamard) { - float q_tmp[1024]; + float q_tmp[MXFP_FA_MAX_D]; memcpy(q_tmp, pq, DK * sizeof(float)); ggml_apply_hadamard_blocks(q_tmp, DK); mxfp.q_quantize(q_tmp, Q_q, DK); } else { mxfp.q_quantize(pq, Q_q, DK); } - mxfp.k_dequantize(Q_q, Q_f32, DK); + mxfp.k.dequantize(Q_q, Q_f32, DK); } else { if (mxfp.apply_hadamard) { - float q_tmp[1024]; + float q_tmp[MXFP_FA_MAX_D]; memcpy(q_tmp, pq, DK * sizeof(float)); ggml_apply_hadamard_blocks(q_tmp, DK); q_to_vec_dot(q_tmp, Q_q, DK); @@ -8525,15 +8526,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( float s; // KQ value if (is_mxfp_k) { - if (mxfp.k_multihead) { - // Extract this head's SoA blocks - const char * row = k_row_base + ic*nbk1; - memcpy(k_head_soa, row + k_head_qs_start, mxfp.k_head_qs_bytes); - memcpy(k_head_soa + mxfp.k_head_qs_bytes, row + k_head_e8m0_start, mxfp.k_blocks_per_head); - mxfp.k_dequantize(k_head_soa, k_dequant_buf, DK); - } else { - mxfp.k_dequantize(k_base + ic*nbk1, k_dequant_buf, DK); - } + const char * k_row = mxfp.k.multihead ? k_row_base + ic*nbk1 : k_base + ic*nbk1; + mxfp_dequant_head(mxfp.k, k_row, ik2, k_head_soa, k_dequant_buf, DK); ggml_vec_dot_f32(DK, &s, 0, k_dequant_buf, 0, Q_f32, 0, 1); } else { kq_vec_dot(DK, &s, 0, k_base + ic*nbk1, 0, Q_q, 0, 1); @@ -8577,15 +8571,9 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( } // V += v*expf(s - M) - if (mxfp.v_dequantize) { - if (mxfp.v_multihead) { - const char * row = v_row_base + ic*nbv1; - memcpy(v_head_soa, row + v_head_qs_start, mxfp.v_head_qs_bytes); - memcpy(v_head_soa + mxfp.v_head_qs_bytes, row + v_head_e8m0_start, mxfp.v_blocks_per_head); - mxfp.v_dequantize(v_head_soa, v_dequant_buf, DV); - } else { - mxfp.v_dequantize(v_base + ic*nbv1, v_dequant_buf, DV); - } + if (mxfp.v.dequantize) { + const char * v_row = mxfp.v.multihead ? v_row_base + ic*nbv1 : v_base + ic*nbv1; + mxfp_dequant_head(mxfp.v, v_row, iv2, v_head_soa, v_dequant_buf, DV); ggml_vec_mad_f32(DV, VKQ32, v_dequant_buf, vs); } else if (v_to_float) { v_to_float(v_base + ic*nbv1, V32, DV); @@ -8731,14 +8719,14 @@ static void ggml_compute_forward_flash_attn_ext_tiled( static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q; static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV; - if (is_mxfp_k) { GGML_ASSERT(DK <= 1024); } - if (is_mxfp_v) { GGML_ASSERT(DV <= 1024); } + if (is_mxfp_k) { GGML_ASSERT(DK <= MXFP_FA_MAX_D); } + if (is_mxfp_v) { GGML_ASSERT(DV <= MXFP_FA_MAX_D); } - float k_dequant_buf[1024]; - float v_dequant_buf[1024]; + float k_dequant_buf[MXFP_FA_MAX_D]; + float v_dequant_buf[MXFP_FA_MAX_D]; - char k_head_soa[1088]; - char v_head_soa[1088]; + char k_head_soa[MXFP_FA_SOA_BUF]; + char v_head_soa[MXFP_FA_SOA_BUF]; int ir = ir0; while (ir < ir1) { @@ -8802,9 +8790,9 @@ static void ggml_compute_forward_flash_attn_ext_tiled( if (mxfp.apply_hadamard) { ggml_apply_hadamard_blocks(Q_f32 + tq * DK, DK); } - uint8_t q_mxfp_buf[1088]; // max: DK=1024 MXFP8 -> 1056 bytes + uint8_t q_mxfp_buf[MXFP_FA_SOA_BUF]; mxfp.q_quantize(Q_f32 + tq * DK, q_mxfp_buf, DK); - mxfp.k_dequantize(q_mxfp_buf, Q_f32 + tq * DK, DK); + mxfp.k.dequantize(q_mxfp_buf, Q_f32 + tq * DK, DK); } } for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) { @@ -8854,23 +8842,13 @@ static void ggml_compute_forward_flash_attn_ext_tiled( for (int64_t dk = 0; dk < DK; dk++) { K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk]; } - } else if (mxfp.k_dequantize) { - if (mxfp.k_multihead) { - // Per-head extraction: copy only this head's SoA blocks + e8m0, dequant DK elements. - const char * row = (const char *)k->data + (ic + tk)*nbk1 + ik3*nbk3; - const int kqs = ik2 * mxfp.k_head_qs_bytes; - const int ke8 = (int)mxfp.k_head_e8m0_offset + ik2 * mxfp.k_blocks_per_head; - memcpy(k_head_soa, row + kqs, mxfp.k_head_qs_bytes); - memcpy(k_head_soa + mxfp.k_head_qs_bytes, row + ke8, mxfp.k_blocks_per_head); - mxfp.k_dequantize(k_head_soa, k_dequant_buf, DK); - } else { - mxfp.k_dequantize(k_data, k_dequant_buf, DK); - } + } else if (mxfp.k.dequantize) { + mxfp_dequant_head(mxfp.k, k_data, ik2, k_head_soa, k_dequant_buf, DK); for (int64_t dk = 0; dk < DK; dk++) { K_f32[dk * KV_TILE_SZ + tk] = k_dequant_buf[dk]; } } else { - float k_tmp[1024]; + float k_tmp[MXFP_FA_MAX_D]; k_to_float(k_data, k_tmp, DK); for (int64_t dk = 0; dk < DK; dk++) { K_f32[dk * KV_TILE_SZ + tk] = k_tmp[dk]; @@ -8934,18 +8912,8 @@ static void ggml_compute_forward_flash_attn_ext_tiled( ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV); } else if (v_type == GGML_TYPE_F32) { memcpy(V32 + tk * DV, v_data, DV * sizeof(float)); - } else if (mxfp.v_dequantize) { - if (mxfp.v_multihead) { - // Per-head extraction: copy only this head's SoA blocks + e8m0, dequant DV elements. - const char * row = (const char *)v->data + (ic + tk)*nbv1 + iv3*nbv3; - const int vqs = iv2 * mxfp.v_head_qs_bytes; - const int ve8 = (int)mxfp.v_head_e8m0_offset + iv2 * mxfp.v_blocks_per_head; - memcpy(v_head_soa, row + vqs, mxfp.v_head_qs_bytes); - memcpy(v_head_soa + mxfp.v_head_qs_bytes, row + ve8, mxfp.v_blocks_per_head); - mxfp.v_dequantize(v_head_soa, v_dequant_buf, DV); - } else { - mxfp.v_dequantize(v_data, v_dequant_buf, DV); - } + } else if (mxfp.v.dequantize) { + mxfp_dequant_head(mxfp.v, v_data, iv2, v_head_soa, v_dequant_buf, DV); memcpy(V32 + tk * DV, v_dequant_buf, DV * sizeof(float)); } else { v_to_float(v_data, V32 + tk * DV, DV); diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index eed3be90fc..5cbd177234 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -54,14 +54,6 @@ void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, i quantize_row_nvfp4_ref(x, y, k); } -void quantize_row_mxfp8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { - quantize_row_mxfp8_ref(x, y, k); -} - -void quantize_row_mxfp6(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { - quantize_row_mxfp6_ref(x, y, k); -} - // // 2-6 bit quantization in super-blocks // @@ -264,51 +256,6 @@ void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, *s = sumf; } -// Generic MXFP x Q8_0 dot product (scalar, not SIMD-optimized) -static void ggml_vec_dot_mxfp_q8_0_impl( - int n, float * GGML_RESTRICT s, - const void * GGML_RESTRICT vx, size_t block_size, - const void * GGML_RESTRICT vy, - ggml_to_float_t dequant) { - assert(n % QK8_0 == 0); - const int nb = n / QK8_0; - const block_q8_0 * GGML_RESTRICT y = vy; - float sumf = 0; - - for (int ib = 0; ib < nb; ib++) { - float tmp[QK8_0]; - dequant((const char *)vx + ib * block_size, tmp, QK8_0); - - const float y_d = GGML_CPU_FP16_TO_FP32(y[ib].d); - float block_sum = 0; - for (int j = 0; j < QK8_0; j++) { - block_sum += tmp[j] * (float)y[ib].qs[j]; - } - sumf += block_sum * y_d; - } - *s = sumf; -} - -void ggml_vec_dot_mxfp8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - ggml_vec_dot_mxfp_q8_0_impl(n, s, vx, sizeof(block_mxfp8), vy, - (ggml_to_float_t)dequantize_row_mxfp8); -} - -void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - ggml_vec_dot_mxfp_q8_0_impl(n, s, vx, sizeof(block_mxfp6), vy, - (ggml_to_float_t)dequantize_row_mxfp6); -} - // Generic SoA dequant wrappers — arch-specific SIMD versions override via fallback.h. void dequantize_row_mxfp4_soa_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { dequantize_row_mxfp4_soa(x, y, k); diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h index 78c9984bdc..4a4dd264fe 100644 --- a/ggml/src/ggml-cpu/quants.h +++ b/ggml/src/ggml-cpu/quants.h @@ -21,9 +21,6 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_mxfp8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_mxfp6(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); - void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -46,9 +43,6 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); - void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -80,9 +74,6 @@ void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, c void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_mxfp8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); - // SoA dequant (SIMD-dispatched, CPU backend) void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index e9a101ff66..f18758f16b 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -3770,7 +3770,7 @@ static block_mxfp4x4 make_block_mxfp4x4(block_mxfp4 * in, unsigned int blck_size } static int repack_mxfp4_to_mxfp4_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_MXFP4_E2M1); + GGML_ASSERT(t->type == GGML_TYPE_MXFP4); GGML_ASSERT(interleave_block == 4); const block_mxfp4 * src = (const block_mxfp4 *)data; @@ -3827,7 +3827,7 @@ static block_mxfp4x8 make_block_mxfp4x8(block_mxfp4 * in, unsigned int blck_size } static int repack_mxfp4_to_mxfp4_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_MXFP4_E2M1); + GGML_ASSERT(t->type == GGML_TYPE_MXFP4); GGML_ASSERT(interleave_block == 8); const block_mxfp4 * src = (const block_mxfp4 *)data; @@ -4685,7 +4685,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons } #endif } - } else if (cur->type == GGML_TYPE_MXFP4_E2M1) { + } else if (cur->type == GGML_TYPE_MXFP4) { if (ggml_cpu_has_avx2()) { if (cur->ne[1] % 8 == 0) { return &mxfp4_8x8_q8_0; diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index a8996a2ab5..9388b1be46 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1014,7 +1014,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te // MXFP4: has AoS shaders (MUL_MAT, GET_ROWS) but no SoA/flash attention support yet. for (size_t i = 0, n = 3; i < n; ++i) { if (op->src[i] != NULL && ggml_is_type_mxfp(op->src[i]->type)) { - if (op->src[i]->type != GGML_TYPE_MXFP4_E2M1) { + if (op->src[i]->type != GGML_TYPE_MXFP4) { return false; } if (op->op == GGML_OP_FLASH_ATTN_EXT || op->op == GGML_OP_SET_ROWS) { diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index f2fcd73dba..e1dca6b4b4 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -3760,7 +3760,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te } else if (op->src[0]->type == GGML_TYPE_F32) { return op->src[1]->type == GGML_TYPE_F32; } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || - op->src[0]->type == GGML_TYPE_MXFP4_E2M1 || + op->src[0]->type == GGML_TYPE_MXFP4 || op->src[0]->type == GGML_TYPE_Q4_K || op->src[0]->type == GGML_TYPE_Q6_K) { return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); @@ -3771,7 +3771,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_MUL_MAT_ID: if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0 || - op->src[0]->type == GGML_TYPE_MXFP4_E2M1) { + op->src[0]->type == GGML_TYPE_MXFP4) { if (op->src[1]->type == GGML_TYPE_F32) { return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); } @@ -4559,7 +4559,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, #endif // GGML_OPENCL_USE_ADRENO_KERNELS return; } - if (tensor->type == GGML_TYPE_MXFP4_E2M1) { + if (tensor->type == GGML_TYPE_MXFP4) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); @@ -5136,7 +5136,7 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } - if (tensor->type == GGML_TYPE_MXFP4_E2M1) { + if (tensor->type == GGML_TYPE_MXFP4) { ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *)tensor->extra; cl_int err; @@ -5585,7 +5585,7 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tenso CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL)); CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_d, buf_d, 0, NULL, NULL)); CL_CHECK(clFinish(queue)); - } else if (tensor->type == GGML_TYPE_MXFP4_E2M1) { + } else if (tensor->type == GGML_TYPE_MXFP4) { ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *) tensor->extra; GGML_ASSERT(extra); @@ -10550,7 +10550,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); #endif // GGML_OPENCL_SOA_Q break; - case GGML_TYPE_MXFP4_E2M1: { + case GGML_TYPE_MXFP4: { #ifdef GGML_OPENCL_SOA_Q kernel = backend_ctx->kernel_mul_mv_mxfp4_f32_flat; @@ -10630,7 +10630,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co GGML_ASSERT(false && "not implemented"); } - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4_E2M1 || + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K) { @@ -10864,7 +10864,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, #endif // GGML_OPENCL_SOA_Q break; } - case GGML_TYPE_MXFP4_E2M1: { + case GGML_TYPE_MXFP4: { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (use_adreno_moe_kernels(backend_ctx, src0)) { cl_int status; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index cca3d99c82..5c8eb97806 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -267,10 +267,12 @@ uint8_t float_to_fp8_e4m3_rn(float x) { return ggml_mxfp_float_to_fp8_e4m3(x); } // MSE-optimal E8M0: tests candidates around round(log2(amax)), picks lowest quantization error. typedef struct { - int emax_offset; // type-specific offset to max representable exponent + int emax_offset; // type-specific offset to max representable exponent + int qs_per_block; // quantized scalar bytes per 32-element block + int bits_per_elem; // 8 = byte-aligned, 6 = packed via fp6x4 uint8_t (*to_elem)(float); float (*to_float)(uint8_t); - float (*mse_error)(float val, float inv_scale, float scale); + float (*mse_error)(float val, float inv_scale, float scale); // NULL = use generic round-trip via to_elem/to_float } mxfp_elem_traits_t; static inline int best_index_mxfp4(float x, float e); @@ -294,7 +296,7 @@ static float mse_error_mxfp4(float val, float inv_scale, float scale) { return err * err; } -static const mxfp_elem_traits_t mxfp4_traits = { MXFP4_E2M1_EMAX_OFFSET, NULL, NULL, mse_error_mxfp4 }; +static const mxfp_elem_traits_t mxfp4_traits = { MXFP4_E2M1_EMAX_OFFSET, MXFP4_SOA_QS_PER_BLOCK, 4, NULL, NULL, mse_error_mxfp4 }; static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_elem_traits_t * traits) { float amax = 0.0f; @@ -319,7 +321,13 @@ static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_ const float test_inv = 1.0f / test_scale; float mse = 0.0f; for (int j = 0; j < qk; ++j) { - mse += traits->mse_error(x[j], test_inv, test_scale); + if (traits->mse_error) { + mse += traits->mse_error(x[j], test_inv, test_scale); + } else { + const float recon = traits->to_float(traits->to_elem(x[j] * test_inv)) * test_scale; + const float err = x[j] - recon; + mse += err * err; + } } if (mse < best_mse) { best_mse = mse; @@ -574,102 +582,8 @@ uint8_t float_to_fp8_e5m2_rn(float x) { return ggml_mxfp_float_to_fp8_e5m2(x); } void pack_fp6x4(const uint8_t v[4], uint8_t out[3]) { ggml_mxfp_pack_fp6x4(v, out); } void unpack_fp6x4(const uint8_t in[3], uint8_t v[4]) { ggml_mxfp_unpack_fp6x4(in, v); } -static float mse_error_fp8_e4m3(float val, float inv_scale, float scale) { - const float recon = fp8_e4m3_to_float(float_to_fp8_e4m3_rn(val * inv_scale)) * scale; - const float err = val - recon; - return err * err; -} -static float mse_error_fp6_e2m3(float val, float inv_scale, float scale) { - const float recon = fp6_e2m3_to_float(float_to_fp6_e2m3_rn(val * inv_scale)) * scale; - const float err = val - recon; - return err * err; -} -static const mxfp_elem_traits_t mxfp8_e4m3_traits = { MXFP8_E4M3_EMAX_OFFSET, float_to_fp8_e4m3_rn, fp8_e4m3_to_float, mse_error_fp8_e4m3 }; -static const mxfp_elem_traits_t mxfp6_e2m3_traits = { MXFP6_E2M3_EMAX_OFFSET, float_to_fp6_e2m3_rn, fp6_e2m3_to_float, mse_error_fp6_e2m3 }; - -static void quantize_row_mxfp8_impl(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, - int64_t k, const mxfp_elem_traits_t * traits) { - assert(k % QK_MXFP8 == 0); - const int nb = k / QK_MXFP8; - - for (int i = 0; i < nb; i++) { - const uint8_t e = mxfp_compute_e8m0_mse(&x[i*QK_MXFP8], QK_MXFP8, traits); - const float d = GGML_E8M0_TO_FP32(e); - const float inv_d = d > 0.0f ? 1.0f / d : 0.0f; - y[i].e = e; - - for (int j = 0; j < QK_MXFP8; ++j) { - y[i].qs[j] = traits->to_elem(x[i*QK_MXFP8 + j] * inv_d); - } - } -} - -static void dequantize_row_mxfp8_impl(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, - int64_t k, const mxfp_elem_traits_t * traits) { - assert(k % QK_MXFP8 == 0); - const int nb = k / QK_MXFP8; - - for (int i = 0; i < nb; i++) { - const float d = GGML_E8M0_TO_FP32(x[i].e); - for (int j = 0; j < QK_MXFP8; ++j) { - y[i*QK_MXFP8 + j] = traits->to_float(x[i].qs[j]) * d; - } - } -} - -static void quantize_row_mxfp6_impl(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, - int64_t k, const mxfp_elem_traits_t * traits) { - assert(k % QK_MXFP6 == 0); - const int nb = k / QK_MXFP6; - - for (int i = 0; i < nb; i++) { - const uint8_t e = mxfp_compute_e8m0_mse(&x[i*QK_MXFP6], QK_MXFP6, traits); - const float d = GGML_E8M0_TO_FP32(e); - const float inv_d = d > 0.0f ? 1.0f / d : 0.0f; - y[i].e = e; - - for (int j = 0; j < QK_MXFP6; j += 4) { - uint8_t vals[4]; - for (int jj = 0; jj < 4; jj++) { - vals[jj] = traits->to_elem(x[i*QK_MXFP6 + j + jj] * inv_d); - } - pack_fp6x4(vals, &y[i].qs[j * 3 / 4]); - } - } -} - -static void dequantize_row_mxfp6_impl(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, - int64_t k, const mxfp_elem_traits_t * traits) { - assert(k % QK_MXFP6 == 0); - const int nb = k / QK_MXFP6; - - for (int i = 0; i < nb; i++) { - const float d = GGML_E8M0_TO_FP32(x[i].e); - for (int j = 0; j < QK_MXFP6; j += 4) { - uint8_t vals[4]; - unpack_fp6x4(&x[i].qs[j * 3 / 4], vals); - for (int jj = 0; jj < 4; jj++) { - y[i*QK_MXFP6 + j + jj] = traits->to_float(vals[jj]) * d; - } - } - } -} - -void quantize_row_mxfp8_ref(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k) { - quantize_row_mxfp8_impl(x, y, k, &mxfp8_e4m3_traits); -} - -void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { - dequantize_row_mxfp8_impl(x, y, k, &mxfp8_e4m3_traits); -} - -void quantize_row_mxfp6_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k) { - quantize_row_mxfp6_impl(x, y, k, &mxfp6_e2m3_traits); -} - -void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { - dequantize_row_mxfp6_impl(x, y, k, &mxfp6_e2m3_traits); -} +static const mxfp_elem_traits_t mxfp8_e4m3_traits = { MXFP8_E4M3_EMAX_OFFSET, MXFP8_SOA_QS_PER_BLOCK, 8, float_to_fp8_e4m3_rn, fp8_e4m3_to_float, NULL }; +static const mxfp_elem_traits_t mxfp6_e2m3_traits = { MXFP6_E2M3_EMAX_OFFSET, MXFP6_SOA_QS_PER_BLOCK, 6, float_to_fp6_e2m3_rn, fp6_e2m3_to_float, NULL }; // ====================== SoA (Struct-of-Arrays) quantize/dequantize for flash attention @@ -715,101 +629,79 @@ void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTR } } -static void quantize_row_mxfp8_soa_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, - int64_t k, const mxfp_elem_traits_t * traits) { - assert(k % QK_MXFP8 == 0); - const int nb = k / QK_MXFP8; - char * row = (char *)dst; - char * qs_base = row; - char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK); +// Unified SoA quantize for byte-aligned (FP8) and 6-bit packed (FP6) MXFP formats. +static void quantize_row_mxfp_soa_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, + int64_t k, const mxfp_elem_traits_t * traits) { + const int qk = 32; + assert(k % qk == 0); + const int nb = k / qk; + const int qpb = traits->qs_per_block; + char * qs_base = (char *)dst; + char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, qpb); for (int i = 0; i < nb; i++) { - const uint8_t e = mxfp_compute_e8m0_mse(&x[i*QK_MXFP8], QK_MXFP8, traits); + const uint8_t e = mxfp_compute_e8m0_mse(&x[i*qk], qk, traits); const float d = GGML_E8M0_TO_FP32(e); const float inv_d = d > 0.0f ? 1.0f / d : 0.0f; e8m0_base[i] = (char)e; - uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP8_SOA_QS_PER_BLOCK)); - for (int j = 0; j < QK_MXFP8; ++j) { - qs[j] = traits->to_elem(x[i*QK_MXFP8 + j] * inv_d); - } - } -} - -static void dequantize_row_mxfp8_soa_impl(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, - int64_t k, const mxfp_elem_traits_t * traits) { - assert(k % QK_MXFP8 == 0); - const int nb = k / QK_MXFP8; - const char * row = (const char *)src; - const char * qs_base = row; - const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK); - - for (int i = 0; i < nb; i++) { - const float d = GGML_E8M0_TO_FP32((uint8_t)e8m0_base[i]); - const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP8_SOA_QS_PER_BLOCK)); - for (int j = 0; j < QK_MXFP8; ++j) { - y[i*QK_MXFP8 + j] = traits->to_float(qs[j]) * d; - } - } -} - -static void quantize_row_mxfp6_soa_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, - int64_t k, const mxfp_elem_traits_t * traits) { - assert(k % QK_MXFP6 == 0); - const int nb = k / QK_MXFP6; - char * row = (char *)dst; - char * qs_base = row; - char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK); - - for (int i = 0; i < nb; i++) { - const uint8_t e = mxfp_compute_e8m0_mse(&x[i*QK_MXFP6], QK_MXFP6, traits); - const float d = GGML_E8M0_TO_FP32(e); - const float inv_d = d > 0.0f ? 1.0f / d : 0.0f; - e8m0_base[i] = (char)e; - - uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP6_SOA_QS_PER_BLOCK)); - for (int j = 0; j < QK_MXFP6; j += 4) { - uint8_t vals[4]; - for (int jj = 0; jj < 4; jj++) { - vals[jj] = traits->to_elem(x[i*QK_MXFP6 + j + jj] * inv_d); + uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, qpb)); + if (traits->bits_per_elem == 8) { + for (int j = 0; j < qk; ++j) { + qs[j] = traits->to_elem(x[i*qk + j] * inv_d); + } + } else { + for (int j = 0; j < qk; j += 4) { + uint8_t vals[4]; + for (int jj = 0; jj < 4; jj++) { + vals[jj] = traits->to_elem(x[i*qk + j + jj] * inv_d); + } + pack_fp6x4(vals, &qs[j * 3 / 4]); } - pack_fp6x4(vals, &qs[j * 3 / 4]); } } } -static void dequantize_row_mxfp6_soa_impl(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, - int64_t k, const mxfp_elem_traits_t * traits) { - assert(k % QK_MXFP6 == 0); - const int nb = k / QK_MXFP6; - const char * row = (const char *)src; - const char * qs_base = row; - const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK); +// Unified SoA dequantize for byte-aligned (FP8) and 6-bit packed (FP6) MXFP formats. +static void dequantize_row_mxfp_soa_impl(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, + int64_t k, const mxfp_elem_traits_t * traits) { + const int qk = 32; + assert(k % qk == 0); + const int nb = k / qk; + const int qpb = traits->qs_per_block; + const char * qs_base = (const char *)src; + const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, qpb); for (int i = 0; i < nb; i++) { const float d = GGML_E8M0_TO_FP32((uint8_t)e8m0_base[i]); - const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP6_SOA_QS_PER_BLOCK)); - for (int j = 0; j < QK_MXFP6; j += 4) { - uint8_t vals[4]; - unpack_fp6x4(&qs[j * 3 / 4], vals); - for (int jj = 0; jj < 4; jj++) { - y[i*QK_MXFP6 + j + jj] = traits->to_float(vals[jj]) * d; + const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, qpb)); + if (traits->bits_per_elem == 8) { + for (int j = 0; j < qk; ++j) { + y[i*qk + j] = traits->to_float(qs[j]) * d; + } + } else { + for (int j = 0; j < qk; j += 4) { + uint8_t vals[4]; + unpack_fp6x4(&qs[j * 3 / 4], vals); + for (int jj = 0; jj < 4; jj++) { + y[i*qk + j + jj] = traits->to_float(vals[jj]) * d; + } } } } } void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) { - quantize_row_mxfp8_soa_impl(x, dst, k, &mxfp8_e4m3_traits); + quantize_row_mxfp_soa_impl(x, dst, k, &mxfp8_e4m3_traits); } void dequantize_row_mxfp8_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) { - dequantize_row_mxfp8_soa_impl(src, y, k, &mxfp8_e4m3_traits); + dequantize_row_mxfp_soa_impl(src, y, k, &mxfp8_e4m3_traits); } void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) { - quantize_row_mxfp6_soa_impl(x, dst, k, &mxfp6_e2m3_traits); + quantize_row_mxfp_soa_impl(x, dst, k, &mxfp6_e2m3_traits); } void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) { - dequantize_row_mxfp6_soa_impl(src, y, k, &mxfp6_e2m3_traits); + dequantize_row_mxfp_soa_impl(src, y, k, &mxfp6_e2m3_traits); } // // 2-6 bit quantization in super-blocks @@ -2472,7 +2364,7 @@ size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { GGML_UNUSED(quant_weights); quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row); - return nrow * ggml_row_size(GGML_TYPE_MXFP4_E2M1, n_per_row); + return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row); } size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { @@ -2481,18 +2373,6 @@ size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, return nrow * ggml_row_size(GGML_TYPE_NVFP4, n_per_row); } -size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - GGML_UNUSED(quant_weights); - quantize_row_mxfp8_ref(src, dst, (int64_t)nrow*n_per_row); - return nrow * ggml_row_size(GGML_TYPE_MXFP8_E4M3, n_per_row); -} - -size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - GGML_UNUSED(quant_weights); - quantize_row_mxfp6_ref(src, dst, (int64_t)nrow*n_per_row); - return nrow * ggml_row_size(GGML_TYPE_MXFP6_E2M3, n_per_row); -} - // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs) void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) { @@ -5635,15 +5515,15 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb); } break; - case GGML_TYPE_MXFP4_E2M1: + case GGML_TYPE_MXFP4: { VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb); } break; - case GGML_TYPE_MXFP8_E4M3: + case GGML_TYPE_MXFP8: { VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp8, data, nb); } break; - case GGML_TYPE_MXFP6_E2M3: + case GGML_TYPE_MXFP6: { VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp6, data, nb); } break; diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index 7d848edffe..4dec9ad351 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -23,9 +23,6 @@ GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k); -GGML_API void quantize_row_mxfp8_ref(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k); -GGML_API void quantize_row_mxfp6_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k); - GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k); @@ -52,9 +49,6 @@ GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GG GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -GGML_API void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -GGML_API void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); - // SoA quantize/dequantize for flash attention GGML_API void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); GGML_API void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); @@ -110,9 +104,6 @@ GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTR GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -GGML_API size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -GGML_API size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); - // MXFP element converters GGML_API float fp8_e4m3_to_float(uint8_t v); GGML_API uint8_t float_to_fp8_e4m3_rn(float x); diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index 09f3a43a90..d17aca2cac 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -639,7 +639,7 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { return dequantize_row_iq4_xs_sycl; case GGML_TYPE_IQ4_NL: return dequantize_row_iq4_nl_sycl; - case GGML_TYPE_MXFP4_E2M1: + case GGML_TYPE_MXFP4: return dequantize_row_mxfp4_sycl; case GGML_TYPE_F32: return convert_unary_sycl; @@ -706,7 +706,7 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { return dequantize_row_iq4_xs_sycl; case GGML_TYPE_IQ4_NL: return dequantize_row_iq4_nl_sycl; - case GGML_TYPE_MXFP4_E2M1: + case GGML_TYPE_MXFP4: return dequantize_row_mxfp4_sycl; case GGML_TYPE_F16: return convert_unary_sycl; diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index f0d1472f44..316aa0d0fb 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -1142,7 +1142,7 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens case GGML_TYPE_IQ4_XS: mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; - case GGML_TYPE_MXFP4_E2M1: + case GGML_TYPE_MXFP4: mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; default: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 40a0aab62b..21b9a81eae 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -710,7 +710,7 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref, }, - [GGML_TYPE_MXFP4_E2M1] = { + [GGML_TYPE_MXFP4] = { .type_name = "mxfp4", .blck_size = QK_MXFP4, .type_size = sizeof(block_mxfp4), @@ -726,21 +726,17 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .to_float = (ggml_to_float_t) dequantize_row_nvfp4, .from_float_ref = (ggml_from_float_t)quantize_row_nvfp4_ref, }, - [GGML_TYPE_MXFP8_E4M3] = { + [GGML_TYPE_MXFP8] = { .type_name = "mxfp8_e4m3", .blck_size = QK_MXFP8, .type_size = sizeof(block_mxfp8), .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_mxfp8, - .from_float_ref = (ggml_from_float_t)quantize_row_mxfp8_ref, }, - [GGML_TYPE_MXFP6_E2M3] = { + [GGML_TYPE_MXFP6] = { .type_name = "mxfp6_e2m3", .blck_size = QK_MXFP6, .type_size = sizeof(block_mxfp6), .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_mxfp6, - .from_float_ref = (ggml_from_float_t)quantize_row_mxfp6_ref, }, [GGML_TYPE_Q2_K] = { .type_name = "q2_K", @@ -1329,25 +1325,25 @@ bool ggml_is_quantized(enum ggml_type type) { } bool ggml_is_type_mxfp(enum ggml_type type) { - return type == GGML_TYPE_MXFP4_E2M1 || - type == GGML_TYPE_MXFP8_E4M3 || - type == GGML_TYPE_MXFP6_E2M3; + return type == GGML_TYPE_MXFP4 || + type == GGML_TYPE_MXFP8 || + type == GGML_TYPE_MXFP6; } bool ggml_mxfp_use_hadamard(enum ggml_type type) { switch (type) { - case GGML_TYPE_MXFP4_E2M1: return MXFP_USE_HADAMARD_E2M1; - case GGML_TYPE_MXFP8_E4M3: return MXFP_USE_HADAMARD_E4M3; - case GGML_TYPE_MXFP6_E2M3: return MXFP_USE_HADAMARD_E2M3; + case GGML_TYPE_MXFP4: return MXFP_USE_HADAMARD_E2M1; + case GGML_TYPE_MXFP8: return MXFP_USE_HADAMARD_E4M3; + case GGML_TYPE_MXFP6: return MXFP_USE_HADAMARD_E2M3; default: return false; } } int ggml_mxfp_qs_per_block(enum ggml_type type) { switch (type) { - case GGML_TYPE_MXFP4_E2M1: return MXFP_QS_PER_BLOCK_E2M1; - case GGML_TYPE_MXFP8_E4M3: return MXFP_QS_PER_BLOCK_E4M3; - case GGML_TYPE_MXFP6_E2M3: return MXFP_QS_PER_BLOCK_E2M3; + case GGML_TYPE_MXFP4: return MXFP_QS_PER_BLOCK_E2M1; + case GGML_TYPE_MXFP8: return MXFP_QS_PER_BLOCK_E4M3; + case GGML_TYPE_MXFP6: return MXFP_QS_PER_BLOCK_E2M3; default: return 0; } } @@ -7695,10 +7691,10 @@ size_t ggml_quantize_chunk( case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_MXFP4_E2M1: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_MXFP8_E4M3: result = quantize_mxfp8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_MXFP6_E2M3: result = quantize_mxfp6(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_MXFP4: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_MXFP8: GGML_ABORT("MXFP8 is KV-cache-only (SoA layout) — use from_float_soa"); + case GGML_TYPE_MXFP6: GGML_ABORT("MXFP6 is KV-cache-only (SoA layout) — use from_float_soa"); case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 279c57e582..8e8ce23124 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -457,7 +457,7 @@ static ggml_type llama_tensor_get_type_impl(quantize_state_impl & qs, ggml_type // MoE tensors -> MXFP4 // other tensors -> Q8_0 if (tensor->ne[2] > 1) { - new_type = GGML_TYPE_MXFP4_E2M1; + new_type = GGML_TYPE_MXFP4; } else { new_type = GGML_TYPE_Q8_0; } @@ -795,7 +795,7 @@ static ggml_type llama_ftype_get_default_type(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_BF16: return GGML_TYPE_BF16; case LLAMA_FTYPE_ALL_F32: return GGML_TYPE_F32; - case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return GGML_TYPE_MXFP4_E2M1; + case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return GGML_TYPE_MXFP4; // K-quants case LLAMA_FTYPE_MOSTLY_Q2_K_S: diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 8e57cb1d1d..d102c5676c 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -170,9 +170,9 @@ struct mxfp_soa_fns { }; static const mxfp_soa_fns mxfp_soa_table[] = { - { GGML_TYPE_MXFP4_E2M1, quantize_row_mxfp4_soa, dequantize_row_mxfp4_soa }, - { GGML_TYPE_MXFP8_E4M3, quantize_row_mxfp8_soa, dequantize_row_mxfp8_soa }, - { GGML_TYPE_MXFP6_E2M3, quantize_row_mxfp6_soa, dequantize_row_mxfp6_soa }, + { GGML_TYPE_MXFP4, quantize_row_mxfp4_soa, dequantize_row_mxfp4_soa }, + { GGML_TYPE_MXFP8, quantize_row_mxfp8_soa, dequantize_row_mxfp8_soa }, + { GGML_TYPE_MXFP6, quantize_row_mxfp6_soa, dequantize_row_mxfp6_soa }, }; static const mxfp_soa_fns * get_mxfp_soa(ggml_type type) { @@ -3908,7 +3908,7 @@ struct test_mul_mat : public test_case { double max_nmse_err(ggml_backend_t backend) override { // for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance - if (type_a == GGML_TYPE_MXFP4_E2M1 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) { + if (type_a == GGML_TYPE_MXFP4 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) { return 2e-2; } return max_nmse_err(); @@ -4044,7 +4044,7 @@ struct test_mul_mat_id : public test_case { double max_nmse_err(ggml_backend_t backend) override { // for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance - if (type_a == GGML_TYPE_MXFP4_E2M1 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) { + if (type_a == GGML_TYPE_MXFP4 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) { return 2e-2; } return max_nmse_err(); @@ -7398,7 +7398,7 @@ static const ggml_type all_types[] = { GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0, - GGML_TYPE_MXFP4_E2M1, + GGML_TYPE_MXFP4, GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K, @@ -7414,7 +7414,7 @@ static const ggml_type base_types[] = { GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, // for I8MM tests GGML_TYPE_Q4_K, - GGML_TYPE_MXFP4_E2M1, + GGML_TYPE_MXFP4, GGML_TYPE_IQ2_XXS }; @@ -7533,8 +7533,8 @@ static std::vector> make_test_cases_eval() { } // SET_ROWS with Hadamard rotation (exercises the op_params[0] flag used by MXFP KV cache) - for (ggml_type type : {GGML_TYPE_MXFP4_E2M1, GGML_TYPE_MXFP8_E4M3, - GGML_TYPE_MXFP6_E2M3}) { + for (ggml_type type : {GGML_TYPE_MXFP4, GGML_TYPE_MXFP8, + GGML_TYPE_MXFP6}) { // ne[0] must be divisible by 32 (Hadamard block size) test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 128, 5, 1, 1 }, { 1, 1 }, 1, false, true)); test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 256, 5, 1, 3 }, { 1, 1 }, 1, false, true)); @@ -8270,7 +8270,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat_id_fusion(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3)); // gpt-oss issue with Vulkan mmq_id - test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4_E2M1, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880)); + test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880)); for (ggml_type type_a : base_types) { for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) { @@ -8731,7 +8731,7 @@ static std::vector> make_test_cases_eval() { for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) { if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue; for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0, - GGML_TYPE_MXFP4_E2M1, GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3, + GGML_TYPE_MXFP4, GGML_TYPE_MXFP8, GGML_TYPE_MXFP6, }) { // Non-F16 types: test at D=64, D=72, and D=128. if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72 && hsk != 128) continue; @@ -8760,8 +8760,8 @@ static std::vector> make_test_cases_eval() { // MXFP-specific K/V type combinations (mixed and same-type) // Mixed: mxfp8 K + mxfp4 V, mxfp6 K + mxfp4 V (our recommended configs) - for (ggml_type type_K : {GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3}) { - for (ggml_type type_V : {GGML_TYPE_MXFP4_E2M1}) { + for (ggml_type type_K : {GGML_TYPE_MXFP8, GGML_TYPE_MXFP6}) { + for (ggml_type type_V : {GGML_TYPE_MXFP4}) { if (type_K == type_V) continue; for (int nb : {1, 3, 32}) { test_cases.emplace_back(new test_flash_attn_ext( @@ -8770,7 +8770,7 @@ static std::vector> make_test_cases_eval() { } } // Same-type: mxfp8/mxfp8, mxfp6/mxfp6 - for (ggml_type type_KV : {GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3}) { + for (ggml_type type_KV : {GGML_TYPE_MXFP8, GGML_TYPE_MXFP6}) { for (int nb : {1, 3, 32}) { test_cases.emplace_back(new test_flash_attn_ext( 128, 128, 4, {1, 1}, 512, nb, true, false, 0.0f, 0.0f, GGML_PREC_F32, type_KV, {0, 1, 2, 3}, type_KV)); @@ -9000,7 +9000,7 @@ static std::vector> make_test_cases_perf() { // gpt-oss-20b for (int bs : {1, 4, 8, 512}) { - for (ggml_type type_a : {GGML_TYPE_MXFP4_E2M1}) { + for (ggml_type type_a : {GGML_TYPE_MXFP4}) { for (ggml_type type_b : {GGML_TYPE_F32}) { test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 2880, bs, 2880)); test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 2880, bs, 2880, 1)); diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp index 98e3d489dd..8f1dcf10f0 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -178,9 +178,9 @@ int main(int argc, char * argv[]) { type == GGML_TYPE_IQ3_S ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS : type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS : type == GGML_TYPE_NVFP4 ? MAX_QUANTIZATION_TOTAL_ERROR_FP4 : - type == GGML_TYPE_MXFP4_E2M1 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 : - type == GGML_TYPE_MXFP6_E2M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 : - type == GGML_TYPE_MXFP8_E4M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 : MAX_QUANTIZATION_TOTAL_ERROR; + type == GGML_TYPE_MXFP4 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 : + type == GGML_TYPE_MXFP6 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 : + type == GGML_TYPE_MXFP8 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 : MAX_QUANTIZATION_TOTAL_ERROR; failed = !(total_error < max_quantization_error); num_failed += failed; if (failed || verbose) { @@ -202,7 +202,7 @@ int main(int argc, char * argv[]) { ? MAX_DOT_PRODUCT_ERROR_TERNARY : type == GGML_TYPE_NVFP4 ? MAX_DOT_PRODUCT_ERROR_FP4 - : type == GGML_TYPE_MXFP4_E2M1 || type == GGML_TYPE_MXFP6_E2M3 || type == GGML_TYPE_MXFP8_E4M3 + : type == GGML_TYPE_MXFP4 || type == GGML_TYPE_MXFP6 || type == GGML_TYPE_MXFP8 ? MAX_DOT_PRODUCT_ERROR_MXFP : MAX_DOT_PRODUCT_ERROR; failed = !(vec_dot_error < max_allowed_error); @@ -231,9 +231,9 @@ int main(int argc, char * argv[]) { const float soa_error = array_rmse(test_data.data(), tmp_out.data(), test_size); const float max_soa_error = - type == GGML_TYPE_MXFP4_E2M1 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 : - type == GGML_TYPE_MXFP6_E2M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 : - type == GGML_TYPE_MXFP8_E4M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 : MAX_QUANTIZATION_TOTAL_ERROR; + type == GGML_TYPE_MXFP4 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 : + type == GGML_TYPE_MXFP6 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 : + type == GGML_TYPE_MXFP8 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 : MAX_QUANTIZATION_TOTAL_ERROR; failed = !(soa_error < max_soa_error); num_failed += failed; if (failed || verbose) { @@ -243,7 +243,7 @@ int main(int argc, char * argv[]) { // MXFP traits: SoA required, MXFP6/MXFP8 are KV-cache-only (no AoS dequant) { - const ggml_type all_mxfp_types[] = { GGML_TYPE_MXFP4_E2M1, GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3 }; + const ggml_type all_mxfp_types[] = { GGML_TYPE_MXFP4, GGML_TYPE_MXFP8, GGML_TYPE_MXFP6 }; for (ggml_type type : all_mxfp_types) { const auto * cpu = ggml_get_type_traits_cpu(type); @@ -255,7 +255,7 @@ int main(int argc, char * argv[]) { } // KV-cache-only types: no AoS dequant - const ggml_type kv_only_types[] = { GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3 }; + const ggml_type kv_only_types[] = { GGML_TYPE_MXFP8, GGML_TYPE_MXFP6 }; for (ggml_type type : kv_only_types) { const auto * cpu = ggml_get_type_traits_cpu(type); failed = (cpu->to_float != nullptr); @@ -297,9 +297,9 @@ int main(int argc, char * argv[]) { }; const soa_cross_check checks[] = { - { GGML_TYPE_MXFP4_E2M1, dequantize_row_mxfp4_soa }, - { GGML_TYPE_MXFP8_E4M3, dequantize_row_mxfp8_soa }, - { GGML_TYPE_MXFP6_E2M3, dequantize_row_mxfp6_soa }, + { GGML_TYPE_MXFP4, dequantize_row_mxfp4_soa }, + { GGML_TYPE_MXFP8, dequantize_row_mxfp8_soa }, + { GGML_TYPE_MXFP6, dequantize_row_mxfp6_soa }, }; for (const auto & c : checks) { @@ -774,9 +774,9 @@ int main(int argc, char * argv[]) { // SoA layout: verify offset macros produce correct byte positions { const struct { ggml_type type; int qs_per_block; } soa_types[] = { - { GGML_TYPE_MXFP4_E2M1, MXFP4_SOA_QS_PER_BLOCK }, - { GGML_TYPE_MXFP8_E4M3, MXFP8_SOA_QS_PER_BLOCK }, - { GGML_TYPE_MXFP6_E2M3, MXFP6_SOA_QS_PER_BLOCK }, + { GGML_TYPE_MXFP4, MXFP4_SOA_QS_PER_BLOCK }, + { GGML_TYPE_MXFP8, MXFP8_SOA_QS_PER_BLOCK }, + { GGML_TYPE_MXFP6, MXFP6_SOA_QS_PER_BLOCK }, }; for (const auto & st : soa_types) { @@ -864,7 +864,7 @@ int main(int argc, char * argv[]) { dequantize_row_mxfp4(aos_q.data(), aos_out.data(), nelems); // Quantize and dequant via SoA - const size_t soa_buf_size = ggml_row_size(GGML_TYPE_MXFP4_E2M1, nelems); + const size_t soa_buf_size = ggml_row_size(GGML_TYPE_MXFP4, nelems); std::vector soa_q(soa_buf_size); std::vector soa_out(nelems); quantize_row_mxfp4_soa(input, soa_q.data(), nelems); @@ -901,9 +901,9 @@ int main(int argc, char * argv[]) { }; const hadamard_pipeline_check pipeline_checks[] = { - { "mxfp4", GGML_TYPE_MXFP4_E2M1, MAX_MXFP_PIPELINE_ERROR_MXFP4 }, - { "mxfp8", GGML_TYPE_MXFP8_E4M3, MAX_MXFP_PIPELINE_ERROR_MXFP8 }, - { "mxfp6", GGML_TYPE_MXFP6_E2M3, MAX_MXFP_PIPELINE_ERROR_MXFP6 }, + { "mxfp4", GGML_TYPE_MXFP4, MAX_MXFP_PIPELINE_ERROR_MXFP4 }, + { "mxfp8", GGML_TYPE_MXFP8, MAX_MXFP_PIPELINE_ERROR_MXFP8 }, + { "mxfp6", GGML_TYPE_MXFP6, MAX_MXFP_PIPELINE_ERROR_MXFP6 }, }; for (const auto & p : pipeline_checks) { @@ -963,7 +963,7 @@ int main(int argc, char * argv[]) { // zero block produces E8M0=0 { float zeros[32] = {}; - const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP8_E4M3, 32); + const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP8, 32); std::vector buf(buf_size, 0xFF); // fill with 0xFF to detect non-writes quantize_row_mxfp8_soa(zeros, buf.data(), 32); @@ -991,7 +991,7 @@ int main(int argc, char * argv[]) { // MXFP4 { - const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP4_E2M1, nelems); + const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP4, nelems); std::vector buf(buf_size); std::vector ref_out(nelems); std::vector manual_out(nelems); @@ -1032,7 +1032,7 @@ int main(int argc, char * argv[]) { // MXFP8 { - const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP8_E4M3, nelems); + const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP8, nelems); std::vector buf(buf_size); std::vector ref_out(nelems); std::vector manual_out(nelems); @@ -1069,7 +1069,7 @@ int main(int argc, char * argv[]) { // MXFP6 { - const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP6_E2M3, nelems); + const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP6, nelems); std::vector buf(buf_size); std::vector ref_out(nelems); std::vector manual_out(nelems); diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp index 00c6536589..27db9a065d 100644 --- a/tools/llama-bench/llama-bench.cpp +++ b/tools/llama-bench/llama-bench.cpp @@ -484,13 +484,13 @@ static ggml_type ggml_type_from_name(const std::string & s) { return GGML_TYPE_IQ4_NL; } if (s == "mxfp4" || s == "mxfp4_e2m1") { - return GGML_TYPE_MXFP4_E2M1; + return GGML_TYPE_MXFP4; } if (s == "mxfp8" || s == "mxfp8_e4m3") { - return GGML_TYPE_MXFP8_E4M3; + return GGML_TYPE_MXFP8; } if (s == "mxfp6" || s == "mxfp6_e2m3") { - return GGML_TYPE_MXFP6_E2M3; + return GGML_TYPE_MXFP6; } return GGML_TYPE_COUNT; }