diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h index 2e13dd58ba..19c06a033d 100644 --- a/ggml/include/ggml-cpu.h +++ b/ggml/include/ggml-cpu.h @@ -115,10 +115,12 @@ extern "C" { struct ggml_type_traits_cpu { ggml_from_float_t from_float; - ggml_to_float_t to_float; // SIMD-optimized dequant (NULL = use global to_float) + ggml_to_float_t to_float; + ggml_from_float_t from_float_soa; // SoA quantize (MXFP flash attention layout) + ggml_to_float_t to_float_soa; // SoA dequant (MXFP flash attention layout) ggml_vec_dot_t vec_dot; enum ggml_type vec_dot_type; - int64_t nrows; // number of rows to process simultaneously + int64_t nrows; }; GGML_BACKEND_API const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type); diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 81b552ec78..6edf9909cf 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -427,6 +427,7 @@ extern "C" { // GGML_TYPE_IQ4_NL_4_8 = 37, // GGML_TYPE_IQ4_NL_8_8 = 38, GGML_TYPE_MXFP4_E2M1 = 39, // MX FP4 E2M1 + GGML_TYPE_MXFP4 = GGML_TYPE_MXFP4_E2M1, // compat alias GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale) GGML_TYPE_MXFP8_E4M3 = 41, // MX FP8 E4M3 GGML_TYPE_MXFP6_E2M3 = 42, // MX FP6 E2M3 diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 7b7fb1e5ce..9b8618423c 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -266,6 +266,8 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { }, [GGML_TYPE_MXFP4_E2M1] = { .from_float = quantize_row_mxfp4, + .from_float_soa = quantize_row_mxfp4_soa, + .to_float_soa = dequantize_row_mxfp4_soa_cpu, .vec_dot = ggml_vec_dot_mxfp4_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, @@ -279,6 +281,8 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { [GGML_TYPE_MXFP8_E4M3] = { .from_float = quantize_row_mxfp8, .to_float = dequantize_row_mxfp8_cpu, + .from_float_soa = quantize_row_mxfp8_soa, + .to_float_soa = dequantize_row_mxfp8_soa_cpu, .vec_dot = ggml_vec_dot_mxfp8_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, @@ -286,6 +290,8 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { [GGML_TYPE_MXFP6_E2M3] = { .from_float = quantize_row_mxfp6, .to_float = dequantize_row_mxfp6_cpu, + .from_float_soa = quantize_row_mxfp6_soa, + .to_float_soa = dequantize_row_mxfp6_soa_cpu, .vec_dot = ggml_vec_dot_mxfp6_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index a8f55efbed..cb1f881391 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -5028,18 +5028,9 @@ static void ggml_compute_forward_set_rows_f32( const int32_t apply_hadamard = ((const int32_t *)dst->op_params)[0]; - typedef void (*quantize_soa_fn)(const float *, void *, int64_t); - quantize_soa_fn mxfp_soa_quantize = nullptr; - ggml_from_float_t from_float = nullptr; - - switch (dst->type) { - case GGML_TYPE_MXFP4_E2M1: mxfp_soa_quantize = quantize_row_mxfp4_soa; break; - case GGML_TYPE_MXFP8_E4M3: mxfp_soa_quantize = quantize_row_mxfp8_soa; break; - case GGML_TYPE_MXFP6_E2M3: mxfp_soa_quantize = quantize_row_mxfp6_soa; break; - default: - from_float = ggml_get_type_traits_cpu(dst->type)->from_float; - break; - } + const struct ggml_type_traits_cpu * dst_traits = ggml_get_type_traits_cpu(dst->type); + ggml_from_float_t mxfp_soa_quantize = dst_traits->from_float_soa; + ggml_from_float_t from_float = mxfp_soa_quantize ? nullptr : dst_traits->from_float; std::vector had_tmp; if (apply_hadamard) { @@ -8300,21 +8291,13 @@ static mxfp_fa_params mxfp_fa_params_init( const bool is_mxfp_v = ggml_is_type_mxfp(v->type); if (is_mxfp_k) { - switch (k->type) { - case GGML_TYPE_MXFP4_E2M1: p.q_quantize = quantize_row_mxfp4_soa; p.k_dequantize = dequantize_row_mxfp4_soa_cpu; break; - case GGML_TYPE_MXFP8_E4M3: p.q_quantize = quantize_row_mxfp8_soa; p.k_dequantize = dequantize_row_mxfp8_soa_cpu; break; - case GGML_TYPE_MXFP6_E2M3: p.q_quantize = quantize_row_mxfp6_soa; p.k_dequantize = dequantize_row_mxfp6_soa_cpu; break; - default: GGML_ABORT("unsupported MXFP K type"); - } + const struct ggml_type_traits_cpu * k_traits = ggml_get_type_traits_cpu(k->type); + p.q_quantize = k_traits->from_float_soa; + p.k_dequantize = k_traits->to_float_soa; } if (is_mxfp_v) { - switch (v->type) { - case GGML_TYPE_MXFP4_E2M1: p.v_dequantize = dequantize_row_mxfp4_soa_cpu; break; - case GGML_TYPE_MXFP8_E4M3: p.v_dequantize = dequantize_row_mxfp8_soa_cpu; break; - case GGML_TYPE_MXFP6_E2M3: p.v_dequantize = dequantize_row_mxfp6_soa_cpu; break; - default: GGML_ABORT("unsupported MXFP V type"); - } + p.v_dequantize = ggml_get_type_traits_cpu(v->type)->to_float_soa; } // Hadamard rotation must match K rotation. diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h index 0a7ea64135..c16e87a2e9 100644 --- a/ggml/src/ggml-cpu/quants.h +++ b/ggml/src/ggml-cpu/quants.h @@ -90,7 +90,10 @@ void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, void dequantize_row_mxfp8_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_mxfp6_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -// SoA dequant (SIMD-optimized for FA) +// SoA quantize/dequant for MXFP flash attention +void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); +void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); +void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_mxfp6_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 1a99711401..cca3d99c82 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -5639,6 +5639,14 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb); } break; + case GGML_TYPE_MXFP8_E4M3: + { + VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp8, data, nb); + } break; + case GGML_TYPE_MXFP6_E2M3: + { + VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp6, data, nb); + } break; case GGML_TYPE_NVFP4: { // UE4M3 scales are uint8_t — all byte values are valid diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0e48e9e354..7f47079835 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -150,29 +151,15 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m } } -// MXFP SoA functions (internal to ggml, not in test include path) typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t); -extern "C" { - void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); - void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); - void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); - void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); - void dequantize_row_mxfp8_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); - void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); -} // Initialize an MXFP tensor with SoA layout (soa_bytes = region width, 0 = one row). static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f, size_t soa_bytes = 0) { GGML_ASSERT(ggml_is_type_mxfp(tensor->type)); - typedef void (*soa_quantize_fn)(const float *, void *, int64_t); - soa_quantize_fn quantize_soa = nullptr; - switch (tensor->type) { - case GGML_TYPE_MXFP4_E2M1: quantize_soa = quantize_row_mxfp4_soa; break; - case GGML_TYPE_MXFP8_E4M3: quantize_soa = quantize_row_mxfp8_soa; break; - case GGML_TYPE_MXFP6_E2M3: quantize_soa = quantize_row_mxfp6_soa; break; - default: GGML_ABORT("unsupported MXFP type for SoA init"); - } + const auto * traits = ggml_get_type_traits_cpu(tensor->type); + GGML_ASSERT(traits->from_float_soa && "MXFP type missing SoA quantize in traits"); + auto quantize_soa = traits->from_float_soa; const int qk = (int)ggml_blck_size(tensor->type); const size_t block_size = ggml_type_size(tensor->type); @@ -318,12 +305,8 @@ static std::vector tensor_to_float(const ggml_tensor * t) { mxfp_soa_dequantize_fn mxfp_dequant_soa = nullptr; if (is_mxfp) { - switch (t->type) { - case GGML_TYPE_MXFP4_E2M1: mxfp_dequant_soa = dequantize_row_mxfp4_soa; break; - case GGML_TYPE_MXFP8_E4M3: mxfp_dequant_soa = dequantize_row_mxfp8_soa; break; - case GGML_TYPE_MXFP6_E2M3: mxfp_dequant_soa = dequantize_row_mxfp6_soa; break; - default: GGML_ABORT("unsupported MXFP type in tensor_to_float"); - } + mxfp_dequant_soa = (mxfp_soa_dequantize_fn) ggml_get_type_traits_cpu(t->type)->to_float_soa; + GGML_ASSERT(mxfp_dequant_soa && "MXFP type missing SoA dequant in traits"); } // access elements by index to avoid gaps in views diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp index a8fb192623..ca2f4a2994 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -21,9 +21,13 @@ constexpr float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075f; constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040f; constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS = 0.0050f; constexpr float MAX_QUANTIZATION_TOTAL_ERROR_FP4 = 0.0030f; +constexpr float MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 = 0.0070f; +constexpr float MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 = 0.0040f; +constexpr float MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 = 0.0020f; constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f; constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f; constexpr float MAX_DOT_PRODUCT_ERROR_FP4 = 0.03f; +constexpr float MAX_DOT_PRODUCT_ERROR_MXFP = 0.04f; constexpr float MAX_DOT_PRODUCT_ERROR_TERNARY = 0.15f; static const char* RESULT_STR[] = {"ok", "FAILED"}; @@ -152,7 +156,10 @@ int main(int argc, char * argv[]) { type == GGML_TYPE_Q3_K ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS : type == GGML_TYPE_IQ3_S ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS : type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS : - type == GGML_TYPE_NVFP4 ? MAX_QUANTIZATION_TOTAL_ERROR_FP4 : MAX_QUANTIZATION_TOTAL_ERROR; + type == GGML_TYPE_NVFP4 ? MAX_QUANTIZATION_TOTAL_ERROR_FP4 : + type == GGML_TYPE_MXFP4_E2M1 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 : + type == GGML_TYPE_MXFP6_E2M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 : + type == GGML_TYPE_MXFP8_E4M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 : MAX_QUANTIZATION_TOTAL_ERROR; failed = !(total_error < max_quantization_error); num_failed += failed; if (failed || verbose) { @@ -174,6 +181,8 @@ int main(int argc, char * argv[]) { ? MAX_DOT_PRODUCT_ERROR_TERNARY : type == GGML_TYPE_NVFP4 ? MAX_DOT_PRODUCT_ERROR_FP4 + : type == GGML_TYPE_MXFP4_E2M1 || type == GGML_TYPE_MXFP6_E2M3 || type == GGML_TYPE_MXFP8_E4M3 + ? MAX_DOT_PRODUCT_ERROR_MXFP : MAX_DOT_PRODUCT_ERROR; failed = !(vec_dot_error < max_allowed_error); num_failed += failed; @@ -183,6 +192,34 @@ int main(int argc, char * argv[]) { } } + // MXFP SoA roundtrip: test from_float_soa → to_float_soa through the traits system + for (int i = 0; i < GGML_TYPE_COUNT; i++) { + ggml_type type = (ggml_type) i; + const auto * qfns_cpu = ggml_get_type_traits_cpu(type); + + if (!qfns_cpu->from_float_soa || !qfns_cpu->to_float_soa) { + continue; + } + + const size_t buf_size = ggml_row_size(type, test_size); + std::vector tmp_q(buf_size); + std::vector tmp_out(test_size); + + qfns_cpu->from_float_soa(test_data.data(), tmp_q.data(), test_size); + qfns_cpu->to_float_soa(tmp_q.data(), tmp_out.data(), test_size); + + const float soa_error = array_rmse(test_data.data(), tmp_out.data(), test_size); + const float max_soa_error = + type == GGML_TYPE_MXFP4_E2M1 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 : + type == GGML_TYPE_MXFP6_E2M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 : + type == GGML_TYPE_MXFP8_E4M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 : MAX_QUANTIZATION_TOTAL_ERROR; + failed = !(soa_error < max_soa_error); + num_failed += failed; + if (failed || verbose) { + printf("%5s SoA quantization error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], soa_error); + } + } + if (num_failed || verbose) { printf("%d tests failed\n", num_failed); }