mxfp traits : ensure mxfp soa quant and dequant functions are tested
This commit is contained in:
parent
5bb05ed21c
commit
dd263ff567
|
|
@ -115,10 +115,12 @@ extern "C" {
|
|||
|
||||
struct ggml_type_traits_cpu {
|
||||
ggml_from_float_t from_float;
|
||||
ggml_to_float_t to_float; // SIMD-optimized dequant (NULL = use global to_float)
|
||||
ggml_to_float_t to_float;
|
||||
ggml_from_float_t from_float_soa; // SoA quantize (MXFP flash attention layout)
|
||||
ggml_to_float_t to_float_soa; // SoA dequant (MXFP flash attention layout)
|
||||
ggml_vec_dot_t vec_dot;
|
||||
enum ggml_type vec_dot_type;
|
||||
int64_t nrows; // number of rows to process simultaneously
|
||||
int64_t nrows;
|
||||
};
|
||||
|
||||
GGML_BACKEND_API const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type);
|
||||
|
|
|
|||
|
|
@ -427,6 +427,7 @@ extern "C" {
|
|||
// GGML_TYPE_IQ4_NL_4_8 = 37,
|
||||
// GGML_TYPE_IQ4_NL_8_8 = 38,
|
||||
GGML_TYPE_MXFP4_E2M1 = 39, // MX FP4 E2M1
|
||||
GGML_TYPE_MXFP4 = GGML_TYPE_MXFP4_E2M1, // compat alias
|
||||
GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale)
|
||||
GGML_TYPE_MXFP8_E4M3 = 41, // MX FP8 E4M3
|
||||
GGML_TYPE_MXFP6_E2M3 = 42, // MX FP6 E2M3
|
||||
|
|
|
|||
|
|
@ -266,6 +266,8 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
|
|||
},
|
||||
[GGML_TYPE_MXFP4_E2M1] = {
|
||||
.from_float = quantize_row_mxfp4,
|
||||
.from_float_soa = quantize_row_mxfp4_soa,
|
||||
.to_float_soa = dequantize_row_mxfp4_soa_cpu,
|
||||
.vec_dot = ggml_vec_dot_mxfp4_q8_0,
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
.nrows = 1,
|
||||
|
|
@ -279,6 +281,8 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
|
|||
[GGML_TYPE_MXFP8_E4M3] = {
|
||||
.from_float = quantize_row_mxfp8,
|
||||
.to_float = dequantize_row_mxfp8_cpu,
|
||||
.from_float_soa = quantize_row_mxfp8_soa,
|
||||
.to_float_soa = dequantize_row_mxfp8_soa_cpu,
|
||||
.vec_dot = ggml_vec_dot_mxfp8_q8_0,
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
.nrows = 1,
|
||||
|
|
@ -286,6 +290,8 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
|
|||
[GGML_TYPE_MXFP6_E2M3] = {
|
||||
.from_float = quantize_row_mxfp6,
|
||||
.to_float = dequantize_row_mxfp6_cpu,
|
||||
.from_float_soa = quantize_row_mxfp6_soa,
|
||||
.to_float_soa = dequantize_row_mxfp6_soa_cpu,
|
||||
.vec_dot = ggml_vec_dot_mxfp6_q8_0,
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
.nrows = 1,
|
||||
|
|
|
|||
|
|
@ -5028,18 +5028,9 @@ static void ggml_compute_forward_set_rows_f32(
|
|||
|
||||
const int32_t apply_hadamard = ((const int32_t *)dst->op_params)[0];
|
||||
|
||||
typedef void (*quantize_soa_fn)(const float *, void *, int64_t);
|
||||
quantize_soa_fn mxfp_soa_quantize = nullptr;
|
||||
ggml_from_float_t from_float = nullptr;
|
||||
|
||||
switch (dst->type) {
|
||||
case GGML_TYPE_MXFP4_E2M1: mxfp_soa_quantize = quantize_row_mxfp4_soa; break;
|
||||
case GGML_TYPE_MXFP8_E4M3: mxfp_soa_quantize = quantize_row_mxfp8_soa; break;
|
||||
case GGML_TYPE_MXFP6_E2M3: mxfp_soa_quantize = quantize_row_mxfp6_soa; break;
|
||||
default:
|
||||
from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
|
||||
break;
|
||||
}
|
||||
const struct ggml_type_traits_cpu * dst_traits = ggml_get_type_traits_cpu(dst->type);
|
||||
ggml_from_float_t mxfp_soa_quantize = dst_traits->from_float_soa;
|
||||
ggml_from_float_t from_float = mxfp_soa_quantize ? nullptr : dst_traits->from_float;
|
||||
|
||||
std::vector<float> had_tmp;
|
||||
if (apply_hadamard) {
|
||||
|
|
@ -8300,21 +8291,13 @@ static mxfp_fa_params mxfp_fa_params_init(
|
|||
const bool is_mxfp_v = ggml_is_type_mxfp(v->type);
|
||||
|
||||
if (is_mxfp_k) {
|
||||
switch (k->type) {
|
||||
case GGML_TYPE_MXFP4_E2M1: p.q_quantize = quantize_row_mxfp4_soa; p.k_dequantize = dequantize_row_mxfp4_soa_cpu; break;
|
||||
case GGML_TYPE_MXFP8_E4M3: p.q_quantize = quantize_row_mxfp8_soa; p.k_dequantize = dequantize_row_mxfp8_soa_cpu; break;
|
||||
case GGML_TYPE_MXFP6_E2M3: p.q_quantize = quantize_row_mxfp6_soa; p.k_dequantize = dequantize_row_mxfp6_soa_cpu; break;
|
||||
default: GGML_ABORT("unsupported MXFP K type");
|
||||
}
|
||||
const struct ggml_type_traits_cpu * k_traits = ggml_get_type_traits_cpu(k->type);
|
||||
p.q_quantize = k_traits->from_float_soa;
|
||||
p.k_dequantize = k_traits->to_float_soa;
|
||||
}
|
||||
|
||||
if (is_mxfp_v) {
|
||||
switch (v->type) {
|
||||
case GGML_TYPE_MXFP4_E2M1: p.v_dequantize = dequantize_row_mxfp4_soa_cpu; break;
|
||||
case GGML_TYPE_MXFP8_E4M3: p.v_dequantize = dequantize_row_mxfp8_soa_cpu; break;
|
||||
case GGML_TYPE_MXFP6_E2M3: p.v_dequantize = dequantize_row_mxfp6_soa_cpu; break;
|
||||
default: GGML_ABORT("unsupported MXFP V type");
|
||||
}
|
||||
p.v_dequantize = ggml_get_type_traits_cpu(v->type)->to_float_soa;
|
||||
}
|
||||
|
||||
// Hadamard rotation must match K rotation.
|
||||
|
|
|
|||
|
|
@ -90,7 +90,10 @@ void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
|||
void dequantize_row_mxfp8_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
void dequantize_row_mxfp6_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
|
||||
// SoA dequant (SIMD-optimized for FA)
|
||||
// SoA quantize/dequant for MXFP flash attention
|
||||
void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
|
||||
void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
|
||||
void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
|
||||
void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
void dequantize_row_mxfp6_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
|
|
|
|||
|
|
@ -5639,6 +5639,14 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
|
|||
{
|
||||
VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb);
|
||||
} break;
|
||||
case GGML_TYPE_MXFP8_E4M3:
|
||||
{
|
||||
VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp8, data, nb);
|
||||
} break;
|
||||
case GGML_TYPE_MXFP6_E2M3:
|
||||
{
|
||||
VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp6, data, nb);
|
||||
} break;
|
||||
case GGML_TYPE_NVFP4:
|
||||
{
|
||||
// UE4M3 scales are uint8_t — all byte values are valid
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@
|
|||
#include <ggml.h>
|
||||
#include <ggml-alloc.h>
|
||||
#include <ggml-backend.h>
|
||||
#include <ggml-cpu.h>
|
||||
#include <ggml-cpp.h>
|
||||
|
||||
#include <algorithm>
|
||||
|
|
@ -150,29 +151,15 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
|
|||
}
|
||||
}
|
||||
|
||||
// MXFP SoA functions (internal to ggml, not in test include path)
|
||||
typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t);
|
||||
extern "C" {
|
||||
void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
|
||||
void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
|
||||
void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
|
||||
void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
|
||||
void dequantize_row_mxfp8_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
|
||||
void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
|
||||
}
|
||||
|
||||
// Initialize an MXFP tensor with SoA layout (soa_bytes = region width, 0 = one row).
|
||||
static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f, size_t soa_bytes = 0) {
|
||||
GGML_ASSERT(ggml_is_type_mxfp(tensor->type));
|
||||
|
||||
typedef void (*soa_quantize_fn)(const float *, void *, int64_t);
|
||||
soa_quantize_fn quantize_soa = nullptr;
|
||||
switch (tensor->type) {
|
||||
case GGML_TYPE_MXFP4_E2M1: quantize_soa = quantize_row_mxfp4_soa; break;
|
||||
case GGML_TYPE_MXFP8_E4M3: quantize_soa = quantize_row_mxfp8_soa; break;
|
||||
case GGML_TYPE_MXFP6_E2M3: quantize_soa = quantize_row_mxfp6_soa; break;
|
||||
default: GGML_ABORT("unsupported MXFP type for SoA init");
|
||||
}
|
||||
const auto * traits = ggml_get_type_traits_cpu(tensor->type);
|
||||
GGML_ASSERT(traits->from_float_soa && "MXFP type missing SoA quantize in traits");
|
||||
auto quantize_soa = traits->from_float_soa;
|
||||
|
||||
const int qk = (int)ggml_blck_size(tensor->type);
|
||||
const size_t block_size = ggml_type_size(tensor->type);
|
||||
|
|
@ -318,12 +305,8 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) {
|
|||
|
||||
mxfp_soa_dequantize_fn mxfp_dequant_soa = nullptr;
|
||||
if (is_mxfp) {
|
||||
switch (t->type) {
|
||||
case GGML_TYPE_MXFP4_E2M1: mxfp_dequant_soa = dequantize_row_mxfp4_soa; break;
|
||||
case GGML_TYPE_MXFP8_E4M3: mxfp_dequant_soa = dequantize_row_mxfp8_soa; break;
|
||||
case GGML_TYPE_MXFP6_E2M3: mxfp_dequant_soa = dequantize_row_mxfp6_soa; break;
|
||||
default: GGML_ABORT("unsupported MXFP type in tensor_to_float");
|
||||
}
|
||||
mxfp_dequant_soa = (mxfp_soa_dequantize_fn) ggml_get_type_traits_cpu(t->type)->to_float_soa;
|
||||
GGML_ASSERT(mxfp_dequant_soa && "MXFP type missing SoA dequant in traits");
|
||||
}
|
||||
|
||||
// access elements by index to avoid gaps in views
|
||||
|
|
|
|||
|
|
@ -21,9 +21,13 @@ constexpr float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075f;
|
|||
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040f;
|
||||
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS = 0.0050f;
|
||||
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_FP4 = 0.0030f;
|
||||
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 = 0.0070f;
|
||||
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 = 0.0040f;
|
||||
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 = 0.0020f;
|
||||
constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f;
|
||||
constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f;
|
||||
constexpr float MAX_DOT_PRODUCT_ERROR_FP4 = 0.03f;
|
||||
constexpr float MAX_DOT_PRODUCT_ERROR_MXFP = 0.04f;
|
||||
constexpr float MAX_DOT_PRODUCT_ERROR_TERNARY = 0.15f;
|
||||
|
||||
static const char* RESULT_STR[] = {"ok", "FAILED"};
|
||||
|
|
@ -152,7 +156,10 @@ int main(int argc, char * argv[]) {
|
|||
type == GGML_TYPE_Q3_K ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
|
||||
type == GGML_TYPE_IQ3_S ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
|
||||
type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS :
|
||||
type == GGML_TYPE_NVFP4 ? MAX_QUANTIZATION_TOTAL_ERROR_FP4 : MAX_QUANTIZATION_TOTAL_ERROR;
|
||||
type == GGML_TYPE_NVFP4 ? MAX_QUANTIZATION_TOTAL_ERROR_FP4 :
|
||||
type == GGML_TYPE_MXFP4_E2M1 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 :
|
||||
type == GGML_TYPE_MXFP6_E2M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 :
|
||||
type == GGML_TYPE_MXFP8_E4M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 : MAX_QUANTIZATION_TOTAL_ERROR;
|
||||
failed = !(total_error < max_quantization_error);
|
||||
num_failed += failed;
|
||||
if (failed || verbose) {
|
||||
|
|
@ -174,6 +181,8 @@ int main(int argc, char * argv[]) {
|
|||
? MAX_DOT_PRODUCT_ERROR_TERNARY
|
||||
: type == GGML_TYPE_NVFP4
|
||||
? MAX_DOT_PRODUCT_ERROR_FP4
|
||||
: type == GGML_TYPE_MXFP4_E2M1 || type == GGML_TYPE_MXFP6_E2M3 || type == GGML_TYPE_MXFP8_E4M3
|
||||
? MAX_DOT_PRODUCT_ERROR_MXFP
|
||||
: MAX_DOT_PRODUCT_ERROR;
|
||||
failed = !(vec_dot_error < max_allowed_error);
|
||||
num_failed += failed;
|
||||
|
|
@ -183,6 +192,34 @@ int main(int argc, char * argv[]) {
|
|||
}
|
||||
}
|
||||
|
||||
// MXFP SoA roundtrip: test from_float_soa → to_float_soa through the traits system
|
||||
for (int i = 0; i < GGML_TYPE_COUNT; i++) {
|
||||
ggml_type type = (ggml_type) i;
|
||||
const auto * qfns_cpu = ggml_get_type_traits_cpu(type);
|
||||
|
||||
if (!qfns_cpu->from_float_soa || !qfns_cpu->to_float_soa) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const size_t buf_size = ggml_row_size(type, test_size);
|
||||
std::vector<uint8_t> tmp_q(buf_size);
|
||||
std::vector<float> tmp_out(test_size);
|
||||
|
||||
qfns_cpu->from_float_soa(test_data.data(), tmp_q.data(), test_size);
|
||||
qfns_cpu->to_float_soa(tmp_q.data(), tmp_out.data(), test_size);
|
||||
|
||||
const float soa_error = array_rmse(test_data.data(), tmp_out.data(), test_size);
|
||||
const float max_soa_error =
|
||||
type == GGML_TYPE_MXFP4_E2M1 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 :
|
||||
type == GGML_TYPE_MXFP6_E2M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 :
|
||||
type == GGML_TYPE_MXFP8_E4M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 : MAX_QUANTIZATION_TOTAL_ERROR;
|
||||
failed = !(soa_error < max_soa_error);
|
||||
num_failed += failed;
|
||||
if (failed || verbose) {
|
||||
printf("%5s SoA quantization error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], soa_error);
|
||||
}
|
||||
}
|
||||
|
||||
if (num_failed || verbose) {
|
||||
printf("%d tests failed\n", num_failed);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue