From a51ff77fae4f698e22d1d8ee9a88e8aa195b4be3 Mon Sep 17 00:00:00 2001 From: Tim Burke Date: Sun, 15 Mar 2026 18:57:12 -0400 Subject: [PATCH] =?UTF-8?q?ggml:=20address=20PR=20review=20=E2=80=94=20fix?= =?UTF-8?q?=20buffer=20overflows,=20add=20assertions,=20normalize=20MXFP6?= =?UTF-8?q?=20naming?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix potential buffer overflows flagged in PR #20609 review: - set_rows: replace fixed float tmp[1024] with std::vector for large n_embd_k_gqa - tiled FA: size q_mxfp_buf with ggml_row_size guard instead of fixed 1024 - one_chunk FA: pre-allocate k/v dequant buffers from mxfp.{k,v}_soa_elems instead of hard-coded float[4096] stack arrays - kv-cache: assert n_embd_k_gqa % qk == 0 before integer division - test init: assert soa_bytes % block_size == 0 Normalize MXFP6 function naming to match MXFP8 convention (short form without element format suffix): mxfp6_e2m3 → mxfp6 in all function identifiers across 14 files. Format-specific items (type enums, traits, lookup tables, constants) retain their _e2m3 suffix. 
--- ggml/src/ggml-cpu/arch-fallback.h | 4 +-- ggml/src/ggml-cpu/arch/arm/quants.c | 8 +++--- ggml/src/ggml-cpu/arch/loongarch/quants.c | 4 +-- ggml/src/ggml-cpu/arch/powerpc/quants.c | 4 +-- ggml/src/ggml-cpu/arch/riscv/quants.c | 4 +-- ggml/src/ggml-cpu/arch/s390/quants.c | 4 +-- ggml/src/ggml-cpu/arch/wasm/quants.c | 4 +-- ggml/src/ggml-cpu/arch/x86/quants.c | 8 +++--- ggml/src/ggml-cpu/ggml-cpu.c | 6 ++--- ggml/src/ggml-cpu/ops.cpp | 30 +++++++++++------------ ggml/src/ggml-cpu/quants.c | 12 ++++----- ggml/src/ggml-cpu/quants.h | 10 ++++---- ggml/src/ggml-quants.c | 8 +++--- ggml/src/ggml-quants.h | 6 ++--- ggml/src/ggml.c | 6 ++--- src/llama-kv-cache.cpp | 1 + tests/test-backend-ops.cpp | 1 + 17 files changed, 61 insertions(+), 59 deletions(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 42647e14e1..612786e941 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -17,7 +17,7 @@ #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 #define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0 -#define ggml_vec_dot_mxfp6_e2m3_q8_0_generic ggml_vec_dot_mxfp6_e2m3_q8_0 +#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -349,7 +349,7 @@ #if !defined(__aarch64__) && !defined(__arm__) && !defined(_M_ARM) && !defined(_M_ARM64) && \ !defined(__x86_64__) && !defined(__i386__) && !defined(_M_IX86) && !defined(_M_X64) #define dequantize_row_mxfp8_cpu_generic dequantize_row_mxfp8_cpu -#define dequantize_row_mxfp6_e2m3_cpu_generic dequantize_row_mxfp6_e2m3_cpu +#define dequantize_row_mxfp6_cpu_generic dequantize_row_mxfp6_cpu #define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu #define 
dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu #define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index 0f0ba86518..b18a276640 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -4333,7 +4333,7 @@ static inline void ggml_vec_dot_mxfp6_q8_0_neon( } #endif -void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); #if defined(__ARM_NEON) @@ -4342,7 +4342,7 @@ void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, con MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); #else - ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); + ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } @@ -4471,13 +4471,13 @@ void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRIC #endif } -void dequantize_row_mxfp6_e2m3_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { +void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { #if defined(__ARM_NEON) dequantize_row_mxfp6_neon(x, y, k, sizeof(block_mxfp6), MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); #else - dequantize_row_mxfp6_e2m3_cpu_generic(x, y, k); + dequantize_row_mxfp6_cpu_generic(x, y, k); #endif } diff --git a/ggml/src/ggml-cpu/arch/loongarch/quants.c b/ggml/src/ggml-cpu/arch/loongarch/quants.c index 
a75dac8b15..fa05e49c5d 100644 --- a/ggml/src/ggml-cpu/arch/loongarch/quants.c +++ b/ggml/src/ggml-cpu/arch/loongarch/quants.c @@ -2165,6 +2165,6 @@ void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); } -void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); } diff --git a/ggml/src/ggml-cpu/arch/powerpc/quants.c b/ggml/src/ggml-cpu/arch/powerpc/quants.c index 82ca1f9df9..efb669da09 100644 --- a/ggml/src/ggml-cpu/arch/powerpc/quants.c +++ b/ggml/src/ggml-cpu/arch/powerpc/quants.c @@ -2307,6 +2307,6 @@ void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); } -void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); } diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index dcb97756c6..beef1885da 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -3612,6 +3612,6 @@ void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_vec_dot_mxfp8_q8_0_generic(n, 
s, bs, vx, bx, vy, by, nrc); } -void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); } diff --git a/ggml/src/ggml-cpu/arch/s390/quants.c b/ggml/src/ggml-cpu/arch/s390/quants.c index 234488f25c..e696fd4570 100644 --- a/ggml/src/ggml-cpu/arch/s390/quants.c +++ b/ggml/src/ggml-cpu/arch/s390/quants.c @@ -1468,6 +1468,6 @@ void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); } -void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); } diff --git a/ggml/src/ggml-cpu/arch/wasm/quants.c b/ggml/src/ggml-cpu/arch/wasm/quants.c index 88bc6ad778..a3ae8e8885 100644 --- a/ggml/src/ggml-cpu/arch/wasm/quants.c +++ b/ggml/src/ggml-cpu/arch/wasm/quants.c @@ -1227,6 +1227,6 @@ void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); } -void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); 
+void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); } diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c index 29d5a28759..0c6f6ed49a 100644 --- a/ggml/src/ggml-cpu/arch/x86/quants.c +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -3995,7 +3995,7 @@ static inline void ggml_vec_dot_mxfp6_q8_0_avx2( } #endif -void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); #if defined(__AVX2__) @@ -4004,7 +4004,7 @@ void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, con MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); #else - ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); + ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } @@ -4130,13 +4130,13 @@ void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRIC #endif } -void dequantize_row_mxfp6_e2m3_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { +void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { #if defined(__AVX2__) dequantize_row_mxfp6_avx2(x, y, k, sizeof(block_mxfp6), MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); #else - dequantize_row_mxfp6_e2m3_cpu_generic(x, y, k); + dequantize_row_mxfp6_cpu_generic(x, y, k); #endif } diff --git 
a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index a87f808c95..7b7fb1e5ce 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -284,9 +284,9 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .nrows = 1, }, [GGML_TYPE_MXFP6_E2M3] = { - .from_float = quantize_row_mxfp6_e2m3, - .to_float = dequantize_row_mxfp6_e2m3_cpu, - .vec_dot = ggml_vec_dot_mxfp6_e2m3_q8_0, + .from_float = quantize_row_mxfp6, + .to_float = dequantize_row_mxfp6_cpu, + .vec_dot = ggml_vec_dot_mxfp6_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, }, diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 02cd1abb8d..2267eaa27b 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -5061,14 +5061,13 @@ static void ggml_compute_forward_set_rows_f32( char * dst_row = ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3); if (apply_hadamard) { - GGML_ASSERT(nc <= 1024); - float tmp[1024]; - memcpy(tmp, src_row, nc * sizeof(float)); - ggml_apply_hadamard_blocks(tmp, nc); + std::vector<float> tmp(nc); + memcpy(tmp.data(), src_row, nc * sizeof(float)); + ggml_apply_hadamard_blocks(tmp.data(), nc); if (mxfp_soa_quantize) { - mxfp_soa_quantize(tmp, dst_row, nc); + mxfp_soa_quantize(tmp.data(), dst_row, nc); } else { - from_float(tmp, dst_row, nc); + from_float(tmp.data(), dst_row, nc); } } else { if (mxfp_soa_quantize) { @@ -8418,6 +8417,10 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( int ith = params->ith; + // Pre-allocate dequant buffers for MXFP SoA (avoids per-iteration allocation) + std::vector<float> k_dequant_buf(is_mxfp_k ? mxfp.k_soa_elems : 0); + std::vector<float> v_dequant_buf(is_mxfp_v ? mxfp.v_soa_elems : 0); + for (int ir = ir0; ir < ir1; ++ir) { // q indices const int iq3 = ir/(neq2*neq1); @@ -8497,10 +8500,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( const char * k_soa_base = mxfp.k_multihead ? 
((const char *) k->data + ic*nbk1 + ik3*nbk3) : k_data; - float k_soa_f32[4096]; - GGML_ASSERT(mxfp.k_soa_elems <= 4096); - mxfp.k_dequantize(k_soa_base, k_soa_f32, mxfp.k_soa_elems); - const float * k_head = k_soa_f32 + (mxfp.k_multihead ? ik2 * DK : 0); + mxfp.k_dequantize(k_soa_base, k_dequant_buf.data(), mxfp.k_soa_elems); + const float * k_head = k_dequant_buf.data() + (mxfp.k_multihead ? ik2 * DK : 0); ggml_vec_dot_f32(DK, &s, 0, k_head, 0, Q_f32, 0, 1); } else { kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1); @@ -8554,10 +8555,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( const char * v_soa_base = mxfp.v_multihead ? ((const char *) v->data + ic*nbv1 + iv3*nbv3) : v_data; - float v_soa_f32[4096]; - GGML_ASSERT(mxfp.v_soa_elems <= 4096); - mxfp.v_dequantize(v_soa_base, v_soa_f32, mxfp.v_soa_elems); - ggml_vec_mad_f32(DV, VKQ32, v_soa_f32 + (mxfp.v_multihead ? iv2 * DV : 0), vs); + mxfp.v_dequantize(v_soa_base, v_dequant_buf.data(), mxfp.v_soa_elems); + ggml_vec_mad_f32(DV, VKQ32, v_dequant_buf.data() + (mxfp.v_multihead ? iv2 * DV : 0), vs); } else if (v_to_float) { v_to_float(v_data, V32, DV); ggml_vec_mad_f32(DV, VKQ32, V32, vs); @@ -8765,7 +8764,8 @@ static void ggml_compute_forward_flash_attn_ext_tiled( ggml_apply_hadamard_blocks(Q_f32 + tq * DK, DK); } // SoA round-trip: quantize Q to SoA, then dequant back to float. 
- uint8_t q_mxfp_buf[1024]; + uint8_t q_mxfp_buf[512]; // max: DK=256 * 33/32 = 264 bytes (MXFP8) + GGML_ASSERT(ggml_row_size(k->type, DK) <= sizeof(q_mxfp_buf)); mxfp.q_quantize(Q_f32 + tq * DK, q_mxfp_buf, DK); mxfp.k_dequantize(q_mxfp_buf, Q_f32 + tq * DK, DK); } diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index 9152755010..7303638c81 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -58,8 +58,8 @@ void quantize_row_mxfp8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, i quantize_row_mxfp8_ref(x, y, k); } -void quantize_row_mxfp6_e2m3(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { - quantize_row_mxfp6_e2m3_ref(x, y, k); +void quantize_row_mxfp6(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + quantize_row_mxfp6_ref(x, y, k); } // @@ -301,14 +301,14 @@ void ggml_vec_dot_mxfp8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, (ggml_to_float_t)dequantize_row_mxfp8); } -void ggml_vec_dot_mxfp6_e2m3_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); ggml_vec_dot_mxfp_q8_0_impl(n, s, vx, sizeof(block_mxfp6), vy, - (ggml_to_float_t)dequantize_row_mxfp6_e2m3); + (ggml_to_float_t)dequantize_row_mxfp6); } // Generic (scalar) dequant wrappers — delegates to ggml-quants.c reference implementations. 
@@ -316,8 +316,8 @@ void ggml_vec_dot_mxfp6_e2m3_q8_0_generic(int n, float * GGML_RESTRICT s, size_t void dequantize_row_mxfp8_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { dequantize_row_mxfp8(x, y, k); } -void dequantize_row_mxfp6_e2m3_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { - dequantize_row_mxfp6_e2m3(x, y, k); +void dequantize_row_mxfp6_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + dequantize_row_mxfp6(x, y, k); } void dequantize_row_mxfp4_soa_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { dequantize_row_mxfp4_soa(x, y, k); diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h index 7d8c32762a..0a7ea64135 100644 --- a/ggml/src/ggml-cpu/quants.h +++ b/ggml/src/ggml-cpu/quants.h @@ -22,11 +22,11 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_mxfp8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_mxfp6_e2m3(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_mxfp6(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); // Dequantization (SIMD-optimized, arch-dispatched) void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_mxfp6_e2m3_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -51,7 +51,7 @@ void ggml_vec_dot_q8_0_q8_0(int n, float 
* GGML_RESTRICT s, size_t bs, const voi void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -85,10 +85,10 @@ void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, c void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_mxfp8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_mxfp6_e2m3_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, 
int nrc); +void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void dequantize_row_mxfp8_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_mxfp6_e2m3_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_mxfp6_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); // SoA dequant (SIMD-optimized for FA) void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index b2692c45f6..188c7e68b6 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -797,11 +797,11 @@ void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_REST dequantize_row_mxfp8_impl(x, y, k, &mxfp8_e4m3_traits); } -void quantize_row_mxfp6_e2m3_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k) { +void quantize_row_mxfp6_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k) { quantize_row_mxfp6_impl(x, y, k, &mxfp6_e2m3_traits); } -void dequantize_row_mxfp6_e2m3(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { +void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { dequantize_row_mxfp6_impl(x, y, k, &mxfp6_e2m3_traits); } @@ -2627,9 +2627,9 @@ size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, return nrow * ggml_row_size(GGML_TYPE_MXFP8_E4M3, n_per_row); } -size_t quantize_mxfp6_e2m3(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { +size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { GGML_UNUSED(quant_weights); - 
quantize_row_mxfp6_e2m3_ref(src, dst, (int64_t)nrow*n_per_row); + quantize_row_mxfp6_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * ggml_row_size(GGML_TYPE_MXFP6_E2M3, n_per_row); } diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index 33401f2843..a0f6928e10 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -24,7 +24,7 @@ GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_mxfp8_ref(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k); -GGML_API void quantize_row_mxfp6_e2m3_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_mxfp6_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); @@ -53,7 +53,7 @@ GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GG GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -GGML_API void dequantize_row_mxfp6_e2m3(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); // SoA (Struct-of-Arrays) quantize/dequantize — canonical reference for flash attention. 
// Layout: [qs contiguous][e8m0 contiguous] per row. Same total bytes as AoS. @@ -112,7 +112,7 @@ GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTR GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -GGML_API size_t quantize_mxfp6_e2m3(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); // // MXFP element-level conversion functions (reference implementations) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 329f2b93b3..37b99e844e 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -739,8 +739,8 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .blck_size = QK_MXFP6, .type_size = sizeof(block_mxfp6), .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_mxfp6_e2m3, - .from_float_ref = (ggml_from_float_t)quantize_row_mxfp6_e2m3_ref, + .to_float = (ggml_to_float_t) dequantize_row_mxfp6, + .from_float_ref = (ggml_from_float_t)quantize_row_mxfp6_ref, }, [GGML_TYPE_Q2_K] = { .type_name = "q2_K", @@ -7692,7 +7692,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_MXFP4_E2M1: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_MXFP8_E4M3: result = quantize_mxfp8(src + start, (char *) dst + start_row * 
row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_MXFP6_E2M3: result = quantize_mxfp6_e2m3(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_MXFP6_E2M3: result = quantize_mxfp6(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index fcd784e79d..29fdeb3f33 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -141,6 +141,7 @@ llama_kv_cache::llama_kv_cache( const bool is_mxfp_k = ggml_is_type_mxfp(type_k); if (is_mxfp_k) { const int qk = (int)ggml_blck_size(type_k); // 32 for all MXFP types + GGML_ASSERT(n_embd_k_gqa % qk == 0 && "MXFP K cache requires n_embd_k_gqa divisible by block size"); const int blocks = (int)n_embd_k_gqa / qk; const int blocks_aligned = (blocks + 15) & ~15; // align to 16 n_embd_k_alloc = (uint32_t)(blocks_aligned * qk); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 6ddef63336..3915ab4db6 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -181,6 +181,7 @@ static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float const size_t block_size = ggml_type_size(tensor->type); const size_t head_row_sz = ggml_row_size(tensor->type, tensor->ne[0]); if (soa_bytes == 0) { soa_bytes = head_row_sz; } + GGML_ASSERT(soa_bytes % block_size == 0 && "soa_bytes must be a multiple of block_size"); const int64_t soa_elems = (int64_t)(soa_bytes / block_size) * qk; std::default_random_engine gen(42);