diff --git a/common/arg.cpp b/common/arg.cpp index 26c1904a2a..5e3b40d899 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -404,15 +404,6 @@ const std::vector kv_cache_types = { }; static ggml_type kv_cache_type_from_str(const std::string & s) { - if (s == "mxfp4") { - return GGML_TYPE_MXFP4; - } - if (s == "mxfp6") { - return GGML_TYPE_MXFP6; - } - if (s == "mxfp8") { - return GGML_TYPE_MXFP8; - } for (const auto & type : kv_cache_types) { if (ggml_type_name(type) == s) { return type; diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 271de1943c..b60794717c 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -199,13 +199,12 @@ typedef struct { static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding"); // E8M0 shared exponent constants (OCP MX v1.0 SS5.3). -// EMAX_OFFSET = ceil(log2(max_finite)), MSE_RANGE = search radius for optimal scale. -#define MXFP_E8M0_MSE_RANGE 2 -#define MXFP4_E2M1_EMAX_OFFSET 2 // ceil(log2(6.0)) -#define MXFP6_E2M3_EMAX_OFFSET 3 // ceil(log2(7.5)) -#define MXFP6_E3M2_EMAX_OFFSET 5 // ceil(log2(28.0)) -#define MXFP8_E4M3_EMAX_OFFSET 8 // ceil(log2(448)) -#define MXFP8_E5M2_EMAX_OFFSET 16 // ceil(log2(57344)) +// EMAX_OFFSET ≈ log2(max_finite), used by round(log2(amax)) base estimate. +#define MXFP4_E2M1_EMAX_OFFSET 2 // floor(log2(6.0)) = 2 +#define MXFP6_E2M3_EMAX_OFFSET 3 // ceil(log2(7.5)) = 3 +#define MXFP6_E3M2_EMAX_OFFSET 5 // ceil(log2(28.0)) = 5 +#define MXFP8_E4M3_EMAX_OFFSET 8 // floor(log2(448)) = 8 +#define MXFP8_E5M2_EMAX_OFFSET 16 // ceil(log2(57344)) = 16 // MXFP type properties -- shared across all backends. #define MXFP_BITS_PER_ELEM_E2M1 4 @@ -1635,13 +1634,17 @@ GGML_MXFP_FUNC void ggml_mxfp_unpack_fp6x4(const uint8_t in[3], uint8_t v[4]) { // E8M0 shared exponent → float conversion. // E8M0 encoding: value = 2^(x - 127) for x > 0, 2^(-127) for x == 0. +// E8M0 = 255 is NaN per MX spec, but we clamp to 254 (max finite) to match +// the encode path which also clamps to 254, preventing Inf * 0 = NaN in dequant. GGML_MXFP_FUNC float ggml_mxfp_e8m0_to_fp32(uint8_t x) { + if (x == 255) { x = 254; } uint32_t bits = (x == 0) ? 0x00400000u : ((uint32_t)x << 23); return GGML_MXFP_U32_AS_F32(bits); } // E8M0 → float/2. Used with MXFP4 since E2M1 values are doubled in kvalues_mxfp4. GGML_MXFP_FUNC float ggml_mxfp_e8m0_to_fp32_half(uint8_t x) { + if (x == 255) { x = 254; } uint32_t bits = (x < 2) ? (0x00200000u << x) : ((uint32_t)(x - 1) << 23); return GGML_MXFP_U32_AS_F32(bits); } diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index eac031e68e..03f7bc0efe 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -14,6 +14,8 @@ #define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1 #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 +#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0 +#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 #define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu #define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu @@ -72,6 +74,9 @@ #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64) +// quants.c +#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0 +#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0 // repack.cpp #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 @@ -83,6 +88,8 @@ #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) // quants.c +#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0 +#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 @@ -164,6 +171,8 @@ #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 +#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0 +#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 #define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu #define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu @@ -319,6 +328,8 @@ #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 +#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0 +#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 782a54392f..e2720ea3a2 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -280,13 +280,19 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .nrows = 1, }, [GGML_TYPE_MXFP8] = { + .from_float = (ggml_from_float_t)quantize_row_mxfp8_ref, .from_float_soa = quantize_row_mxfp8_soa, .to_float_soa = dequantize_row_mxfp8_soa_cpu, + .vec_dot = ggml_vec_dot_mxfp8_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, }, [GGML_TYPE_MXFP6] = { + .from_float = (ggml_from_float_t)quantize_row_mxfp6_ref, .from_float_soa = quantize_row_mxfp6_soa, .to_float_soa = dequantize_row_mxfp6_soa_cpu, + .vec_dot = ggml_vec_dot_mxfp6_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, }, [GGML_TYPE_Q2_K] = { diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 15424d40c4..fcdd7b045d 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -4909,8 +4909,106 @@ void ggml_compute_forward_get_rows( //} } -// NEON-optimized Hadamard; scalar fallback below -#if defined(__ARM_NEON) +// SIMD-optimized Hadamard; scalar fallback below +#if defined(__AVX2__) || defined(__AVX__) +static void hadamard_32_inplace(float vals[32]) { + // 32 floats = 4 × __m256 + __m256 v0 = _mm256_loadu_ps(vals + 0); + __m256 v1 = _mm256_loadu_ps(vals + 8); + __m256 v2 = _mm256_loadu_ps(vals + 16); + __m256 v3 = _mm256_loadu_ps(vals + 24); + + // Stride 1: butterfly on adjacent pairs within each 256-bit register + { + // Interleave even/odd elements, add/sub + __m256 a, b, s, d; + a = _mm256_shuffle_ps(v0, v0, _MM_SHUFFLE(2, 2, 0, 0)); + b = _mm256_shuffle_ps(v0, v0, _MM_SHUFFLE(3, 3, 1, 1)); + s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b); + v0 = _mm256_shuffle_ps(s, d, _MM_SHUFFLE(2, 0, 2, 0)); + v0 = _mm256_shuffle_ps(v0, v0, _MM_SHUFFLE(3, 1, 2, 0)); + + a = _mm256_shuffle_ps(v1, v1, _MM_SHUFFLE(2, 2, 0, 0)); + b = _mm256_shuffle_ps(v1, v1, _MM_SHUFFLE(3, 3, 1, 1)); + s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b); + v1 = _mm256_shuffle_ps(s, d, _MM_SHUFFLE(2, 0, 2, 0)); + v1 = _mm256_shuffle_ps(v1, v1, _MM_SHUFFLE(3, 1, 2, 0)); + + a = _mm256_shuffle_ps(v2, v2, _MM_SHUFFLE(2, 2, 0, 0)); + b = _mm256_shuffle_ps(v2, v2, _MM_SHUFFLE(3, 3, 1, 1)); + s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b); + v2 = _mm256_shuffle_ps(s, d, _MM_SHUFFLE(2, 0, 2, 0)); + v2 = _mm256_shuffle_ps(v2, v2, _MM_SHUFFLE(3, 1, 2, 0)); + + a = _mm256_shuffle_ps(v3, v3, _MM_SHUFFLE(2, 2, 0, 0)); + b = _mm256_shuffle_ps(v3, v3, _MM_SHUFFLE(3, 3, 1, 1)); + s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b); + v3 = _mm256_shuffle_ps(s, d, _MM_SHUFFLE(2, 0, 2, 0)); + v3 = _mm256_shuffle_ps(v3, v3, _MM_SHUFFLE(3, 1, 2, 0)); + } + + // Stride 2: butterfly on pairs separated by 2 within 128-bit lanes + { + __m256 a, b, s, d; + a = _mm256_permute_ps(v0, _MM_SHUFFLE(1, 0, 1, 0)); + b = _mm256_permute_ps(v0, _MM_SHUFFLE(3, 2, 3, 2)); + s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b); + v0 = _mm256_blend_ps(s, d, 0xCC); // 0b11001100 + + a = _mm256_permute_ps(v1, _MM_SHUFFLE(1, 0, 1, 0)); + b = _mm256_permute_ps(v1, _MM_SHUFFLE(3, 2, 3, 2)); + s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b); + v1 = _mm256_blend_ps(s, d, 0xCC); + + a = _mm256_permute_ps(v2, _MM_SHUFFLE(1, 0, 1, 0)); + b = _mm256_permute_ps(v2, _MM_SHUFFLE(3, 2, 3, 2)); + s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b); + v2 = _mm256_blend_ps(s, d, 0xCC); + + a = _mm256_permute_ps(v3, _MM_SHUFFLE(1, 0, 1, 0)); + b = _mm256_permute_ps(v3, _MM_SHUFFLE(3, 2, 3, 2)); + s = _mm256_add_ps(a, b); d = _mm256_sub_ps(a, b); + v3 = _mm256_blend_ps(s, d, 0xCC); + } + + // Stride 4: butterfly between 128-bit lanes within each 256-bit register + { + __m128 lo, hi; + lo = _mm256_castps256_ps128(v0); hi = _mm256_extractf128_ps(v0, 1); + v0 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_add_ps(lo, hi)), _mm_sub_ps(lo, hi), 1); + + lo = _mm256_castps256_ps128(v1); hi = _mm256_extractf128_ps(v1, 1); + v1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_add_ps(lo, hi)), _mm_sub_ps(lo, hi), 1); + + lo = _mm256_castps256_ps128(v2); hi = _mm256_extractf128_ps(v2, 1); + v2 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_add_ps(lo, hi)), _mm_sub_ps(lo, hi), 1); + + lo = _mm256_castps256_ps128(v3); hi = _mm256_extractf128_ps(v3, 1); + v3 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_add_ps(lo, hi)), _mm_sub_ps(lo, hi), 1); + } + + // Stride 8: butterfly between registers + { + __m256 s, d; + s = _mm256_add_ps(v0, v1); d = _mm256_sub_ps(v0, v1); v0 = s; v1 = d; + s = _mm256_add_ps(v2, v3); d = _mm256_sub_ps(v2, v3); v2 = s; v3 = d; + } + + // Stride 16: butterfly between register pairs + { + __m256 s, d; + s = _mm256_add_ps(v0, v2); d = _mm256_sub_ps(v0, v2); v0 = s; v2 = d; + s = _mm256_add_ps(v1, v3); d = _mm256_sub_ps(v1, v3); v1 = s; v3 = d; + } + + // Normalize by 1/sqrt(32) + const __m256 norm = _mm256_set1_ps(MXFP_HADAMARD_32_NORM); + _mm256_storeu_ps(vals + 0, _mm256_mul_ps(v0, norm)); + _mm256_storeu_ps(vals + 8, _mm256_mul_ps(v1, norm)); + _mm256_storeu_ps(vals + 16, _mm256_mul_ps(v2, norm)); + _mm256_storeu_ps(vals + 24, _mm256_mul_ps(v3, norm)); +} +#elif defined(__ARM_NEON) static void hadamard_32_inplace(float vals[32]) { float32x4_t v0 = vld1q_f32(vals + 0); float32x4_t v1 = vld1q_f32(vals + 4); @@ -5032,9 +5130,15 @@ static void ggml_compute_forward_set_rows_f32( ggml_from_float_t mxfp_soa_quantize = dst_traits->from_float_soa; ggml_from_float_t from_float = mxfp_soa_quantize ? nullptr : dst_traits->from_float; - std::vector had_tmp; - if (apply_hadamard) { - had_tmp.resize(nc); + // Fused Hadamard+quantize: one pass per block, 32-float stack buffer, no heap allocation. + ggml_from_float_t mxfp_soa_hadamard_quantize = nullptr; + if (apply_hadamard && mxfp_soa_quantize) { + switch (dst->type) { + case GGML_TYPE_MXFP4: mxfp_soa_hadamard_quantize = (ggml_from_float_t)quantize_row_mxfp4_soa_hadamard; break; + case GGML_TYPE_MXFP8: mxfp_soa_hadamard_quantize = (ggml_from_float_t)quantize_row_mxfp8_soa_hadamard; break; + case GGML_TYPE_MXFP6: mxfp_soa_hadamard_quantize = (ggml_from_float_t)quantize_row_mxfp6_soa_hadamard; break; + default: break; + } } for (int64_t i03 = 0; i03 < ne03; ++i03) { @@ -5051,20 +5155,12 @@ static void ggml_compute_forward_set_rows_f32( const float * src_row = (const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03); char * dst_row = ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3); - if (apply_hadamard) { - memcpy(had_tmp.data(), src_row, nc * sizeof(float)); - ggml_apply_hadamard_blocks(had_tmp.data(), nc); - if (mxfp_soa_quantize) { - mxfp_soa_quantize(had_tmp.data(), dst_row, nc); - } else { - from_float(had_tmp.data(), dst_row, nc); - } + if (mxfp_soa_hadamard_quantize) { + mxfp_soa_hadamard_quantize(src_row, dst_row, nc); + } else if (mxfp_soa_quantize) { + mxfp_soa_quantize(src_row, dst_row, nc); } else { - if (mxfp_soa_quantize) { - mxfp_soa_quantize(src_row, dst_row, nc); - } else { - from_float(src_row, dst_row, nc); - } + from_float(src_row, dst_row, nc); } } } @@ -8268,7 +8364,6 @@ typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t); struct mxfp_kv_params { mxfp_soa_dequantize_fn dequantize; bool multihead; - int64_t soa_elems; int qs_per_block; int head_qs_bytes; int64_t head_e8m0_offset; @@ -8277,12 +8372,28 @@ struct mxfp_kv_params { // MXFP dispatch parameters for flash attention. struct mxfp_fa_params { - mxfp_soa_quantize_fn q_quantize; + mxfp_soa_quantize_fn q_quantize; // SoA quantize for Q (used only when Hadamard is off AND non-MXFP K path) + // Fused Q round-trip: Hadamard + quantize + dequant in one pass, no SoA buffer. + void (*q_roundtrip)(const float *, float *, int64_t); mxfp_kv_params k; mxfp_kv_params v; bool apply_hadamard; }; +// Compute the SoA row base pointer for a given KV position and head. +// In multihead mode, the SoA region spans all heads at one KV position, +// so the row base must NOT include the per-head offset (head_idx * nb2). +// mxfp_dequant_head handles per-head indexing within the SoA region. +// In per-head mode, each head has its own SoA region, so the base includes nb2. +static inline const char * mxfp_row_ptr( + const mxfp_kv_params & kv, const char * data, + int64_t kv_pos, size_t nb1, int head_idx, size_t nb2, int batch_idx, size_t nb3) { + if (kv.multihead) { + return data + kv_pos*nb1 + batch_idx*nb3; + } + return data + kv_pos*nb1 + head_idx*nb2 + batch_idx*nb3; +} + // Extract one head's SoA data from a multihead row and dequantize. static inline void mxfp_dequant_head( const mxfp_kv_params & kv, const char * row, int head_idx, @@ -8305,7 +8416,6 @@ static mxfp_kv_params mxfp_kv_params_init(ggml_type type, int64_t D, size_t nb2, mxfp_kv_params kv = {}; kv.dequantize = ggml_get_type_traits_cpu(type)->to_float_soa; kv.multihead = (nb2 == (size_t)ggml_row_size(type, D)); - kv.soa_elems = kv.multihead ? ne2 * D : D; kv.qs_per_block = ggml_mxfp_qs_per_block(type); kv.blocks_per_head = (int)(D / 32); kv.head_qs_bytes = kv.blocks_per_head * kv.qs_per_block; @@ -8328,6 +8438,17 @@ static mxfp_fa_params mxfp_fa_params_init( p.q_quantize = ggml_get_type_traits_cpu(k->type)->from_float_soa; p.k = mxfp_kv_params_init(k->type, DK, nbk2, nek2); } + + // Select fused Q round-trip (Hadamard + quantize error, no SoA buffer). + if (is_mxfp_k) { + const bool had = is_mxfp_k && (DK == DV) && ggml_mxfp_use_hadamard(k->type); + switch (k->type) { + case GGML_TYPE_MXFP4: p.q_roundtrip = had ? mxfp4_hadamard_roundtrip : mxfp4_roundtrip; break; + case GGML_TYPE_MXFP8: p.q_roundtrip = had ? mxfp8_hadamard_roundtrip : mxfp8_roundtrip; break; + case GGML_TYPE_MXFP6: p.q_roundtrip = had ? mxfp6_hadamard_roundtrip : mxfp6_roundtrip; break; + default: break; + } + } if (is_mxfp_v) { p.v = mxfp_kv_params_init(v->type, DV, nbv2, nev2); } @@ -8486,22 +8607,14 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( const char * k_base = (const char *) k->data + k_base_offset; const char * v_base = (const char *) v->data + v_base_offset; - const char * k_row_base = mxfp.k.multihead ? ((const char *) k->data + ik3*nbk3) : nullptr; - const char * v_row_base = mxfp.v.multihead ? ((const char *) v->data + iv3*nbv3) : nullptr; + const char * k_data_base = (const char *) k->data; + const char * v_data_base = (const char *) v->data; const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); float Q_f32[MXFP_FA_MAX_D]; - if (is_mxfp_k) { - // Q preprocessing: Hadamard + SoA round-trip captures same quantization loss as K. - if (mxfp.apply_hadamard) { - float q_tmp[MXFP_FA_MAX_D]; - memcpy(q_tmp, pq, DK * sizeof(float)); - ggml_apply_hadamard_blocks(q_tmp, DK); - mxfp.q_quantize(q_tmp, Q_q, DK); - } else { - mxfp.q_quantize(pq, Q_q, DK); - } - mxfp.k.dequantize(Q_q, Q_f32, DK); + if (mxfp.q_roundtrip) { + // Q preprocessing: fused Hadamard + quantize round-trip, no SoA buffer. + mxfp.q_roundtrip(pq, Q_f32, DK); } else { if (mxfp.apply_hadamard) { float q_tmp[MXFP_FA_MAX_D]; @@ -8526,7 +8639,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( float s; // KQ value if (is_mxfp_k) { - const char * k_row = mxfp.k.multihead ? k_row_base + ic*nbk1 : k_base + ic*nbk1; + const char * k_row = mxfp_row_ptr(mxfp.k, k_data_base, + ic, nbk1, ik2, nbk2, ik3, nbk3); mxfp_dequant_head(mxfp.k, k_row, ik2, k_head_soa, k_dequant_buf, DK); ggml_vec_dot_f32(DK, &s, 0, k_dequant_buf, 0, Q_f32, 0, 1); } else { @@ -8572,7 +8686,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( // V += v*expf(s - M) if (mxfp.v.dequantize) { - const char * v_row = mxfp.v.multihead ? v_row_base + ic*nbv1 : v_base + ic*nbv1; + const char * v_row = mxfp_row_ptr(mxfp.v, v_data_base, + ic, nbv1, iv2, nbv2, iv3, nbv3); mxfp_dequant_head(mxfp.v, v_row, iv2, v_head_soa, v_dequant_buf, DV); ggml_vec_mad_f32(DV, VKQ32, v_dequant_buf, vs); } else if (v_to_float) { @@ -8723,7 +8838,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled( if (is_mxfp_v) { GGML_ASSERT(DV <= MXFP_FA_MAX_D); } float k_dequant_buf[MXFP_FA_MAX_D]; - float v_dequant_buf[MXFP_FA_MAX_D]; char k_head_soa[MXFP_FA_SOA_BUF]; char v_head_soa[MXFP_FA_SOA_BUF]; @@ -8786,13 +8900,9 @@ static void ggml_compute_forward_flash_attn_ext_tiled( const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3)); memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float)); - if (is_mxfp_k) { - if (mxfp.apply_hadamard) { - ggml_apply_hadamard_blocks(Q_f32 + tq * DK, DK); - } - uint8_t q_mxfp_buf[MXFP_FA_SOA_BUF]; - mxfp.q_quantize(Q_f32 + tq * DK, q_mxfp_buf, DK); - mxfp.k.dequantize(q_mxfp_buf, Q_f32 + tq * DK, DK); + if (mxfp.q_roundtrip) { + // In-place: Q_f32 is already populated by memcpy above, roundtrip overwrites. + mxfp.q_roundtrip(Q_f32 + tq * DK, Q_f32 + tq * DK, DK); } } for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) { @@ -8843,7 +8953,9 @@ static void ggml_compute_forward_flash_attn_ext_tiled( K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk]; } } else if (mxfp.k.dequantize) { - mxfp_dequant_head(mxfp.k, k_data, ik2, k_head_soa, k_dequant_buf, DK); + const char * k_row = mxfp_row_ptr(mxfp.k, (const char *)k->data, + ic + tk, nbk1, ik2, nbk2, ik3, nbk3); + mxfp_dequant_head(mxfp.k, k_row, ik2, k_head_soa, k_dequant_buf, DK); for (int64_t dk = 0; dk < DK; dk++) { K_f32[dk * KV_TILE_SZ + tk] = k_dequant_buf[dk]; } @@ -8913,8 +9025,9 @@ static void ggml_compute_forward_flash_attn_ext_tiled( } else if (v_type == GGML_TYPE_F32) { memcpy(V32 + tk * DV, v_data, DV * sizeof(float)); } else if (mxfp.v.dequantize) { - mxfp_dequant_head(mxfp.v, v_data, iv2, v_head_soa, v_dequant_buf, DV); - memcpy(V32 + tk * DV, v_dequant_buf, DV * sizeof(float)); + const char * v_row = mxfp_row_ptr(mxfp.v, (const char *)v->data, + ic + tk, nbv1, iv2, nbv2, iv3, nbv3); + mxfp_dequant_head(mxfp.v, v_row, iv2, v_head_soa, V32 + tk * DV, DV); } else { v_to_float(v_data, V32 + tk * DV, DV); } @@ -9087,10 +9200,10 @@ static void ggml_compute_forward_flash_attn_ext_f16( // Split-KV: parallelize across KV chunks for single-query decode (token generation). // Only for types whose tiled/one_chunk paths produce identical results (f32, f16, MXFP). // Standard quant types (q8_0, q4_0) must use the scalar path to preserve vec_dot semantics. - const bool kv_is_f32_f16_or_mxfp = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16 + const bool k_is_f32_f16_or_mxfp = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16 || ggml_is_type_mxfp(k->type)); const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) - && kv_is_f32_f16_or_mxfp + && k_is_f32_f16_or_mxfp && q->type == GGML_TYPE_F32 && nek1 >= 512; if (use_split_kv_path) { @@ -9151,7 +9264,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( // Tiled path: f32, f16, and MXFP only (quant types use one_chunk) bool use_tiled = !use_ref && (q->type == GGML_TYPE_F32 && - kv_is_f32_f16_or_mxfp && + k_is_f32_f16_or_mxfp && (k->type == v->type || ggml_is_type_mxfp(k->type)) && neq1 >= Q_TILE_SZ); #ifdef GGML_SIMD diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index 5cbd177234..0c4faa4fc1 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -189,6 +189,54 @@ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c *s = sumf; } +void ggml_vec_dot_mxfp8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + assert(n % QK_MXFP8 == 0); + static_assert(QK_MXFP8 == QK8_0, "QK_MXFP8 and QK8_0 must be the same"); + + const block_mxfp8 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + const int nb = n / QK_MXFP8; + + float sumf = 0; + for (int ib = 0; ib < nb; ++ib) { + const float d = GGML_CPU_FP16_TO_FP32(y[ib].d) * GGML_E8M0_TO_FP32(x[ib].e); + float sumi = 0; + for (int j = 0; j < QK_MXFP8; ++j) { + sumi += y[ib].qs[j] * ggml_mxfp_fp8_e4m3_to_float(x[ib].qs[j]); + } + sumf += d * sumi; + } + *s = sumf; +} + +void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + assert(n % QK_MXFP6 == 0); + static_assert(QK_MXFP6 == QK8_0, "QK_MXFP6 and QK8_0 must be the same"); + + const block_mxfp6 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + const int nb = n / QK_MXFP6; + + float sumf = 0; + for (int ib = 0; ib < nb; ++ib) { + const float d = GGML_CPU_FP16_TO_FP32(y[ib].d) * GGML_E8M0_TO_FP32(x[ib].e); + float sumi = 0; + for (int j = 0; j < QK_MXFP6; j += 4) { + uint8_t vals[4]; + ggml_mxfp_unpack_fp6x4(&x[ib].qs[j * 3 / 4], vals); + for (int jj = 0; jj < 4; jj++) { + sumi += y[ib].qs[j + jj] * ggml_mxfp_fp6_e2m3_to_float(vals[jj]); + } + } + sumf += d * sumi; + } + *s = sumf; +} + void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h index 4a4dd264fe..4c75f9b0cd 100644 --- a/ggml/src/ggml-cpu/quants.h +++ b/ggml/src/ggml-cpu/quants.h @@ -42,6 +42,8 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index f9358b0432..4d98113139 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -431,13 +431,15 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) // E8M0 shared exponent to float: returns 2^(x - 127). +// Canonical implementation is ggml_mxfp_e8m0_to_fp32 in ggml-common.h. +// This thin wrapper exists because not all callers include ggml-common.h. +// MUST stay in sync — if you change the logic, change ggml-common.h too. +// +// E8M0 = 255 is NaN per MX spec; clamped to 254 (max finite) to match +// the encode path which also clamps to 254, preventing Inf * 0 = NaN. static inline float ggml_e8m0_to_fp32(uint8_t x) { - uint32_t bits; - if (x == 0) { - bits = 0x00400000; // denorm: 0.5 * 2^(-126) = 2^(-127) - } else { - bits = (uint32_t) x << 23; - } + if (x == 255) { x = 254; } + uint32_t bits = (x == 0) ? 0x00400000u : ((uint32_t)x << 23); float result; memcpy(&result, &bits, sizeof(float)); return result; @@ -445,14 +447,8 @@ static inline float ggml_e8m0_to_fp32(uint8_t x) { // E8M0 to float/2: returns 2^(x - 128). static inline float ggml_e8m0_to_fp32_half(uint8_t x) { - uint32_t bits; - if (x < 2) { - // x=0 → 2^(-128), x=1 → 2^(-127): denormal bit patterns - bits = 0x00200000 << x; - } else { - // 0.5 * 2^(x-127) = 2^(x-128): normalized with exponent (x-1) - bits = (uint32_t)(x - 1) << 23; - } + if (x == 255) { x = 254; } + uint32_t bits = (x < 2) ? (0x00200000u << x) : ((uint32_t)(x - 1) << 23); float result; memcpy(&result, &bits, sizeof(float)); return result; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 5c8eb97806..1afd82b6c5 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -263,8 +263,6 @@ float fp8_e4m3_to_float(uint8_t v) { return ggml_mxfp_fp8_e4m3_to_float(v); } uint8_t float_to_fp8_e4m3_rn(float x) { return ggml_mxfp_float_to_fp8_e4m3(x); } // ====================== MXFP quantization infrastructure -// -// MSE-optimal E8M0: tests candidates around round(log2(amax)), picks lowest quantization error. typedef struct { int emax_offset; // type-specific offset to max representable exponent @@ -272,33 +270,13 @@ typedef struct { int bits_per_elem; // 8 = byte-aligned, 6 = packed via fp6x4 uint8_t (*to_elem)(float); float (*to_float)(uint8_t); - float (*mse_error)(float val, float inv_scale, float scale); // NULL = use generic round-trip via to_elem/to_float } mxfp_elem_traits_t; static inline int best_index_mxfp4(float x, float e); -// MSE error for MXFP4 (kvalues are doubled, so scale is halved) -static float mse_error_mxfp4(float val, float inv_scale, float scale) { - const float d = scale * 0.5f; - const float inv_d = (d > 0.0f) ? 1.0f / d : 0.0f; - const float normalized = fabsf(val) * inv_d; - (void)inv_scale; - float qval; - if (normalized < 0.5f) qval = 0.0f; - else if (normalized < 1.5f) qval = 1.0f; - else if (normalized < 2.5f) qval = 2.0f; - else if (normalized < 3.5f) qval = 3.0f; - else if (normalized < 5.0f) qval = 4.0f; - else if (normalized < 7.0f) qval = 6.0f; - else if (normalized < 10.0f) qval = 8.0f; - else qval = 12.0f; - const float err = fabsf(val) - qval * d; - return err * err; -} -static const mxfp_elem_traits_t mxfp4_traits = { MXFP4_E2M1_EMAX_OFFSET, MXFP4_SOA_QS_PER_BLOCK, 4, NULL, NULL, mse_error_mxfp4 }; - -static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_elem_traits_t * traits) { +// E8M0 shared exponent: round(log2(amax)) — no MSE search needed. +static inline uint8_t mxfp_compute_e8m0(const float * x, int qk, int emax_offset) { float amax = 0.0f; for (int j = 0; j < qk; j++) { const float a = fabsf(x[j]); @@ -306,36 +284,8 @@ static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_ } if (amax == 0.0f) return 0; - const int e_base = ggml_mxfp_e8m0_base_estimate(amax, traits->emax_offset); - - int e_lo = e_base - MXFP_E8M0_MSE_RANGE; - int e_hi = e_base + MXFP_E8M0_MSE_RANGE; - if (e_lo < 1) e_lo = 1; - if (e_hi < 1) e_hi = 1; - if (e_hi > 254) e_hi = 254; - int best_e = e_base < 0 ? 0 : (e_base > 254 ? 254 : e_base); - float best_mse = 1e30f; - - for (int test_e = e_lo; test_e <= e_hi; ++test_e) { - const float test_scale = GGML_E8M0_TO_FP32((uint8_t)test_e); - const float test_inv = 1.0f / test_scale; - float mse = 0.0f; - for (int j = 0; j < qk; ++j) { - if (traits->mse_error) { - mse += traits->mse_error(x[j], test_inv, test_scale); - } else { - const float recon = traits->to_float(traits->to_elem(x[j] * test_inv)) * test_scale; - const float err = x[j] - recon; - mse += err * err; - } - } - if (mse < best_mse) { - best_mse = mse; - best_e = test_e; - } - } - - return (uint8_t)best_e; + const int e = ggml_mxfp_e8m0_base_estimate(amax, emax_offset); + return (uint8_t)(e < 0 ? 0 : (e > 254 ? 254 : e)); } static inline int best_index_mxfp4(float x, float e) { @@ -353,26 +303,112 @@ static inline int best_index_mxfp4(float x, float e) { return (x < 0.0f) ? (idx + 8) : idx; } -void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) { - static const int qk = QK_MXFP4; +// Per-block MXFP4 quantize: shared between AoS and SoA paths. +static inline void quantize_block_mxfp4(const float * GGML_RESTRICT src, uint8_t * GGML_RESTRICT qs, uint8_t * e_out) { + const uint8_t e = mxfp_compute_e8m0(src, QK_MXFP4, MXFP4_E2M1_EMAX_OFFSET); + const float d = GGML_E8M0_TO_FP32_HALF(e); + *e_out = e; + for (int j = 0; j < QK_MXFP4/2; ++j) { + const uint8_t x0 = best_index_mxfp4(src[0 + j], d); + const uint8_t x1 = best_index_mxfp4(src[QK_MXFP4/2 + j], d); + qs[j] = x0 | (x1 << 4); + } +} - assert(k % qk == 0); +// Per-block MXFP4 quantize round-trip: apply quantization error without materializing bytes. +// Used for Q preprocessing in flash attention — matches K's error pattern. +static inline void roundtrip_block_mxfp4(float * GGML_RESTRICT vals) { + const uint8_t e = mxfp_compute_e8m0(vals, QK_MXFP4, MXFP4_E2M1_EMAX_OFFSET); + const float d = GGML_E8M0_TO_FP32_HALF(e); + for (int j = 0; j < QK_MXFP4; ++j) { + const int idx = best_index_mxfp4(vals[j], d); + vals[j] = kvalues_mxfp4[idx] * d; // kvalues are doubled, d is halved — matches dequant + } +} - const int nb = k / qk; +// Per-block generic MXFP quantize round-trip (MXFP8/MXFP6). +static inline void roundtrip_block_mxfp(float * GGML_RESTRICT vals, const mxfp_elem_traits_t * traits) { + const uint8_t e = mxfp_compute_e8m0(vals, 32, traits->emax_offset); + const float d = GGML_E8M0_TO_FP32(e); + const float inv_d = d > 0.0f ? 1.0f / d : 0.0f; + for (int j = 0; j < 32; ++j) { + vals[j] = traits->to_float(traits->to_elem(vals[j] * inv_d)) * d; + } +} - for (int i = 0; i < nb; i++) { - const uint8_t e = mxfp_compute_e8m0_mse(&x[i*qk], qk, &mxfp4_traits); - const float d = GGML_E8M0_TO_FP32_HALF(e); +// Fused Hadamard + quantize round-trip: one pass, output is float with quantization error. +void mxfp4_hadamard_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k) { + assert(k % 32 == 0); + for (int64_t i = 0; i < k; i += 32) { + memcpy(dst + i, src + i, 32 * sizeof(float)); + ggml_mxfp_hadamard_32_inplace(dst + i); + roundtrip_block_mxfp4(dst + i); + } +} - y[i].e = e; +// Non-Hadamard round-trip for MXFP4 (Hadamard disabled or V cache). +void mxfp4_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k) { + assert(k % 32 == 0); + for (int64_t i = 0; i < k; i += 32) { + memcpy(dst + i, src + i, 32 * sizeof(float)); + roundtrip_block_mxfp4(dst + i); + } +} - for (int j = 0; j < qk/2; ++j) { - const uint8_t x0 = best_index_mxfp4(x[i*qk + 0 + j], d); - const uint8_t x1 = best_index_mxfp4(x[i*qk + qk/2 + j], d); +// Per-block MXFP4 dequant: shared between AoS and SoA paths. +static inline void dequantize_block_mxfp4(const uint8_t * GGML_RESTRICT qs, uint8_t e, float * GGML_RESTRICT dst) { + const float d = GGML_E8M0_TO_FP32_HALF(e); + for (int j = 0; j < QK_MXFP4/2; ++j) { + dst[0 + j] = kvalues_mxfp4[qs[j] & 0x0F] * d; + dst[QK_MXFP4/2 + j] = kvalues_mxfp4[qs[j] >> 4] * d; + } +} - y[i].qs[j] = x0; - y[i].qs[j] |= x1 << 4; +// Per-block generic MXFP quantize/dequant: shared between AoS and SoA for MXFP8/MXFP6. +static inline void quantize_block_mxfp(const float * GGML_RESTRICT src, uint8_t * GGML_RESTRICT qs, + uint8_t * e_out, const mxfp_elem_traits_t * traits) { + const uint8_t e = mxfp_compute_e8m0(src, 32, traits->emax_offset); + const float d = GGML_E8M0_TO_FP32(e); + const float inv_d = d > 0.0f ? 1.0f / d : 0.0f; + *e_out = e; + if (traits->bits_per_elem == 8) { + for (int j = 0; j < 32; ++j) { + qs[j] = traits->to_elem(src[j] * inv_d); } + } else { + for (int j = 0; j < 32; j += 4) { + uint8_t vals[4]; + for (int jj = 0; jj < 4; jj++) { + vals[jj] = traits->to_elem(src[j + jj] * inv_d); + } + pack_fp6x4(vals, &qs[j * 3 / 4]); + } + } +} + +static inline void dequantize_block_mxfp(const uint8_t * GGML_RESTRICT qs, uint8_t e, + float * GGML_RESTRICT dst, const mxfp_elem_traits_t * traits) { + const float d = GGML_E8M0_TO_FP32(e); + if (traits->bits_per_elem == 8) { + for (int j = 0; j < 32; ++j) { + dst[j] = traits->to_float(qs[j]) * d; + } + } else { + for (int j = 0; j < 32; j += 4) { + uint8_t vals[4]; + unpack_fp6x4(&qs[j * 3 / 4], vals); + for (int jj = 0; jj < 4; jj++) { + dst[j + jj] = traits->to_float(vals[jj]) * d; + } + } + } +} + +void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) { + assert(k % QK_MXFP4 == 0); + const int nb = k / QK_MXFP4; + for (int i = 0; i < nb; i++) { + quantize_block_mxfp4(&x[i*QK_MXFP4], y[i].qs, &y[i].e); } } @@ -522,22 +558,10 @@ void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRI } void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { - static const int qk = QK_MXFP4; - - assert(k % qk == 0); - - const int nb = k / qk; - + assert(k % QK_MXFP4 == 0); + const int nb = k / QK_MXFP4; for (int i = 0; i < nb; i++) { - const float d = GGML_E8M0_TO_FP32_HALF(x[i].e); - - for (int j = 0; j < qk/2; ++j) { - const int8_t x0 = kvalues_mxfp4[x[i].qs[j] & 0x0F]; - const int8_t x1 = kvalues_mxfp4[x[i].qs[j] >> 4]; - - y[i*qk + j + 0 ] = x0*d; - y[i*qk + j + qk/2] = x1*d; - } + dequantize_block_mxfp4(x[i].qs, x[i].e, &y[i*QK_MXFP4]); } } @@ -582,112 +606,95 @@ uint8_t float_to_fp8_e5m2_rn(float x) { return ggml_mxfp_float_to_fp8_e5m2(x); } void pack_fp6x4(const uint8_t v[4], uint8_t out[3]) { ggml_mxfp_pack_fp6x4(v, out); } void unpack_fp6x4(const uint8_t in[3], uint8_t v[4]) { ggml_mxfp_unpack_fp6x4(in, v); } -static const mxfp_elem_traits_t mxfp8_e4m3_traits = { MXFP8_E4M3_EMAX_OFFSET, MXFP8_SOA_QS_PER_BLOCK, 8, float_to_fp8_e4m3_rn, fp8_e4m3_to_float, NULL }; -static const mxfp_elem_traits_t mxfp6_e2m3_traits = { MXFP6_E2M3_EMAX_OFFSET, MXFP6_SOA_QS_PER_BLOCK, 6, float_to_fp6_e2m3_rn, fp6_e2m3_to_float, NULL }; +static const mxfp_elem_traits_t mxfp8_e4m3_traits = { MXFP8_E4M3_EMAX_OFFSET, MXFP8_SOA_QS_PER_BLOCK, 8, float_to_fp8_e4m3_rn, fp8_e4m3_to_float }; +static const mxfp_elem_traits_t mxfp6_e2m3_traits = { MXFP6_E2M3_EMAX_OFFSET, MXFP6_SOA_QS_PER_BLOCK, 6, float_to_fp6_e2m3_rn, fp6_e2m3_to_float }; + +// MXFP8 AoS quantize/dequant — uses shared per-block helpers. +void quantize_row_mxfp8_ref(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k) { + assert(k % QK_MXFP8 == 0); + const int nb = k / QK_MXFP8; + for (int i = 0; i < nb; i++) { + quantize_block_mxfp(&x[i*QK_MXFP8], y[i].qs, &y[i].e, &mxfp8_e4m3_traits); + } +} + +void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_MXFP8 == 0); + const int nb = k / QK_MXFP8; + for (int i = 0; i < nb; i++) { + dequantize_block_mxfp(x[i].qs, x[i].e, &y[i*QK_MXFP8], &mxfp8_e4m3_traits); + } +} + +// MXFP6 AoS quantize/dequant — uses shared per-block helpers. +void quantize_row_mxfp6_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k) { + assert(k % QK_MXFP6 == 0); + const int nb = k / QK_MXFP6; + for (int i = 0; i < nb; i++) { + quantize_block_mxfp(&x[i*QK_MXFP6], y[i].qs, &y[i].e, &mxfp6_e2m3_traits); + } +} + +void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_MXFP6 == 0); + const int nb = k / QK_MXFP6; + for (int i = 0; i < nb; i++) { + dequantize_block_mxfp(x[i].qs, x[i].e, &y[i*QK_MXFP6], &mxfp6_e2m3_traits); + } +} // ====================== SoA (Struct-of-Arrays) quantize/dequantize for flash attention void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) { assert(k % QK_MXFP4 == 0); const int nb = k / QK_MXFP4; - char * row = (char *)dst; - char * qs_base = row; - char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK); + char * qs_base = (char *)dst; + char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK); for (int i = 0; i < nb; i++) { - const uint8_t e = mxfp_compute_e8m0_mse(&x[i*QK_MXFP4], QK_MXFP4, &mxfp4_traits); - const float d = GGML_E8M0_TO_FP32_HALF(e); - - e8m0_base[i] = (char)e; - uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP4_SOA_QS_PER_BLOCK)); - for (int j = 0; j < QK_MXFP4/2; ++j) { - const uint8_t x0 = best_index_mxfp4(x[i*QK_MXFP4 + 0 + j], d); - const uint8_t x1 = best_index_mxfp4(x[i*QK_MXFP4 + QK_MXFP4/2 + j], d); - qs[j] = x0 | (x1 << 4); - } + quantize_block_mxfp4(&x[i*QK_MXFP4], qs, (uint8_t *)&e8m0_base[i]); } } void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) { assert(k % QK_MXFP4 == 0); const int nb = k / QK_MXFP4; - const char * row = (const char *)src; - const char * qs_base = row; - const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK); + const char * qs_base = (const char *)src; + const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK); for (int i = 0; i < nb; i++) { - const float d = GGML_E8M0_TO_FP32_HALF((uint8_t)e8m0_base[i]); const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP4_SOA_QS_PER_BLOCK)); - - for (int j = 0; j < QK_MXFP4/2; ++j) { - const int8_t x0 = kvalues_mxfp4[qs[j] & 0x0F]; - const int8_t x1 = kvalues_mxfp4[qs[j] >> 4]; - y[i*QK_MXFP4 + j + 0 ] = x0*d; - y[i*QK_MXFP4 + j + QK_MXFP4/2] = x1*d; - } + dequantize_block_mxfp4(qs, (uint8_t)e8m0_base[i], &y[i*QK_MXFP4]); } } -// Unified SoA quantize for byte-aligned (FP8) and 6-bit packed (FP6) MXFP formats. +// Unified SoA quantize/dequantize — delegates to shared per-block helpers. static void quantize_row_mxfp_soa_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k, const mxfp_elem_traits_t * traits) { - const int qk = 32; - assert(k % qk == 0); - const int nb = k / qk; + assert(k % 32 == 0); + const int nb = k / 32; const int qpb = traits->qs_per_block; char * qs_base = (char *)dst; char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, qpb); for (int i = 0; i < nb; i++) { - const uint8_t e = mxfp_compute_e8m0_mse(&x[i*qk], qk, traits); - const float d = GGML_E8M0_TO_FP32(e); - const float inv_d = d > 0.0f ? 1.0f / d : 0.0f; - e8m0_base[i] = (char)e; - uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, qpb)); - if (traits->bits_per_elem == 8) { - for (int j = 0; j < qk; ++j) { - qs[j] = traits->to_elem(x[i*qk + j] * inv_d); - } - } else { - for (int j = 0; j < qk; j += 4) { - uint8_t vals[4]; - for (int jj = 0; jj < 4; jj++) { - vals[jj] = traits->to_elem(x[i*qk + j + jj] * inv_d); - } - pack_fp6x4(vals, &qs[j * 3 / 4]); - } - } + quantize_block_mxfp(&x[i*32], qs, (uint8_t *)&e8m0_base[i], traits); } } -// Unified SoA dequantize for byte-aligned (FP8) and 6-bit packed (FP6) MXFP formats. static void dequantize_row_mxfp_soa_impl(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, const mxfp_elem_traits_t * traits) { - const int qk = 32; - assert(k % qk == 0); - const int nb = k / qk; + assert(k % 32 == 0); + const int nb = k / 32; const int qpb = traits->qs_per_block; const char * qs_base = (const char *)src; const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, qpb); for (int i = 0; i < nb; i++) { - const float d = GGML_E8M0_TO_FP32((uint8_t)e8m0_base[i]); const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, qpb)); - if (traits->bits_per_elem == 8) { - for (int j = 0; j < qk; ++j) { - y[i*qk + j] = traits->to_float(qs[j]) * d; - } - } else { - for (int j = 0; j < qk; j += 4) { - uint8_t vals[4]; - unpack_fp6x4(&qs[j * 3 / 4], vals); - for (int jj = 0; jj < 4; jj++) { - y[i*qk + j + jj] = traits->to_float(vals[jj]) * d; - } - } - } + dequantize_block_mxfp(qs, (uint8_t)e8m0_base[i], &y[i*32], traits); } } @@ -703,6 +710,83 @@ void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) { dequantize_row_mxfp_soa_impl(src, y, k, &mxfp6_e2m3_traits); } + +// Fused Hadamard + SoA quantize: one read, one write, 32-float stack buffer per block. +// Eliminates the full-row temp buffer and extra memory pass. +void quantize_row_mxfp4_soa_hadamard(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) { + assert(k % QK_MXFP4 == 0); + const int nb = k / QK_MXFP4; + char * qs_base = (char *)dst; + char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK); + + for (int i = 0; i < nb; i++) { + float tmp[32]; + memcpy(tmp, &x[i*QK_MXFP4], QK_MXFP4 * sizeof(float)); + ggml_mxfp_hadamard_32_inplace(tmp); + uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP4_SOA_QS_PER_BLOCK)); + quantize_block_mxfp4(tmp, qs, (uint8_t *)&e8m0_base[i]); + } +} + +static void quantize_row_mxfp_soa_hadamard_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, + int64_t k, const mxfp_elem_traits_t * traits) { + assert(k % 32 == 0); + const int nb = k / 32; + const int qpb = traits->qs_per_block; + char * qs_base = (char *)dst; + char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, qpb); + + for (int i = 0; i < nb; i++) { + float tmp[32]; + memcpy(tmp, &x[i*32], 32 * sizeof(float)); + ggml_mxfp_hadamard_32_inplace(tmp); + uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, qpb)); + quantize_block_mxfp(tmp, qs, (uint8_t *)&e8m0_base[i], traits); + } +} + +void quantize_row_mxfp8_soa_hadamard(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) { + quantize_row_mxfp_soa_hadamard_impl(x, dst, k, &mxfp8_e4m3_traits); +} +void quantize_row_mxfp6_soa_hadamard(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) { + quantize_row_mxfp_soa_hadamard_impl(x, dst, k, &mxfp6_e2m3_traits); +} + +// MXFP8/6 quantize round-trips (with and without Hadamard). +void mxfp8_hadamard_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k) { + assert(k % 32 == 0); + for (int64_t i = 0; i < k; i += 32) { + memcpy(dst + i, src + i, 32 * sizeof(float)); + ggml_mxfp_hadamard_32_inplace(dst + i); + roundtrip_block_mxfp(dst + i, &mxfp8_e4m3_traits); + } +} + +void mxfp6_hadamard_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k) { + assert(k % 32 == 0); + for (int64_t i = 0; i < k; i += 32) { + memcpy(dst + i, src + i, 32 * sizeof(float)); + ggml_mxfp_hadamard_32_inplace(dst + i); + roundtrip_block_mxfp(dst + i, &mxfp6_e2m3_traits); + } +} + +void mxfp8_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k) { + assert(k % 32 == 0); + for (int64_t i = 0; i < k; i += 32) { + memcpy(dst + i, src + i, 32 * sizeof(float)); + roundtrip_block_mxfp(dst + i, &mxfp8_e4m3_traits); + } +} + +void mxfp6_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k) { + assert(k % 32 == 0); + for (int64_t i = 0; i < k; i += 32) { + memcpy(dst + i, src + i, 32 * sizeof(float)); + roundtrip_block_mxfp(dst + i, &mxfp6_e2m3_traits); + } +} + // // 2-6 bit quantization in super-blocks // @@ -2373,6 +2457,18 @@ size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, return nrow * ggml_row_size(GGML_TYPE_NVFP4, n_per_row); } +size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_UNUSED(quant_weights); + quantize_row_mxfp8_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_MXFP8, n_per_row); +} + +size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_UNUSED(quant_weights); + quantize_row_mxfp6_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_MXFP6, n_per_row); +} + // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs) void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) { diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index 4dec9ad351..d1cc8d4c85 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -22,6 +22,8 @@ GGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_mxfp8_ref(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_mxfp6_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); @@ -48,6 +50,8 @@ GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GG //GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); // SoA quantize/dequantize for flash attention GGML_API void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); @@ -56,6 +60,18 @@ GGML_API void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_ GGML_API void dequantize_row_mxfp8_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); GGML_API void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); +// Fused Hadamard + SoA quantize (one pass, no temp buffer) +GGML_API void quantize_row_mxfp4_soa_hadamard(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); +GGML_API void quantize_row_mxfp8_soa_hadamard(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); +GGML_API void quantize_row_mxfp6_soa_hadamard(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); +// Quantize round-trip: apply quantization error to floats without materializing bytes. +// Hadamard variants include the rotation. Used for Q preprocessing in flash attention. +GGML_API void mxfp4_hadamard_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k); +GGML_API void mxfp8_hadamard_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k); +GGML_API void mxfp6_hadamard_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k); +GGML_API void mxfp4_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k); +GGML_API void mxfp8_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k); +GGML_API void mxfp6_roundtrip(const float * GGML_RESTRICT src, float * GGML_RESTRICT dst, int64_t k); GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -103,6 +119,8 @@ GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTR GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); // MXFP element converters GGML_API float fp8_e4m3_to_float(uint8_t v); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 21b9a81eae..470b68c4bc 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -727,16 +727,20 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .from_float_ref = (ggml_from_float_t)quantize_row_nvfp4_ref, }, [GGML_TYPE_MXFP8] = { - .type_name = "mxfp8_e4m3", + .type_name = "mxfp8", .blck_size = QK_MXFP8, .type_size = sizeof(block_mxfp8), .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_mxfp8, + .from_float_ref = (ggml_from_float_t)quantize_row_mxfp8_ref, }, [GGML_TYPE_MXFP6] = { - .type_name = "mxfp6_e2m3", + .type_name = "mxfp6", .blck_size = QK_MXFP6, .type_size = sizeof(block_mxfp6), .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_mxfp6, + .from_float_ref = (ggml_from_float_t)quantize_row_mxfp6_ref, }, [GGML_TYPE_Q2_K] = { .type_name = "q2_K", @@ -7693,8 +7697,8 @@ size_t ggml_quantize_chunk( case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_MXFP4: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_MXFP8: GGML_ABORT("MXFP8 is KV-cache-only (SoA layout) — use from_float_soa"); - case GGML_TYPE_MXFP6: GGML_ABORT("MXFP6 is KV-cache-only (SoA layout) — use from_float_soa"); + case GGML_TYPE_MXFP8: result = quantize_mxfp8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_MXFP6: result = quantize_mxfp6(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index a890510edf..4cf4ce69d1 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1121,8 +1121,9 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm ggml_tensor * result = ggml_set_rows(ctx, k_dst, k_cur, k_idxs); // enable Hadamard rotation for MXFP K cache (QuaRot arXiv:2404.00456, BRQ arXiv:2511.04214) - // skipped for MLA (V is a view of K) and E5M2/E3M2 (2-bit mantissa, no benefit) - if (is_mxfp && !hparams.is_mla() && ggml_mxfp_use_hadamard(k->type)) { + // skipped when DK != DV (MLA) and for E5M2/E3M2 (2-bit mantissa, no benefit). + // condition must match flash attention read path (ops.cpp: DK == DV). + if (is_mxfp && hparams.n_embd_head_k(il) == hparams.n_embd_head_v(il) && ggml_mxfp_use_hadamard(k->type)) { ((int32_t *)result->op_params)[0] = 1; } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index d102c5676c..281a9b65f4 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6390,7 +6390,7 @@ struct test_flash_attn_ext : public test_case { init_tensor_uniform(t, -10.0f, 10.0f); } else if (strcmp(t->name, "m") == 0) { init_tensor_kq_mask(t); - } else if ((strcmp(t->name, "k") == 0 || strcmp(t->name, "v") == 0) && ggml_is_type_mxfp(t->type)) { + } else if (ggml_is_type_mxfp(t->type)) { init_tensor_mxfp_soa(t); } else { init_tensor_uniform(t); @@ -7398,7 +7398,7 @@ static const ggml_type all_types[] = { GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0, - GGML_TYPE_MXFP4, + GGML_TYPE_MXFP4, GGML_TYPE_MXFP8, GGML_TYPE_MXFP6, GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K, @@ -7533,11 +7533,14 @@ static std::vector> make_test_cases_eval() { } // SET_ROWS with Hadamard rotation (exercises the op_params[0] flag used by MXFP KV cache) - for (ggml_type type : {GGML_TYPE_MXFP4, GGML_TYPE_MXFP8, - GGML_TYPE_MXFP6}) { + for (ggml_type type : {GGML_TYPE_MXFP4, GGML_TYPE_MXFP8, GGML_TYPE_MXFP6}) { // ne[0] must be divisible by 32 (Hadamard block size) test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 128, 5, 1, 1 }, { 1, 1 }, 1, false, true)); test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 256, 5, 1, 3 }, { 1, 1 }, 1, false, true)); + // multi-row, broadcast, views + test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 128, 5, 1, 1 }, { 1, 1 }, 1, true, true)); + test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 256, 11, 1, 1 }, { 2, 3 }, 7, false, true)); + test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 512, 5, 3, 1 }, { 1, 1 }, 1, false, true)); } for (int mode : { GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_VISION }) { diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp index 8f1dcf10f0..babc9f58e1 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -33,7 +33,7 @@ constexpr float MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 = 0.0020f; // These represent actual RMSE through the full KV cache write/read path. constexpr float MAX_MXFP_PIPELINE_ERROR_MXFP4 = 0.40f; constexpr float MAX_MXFP_PIPELINE_ERROR_MXFP8 = 0.08f; -constexpr float MAX_MXFP_PIPELINE_ERROR_MXFP6 = 0.10f; +constexpr float MAX_MXFP_PIPELINE_ERROR_MXFP6 = 0.30f; constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f; constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f;