cleanup: remove unused, untested code and improve consistency
* cleanup: consolidate MXFP type aliases, fix SoA linker bug on 5 platforms
- Add GGML_TYPE_MXFP8 and GGML_TYPE_MXFP6 short aliases (matching
existing GGML_TYPE_MXFP4 pattern) and use short names consistently
throughout the codebase instead of mixing long/short forms.
- Fix missing SoA dequant symbols (dequantize_row_mxfp{4,8,6}_soa_cpu)
on loongarch, powerpc, riscv, s390, and wasm by adding proper aliases
to each arch section in arch-fallback.h. Previously these were only
defined under GGML_CPU_GENERIC, causing linker failures on those
platforms when using MXFP flash attention.
- Remove 10 files from the PR diff:
- 5 arch stub files replaced by arch-fallback.h aliases
- 5 rename-only files (sycl, opencl, repack, llama-quant) reverted
since the GGML_TYPE_MXFP4 compat alias handles them
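The alias mechanism behind both fixes above can be sketched in isolation (toy identifiers, not the real ggml symbols): a `#define` maps the generic-suffixed name onto the public name, so compiling the generic implementation emits the public symbol directly and no per-arch stub file is needed.

```c
/* Toy model of the arch-fallback.h aliasing pattern (hypothetical names):
 * the generic-suffixed identifier is #defined to the public name, so the
 * definition below actually provides the public symbol at link time. */
#define dequantize_row_toy_generic dequantize_row_toy

// Stands in for the real generic dequantization routine.
static int dequantize_row_toy_generic(void) {
    return 42;
}
```

Callers reference `dequantize_row_toy()` and resolve to the generic definition with no extra wrapper.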
* cleanup: DRY FP6 unpack, extract mxfp_kv_params + mxfp_dequant_head helper
- FP6 unpack: x86 and ARM SIMD versions now call ggml_mxfp_unpack_fp6x4()
from ggml-common.h instead of duplicating the scalar bit manipulation.
- Extract mxfp_kv_params sub-struct from mxfp_fa_params: the 7 symmetric
K/V fields (dequantize, multihead, soa_elems, qs_per_block,
head_qs_bytes, head_e8m0_offset, blocks_per_head) are now in a reusable
struct accessed as mxfp.k and mxfp.v.
- Add mxfp_dequant_head() helper: replaces 4 instances of the multihead
SoA extraction pattern (2x memcpy + dequant, with multihead/single-head
branching) with a single function call. Future backends get the pattern
for free.
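The shared FP6 unpack that both SIMD paths now call can be sketched in scalar C (the helper name follows the commit; the exact signature of ggml_mxfp_unpack_fp6x4 is an assumption): four 6-bit values are packed little-endian into three bytes.

```c
#include <stdint.h>

/* Sketch of the shared scalar helper (signature assumed): unpack 4
 * tightly packed 6-bit values from 3 bytes. */
static inline void mxfp_unpack_fp6x4_sketch(const uint8_t * p, uint8_t out[4]) {
    // assemble the 24 packed bits little-endian, then slice 6 bits at a time
    const uint32_t pk = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
    out[0] = (pk >>  0) & 0x3F;
    out[1] = (pk >>  6) & 0x3F;
    out[2] = (pk >> 12) & 0x3F;
    out[3] = (pk >> 18) & 0x3F;
}
```

The SIMD versions call this and then widen the four bytes into vector lanes.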
* cleanup: extract mxfp_kv_params_init to DRY the K/V init blocks
The K and V initialization in mxfp_fa_params_init were structurally
identical 10-line blocks differing only by tensor/dimension. Extract
into mxfp_kv_params_init(type, D, nb2, ne2) so future MXFP formats
get the multihead SoA addressing logic automatically.
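The addressing math this helper centralizes can be sketched standalone (a hedged model: row_size and qs_per_block are passed in here rather than looked up from type traits, and blocks hold 32 elements as in MXFP):

```c
#include <stdint.h>
#include <stdbool.h>
#include <stddef.h>

typedef struct {
    bool    multihead;
    int64_t soa_elems;
    int     blocks_per_head;
    int     head_qs_bytes;
    int64_t head_e8m0_offset;
} kv_params_sketch;

/* Sketch of mxfp_kv_params_init's addressing math: multihead when nb2
 * equals the row size for one head (heads contiguous per KV position). */
static kv_params_sketch kv_params_init_sketch(int64_t D, size_t nb2, int64_t ne2,
                                              size_t row_size, int qs_per_block) {
    kv_params_sketch kv;
    kv.multihead       = (nb2 == row_size);
    kv.soa_elems       = kv.multihead ? ne2 * D : D;
    kv.blocks_per_head = (int)(D / 32);          // 32 elements per MXFP block
    kv.head_qs_bytes   = kv.blocks_per_head * qs_per_block;
    // e8m0 scales sit after all qs bytes in the SoA layout
    const int64_t total_blocks = kv.multihead ? ne2 * kv.blocks_per_head
                                              : kv.blocks_per_head;
    kv.head_e8m0_offset = total_blocks * qs_per_block;
    return kv;
}
```

For example, FP8 with D = 128 and 8 heads in multihead mode gives 4 blocks per head, 128 qs bytes per head, and scales starting at byte 1024.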
* cleanup: generic MSE round-trip, replace magic buffer sizes with constants
- Remove mse_error_fp8_e4m3 and mse_error_fp6_e2m3: these were identical
round-trip functions differing only by converter. mxfp_compute_e8m0_mse
now uses to_elem/to_float directly when mse_error is NULL (FP8/FP6).
MXFP4 keeps its custom decision-tree MSE. New formats get MSE for free
by just setting to_elem/to_float in their traits.
- Replace magic 1024/1088 buffer sizes in flash attention with named
constants MXFP_FA_MAX_D and MXFP_FA_SOA_BUF. One place to change if
max head dimension grows.
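The generic round-trip idea can be illustrated with a toy converter pair (hypothetical helper names; the real code plugs in the to_elem/to_float trait fields): quantize each value at a candidate scale, dequantize, and accumulate squared error.

```c
#include <stdint.h>

typedef uint8_t (*to_elem_fn)(float);
typedef float   (*to_float_fn)(uint8_t);

/* Hedged sketch of a generic round-trip MSE: one function for any format
 * that supplies to_elem/to_float, replacing per-format copies. */
static float mse_round_trip(const float * x, int n, float inv_scale, float scale,
                            to_elem_fn to_elem, to_float_fn to_float) {
    float err = 0.0f;
    for (int i = 0; i < n; ++i) {
        const float r = to_float(to_elem(x[i] * inv_scale)) * scale;
        const float d = x[i] - r;
        err += d * d;
    }
    return err;
}

/* Toy 2-bit "format" for illustration: codes 0..3, value == code. */
static uint8_t toy_to_elem(float v) {
    int q = (int)(v + 0.5f);
    if (q < 0) q = 0;
    if (q > 3) q = 3;
    return (uint8_t)q;
}
static float toy_to_float(uint8_t e) { return (float)e; }
```

With inputs {2, 5} at unit scale, 2 round-trips exactly and 5 clamps to 3, giving an MSE contribution of 4.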
* cleanup: remove dead AoS vec_dot for MXFP8/MXFP6, unify SoA impls
MXFP8 and MXFP6 are KV-cache-only types that use SoA layout for flash
attention. The AoS vec_dot functions (scalar generic, AVX2, NEON) were
dead code — no matmul path uses them.
Removed:
- ggml_vec_dot_mxfp{8,6}_q8_0 from scalar, x86, ARM, quants.h
- ggml_vec_dot_mxfp_q8_0_impl shared helper
- arch-fallback.h aliases for vec_dot mxfp8/mxfp6 (12 lines)
- vec_dot/vec_dot_type registration in ggml-cpu.c
Also unified SoA quantize/dequant: the separate mxfp8_soa_impl and
mxfp6_soa_impl functions (4 functions, ~80 lines) are replaced by two
generic functions (quantize_row_mxfp_soa_impl, dequantize_row_mxfp_soa_impl)
that use traits->bits_per_elem and traits->qs_per_block to handle both
byte-aligned (FP8) and 6-bit packed (FP6) formats. New MXFP formats
get SoA for free by setting these trait fields.
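A minimal sketch of how one generic helper can serve both layouts via a bits-per-element parameter (hypothetical function; it assumes the 6-bit case may read one byte past the final group, so callers pad the buffer):

```c
#include <stdint.h>

/* Fetch raw element i from a packed qs buffer: byte-aligned (8-bit) or
 * tightly packed 6-bit. Value i starts at bit 6*i in the 6-bit case. */
static uint8_t soa_get_elem_sketch(const uint8_t * qs, int i, int bits_per_elem) {
    if (bits_per_elem == 8) {
        return qs[i];
    }
    const int bit  = 6 * i;
    const int byte = bit >> 3;
    const int sh   = bit & 7;
    // read two bytes so a value straddling a byte boundary is covered
    const uint16_t w = (uint16_t)qs[byte] | ((uint16_t)qs[byte + 1] << 8);
    return (uint8_t)((w >> sh) & 0x3F);
}
```

A trait-driven quantize/dequant loop can call this per element and stay oblivious to the packing.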
* cleanup: remove all AoS MXFP8/MXFP6 quantize/dequant — SoA only
MXFP8 and MXFP6 are KV-cache-only types. All quantization and
dequantization goes through the SoA (Struct-of-Arrays) path for flash
attention. The AoS (block_mxfp8/block_mxfp6 struct) implementations
were dead code that should never have been added.
Removed:
- quantize_row_mxfp{8,6}_impl, dequantize_row_mxfp{8,6}_impl
- quantize_row_mxfp{8,6}_ref, dequantize_row_mxfp{8,6}
- quantize_mxfp{8,6} (ggml_quantize_chunk wrappers)
- All declarations from ggml-quants.h and quants.h
- to_float/from_float_ref registrations from ggml.c type traits
- from_float registration from ggml-cpu.c CPU traits
Block struct definitions (block_mxfp8, block_mxfp6) are retained for
sizeof() in type traits and validate_row_data.
* cleanup: fail fast in ggml_quantize_chunk for KV-cache-only types
Add explicit GGML_ABORT for MXFP8/MXFP6 in ggml_quantize_chunk —
these are KV-cache-only types that use SoA layout via from_float_soa.
Attempting AoS quantization through this entry point is a bug.
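The fail-fast guard can be modeled with a toy dispatcher (hypothetical enum and return codes; the real code calls GGML_ABORT inside ggml_quantize_chunk's switch):

```c
/* Toy type ids for illustration only. */
enum sketch_type { SK_MXFP4, SK_MXFP8, SK_MXFP6, SK_Q8_0 };

/* Reject AoS quantization for KV-cache-only types before doing any work;
 * -1 stands in for the GGML_ABORT path. */
static int sketch_quantize_chunk(enum sketch_type t) {
    switch (t) {
        case SK_MXFP8:
        case SK_MXFP6:
            return -1; // KV-cache-only: SoA via from_float_soa, never AoS
        default:
            return 0;  // proceed with AoS quantization
    }
}
```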
This commit is contained in:
parent
0e3304fbca
commit
c919bc471b
@@ -398,20 +398,20 @@ const std::vector<ggml_type> kv_cache_types = {
     GGML_TYPE_IQ4_NL,
     GGML_TYPE_Q5_0,
     GGML_TYPE_Q5_1,
-    GGML_TYPE_MXFP4_E2M1,
-    GGML_TYPE_MXFP8_E4M3,
-    GGML_TYPE_MXFP6_E2M3,
+    GGML_TYPE_MXFP4,
+    GGML_TYPE_MXFP8,
+    GGML_TYPE_MXFP6,
 };
 
 static ggml_type kv_cache_type_from_str(const std::string & s) {
     if (s == "mxfp4") {
-        return GGML_TYPE_MXFP4_E2M1;
+        return GGML_TYPE_MXFP4;
     }
     if (s == "mxfp6") {
-        return GGML_TYPE_MXFP6_E2M3;
+        return GGML_TYPE_MXFP6;
     }
     if (s == "mxfp8") {
-        return GGML_TYPE_MXFP8_E4M3;
+        return GGML_TYPE_MXFP8;
     }
     for (const auto & type : kv_cache_types) {
         if (ggml_type_name(type) == s) {
@@ -430,7 +430,9 @@ extern "C" {
     GGML_TYPE_MXFP4 = GGML_TYPE_MXFP4_E2M1, // compat alias
     GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale)
     GGML_TYPE_MXFP8_E4M3 = 41, // MX FP8 E4M3
+    GGML_TYPE_MXFP8 = GGML_TYPE_MXFP8_E4M3, // compat alias
     GGML_TYPE_MXFP6_E2M3 = 42, // MX FP6 E2M3
+    GGML_TYPE_MXFP6 = GGML_TYPE_MXFP6_E2M3, // compat alias
     GGML_TYPE_COUNT = 43,
 };
 
@@ -15,8 +15,9 @@
 #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
-#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0
-#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0
+#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
+#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
+#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
@@ -113,6 +114,9 @@
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
+#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
+#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
+#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
@@ -161,6 +165,9 @@
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
+#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
+#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
+#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -201,6 +208,9 @@
 #elif defined(__riscv)
 // quants.c
 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
+#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
+#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
+#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
@@ -241,6 +251,9 @@
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
+#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
+#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
+#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
@@ -291,6 +304,9 @@
 #elif defined(__wasm__)
 // quants.c
 #define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1
+#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
+#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
+#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K
@@ -342,10 +358,3 @@
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #endif
-
-// MXFP dequantize fallbacks (same GGML_CPU_GENERIC guard as above)
-#if defined(GGML_CPU_GENERIC)
-#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
-#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
-#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu
-#endif
@@ -4191,12 +4191,8 @@ static inline float32x4_t mxfp6_dequant_neon(
 
 // Unpack 4 tightly-packed 6-bit values from 3 bytes, widen to uint32x4_t.
 static inline uint32x4_t unpack_fp6x4_neon(const uint8_t * p) {
-    const uint32_t pk = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
     uint8_t u[4];
-    u[0] = (pk >> 0) & 0x3F;
-    u[1] = (pk >> 6) & 0x3F;
-    u[2] = (pk >> 12) & 0x3F;
-    u[3] = (pk >> 18) & 0x3F;
+    ggml_mxfp_unpack_fp6x4(p, u);
     const uint8x8_t raw8 = vcreate_u8(
         (uint64_t)u[0] | ((uint64_t)u[1] << 8) |
         ((uint64_t)u[2] << 16) | ((uint64_t)u[3] << 24));
@@ -4221,96 +4217,6 @@ static inline void widen_s8x8_to_f32x4x2(const int8_t * src,
     *hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16)));
 }
 
-// MXFP FP8/FP6 vec_dot
-
-static void ggml_vec_dot_mxfp8_q8_0_neon(
-        int n, float * GGML_RESTRICT s,
-        const void * GGML_RESTRICT vx,
-        const void * GGML_RESTRICT vy,
-        const mxfp_neon_traits_t * t) {
-    assert(n % QK_MXFP8 == 0);
-    const int nb = n / QK_MXFP8;
-    const block_mxfp8 * GGML_RESTRICT x = vx;
-    const block_q8_0 * GGML_RESTRICT y = vy;
-
-    const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask);
-    const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask);
-    const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off);
-    const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale);
-    const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift);
-    const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift);
-
-    float32x4_t acc0 = vdupq_n_f32(0.0f);
-    float32x4_t acc1 = vdupq_n_f32(0.0f);
-
-    for (int ib = 0; ib < nb; ++ib) {
-        const float32x4_t v_scale = vdupq_n_f32(
-            GGML_E8M0_TO_FP32(x[ib].e) * GGML_CPU_FP16_TO_FP32(y[ib].d));
-
-        for (int j = 0; j < 32; j += 8) {
-            uint32x4_t v_lo, v_hi;
-            widen_u8x8_to_u32x4x2(x[ib].qs + j, &v_lo, &v_hi);
-
-            float32x4_t qf_lo, qf_hi;
-            widen_s8x8_to_f32x4x2(y[ib].qs + j, &qf_lo, &qf_hi);
-
-            const float32x4_t val_lo = mxfp8_dequant_neon(v_lo,
-                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
-            const float32x4_t val_hi = mxfp8_dequant_neon(v_hi,
-                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
-
-            acc0 = vfmaq_f32(acc0, vmulq_f32(val_lo, v_scale), qf_lo);
-            acc1 = vfmaq_f32(acc1, vmulq_f32(val_hi, v_scale), qf_hi);
-        }
-    }
-
-    *s = vaddvq_f32(vaddq_f32(acc0, acc1));
-}
-
-static void ggml_vec_dot_mxfp6_q8_0_neon(
-        int n, float * GGML_RESTRICT s,
-        const void * GGML_RESTRICT vx,
-        const void * GGML_RESTRICT vy,
-        const mxfp_neon_traits_t * t) {
-    assert(n % QK_MXFP6 == 0);
-    const int nb = n / QK_MXFP6;
-    const block_q8_0 * GGML_RESTRICT y = vy;
-
-    const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask);
-    const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask);
-    const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off);
-    const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale);
-    const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift);
-    const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift);
-
-    float32x4_t acc0 = vdupq_n_f32(0.0f);
-    float32x4_t acc1 = vdupq_n_f32(0.0f);
-
-    for (int ib = 0; ib < nb; ++ib) {
-        const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib;
-        const float32x4_t v_scale = vdupq_n_f32(
-            GGML_E8M0_TO_FP32(xb->e) * GGML_CPU_FP16_TO_FP32(y[ib].d));
-
-        for (int j = 0; j < 32; j += 8) {
-            const uint32x4_t v_lo = unpack_fp6x4_neon(xb->qs + (j * 3 / 4));
-            const uint32x4_t v_hi = unpack_fp6x4_neon(xb->qs + ((j + 4) * 3 / 4));
-
-            float32x4_t qf_lo, qf_hi;
-            widen_s8x8_to_f32x4x2(y[ib].qs + j, &qf_lo, &qf_hi);
-
-            const float32x4_t val_lo = mxfp6_dequant_neon(v_lo,
-                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
-            const float32x4_t val_hi = mxfp6_dequant_neon(v_hi,
-                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
-
-            acc0 = vfmaq_f32(acc0, vmulq_f32(val_lo, v_scale), qf_lo);
-            acc1 = vfmaq_f32(acc1, vmulq_f32(val_hi, v_scale), qf_hi);
-        }
-    }
-
-    *s = vaddvq_f32(vaddq_f32(acc0, acc1));
-}
-
 // MXFP SoA dequant (flash attention)
 
 static void dequantize_row_mxfp8_soa_neon(
@@ -4424,26 +4330,6 @@ static void dequantize_row_mxfp4_soa_neon(
 
 // Public dispatch functions
 
-void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    assert(nrc == 1);
-    UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
-#if defined(__ARM_NEON)
-    ggml_vec_dot_mxfp8_q8_0_neon(n, s, vx, vy, &MXFP_TRAITS_E4M3);
-#else
-    ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-#endif
-}
-
-void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    assert(nrc == 1);
-    UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
-#if defined(__ARM_NEON)
-    ggml_vec_dot_mxfp6_q8_0_neon(n, s, vx, vy, &MXFP_TRAITS_E2M3);
-#else
-    ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-#endif
-}
-
 void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
 #if defined(__ARM_NEON)
     dequantize_row_mxfp4_soa_neon(x, y, k);
@@ -2157,14 +2157,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
 #endif
 }
 
-void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-}
-
-void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-}
-
-void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-}
@@ -2303,10 +2303,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
 #endif
 }
 
-void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-}
-
-void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-}
@@ -3621,11 +3621,3 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
     ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
-
-void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-}
-
-void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-}
@@ -1464,10 +1464,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
 #endif
 }
 
-void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-}
-
-void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-}
@@ -1219,14 +1219,3 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 #endif
 }
 
-void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-}
-
-void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-}
-
-void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-}
@@ -3850,106 +3850,14 @@ static inline __m256 mxfp_dequant_avx2(
     return _mm256_blendv_ps(normal, sub_val, is_sub);
 }
 
-// Unpack 4 tightly-packed 6-bit values from 3 bytes into separate bytes.
-static inline void unpack_fp6x4_avx2(const uint8_t * p, uint8_t out[4]) {
-    const uint32_t pk = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
-    out[0] = (pk >> 0) & 0x3F;
-    out[1] = (pk >> 6) & 0x3F;
-    out[2] = (pk >> 12) & 0x3F;
-    out[3] = (pk >> 18) & 0x3F;
-}
-
 // Unpack 8 FP6 values (two groups of 4) from packed qs data at offset j.
 static inline __m256i unpack_fp6x8_avx2(const uint8_t * qs, int j) {
     uint8_t unpacked[8];
-    unpack_fp6x4_avx2(qs + (j * 3 / 4), unpacked);
-    unpack_fp6x4_avx2(qs + ((j + 4) * 3 / 4), unpacked + 4);
+    ggml_mxfp_unpack_fp6x4(qs + (j * 3 / 4), unpacked);
+    ggml_mxfp_unpack_fp6x4(qs + ((j + 4) * 3 / 4), unpacked + 4);
     return _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)unpacked));
 }
 
-// MXFP FP8/FP6 vec_dot
-
-// FP8 x Q8_0 dot product (E4M3/E5M2).
-static void ggml_vec_dot_mxfp8_q8_0_avx2(
-        int n, float * GGML_RESTRICT s,
-        const void * GGML_RESTRICT vx,
-        const void * GGML_RESTRICT vy,
-        const mxfp_avx2_traits_t * t) {
-    assert(n % QK_MXFP8 == 0);
-    const int nb = n / QK_MXFP8;
-    const block_mxfp8 * GGML_RESTRICT x = vx;
-    const block_q8_0 * GGML_RESTRICT y = vy;
-
-    const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask);
-    const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask);
-    const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off);
-    const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale);
-    const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask);
-    const __m256i v_zero = _mm256_setzero_si256();
-
-    __m256 acc = _mm256_setzero_ps();
-
-    for (int ib = 0; ib < nb; ++ib) {
-        const __m256 v_scale = _mm256_set1_ps(
-            GGML_E8M0_TO_FP32(x[ib].e) * GGML_CPU_FP16_TO_FP32(y[ib].d));
-
-        for (int j = 0; j < 32; j += 8) {
-            const __m256i v_raw = _mm256_cvtepu8_epi32(
-                _mm_loadl_epi64((const __m128i *)(x[ib].qs + j)));
-            const __m256 qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
-                _mm_loadl_epi64((const __m128i *)(y[ib].qs + j))));
-
-            const __m256 val = mxfp_dequant_avx2(v_raw,
-                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc,
-                v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift);
-
-            acc = _mm256_fmadd_ps(_mm256_mul_ps(val, v_scale), qf, acc);
-        }
-    }
-
-    *s = hsum_float_8(acc);
-}
-
-// FP6 x Q8_0 dot product (E2M3/E3M2).
-static void ggml_vec_dot_mxfp6_q8_0_avx2(
-        int n, float * GGML_RESTRICT s,
-        const void * GGML_RESTRICT vx,
-        const void * GGML_RESTRICT vy,
-        const mxfp_avx2_traits_t * t) {
-    assert(n % QK_MXFP6 == 0);
-    const int nb = n / QK_MXFP6;
-    const block_q8_0 * GGML_RESTRICT y = vy;
-
-    const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask);
-    const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask);
-    const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off);
-    const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale);
-    const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask);
-    const __m256i v_zero = _mm256_setzero_si256();
-
-    __m256 acc = _mm256_setzero_ps();
-
-    for (int ib = 0; ib < nb; ++ib) {
-        const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib;
-        const __m256 v_scale = _mm256_set1_ps(
-            GGML_E8M0_TO_FP32(xb->e) * GGML_CPU_FP16_TO_FP32(y[ib].d));
-
-        for (int j = 0; j < 32; j += 8) {
-            const __m256i v_raw = unpack_fp6x8_avx2(xb->qs, j);
-            const __m256 qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
-                _mm_loadl_epi64((const __m128i *)(y[ib].qs + j))));
-
-            const __m256 val = mxfp_dequant_avx2(v_raw,
-                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc,
-                v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift);
-
-            acc = _mm256_fmadd_ps(_mm256_mul_ps(val, v_scale), qf, acc);
-        }
-    }
-
-    *s = hsum_float_8(acc);
-}
-
 // MXFP SoA dequant (flash attention)
 
 static void dequantize_row_mxfp8_soa_avx2(
@@ -4052,26 +3960,6 @@ static void dequantize_row_mxfp4_soa_avx2(
 
 // Public dispatch functions
 
-void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    assert(nrc == 1);
-    UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
-#if defined(__AVX2__)
-    ggml_vec_dot_mxfp8_q8_0_avx2(n, s, vx, vy, &MXFP_TRAITS_E4M3);
-#else
-    ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-#endif
-}
-
-void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    assert(nrc == 1);
-    UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
-#if defined(__AVX2__)
-    ggml_vec_dot_mxfp6_q8_0_avx2(n, s, vx, vy, &MXFP_TRAITS_E2M3);
-#else
-    ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
-#endif
-}
-
 void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
 #if defined(__AVX2__)
     dequantize_row_mxfp4_soa_avx2(x, y, k);
@@ -265,7 +265,7 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .vec_dot_type = GGML_TYPE_Q8_1,
         .nrows = 1,
     },
-    [GGML_TYPE_MXFP4_E2M1] = {
+    [GGML_TYPE_MXFP4] = {
         .from_float = quantize_row_mxfp4,
         .from_float_soa = quantize_row_mxfp4_soa,
         .to_float_soa = dequantize_row_mxfp4_soa_cpu,
@@ -279,20 +279,14 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .vec_dot_type = GGML_TYPE_Q8_0,
         .nrows = 1,
     },
-    [GGML_TYPE_MXFP8_E4M3] = {
-        .from_float = quantize_row_mxfp8,
+    [GGML_TYPE_MXFP8] = {
         .from_float_soa = quantize_row_mxfp8_soa,
         .to_float_soa = dequantize_row_mxfp8_soa_cpu,
-        .vec_dot = ggml_vec_dot_mxfp8_q8_0,
-        .vec_dot_type = GGML_TYPE_Q8_0,
         .nrows = 1,
     },
-    [GGML_TYPE_MXFP6_E2M3] = {
-        .from_float = quantize_row_mxfp6,
+    [GGML_TYPE_MXFP6] = {
         .from_float_soa = quantize_row_mxfp6_soa,
         .to_float_soa = dequantize_row_mxfp6_soa_cpu,
-        .vec_dot = ggml_vec_dot_mxfp6_q8_0,
-        .vec_dot_type = GGML_TYPE_Q8_0,
         .nrows = 1,
     },
     [GGML_TYPE_Q2_K] = {
@@ -672,10 +672,10 @@ void ggml_compute_forward_add(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
-        case GGML_TYPE_MXFP4_E2M1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_NVFP4:
-        case GGML_TYPE_MXFP8_E4M3:
-        case GGML_TYPE_MXFP6_E2M3:
+        case GGML_TYPE_MXFP8:
+        case GGML_TYPE_MXFP6:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
@@ -1124,10 +1124,10 @@ void ggml_compute_forward_add1(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
-        case GGML_TYPE_MXFP4_E2M1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_NVFP4:
-        case GGML_TYPE_MXFP8_E4M3:
-        case GGML_TYPE_MXFP6_E2M3:
+        case GGML_TYPE_MXFP8:
+        case GGML_TYPE_MXFP6:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
@@ -1255,10 +1255,10 @@ void ggml_compute_forward_acc(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
-        case GGML_TYPE_MXFP4_E2M1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_NVFP4:
-        case GGML_TYPE_MXFP8_E4M3:
-        case GGML_TYPE_MXFP6_E2M3:
+        case GGML_TYPE_MXFP8:
+        case GGML_TYPE_MXFP6:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
@@ -4345,10 +4345,10 @@ void ggml_compute_forward_out_prod(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
-        case GGML_TYPE_MXFP4_E2M1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_NVFP4:
-        case GGML_TYPE_MXFP8_E4M3:
-        case GGML_TYPE_MXFP6_E2M3:
+        case GGML_TYPE_MXFP8:
+        case GGML_TYPE_MXFP6:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
@@ -4623,10 +4623,10 @@ void ggml_compute_forward_set(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
-        case GGML_TYPE_MXFP4_E2M1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_NVFP4:
-        case GGML_TYPE_MXFP8_E4M3:
-        case GGML_TYPE_MXFP6_E2M3:
+        case GGML_TYPE_MXFP8:
+        case GGML_TYPE_MXFP6:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
@@ -4848,10 +4848,10 @@ void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
-        case GGML_TYPE_MXFP4_E2M1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_NVFP4:
-        case GGML_TYPE_MXFP8_E4M3:
-        case GGML_TYPE_MXFP6_E2M3:
+        case GGML_TYPE_MXFP8:
+        case GGML_TYPE_MXFP6:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
@@ -5686,10 +5686,10 @@ void ggml_compute_forward_clamp(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
-        case GGML_TYPE_MXFP4_E2M1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_NVFP4:
-        case GGML_TYPE_MXFP8_E4M3:
-        case GGML_TYPE_MXFP6_E2M3:
+        case GGML_TYPE_MXFP8:
+        case GGML_TYPE_MXFP6:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
@@ -8255,31 +8255,65 @@ void ggml_compute_forward_top_k(
     }
 }
 
+// Max head dimension for stack-allocated MXFP buffers.
+static constexpr int64_t MXFP_FA_MAX_D = 1024;
+// SoA buffer size for MXFP_FA_MAX_D with MXFP8 (worst case: 1024 + 32 e8m0 = 1056, rounded up).
+static constexpr int MXFP_FA_SOA_BUF = 1088;
+
 // SoA function pointer types for MXFP flash attention paths.
 typedef void (*mxfp_soa_quantize_fn)(const float *, void *, int64_t);
 typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t);
 
+// Per-KV-type MXFP parameters (shared between K and V).
+struct mxfp_kv_params {
+    mxfp_soa_dequantize_fn dequantize;
+    bool multihead;
+    int64_t soa_elems;
+    int qs_per_block;
+    int head_qs_bytes;
+    int64_t head_e8m0_offset;
+    int blocks_per_head;
+};
+
 // MXFP dispatch parameters for flash attention.
 struct mxfp_fa_params {
-    mxfp_soa_quantize_fn q_quantize;
-    mxfp_soa_dequantize_fn k_dequantize;
-    mxfp_soa_dequantize_fn v_dequantize;
-    bool k_multihead;
-    bool v_multihead;
-    int64_t k_soa_elems;
-    int64_t v_soa_elems;
-    bool apply_hadamard;
-    // Per-head SoA addressing (avoids dequanting all heads in multihead mode).
-    int k_qs_per_block;
-    int v_qs_per_block;
-    int k_head_qs_bytes;
-    int v_head_qs_bytes;
-    int64_t k_head_e8m0_offset;
-    int64_t v_head_e8m0_offset;
-    int k_blocks_per_head;
-    int v_blocks_per_head;
+    mxfp_soa_quantize_fn q_quantize;
+    mxfp_kv_params k;
+    mxfp_kv_params v;
+    bool apply_hadamard;
 };
 
+// Extract one head's SoA data from a multihead row and dequantize.
+static inline void mxfp_dequant_head(
+        const mxfp_kv_params & kv, const char * row, int head_idx,
+        char * soa_buf, float * out, int64_t D) {
+    if (kv.multihead) {
+        const int qs_off = head_idx * kv.head_qs_bytes;
+        const int e8m0_off = (int)kv.head_e8m0_offset + head_idx * kv.blocks_per_head;
+        memcpy(soa_buf, row + qs_off, kv.head_qs_bytes);
+        memcpy(soa_buf + kv.head_qs_bytes, row + e8m0_off, kv.blocks_per_head);
+        kv.dequantize(soa_buf, out, D);
+    } else {
+        kv.dequantize(row, out, D);
+    }
+}
+
+// Initialize per-KV-type params from tensor metadata.
+// Multihead detection: nb2 == row_size(D) means heads are contiguous within
+// one KV-position stride, so SoA spans all heads. Otherwise SoA is per-head.
+static mxfp_kv_params mxfp_kv_params_init(ggml_type type, int64_t D, size_t nb2, int64_t ne2) {
+    mxfp_kv_params kv = {};
+    kv.dequantize = ggml_get_type_traits_cpu(type)->to_float_soa;
+    kv.multihead = (nb2 == (size_t)ggml_row_size(type, D));
+    kv.soa_elems = kv.multihead ? ne2 * D : D;
+    kv.qs_per_block = ggml_mxfp_qs_per_block(type);
+    kv.blocks_per_head = (int)(D / 32);
+    kv.head_qs_bytes = kv.blocks_per_head * kv.qs_per_block;
+    const int64_t total_blocks = kv.multihead ? ne2 * kv.blocks_per_head : kv.blocks_per_head;
+    kv.head_e8m0_offset = total_blocks * kv.qs_per_block;
+    return kv;
+}
+
 static mxfp_fa_params mxfp_fa_params_init(
         const ggml_tensor * k, const ggml_tensor * v,
         int64_t DK, int64_t DV,
@@ -8291,44 +8325,17 @@ static mxfp_fa_params mxfp_fa_params_init(
    const bool is_mxfp_v = ggml_is_type_mxfp(v->type);

    if (is_mxfp_k) {
-        const struct ggml_type_traits_cpu * k_traits = ggml_get_type_traits_cpu(k->type);
-        p.q_quantize = k_traits->from_float_soa;
-        p.k_dequantize = k_traits->to_float_soa;
+        p.q_quantize = ggml_get_type_traits_cpu(k->type)->from_float_soa;
+        p.k = mxfp_kv_params_init(k->type, DK, nbk2, nek2);
    }

    if (is_mxfp_v) {
-        p.v_dequantize = ggml_get_type_traits_cpu(v->type)->to_float_soa;
+        p.v = mxfp_kv_params_init(v->type, DV, nbv2, nev2);
    }

    // Hadamard rotation must match K rotation.
-    // Skipped for: MLA (DK != DV, V is a view of K).
+    // Skipped for MLA (DK != DV, V is a view of K).
    p.apply_hadamard = is_mxfp_k && (DK == DV) && ggml_mxfp_use_hadamard(k->type);

-    // SoA layout detection: in the real KV cache, heads are contiguous within
-    // one KV-position stride (nb[2] == row_size(DK)), so SoA spans all heads.
-    // In test tensors, heads may be at distant offsets (nb[2] >> row_size(DK)),
-    // so SoA is per-head. Detect which case and set dequant parameters accordingly.
-    p.k_multihead = is_mxfp_k && (nbk2 == (size_t)ggml_row_size(k->type, DK));
-    p.k_soa_elems = is_mxfp_k ? (p.k_multihead ? nek2 * DK : DK) : 0;
-    p.v_multihead = is_mxfp_v && (nbv2 == (size_t)ggml_row_size(v->type, DV));
-    p.v_soa_elems = is_mxfp_v ? (p.v_multihead ? nev2 * DV : DV) : 0;
-
-    if (is_mxfp_k) {
-        p.k_qs_per_block = ggml_mxfp_qs_per_block(k->type);
-        p.k_blocks_per_head = (int)(DK / 32);
-        p.k_head_qs_bytes = p.k_blocks_per_head * p.k_qs_per_block;
-        const int64_t k_total_blocks = p.k_multihead ? nek2 * p.k_blocks_per_head : p.k_blocks_per_head;
-        p.k_head_e8m0_offset = k_total_blocks * p.k_qs_per_block;
-    }
-
-    if (is_mxfp_v) {
-        p.v_qs_per_block = ggml_mxfp_qs_per_block(v->type);
-        p.v_blocks_per_head = (int)(DV / 32);
-        p.v_head_qs_bytes = p.v_blocks_per_head * p.v_qs_per_block;
-        const int64_t v_total_blocks = p.v_multihead ? nev2 * p.v_blocks_per_head : p.v_blocks_per_head;
-        p.v_head_e8m0_offset = v_total_blocks * p.v_qs_per_block;
-    }
-
    return p;
}
@@ -8430,14 +8437,14 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(

    int ith = params->ith;

-    if (is_mxfp_k) { GGML_ASSERT(DK <= 1024); }
-    if (is_mxfp_v) { GGML_ASSERT(DV <= 1024); }
+    if (is_mxfp_k) { GGML_ASSERT(DK <= MXFP_FA_MAX_D); }
+    if (is_mxfp_v) { GGML_ASSERT(DV <= MXFP_FA_MAX_D); }

-    float k_dequant_buf[1024];
-    float v_dequant_buf[1024];
+    float k_dequant_buf[MXFP_FA_MAX_D];
+    float v_dequant_buf[MXFP_FA_MAX_D];

-    char k_head_soa[1088]; // max: DK=1024 MXFP8 -> 1056 bytes, rounded up
-    char v_head_soa[1088];
+    char k_head_soa[MXFP_FA_SOA_BUF]; // max: DK=1024 MXFP8 -> 1056 bytes, rounded up
+    char v_head_soa[MXFP_FA_SOA_BUF];

    float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32);
    float * V32 = (VKQ32 + 1*DV);
@@ -8479,31 +8486,25 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
        const char * k_base = (const char *) k->data + k_base_offset;
        const char * v_base = (const char *) v->data + v_base_offset;

-        // Per-head SoA byte offsets
-        const int k_head_qs_start = mxfp.k_multihead ? ik2 * mxfp.k_head_qs_bytes : 0;
-        const int k_head_e8m0_start = mxfp.k_multihead ? (int)mxfp.k_head_e8m0_offset + ik2 * mxfp.k_blocks_per_head : 0;
-        const int v_head_qs_start = mxfp.v_multihead ? iv2 * mxfp.v_head_qs_bytes : 0;
-        const int v_head_e8m0_start = mxfp.v_multihead ? (int)mxfp.v_head_e8m0_offset + iv2 * mxfp.v_blocks_per_head : 0;
-
-        const char * k_row_base = mxfp.k_multihead ? ((const char *) k->data + ik3*nbk3) : nullptr;
-        const char * v_row_base = mxfp.v_multihead ? ((const char *) v->data + iv3*nbv3) : nullptr;
+        const char * k_row_base = mxfp.k.multihead ? ((const char *) k->data + ik3*nbk3) : nullptr;
+        const char * v_row_base = mxfp.v.multihead ? ((const char *) v->data + iv3*nbv3) : nullptr;

        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
-        float Q_f32[1024];
+        float Q_f32[MXFP_FA_MAX_D];
        if (is_mxfp_k) {
            // Q preprocessing: Hadamard + SoA round-trip captures same quantization loss as K.
            if (mxfp.apply_hadamard) {
-                float q_tmp[1024];
+                float q_tmp[MXFP_FA_MAX_D];
                memcpy(q_tmp, pq, DK * sizeof(float));
                ggml_apply_hadamard_blocks(q_tmp, DK);
                mxfp.q_quantize(q_tmp, Q_q, DK);
            } else {
                mxfp.q_quantize(pq, Q_q, DK);
            }
-            mxfp.k_dequantize(Q_q, Q_f32, DK);
+            mxfp.k.dequantize(Q_q, Q_f32, DK);
        } else {
            if (mxfp.apply_hadamard) {
-                float q_tmp[1024];
+                float q_tmp[MXFP_FA_MAX_D];
                memcpy(q_tmp, pq, DK * sizeof(float));
                ggml_apply_hadamard_blocks(q_tmp, DK);
                q_to_vec_dot(q_tmp, Q_q, DK);
@@ -8525,15 +8526,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
            float s; // KQ value

            if (is_mxfp_k) {
-                if (mxfp.k_multihead) {
-                    // Extract this head's SoA blocks
-                    const char * row = k_row_base + ic*nbk1;
-                    memcpy(k_head_soa, row + k_head_qs_start, mxfp.k_head_qs_bytes);
-                    memcpy(k_head_soa + mxfp.k_head_qs_bytes, row + k_head_e8m0_start, mxfp.k_blocks_per_head);
-                    mxfp.k_dequantize(k_head_soa, k_dequant_buf, DK);
-                } else {
-                    mxfp.k_dequantize(k_base + ic*nbk1, k_dequant_buf, DK);
-                }
+                const char * k_row = mxfp.k.multihead ? k_row_base + ic*nbk1 : k_base + ic*nbk1;
+                mxfp_dequant_head(mxfp.k, k_row, ik2, k_head_soa, k_dequant_buf, DK);
                ggml_vec_dot_f32(DK, &s, 0, k_dequant_buf, 0, Q_f32, 0, 1);
            } else {
                kq_vec_dot(DK, &s, 0, k_base + ic*nbk1, 0, Q_q, 0, 1);
@@ -8577,15 +8571,9 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
            }

            // V += v*expf(s - M)
-            if (mxfp.v_dequantize) {
-                if (mxfp.v_multihead) {
-                    const char * row = v_row_base + ic*nbv1;
-                    memcpy(v_head_soa, row + v_head_qs_start, mxfp.v_head_qs_bytes);
-                    memcpy(v_head_soa + mxfp.v_head_qs_bytes, row + v_head_e8m0_start, mxfp.v_blocks_per_head);
-                    mxfp.v_dequantize(v_head_soa, v_dequant_buf, DV);
-                } else {
-                    mxfp.v_dequantize(v_base + ic*nbv1, v_dequant_buf, DV);
-                }
+            if (mxfp.v.dequantize) {
+                const char * v_row = mxfp.v.multihead ? v_row_base + ic*nbv1 : v_base + ic*nbv1;
+                mxfp_dequant_head(mxfp.v, v_row, iv2, v_head_soa, v_dequant_buf, DV);
                ggml_vec_mad_f32(DV, VKQ32, v_dequant_buf, vs);
            } else if (v_to_float) {
                v_to_float(v_base + ic*nbv1, V32, DV);
@@ -8731,14 +8719,14 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
    static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
    static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;

-    if (is_mxfp_k) { GGML_ASSERT(DK <= 1024); }
-    if (is_mxfp_v) { GGML_ASSERT(DV <= 1024); }
+    if (is_mxfp_k) { GGML_ASSERT(DK <= MXFP_FA_MAX_D); }
+    if (is_mxfp_v) { GGML_ASSERT(DV <= MXFP_FA_MAX_D); }

-    float k_dequant_buf[1024];
-    float v_dequant_buf[1024];
+    float k_dequant_buf[MXFP_FA_MAX_D];
+    float v_dequant_buf[MXFP_FA_MAX_D];

-    char k_head_soa[1088];
-    char v_head_soa[1088];
+    char k_head_soa[MXFP_FA_SOA_BUF];
+    char v_head_soa[MXFP_FA_SOA_BUF];

    int ir = ir0;
    while (ir < ir1) {
@@ -8802,9 +8790,9 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
                if (mxfp.apply_hadamard) {
                    ggml_apply_hadamard_blocks(Q_f32 + tq * DK, DK);
                }
-                uint8_t q_mxfp_buf[1088]; // max: DK=1024 MXFP8 -> 1056 bytes
+                uint8_t q_mxfp_buf[MXFP_FA_SOA_BUF];
                mxfp.q_quantize(Q_f32 + tq * DK, q_mxfp_buf, DK);
-                mxfp.k_dequantize(q_mxfp_buf, Q_f32 + tq * DK, DK);
+                mxfp.k.dequantize(q_mxfp_buf, Q_f32 + tq * DK, DK);
            }
        }
        for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
@@ -8854,23 +8842,13 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
                for (int64_t dk = 0; dk < DK; dk++) {
                    K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
                }
-            } else if (mxfp.k_dequantize) {
-                if (mxfp.k_multihead) {
-                    // Per-head extraction: copy only this head's SoA blocks + e8m0, dequant DK elements.
-                    const char * row = (const char *)k->data + (ic + tk)*nbk1 + ik3*nbk3;
-                    const int kqs = ik2 * mxfp.k_head_qs_bytes;
-                    const int ke8 = (int)mxfp.k_head_e8m0_offset + ik2 * mxfp.k_blocks_per_head;
-                    memcpy(k_head_soa, row + kqs, mxfp.k_head_qs_bytes);
-                    memcpy(k_head_soa + mxfp.k_head_qs_bytes, row + ke8, mxfp.k_blocks_per_head);
-                    mxfp.k_dequantize(k_head_soa, k_dequant_buf, DK);
-                } else {
-                    mxfp.k_dequantize(k_data, k_dequant_buf, DK);
-                }
+            } else if (mxfp.k.dequantize) {
+                mxfp_dequant_head(mxfp.k, k_data, ik2, k_head_soa, k_dequant_buf, DK);
                for (int64_t dk = 0; dk < DK; dk++) {
                    K_f32[dk * KV_TILE_SZ + tk] = k_dequant_buf[dk];
                }
            } else {
-                float k_tmp[1024];
+                float k_tmp[MXFP_FA_MAX_D];
                k_to_float(k_data, k_tmp, DK);
                for (int64_t dk = 0; dk < DK; dk++) {
                    K_f32[dk * KV_TILE_SZ + tk] = k_tmp[dk];
@@ -8934,18 +8912,8 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
                ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);
            } else if (v_type == GGML_TYPE_F32) {
                memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
-            } else if (mxfp.v_dequantize) {
-                if (mxfp.v_multihead) {
-                    // Per-head extraction: copy only this head's SoA blocks + e8m0, dequant DV elements.
-                    const char * row = (const char *)v->data + (ic + tk)*nbv1 + iv3*nbv3;
-                    const int vqs = iv2 * mxfp.v_head_qs_bytes;
-                    const int ve8 = (int)mxfp.v_head_e8m0_offset + iv2 * mxfp.v_blocks_per_head;
-                    memcpy(v_head_soa, row + vqs, mxfp.v_head_qs_bytes);
-                    memcpy(v_head_soa + mxfp.v_head_qs_bytes, row + ve8, mxfp.v_blocks_per_head);
-                    mxfp.v_dequantize(v_head_soa, v_dequant_buf, DV);
-                } else {
-                    mxfp.v_dequantize(v_data, v_dequant_buf, DV);
-                }
+            } else if (mxfp.v.dequantize) {
+                mxfp_dequant_head(mxfp.v, v_data, iv2, v_head_soa, v_dequant_buf, DV);
                memcpy(V32 + tk * DV, v_dequant_buf, DV * sizeof(float));
            } else {
                v_to_float(v_data, V32 + tk * DV, DV);
@@ -54,14 +54,6 @@ void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, i
    quantize_row_nvfp4_ref(x, y, k);
}

-void quantize_row_mxfp8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
-    quantize_row_mxfp8_ref(x, y, k);
-}
-
-void quantize_row_mxfp6(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
-    quantize_row_mxfp6_ref(x, y, k);
-}
-
//
// 2-6 bit quantization in super-blocks
//
@@ -264,51 +256,6 @@ void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
    *s = sumf;
}

-// Generic MXFP x Q8_0 dot product (scalar, not SIMD-optimized)
-static void ggml_vec_dot_mxfp_q8_0_impl(
-        int n, float * GGML_RESTRICT s,
-        const void * GGML_RESTRICT vx, size_t block_size,
-        const void * GGML_RESTRICT vy,
-        ggml_to_float_t dequant) {
-    assert(n % QK8_0 == 0);
-    const int nb = n / QK8_0;
-    const block_q8_0 * GGML_RESTRICT y = vy;
-    float sumf = 0;
-
-    for (int ib = 0; ib < nb; ib++) {
-        float tmp[QK8_0];
-        dequant((const char *)vx + ib * block_size, tmp, QK8_0);
-
-        const float y_d = GGML_CPU_FP16_TO_FP32(y[ib].d);
-        float block_sum = 0;
-        for (int j = 0; j < QK8_0; j++) {
-            block_sum += tmp[j] * (float)y[ib].qs[j];
-        }
-        sumf += block_sum * y_d;
-    }
-    *s = sumf;
-}
-
-void ggml_vec_dot_mxfp8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    assert(nrc == 1);
-    UNUSED(nrc);
-    UNUSED(bx);
-    UNUSED(by);
-    UNUSED(bs);
-    ggml_vec_dot_mxfp_q8_0_impl(n, s, vx, sizeof(block_mxfp8), vy,
-                                (ggml_to_float_t)dequantize_row_mxfp8);
-}
-
-void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-    assert(nrc == 1);
-    UNUSED(nrc);
-    UNUSED(bx);
-    UNUSED(by);
-    UNUSED(bs);
-    ggml_vec_dot_mxfp_q8_0_impl(n, s, vx, sizeof(block_mxfp6), vy,
-                                (ggml_to_float_t)dequantize_row_mxfp6);
-}
-
// Generic SoA dequant wrappers — arch-specific SIMD versions override via fallback.h.
void dequantize_row_mxfp4_soa_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
    dequantize_row_mxfp4_soa(x, y, k);
@@ -21,9 +21,6 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in

void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
-void quantize_row_mxfp8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
-void quantize_row_mxfp6(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);

void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@@ -46,9 +43,6 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi

void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
-void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
-void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);

void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -80,9 +74,6 @@ void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, c

void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
-void ggml_vec_dot_mxfp8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
-void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);

// SoA dequant (SIMD-dispatched, CPU backend)
void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@@ -3770,7 +3770,7 @@ static block_mxfp4x4 make_block_mxfp4x4(block_mxfp4 * in, unsigned int blck_size
}

static int repack_mxfp4_to_mxfp4_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
-    GGML_ASSERT(t->type == GGML_TYPE_MXFP4_E2M1);
+    GGML_ASSERT(t->type == GGML_TYPE_MXFP4);
    GGML_ASSERT(interleave_block == 4);

    const block_mxfp4 * src = (const block_mxfp4 *)data;
@@ -3827,7 +3827,7 @@ static block_mxfp4x8 make_block_mxfp4x8(block_mxfp4 * in, unsigned int blck_size
}

static int repack_mxfp4_to_mxfp4_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
-    GGML_ASSERT(t->type == GGML_TYPE_MXFP4_E2M1);
+    GGML_ASSERT(t->type == GGML_TYPE_MXFP4);
    GGML_ASSERT(interleave_block == 8);

    const block_mxfp4 * src = (const block_mxfp4 *)data;
@@ -4685,7 +4685,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
            }
#endif
        }
-    } else if (cur->type == GGML_TYPE_MXFP4_E2M1) {
+    } else if (cur->type == GGML_TYPE_MXFP4) {
        if (ggml_cpu_has_avx2()) {
            if (cur->ne[1] % 8 == 0) {
                return &mxfp4_8x8_q8_0;
@@ -1014,7 +1014,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
    // MXFP4: has AoS shaders (MUL_MAT, GET_ROWS) but no SoA/flash attention support yet.
    for (size_t i = 0, n = 3; i < n; ++i) {
        if (op->src[i] != NULL && ggml_is_type_mxfp(op->src[i]->type)) {
-            if (op->src[i]->type != GGML_TYPE_MXFP4_E2M1) {
+            if (op->src[i]->type != GGML_TYPE_MXFP4) {
                return false;
            }
            if (op->op == GGML_OP_FLASH_ATTN_EXT || op->op == GGML_OP_SET_ROWS) {
@@ -3760,7 +3760,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
        } else if (op->src[0]->type == GGML_TYPE_F32) {
            return op->src[1]->type == GGML_TYPE_F32;
        } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 ||
-                   op->src[0]->type == GGML_TYPE_MXFP4_E2M1 ||
+                   op->src[0]->type == GGML_TYPE_MXFP4 ||
                   op->src[0]->type == GGML_TYPE_Q4_K ||
                   op->src[0]->type == GGML_TYPE_Q6_K) {
            return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
@@ -3771,7 +3771,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
        case GGML_OP_MUL_MAT_ID:
            if (op->src[0]->type == GGML_TYPE_Q4_0 ||
                op->src[0]->type == GGML_TYPE_Q8_0 ||
-                op->src[0]->type == GGML_TYPE_MXFP4_E2M1) {
+                op->src[0]->type == GGML_TYPE_MXFP4) {
                if (op->src[1]->type == GGML_TYPE_F32) {
                    return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
                }
@@ -4559,7 +4559,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
        return;
    }
-    if (tensor->type == GGML_TYPE_MXFP4_E2M1) {
+    if (tensor->type == GGML_TYPE_MXFP4) {
        ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
        GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");

@@ -5136,7 +5136,7 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
        CL_CHECK(clReleaseMemObject(data_device));
        return;
    }
-    if (tensor->type == GGML_TYPE_MXFP4_E2M1) {
+    if (tensor->type == GGML_TYPE_MXFP4) {
        ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *)tensor->extra;

        cl_int err;
@@ -5585,7 +5585,7 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tenso
        CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL));
        CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_d, buf_d, 0, NULL, NULL));
        CL_CHECK(clFinish(queue));
-    } else if (tensor->type == GGML_TYPE_MXFP4_E2M1) {
+    } else if (tensor->type == GGML_TYPE_MXFP4) {
        ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *) tensor->extra;
        GGML_ASSERT(extra);

@@ -10550,7 +10550,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3));
#endif // GGML_OPENCL_SOA_Q
            break;
-        case GGML_TYPE_MXFP4_E2M1: {
+        case GGML_TYPE_MXFP4: {
#ifdef GGML_OPENCL_SOA_Q
            kernel = backend_ctx->kernel_mul_mv_mxfp4_f32_flat;

@@ -10630,7 +10630,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
        GGML_ASSERT(false && "not implemented");
    }

-    if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4_E2M1 ||
+    if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4 ||
        src0t == GGML_TYPE_Q4_1 ||
        src0t == GGML_TYPE_Q8_0 ||
        src0t == GGML_TYPE_Q2_K) {
@@ -10864,7 +10864,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
#endif // GGML_OPENCL_SOA_Q
            break;
        }
-        case GGML_TYPE_MXFP4_E2M1: {
+        case GGML_TYPE_MXFP4: {
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
            if (use_adreno_moe_kernels(backend_ctx, src0)) {
                cl_int status;
|||
|
|
@ -267,10 +267,12 @@ uint8_t float_to_fp8_e4m3_rn(float x) { return ggml_mxfp_float_to_fp8_e4m3(x); }
|
|||
// MSE-optimal E8M0: tests candidates around round(log2(amax)), picks lowest quantization error.
|
||||
|
||||
typedef struct {
|
||||
int emax_offset; // type-specific offset to max representable exponent
|
||||
int emax_offset; // type-specific offset to max representable exponent
|
||||
int qs_per_block; // quantized scalar bytes per 32-element block
|
||||
int bits_per_elem; // 8 = byte-aligned, 6 = packed via fp6x4
|
||||
uint8_t (*to_elem)(float);
|
||||
float (*to_float)(uint8_t);
|
||||
float (*mse_error)(float val, float inv_scale, float scale);
|
||||
float (*mse_error)(float val, float inv_scale, float scale); // NULL = use generic round-trip via to_elem/to_float
|
||||
} mxfp_elem_traits_t;
|
||||
|
||||
static inline int best_index_mxfp4(float x, float e);
|
||||
|
|
@ -294,7 +296,7 @@ static float mse_error_mxfp4(float val, float inv_scale, float scale) {
|
|||
return err * err;
|
||||
}
|
||||
|
||||
static const mxfp_elem_traits_t mxfp4_traits = { MXFP4_E2M1_EMAX_OFFSET, NULL, NULL, mse_error_mxfp4 };
|
||||
static const mxfp_elem_traits_t mxfp4_traits = { MXFP4_E2M1_EMAX_OFFSET, MXFP4_SOA_QS_PER_BLOCK, 4, NULL, NULL, mse_error_mxfp4 };
|
||||
|
||||
static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_elem_traits_t * traits) {
|
||||
float amax = 0.0f;
|
||||
|
|
@ -319,7 +321,13 @@ static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_
|
|||
const float test_inv = 1.0f / test_scale;
|
||||
float mse = 0.0f;
|
||||
for (int j = 0; j < qk; ++j) {
|
||||
mse += traits->mse_error(x[j], test_inv, test_scale);
|
||||
if (traits->mse_error) {
|
||||
mse += traits->mse_error(x[j], test_inv, test_scale);
|
||||
} else {
|
||||
const float recon = traits->to_float(traits->to_elem(x[j] * test_inv)) * test_scale;
|
||||
const float err = x[j] - recon;
|
||||
mse += err * err;
|
||||
}
|
||||
}
|
||||
if (mse < best_mse) {
|
||||
best_mse = mse;
|
||||
|
|
@@ -574,102 +582,8 @@ uint8_t float_to_fp8_e5m2_rn(float x) { return ggml_mxfp_float_to_fp8_e5m2(x); }
void pack_fp6x4(const uint8_t v[4], uint8_t out[3]) { ggml_mxfp_pack_fp6x4(v, out); }
void unpack_fp6x4(const uint8_t in[3], uint8_t v[4]) { ggml_mxfp_unpack_fp6x4(in, v); }

-static float mse_error_fp8_e4m3(float val, float inv_scale, float scale) {
-    const float recon = fp8_e4m3_to_float(float_to_fp8_e4m3_rn(val * inv_scale)) * scale;
-    const float err = val - recon;
-    return err * err;
-}
-static float mse_error_fp6_e2m3(float val, float inv_scale, float scale) {
-    const float recon = fp6_e2m3_to_float(float_to_fp6_e2m3_rn(val * inv_scale)) * scale;
-    const float err = val - recon;
-    return err * err;
-}
-static const mxfp_elem_traits_t mxfp8_e4m3_traits = { MXFP8_E4M3_EMAX_OFFSET, float_to_fp8_e4m3_rn, fp8_e4m3_to_float, mse_error_fp8_e4m3 };
-static const mxfp_elem_traits_t mxfp6_e2m3_traits = { MXFP6_E2M3_EMAX_OFFSET, float_to_fp6_e2m3_rn, fp6_e2m3_to_float, mse_error_fp6_e2m3 };
-
-static void quantize_row_mxfp8_impl(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y,
-                                    int64_t k, const mxfp_elem_traits_t * traits) {
-    assert(k % QK_MXFP8 == 0);
-    const int nb = k / QK_MXFP8;
-
-    for (int i = 0; i < nb; i++) {
-        const uint8_t e = mxfp_compute_e8m0_mse(&x[i*QK_MXFP8], QK_MXFP8, traits);
-        const float d = GGML_E8M0_TO_FP32(e);
-        const float inv_d = d > 0.0f ? 1.0f / d : 0.0f;
-        y[i].e = e;
-
-        for (int j = 0; j < QK_MXFP8; ++j) {
-            y[i].qs[j] = traits->to_elem(x[i*QK_MXFP8 + j] * inv_d);
-        }
-    }
-}
-
-static void dequantize_row_mxfp8_impl(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y,
-                                      int64_t k, const mxfp_elem_traits_t * traits) {
-    assert(k % QK_MXFP8 == 0);
-    const int nb = k / QK_MXFP8;
-
-    for (int i = 0; i < nb; i++) {
-        const float d = GGML_E8M0_TO_FP32(x[i].e);
-        for (int j = 0; j < QK_MXFP8; ++j) {
-            y[i*QK_MXFP8 + j] = traits->to_float(x[i].qs[j]) * d;
-        }
-    }
-}
-
-static void quantize_row_mxfp6_impl(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y,
-                                    int64_t k, const mxfp_elem_traits_t * traits) {
-    assert(k % QK_MXFP6 == 0);
-    const int nb = k / QK_MXFP6;
-
-    for (int i = 0; i < nb; i++) {
-        const uint8_t e = mxfp_compute_e8m0_mse(&x[i*QK_MXFP6], QK_MXFP6, traits);
-        const float d = GGML_E8M0_TO_FP32(e);
-        const float inv_d = d > 0.0f ? 1.0f / d : 0.0f;
-        y[i].e = e;
-
-        for (int j = 0; j < QK_MXFP6; j += 4) {
-            uint8_t vals[4];
-            for (int jj = 0; jj < 4; jj++) {
-                vals[jj] = traits->to_elem(x[i*QK_MXFP6 + j + jj] * inv_d);
-            }
-            pack_fp6x4(vals, &y[i].qs[j * 3 / 4]);
-        }
-    }
-}
-
-static void dequantize_row_mxfp6_impl(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y,
-                                      int64_t k, const mxfp_elem_traits_t * traits) {
-    assert(k % QK_MXFP6 == 0);
-    const int nb = k / QK_MXFP6;
-
-    for (int i = 0; i < nb; i++) {
-        const float d = GGML_E8M0_TO_FP32(x[i].e);
-        for (int j = 0; j < QK_MXFP6; j += 4) {
-            uint8_t vals[4];
-            unpack_fp6x4(&x[i].qs[j * 3 / 4], vals);
-            for (int jj = 0; jj < 4; jj++) {
-                y[i*QK_MXFP6 + j + jj] = traits->to_float(vals[jj]) * d;
-            }
-        }
-    }
-}
-
-void quantize_row_mxfp8_ref(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k) {
-    quantize_row_mxfp8_impl(x, y, k, &mxfp8_e4m3_traits);
-}
-
-void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
-    dequantize_row_mxfp8_impl(x, y, k, &mxfp8_e4m3_traits);
-}
-
-void quantize_row_mxfp6_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k) {
-    quantize_row_mxfp6_impl(x, y, k, &mxfp6_e2m3_traits);
-}
-
-void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
-    dequantize_row_mxfp6_impl(x, y, k, &mxfp6_e2m3_traits);
-}
+static const mxfp_elem_traits_t mxfp8_e4m3_traits = { MXFP8_E4M3_EMAX_OFFSET, MXFP8_SOA_QS_PER_BLOCK, 8, float_to_fp8_e4m3_rn, fp8_e4m3_to_float, NULL };
+static const mxfp_elem_traits_t mxfp6_e2m3_traits = { MXFP6_E2M3_EMAX_OFFSET, MXFP6_SOA_QS_PER_BLOCK, 6, float_to_fp6_e2m3_rn, fp6_e2m3_to_float, NULL };

// ====================== SoA (Struct-of-Arrays) quantize/dequantize for flash attention
@@ -715,101 +629,79 @@ void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTR
    }
}

-static void quantize_row_mxfp8_soa_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst,
-                                        int64_t k, const mxfp_elem_traits_t * traits) {
-    assert(k % QK_MXFP8 == 0);
-    const int nb = k / QK_MXFP8;
-    char * row = (char *)dst;
-    char * qs_base = row;
-    char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK);
+// Unified SoA quantize for byte-aligned (FP8) and 6-bit packed (FP6) MXFP formats.
+static void quantize_row_mxfp_soa_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst,
+                                       int64_t k, const mxfp_elem_traits_t * traits) {
+    const int qk = 32;
+    assert(k % qk == 0);
+    const int nb = k / qk;
+    const int qpb = traits->qs_per_block;
+    char * qs_base = (char *)dst;
+    char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, qpb);

    for (int i = 0; i < nb; i++) {
-        const uint8_t e = mxfp_compute_e8m0_mse(&x[i*QK_MXFP8], QK_MXFP8, traits);
+        const uint8_t e = mxfp_compute_e8m0_mse(&x[i*qk], qk, traits);
        const float d = GGML_E8M0_TO_FP32(e);
        const float inv_d = d > 0.0f ? 1.0f / d : 0.0f;
        e8m0_base[i] = (char)e;

-        uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP8_SOA_QS_PER_BLOCK));
-        for (int j = 0; j < QK_MXFP8; ++j) {
-            qs[j] = traits->to_elem(x[i*QK_MXFP8 + j] * inv_d);
-        }
-    }
-}
-
-static void dequantize_row_mxfp8_soa_impl(const void * GGML_RESTRICT src, float * GGML_RESTRICT y,
-                                          int64_t k, const mxfp_elem_traits_t * traits) {
-    assert(k % QK_MXFP8 == 0);
-    const int nb = k / QK_MXFP8;
-    const char * row = (const char *)src;
-    const char * qs_base = row;
-    const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK);
-
-    for (int i = 0; i < nb; i++) {
-        const float d = GGML_E8M0_TO_FP32((uint8_t)e8m0_base[i]);
-        const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP8_SOA_QS_PER_BLOCK));
-        for (int j = 0; j < QK_MXFP8; ++j) {
-            y[i*QK_MXFP8 + j] = traits->to_float(qs[j]) * d;
-        }
-    }
-}
-
-static void quantize_row_mxfp6_soa_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst,
-                                        int64_t k, const mxfp_elem_traits_t * traits) {
-    assert(k % QK_MXFP6 == 0);
-    const int nb = k / QK_MXFP6;
-    char * row = (char *)dst;
-    char * qs_base = row;
-    char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK);
-
-    for (int i = 0; i < nb; i++) {
-        const uint8_t e = mxfp_compute_e8m0_mse(&x[i*QK_MXFP6], QK_MXFP6, traits);
-        const float d = GGML_E8M0_TO_FP32(e);
-        const float inv_d = d > 0.0f ? 1.0f / d : 0.0f;
-        e8m0_base[i] = (char)e;
-
-        uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP6_SOA_QS_PER_BLOCK));
-        for (int j = 0; j < QK_MXFP6; j += 4) {
-            uint8_t vals[4];
-            for (int jj = 0; jj < 4; jj++) {
-                vals[jj] = traits->to_elem(x[i*QK_MXFP6 + j + jj] * inv_d);
+        uint8_t * qs = (uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, qpb));
+        if (traits->bits_per_elem == 8) {
+            for (int j = 0; j < qk; ++j) {
+                qs[j] = traits->to_elem(x[i*qk + j] * inv_d);
}
|
||||
} else {
|
||||
for (int j = 0; j < qk; j += 4) {
|
||||
uint8_t vals[4];
|
||||
for (int jj = 0; jj < 4; jj++) {
|
||||
vals[jj] = traits->to_elem(x[i*qk + j + jj] * inv_d);
|
||||
}
|
||||
pack_fp6x4(vals, &qs[j * 3 / 4]);
|
||||
}
|
||||
pack_fp6x4(vals, &qs[j * 3 / 4]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void dequantize_row_mxfp6_soa_impl(const void * GGML_RESTRICT src, float * GGML_RESTRICT y,
|
||||
int64_t k, const mxfp_elem_traits_t * traits) {
|
||||
assert(k % QK_MXFP6 == 0);
|
||||
const int nb = k / QK_MXFP6;
|
||||
const char * row = (const char *)src;
|
||||
const char * qs_base = row;
|
||||
const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK);
|
||||
// Unified SoA dequantize for byte-aligned (FP8) and 6-bit packed (FP6) MXFP formats.
|
||||
static void dequantize_row_mxfp_soa_impl(const void * GGML_RESTRICT src, float * GGML_RESTRICT y,
|
||||
int64_t k, const mxfp_elem_traits_t * traits) {
|
||||
const int qk = 32;
|
||||
assert(k % qk == 0);
|
||||
const int nb = k / qk;
|
||||
const int qpb = traits->qs_per_block;
|
||||
const char * qs_base = (const char *)src;
|
||||
const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, qpb);
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const float d = GGML_E8M0_TO_FP32((uint8_t)e8m0_base[i]);
|
||||
const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, MXFP6_SOA_QS_PER_BLOCK));
|
||||
for (int j = 0; j < QK_MXFP6; j += 4) {
|
||||
uint8_t vals[4];
|
||||
unpack_fp6x4(&qs[j * 3 / 4], vals);
|
||||
for (int jj = 0; jj < 4; jj++) {
|
||||
y[i*QK_MXFP6 + j + jj] = traits->to_float(vals[jj]) * d;
|
||||
const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(i, qpb));
|
||||
if (traits->bits_per_elem == 8) {
|
||||
for (int j = 0; j < qk; ++j) {
|
||||
y[i*qk + j] = traits->to_float(qs[j]) * d;
|
||||
}
|
||||
} else {
|
||||
for (int j = 0; j < qk; j += 4) {
|
||||
uint8_t vals[4];
|
||||
unpack_fp6x4(&qs[j * 3 / 4], vals);
|
||||
for (int jj = 0; jj < 4; jj++) {
|
||||
y[i*qk + j + jj] = traits->to_float(vals[jj]) * d;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
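The packed path above stores four 6-bit codes in every 3 bytes, which is where the `&qs[j * 3 / 4]` indexing comes from. A standalone sketch of that packing, assuming little-endian bit order (the authoritative versions are `pack_fp6x4()` and the shared `ggml_mxfp_unpack_fp6x4()` in ggml-common.h that the x86/ARM SIMD paths now call):

```c
#include <stdint.h>

// Hypothetical re-implementation of the 4x-FP6 packing: four 6-bit codes
// share 3 bytes. Little-endian bit order is an assumption made here for
// illustration; the real layout is defined by ggml-common.h.
static inline void demo_pack_fp6x4(const uint8_t v[4], uint8_t out[3]) {
    const uint32_t bits = ((uint32_t)(v[0] & 0x3F))
                        | ((uint32_t)(v[1] & 0x3F) <<  6)
                        | ((uint32_t)(v[2] & 0x3F) << 12)
                        | ((uint32_t)(v[3] & 0x3F) << 18);
    out[0] = (uint8_t)(bits);
    out[1] = (uint8_t)(bits >>  8);
    out[2] = (uint8_t)(bits >> 16);
}

static inline void demo_unpack_fp6x4(const uint8_t in[3], uint8_t v[4]) {
    const uint32_t bits = (uint32_t)in[0]
                        | ((uint32_t)in[1] <<  8)
                        | ((uint32_t)in[2] << 16);
    v[0] = (uint8_t)( bits        & 0x3F);
    v[1] = (uint8_t)((bits >>  6) & 0x3F);
    v[2] = (uint8_t)((bits >> 12) & 0x3F);
    v[3] = (uint8_t)((bits >> 18) & 0x3F);
}
```

Whatever the concrete bit order, the invariant the scalar and SIMD paths must agree on is `unpack(pack(v)) == v` for every 6-bit value, which is what routing both through the single shared helper guarantees.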

void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) {
    quantize_row_mxfp8_soa_impl(x, dst, k, &mxfp8_e4m3_traits);
    quantize_row_mxfp_soa_impl(x, dst, k, &mxfp8_e4m3_traits);
}
void dequantize_row_mxfp8_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) {
    dequantize_row_mxfp8_soa_impl(src, y, k, &mxfp8_e4m3_traits);
    dequantize_row_mxfp_soa_impl(src, y, k, &mxfp8_e4m3_traits);
}
void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) {
    quantize_row_mxfp6_soa_impl(x, dst, k, &mxfp6_e2m3_traits);
    quantize_row_mxfp_soa_impl(x, dst, k, &mxfp6_e2m3_traits);
}
void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) {
    dequantize_row_mxfp6_soa_impl(src, y, k, &mxfp6_e2m3_traits);
    dequantize_row_mxfp_soa_impl(src, y, k, &mxfp6_e2m3_traits);
}
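The wrappers above all target one row layout: every block's quantized bytes first, then one E8M0 scale byte per block at the end of the row. A minimal sketch of that addressing, with the offset macros re-derived here as assumptions (the real `MXFP_SOA_QS_OFFSET` / `MXFP_SOA_E8M0_OFFSET` definitions live in ggml-common.h):

```c
#include <stddef.h>

// Assumed reconstruction of the SoA row layout used by the unified impls:
//   row = [ qs block 0 | qs block 1 | ... | qs block nb-1 | e8m0[nb] ]
// DEMO_* names are placeholders for the real MXFP_SOA_* macros.
#define DEMO_SOA_QS_OFFSET(i, qpb)    ((size_t)(i)  * (size_t)(qpb))
#define DEMO_SOA_E8M0_OFFSET(nb, qpb) ((size_t)(nb) * (size_t)(qpb))
```

With the shared block size `qk = 32`, FP8 needs 32 qs bytes per block (one per element) while FP6 packs 32 elements into 24 bytes (32 * 6 / 8), which is why the same two macros serve both formats once parameterized by `traits->qs_per_block`.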

//
// 2-6 bit quantization in super-blocks

@@ -2472,7 +2364,7 @@ size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    GGML_UNUSED(quant_weights);
    quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);
    return nrow * ggml_row_size(GGML_TYPE_MXFP4_E2M1, n_per_row);
    return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
}

size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {

@@ -2481,18 +2373,6 @@ size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
    return nrow * ggml_row_size(GGML_TYPE_NVFP4, n_per_row);
}

size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    GGML_UNUSED(quant_weights);
    quantize_row_mxfp8_ref(src, dst, (int64_t)nrow*n_per_row);
    return nrow * ggml_row_size(GGML_TYPE_MXFP8_E4M3, n_per_row);
}

size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    GGML_UNUSED(quant_weights);
    quantize_row_mxfp6_ref(src, dst, (int64_t)nrow*n_per_row);
    return nrow * ggml_row_size(GGML_TYPE_MXFP6_E2M3, n_per_row);
}

// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)

void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) {

@@ -5635,15 +5515,15 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
            {
                VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
            } break;
        case GGML_TYPE_MXFP4_E2M1:
        case GGML_TYPE_MXFP4:
            {
                VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb);
            } break;
        case GGML_TYPE_MXFP8_E4M3:
        case GGML_TYPE_MXFP8:
            {
                VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp8, data, nb);
            } break;
        case GGML_TYPE_MXFP6_E2M3:
        case GGML_TYPE_MXFP6:
            {
                VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp6, data, nb);
            } break;

@@ -23,9 +23,6 @@ GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 *
GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp8_ref(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp6_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k);

GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);

@@ -52,9 +49,6 @@ GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GG
GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);

// SoA quantize/dequantize for flash attention
GGML_API void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
GGML_API void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);

@@ -110,9 +104,6 @@ GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTR
GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);

// MXFP element converters
GGML_API float fp8_e4m3_to_float(uint8_t v);
GGML_API uint8_t float_to_fp8_e4m3_rn(float x);

@@ -639,7 +639,7 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
            return dequantize_row_iq4_xs_sycl;
        case GGML_TYPE_IQ4_NL:
            return dequantize_row_iq4_nl_sycl;
        case GGML_TYPE_MXFP4_E2M1:
        case GGML_TYPE_MXFP4:
            return dequantize_row_mxfp4_sycl;
        case GGML_TYPE_F32:
            return convert_unary_sycl<float>;

@@ -706,7 +706,7 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
            return dequantize_row_iq4_xs_sycl;
        case GGML_TYPE_IQ4_NL:
            return dequantize_row_iq4_nl_sycl;
        case GGML_TYPE_MXFP4_E2M1:
        case GGML_TYPE_MXFP4:
            return dequantize_row_mxfp4_sycl;
        case GGML_TYPE_F16:
            return convert_unary_sycl<sycl::half>;

@@ -1142,7 +1142,7 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
        case GGML_TYPE_IQ4_XS:
            mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
            break;
        case GGML_TYPE_MXFP4_E2M1:
        case GGML_TYPE_MXFP4:
            mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
            break;
        default:

@@ -710,7 +710,7 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
        .is_quantized             = true,
        .from_float_ref           = (ggml_from_float_t) quantize_row_q8_1_ref,
    },
    [GGML_TYPE_MXFP4_E2M1] = {
    [GGML_TYPE_MXFP4] = {
        .type_name                = "mxfp4",
        .blck_size                = QK_MXFP4,
        .type_size                = sizeof(block_mxfp4),

@@ -726,21 +726,17 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
        .to_float                 = (ggml_to_float_t) dequantize_row_nvfp4,
        .from_float_ref           = (ggml_from_float_t)quantize_row_nvfp4_ref,
    },
    [GGML_TYPE_MXFP8_E4M3] = {
    [GGML_TYPE_MXFP8] = {
        .type_name                = "mxfp8_e4m3",
        .blck_size                = QK_MXFP8,
        .type_size                = sizeof(block_mxfp8),
        .is_quantized             = true,
        .to_float                 = (ggml_to_float_t) dequantize_row_mxfp8,
        .from_float_ref           = (ggml_from_float_t)quantize_row_mxfp8_ref,
    },
    [GGML_TYPE_MXFP6_E2M3] = {
    [GGML_TYPE_MXFP6] = {
        .type_name                = "mxfp6_e2m3",
        .blck_size                = QK_MXFP6,
        .type_size                = sizeof(block_mxfp6),
        .is_quantized             = true,
        .to_float                 = (ggml_to_float_t) dequantize_row_mxfp6,
        .from_float_ref           = (ggml_from_float_t)quantize_row_mxfp6_ref,
    },
    [GGML_TYPE_Q2_K] = {
        .type_name                = "q2_K",

@@ -1329,25 +1325,25 @@ bool ggml_is_quantized(enum ggml_type type) {
}

bool ggml_is_type_mxfp(enum ggml_type type) {
    return type == GGML_TYPE_MXFP4_E2M1 ||
           type == GGML_TYPE_MXFP8_E4M3 ||
           type == GGML_TYPE_MXFP6_E2M3;
    return type == GGML_TYPE_MXFP4 ||
           type == GGML_TYPE_MXFP8 ||
           type == GGML_TYPE_MXFP6;
}

bool ggml_mxfp_use_hadamard(enum ggml_type type) {
    switch (type) {
        case GGML_TYPE_MXFP4_E2M1: return MXFP_USE_HADAMARD_E2M1;
        case GGML_TYPE_MXFP8_E4M3: return MXFP_USE_HADAMARD_E4M3;
        case GGML_TYPE_MXFP6_E2M3: return MXFP_USE_HADAMARD_E2M3;
        case GGML_TYPE_MXFP4: return MXFP_USE_HADAMARD_E2M1;
        case GGML_TYPE_MXFP8: return MXFP_USE_HADAMARD_E4M3;
        case GGML_TYPE_MXFP6: return MXFP_USE_HADAMARD_E2M3;
        default: return false;
    }
}

int ggml_mxfp_qs_per_block(enum ggml_type type) {
    switch (type) {
        case GGML_TYPE_MXFP4_E2M1: return MXFP_QS_PER_BLOCK_E2M1;
        case GGML_TYPE_MXFP8_E4M3: return MXFP_QS_PER_BLOCK_E4M3;
        case GGML_TYPE_MXFP6_E2M3: return MXFP_QS_PER_BLOCK_E2M3;
        case GGML_TYPE_MXFP4: return MXFP_QS_PER_BLOCK_E2M1;
        case GGML_TYPE_MXFP8: return MXFP_QS_PER_BLOCK_E4M3;
        case GGML_TYPE_MXFP6: return MXFP_QS_PER_BLOCK_E2M3;
        default: return 0;
    }
}
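The alias consolidation the commit message describes can be illustrated with a hypothetical enum: the long canonical enumerator keeps its value and the short name aliases it, so both spellings compare equal and existing code using either name keeps compiling. `demo_type` and the value 39 are made up for illustration; they are not the real `ggml_type` values.

```c
// Hypothetical sketch of the short-alias pattern (not the real ggml enum):
// the canonical long name defines the value, the short name aliases it.
enum demo_type {
    DEMO_TYPE_MXFP4_E2M1 = 39,
    DEMO_TYPE_MXFP4      = DEMO_TYPE_MXFP4_E2M1, // short compat alias
};
```

Because the two enumerators are the same constant, comparisons and table lookups written against either spelling behave identically, which is what lets the rename-only files be dropped from the diff.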

@@ -7695,10 +7691,10 @@ size_t ggml_quantize_chunk(
        case GGML_TYPE_Q5_0:       result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q5_1:       result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q8_0:       result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_MXFP4_E2M1: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_NVFP4:      result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_MXFP8_E4M3: result = quantize_mxfp8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_MXFP6_E2M3: result = quantize_mxfp6(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_MXFP4:      result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_NVFP4:      result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_MXFP8:      GGML_ABORT("MXFP8 is KV-cache-only (SoA layout) — use from_float_soa");
        case GGML_TYPE_MXFP6:      GGML_ABORT("MXFP6 is KV-cache-only (SoA layout) — use from_float_soa");
        case GGML_TYPE_Q2_K:       result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q3_K:       result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q4_K:       result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;

@@ -457,7 +457,7 @@ static ggml_type llama_tensor_get_type_impl(quantize_state_impl & qs, ggml_type
        // MoE tensors -> MXFP4
        // other tensors -> Q8_0
        if (tensor->ne[2] > 1) {
            new_type = GGML_TYPE_MXFP4_E2M1;
            new_type = GGML_TYPE_MXFP4;
        } else {
            new_type = GGML_TYPE_Q8_0;
        }

@@ -795,7 +795,7 @@ static ggml_type llama_ftype_get_default_type(llama_ftype ftype) {
        case LLAMA_FTYPE_MOSTLY_BF16: return GGML_TYPE_BF16;
        case LLAMA_FTYPE_ALL_F32:     return GGML_TYPE_F32;

        case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return GGML_TYPE_MXFP4_E2M1;
        case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return GGML_TYPE_MXFP4;

        // K-quants
        case LLAMA_FTYPE_MOSTLY_Q2_K_S:

@@ -170,9 +170,9 @@ struct mxfp_soa_fns {
};

static const mxfp_soa_fns mxfp_soa_table[] = {
    { GGML_TYPE_MXFP4_E2M1, quantize_row_mxfp4_soa, dequantize_row_mxfp4_soa },
    { GGML_TYPE_MXFP8_E4M3, quantize_row_mxfp8_soa, dequantize_row_mxfp8_soa },
    { GGML_TYPE_MXFP6_E2M3, quantize_row_mxfp6_soa, dequantize_row_mxfp6_soa },
    { GGML_TYPE_MXFP4, quantize_row_mxfp4_soa, dequantize_row_mxfp4_soa },
    { GGML_TYPE_MXFP8, quantize_row_mxfp8_soa, dequantize_row_mxfp8_soa },
    { GGML_TYPE_MXFP6, quantize_row_mxfp6_soa, dequantize_row_mxfp6_soa },
};

static const mxfp_soa_fns * get_mxfp_soa(ggml_type type) {

@@ -3908,7 +3908,7 @@ struct test_mul_mat : public test_case {

    double max_nmse_err(ggml_backend_t backend) override {
        // for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance
        if (type_a == GGML_TYPE_MXFP4_E2M1 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
        if (type_a == GGML_TYPE_MXFP4 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
            return 2e-2;
        }
        return max_nmse_err();

@@ -4044,7 +4044,7 @@ struct test_mul_mat_id : public test_case {

    double max_nmse_err(ggml_backend_t backend) override {
        // for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance
        if (type_a == GGML_TYPE_MXFP4_E2M1 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
        if (type_a == GGML_TYPE_MXFP4 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
            return 2e-2;
        }
        return max_nmse_err();

@@ -7398,7 +7398,7 @@ static const ggml_type all_types[] = {
    GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
    GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
    GGML_TYPE_Q8_0,
    GGML_TYPE_MXFP4_E2M1,
    GGML_TYPE_MXFP4,
    GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
    GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
    GGML_TYPE_Q6_K,

@@ -7414,7 +7414,7 @@ static const ggml_type base_types[] = {
    GGML_TYPE_Q4_0,
    GGML_TYPE_Q4_1, // for I8MM tests
    GGML_TYPE_Q4_K,
    GGML_TYPE_MXFP4_E2M1,
    GGML_TYPE_MXFP4,
    GGML_TYPE_IQ2_XXS
};

@@ -7533,8 +7533,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
    }

    // SET_ROWS with Hadamard rotation (exercises the op_params[0] flag used by MXFP KV cache)
    for (ggml_type type : {GGML_TYPE_MXFP4_E2M1, GGML_TYPE_MXFP8_E4M3,
                           GGML_TYPE_MXFP6_E2M3}) {
    for (ggml_type type : {GGML_TYPE_MXFP4, GGML_TYPE_MXFP8,
                           GGML_TYPE_MXFP6}) {
        // ne[0] must be divisible by 32 (Hadamard block size)
        test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 128, 5, 1, 1 }, { 1, 1 }, 1, false, true));
        test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 256, 5, 1, 3 }, { 1, 1 }, 1, false, true));

@@ -8270,7 +8270,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
    test_cases.emplace_back(new test_mul_mat_id_fusion(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));

    // gpt-oss issue with Vulkan mmq_id
    test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4_E2M1, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));
    test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));

    for (ggml_type type_a : base_types) {
        for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {

@@ -8731,7 +8731,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
    for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
        if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
        for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0,
                                  GGML_TYPE_MXFP4_E2M1, GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3,
                                  GGML_TYPE_MXFP4, GGML_TYPE_MXFP8, GGML_TYPE_MXFP6,
                                 }) {
            // Non-F16 types: test at D=64, D=72, and D=128.
            if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72 && hsk != 128) continue;

@@ -8760,8 +8760,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {

    // MXFP-specific K/V type combinations (mixed and same-type)
    // Mixed: mxfp8 K + mxfp4 V, mxfp6 K + mxfp4 V (our recommended configs)
    for (ggml_type type_K : {GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3}) {
        for (ggml_type type_V : {GGML_TYPE_MXFP4_E2M1}) {
    for (ggml_type type_K : {GGML_TYPE_MXFP8, GGML_TYPE_MXFP6}) {
        for (ggml_type type_V : {GGML_TYPE_MXFP4}) {
            if (type_K == type_V) continue;
            for (int nb : {1, 3, 32}) {
                test_cases.emplace_back(new test_flash_attn_ext(

@@ -8770,7 +8770,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
        }
    }
    // Same-type: mxfp8/mxfp8, mxfp6/mxfp6
    for (ggml_type type_KV : {GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3}) {
    for (ggml_type type_KV : {GGML_TYPE_MXFP8, GGML_TYPE_MXFP6}) {
        for (int nb : {1, 3, 32}) {
            test_cases.emplace_back(new test_flash_attn_ext(
                128, 128, 4, {1, 1}, 512, nb, true, false, 0.0f, 0.0f, GGML_PREC_F32, type_KV, {0, 1, 2, 3}, type_KV));

@@ -9000,7 +9000,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {

    // gpt-oss-20b
    for (int bs : {1, 4, 8, 512}) {
        for (ggml_type type_a : {GGML_TYPE_MXFP4_E2M1}) {
        for (ggml_type type_a : {GGML_TYPE_MXFP4}) {
            for (ggml_type type_b : {GGML_TYPE_F32}) {
                test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 2880, bs, 2880));
                test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 2880, bs, 2880, 1));

@@ -178,9 +178,9 @@ int main(int argc, char * argv[]) {
            type == GGML_TYPE_IQ3_S   ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
            type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS :
            type == GGML_TYPE_NVFP4   ? MAX_QUANTIZATION_TOTAL_ERROR_FP4 :
            type == GGML_TYPE_MXFP4_E2M1 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 :
            type == GGML_TYPE_MXFP6_E2M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 :
            type == GGML_TYPE_MXFP8_E4M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 : MAX_QUANTIZATION_TOTAL_ERROR;
            type == GGML_TYPE_MXFP4   ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 :
            type == GGML_TYPE_MXFP6   ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 :
            type == GGML_TYPE_MXFP8   ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 : MAX_QUANTIZATION_TOTAL_ERROR;
        failed = !(total_error < max_quantization_error);
        num_failed += failed;
        if (failed || verbose) {

@@ -202,7 +202,7 @@ int main(int argc, char * argv[]) {
            ? MAX_DOT_PRODUCT_ERROR_TERNARY
            : type == GGML_TYPE_NVFP4
            ? MAX_DOT_PRODUCT_ERROR_FP4
            : type == GGML_TYPE_MXFP4_E2M1 || type == GGML_TYPE_MXFP6_E2M3 || type == GGML_TYPE_MXFP8_E4M3
            : type == GGML_TYPE_MXFP4 || type == GGML_TYPE_MXFP6 || type == GGML_TYPE_MXFP8
            ? MAX_DOT_PRODUCT_ERROR_MXFP
            : MAX_DOT_PRODUCT_ERROR;
        failed = !(vec_dot_error < max_allowed_error);

@@ -231,9 +231,9 @@ int main(int argc, char * argv[]) {

        const float soa_error = array_rmse(test_data.data(), tmp_out.data(), test_size);
        const float max_soa_error =
            type == GGML_TYPE_MXFP4_E2M1 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 :
            type == GGML_TYPE_MXFP6_E2M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 :
            type == GGML_TYPE_MXFP8_E4M3 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 : MAX_QUANTIZATION_TOTAL_ERROR;
            type == GGML_TYPE_MXFP4 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP4 :
            type == GGML_TYPE_MXFP6 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP6 :
            type == GGML_TYPE_MXFP8 ? MAX_QUANTIZATION_TOTAL_ERROR_MXFP8 : MAX_QUANTIZATION_TOTAL_ERROR;
        failed = !(soa_error < max_soa_error);
        num_failed += failed;
        if (failed || verbose) {

@@ -243,7 +243,7 @@ int main(int argc, char * argv[]) {

    // MXFP traits: SoA required, MXFP6/MXFP8 are KV-cache-only (no AoS dequant)
    {
        const ggml_type all_mxfp_types[] = { GGML_TYPE_MXFP4_E2M1, GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3 };
        const ggml_type all_mxfp_types[] = { GGML_TYPE_MXFP4, GGML_TYPE_MXFP8, GGML_TYPE_MXFP6 };
        for (ggml_type type : all_mxfp_types) {
            const auto * cpu = ggml_get_type_traits_cpu(type);

@@ -255,7 +255,7 @@ int main(int argc, char * argv[]) {
        }

        // KV-cache-only types: no AoS dequant
        const ggml_type kv_only_types[] = { GGML_TYPE_MXFP8_E4M3, GGML_TYPE_MXFP6_E2M3 };
        const ggml_type kv_only_types[] = { GGML_TYPE_MXFP8, GGML_TYPE_MXFP6 };
        for (ggml_type type : kv_only_types) {
            const auto * cpu = ggml_get_type_traits_cpu(type);
            failed = (cpu->to_float != nullptr);

@@ -297,9 +297,9 @@ int main(int argc, char * argv[]) {
        };

        const soa_cross_check checks[] = {
            { GGML_TYPE_MXFP4_E2M1, dequantize_row_mxfp4_soa },
            { GGML_TYPE_MXFP8_E4M3, dequantize_row_mxfp8_soa },
            { GGML_TYPE_MXFP6_E2M3, dequantize_row_mxfp6_soa },
            { GGML_TYPE_MXFP4, dequantize_row_mxfp4_soa },
            { GGML_TYPE_MXFP8, dequantize_row_mxfp8_soa },
            { GGML_TYPE_MXFP6, dequantize_row_mxfp6_soa },
        };

        for (const auto & c : checks) {

@@ -774,9 +774,9 @@ int main(int argc, char * argv[]) {
    // SoA layout: verify offset macros produce correct byte positions
    {
        const struct { ggml_type type; int qs_per_block; } soa_types[] = {
            { GGML_TYPE_MXFP4_E2M1, MXFP4_SOA_QS_PER_BLOCK },
            { GGML_TYPE_MXFP8_E4M3, MXFP8_SOA_QS_PER_BLOCK },
            { GGML_TYPE_MXFP6_E2M3, MXFP6_SOA_QS_PER_BLOCK },
            { GGML_TYPE_MXFP4, MXFP4_SOA_QS_PER_BLOCK },
            { GGML_TYPE_MXFP8, MXFP8_SOA_QS_PER_BLOCK },
            { GGML_TYPE_MXFP6, MXFP6_SOA_QS_PER_BLOCK },
        };

        for (const auto & st : soa_types) {

@@ -864,7 +864,7 @@ int main(int argc, char * argv[]) {
        dequantize_row_mxfp4(aos_q.data(), aos_out.data(), nelems);

        // Quantize and dequant via SoA
        const size_t soa_buf_size = ggml_row_size(GGML_TYPE_MXFP4_E2M1, nelems);
        const size_t soa_buf_size = ggml_row_size(GGML_TYPE_MXFP4, nelems);
        std::vector<uint8_t> soa_q(soa_buf_size);
        std::vector<float> soa_out(nelems);
        quantize_row_mxfp4_soa(input, soa_q.data(), nelems);

@@ -901,9 +901,9 @@ int main(int argc, char * argv[]) {
        };

        const hadamard_pipeline_check pipeline_checks[] = {
            { "mxfp4", GGML_TYPE_MXFP4_E2M1, MAX_MXFP_PIPELINE_ERROR_MXFP4 },
            { "mxfp8", GGML_TYPE_MXFP8_E4M3, MAX_MXFP_PIPELINE_ERROR_MXFP8 },
            { "mxfp6", GGML_TYPE_MXFP6_E2M3, MAX_MXFP_PIPELINE_ERROR_MXFP6 },
            { "mxfp4", GGML_TYPE_MXFP4, MAX_MXFP_PIPELINE_ERROR_MXFP4 },
            { "mxfp8", GGML_TYPE_MXFP8, MAX_MXFP_PIPELINE_ERROR_MXFP8 },
            { "mxfp6", GGML_TYPE_MXFP6, MAX_MXFP_PIPELINE_ERROR_MXFP6 },
        };

        for (const auto & p : pipeline_checks) {

@@ -963,7 +963,7 @@ int main(int argc, char * argv[]) {
    // zero block produces E8M0=0
    {
        float zeros[32] = {};
        const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP8_E4M3, 32);
        const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP8, 32);
        std::vector<uint8_t> buf(buf_size, 0xFF); // fill with 0xFF to detect non-writes

        quantize_row_mxfp8_soa(zeros, buf.data(), 32);

@@ -991,7 +991,7 @@ int main(int argc, char * argv[]) {

    // MXFP4
    {
        const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP4_E2M1, nelems);
        const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP4, nelems);
        std::vector<uint8_t> buf(buf_size);
        std::vector<float> ref_out(nelems);
        std::vector<float> manual_out(nelems);

@@ -1032,7 +1032,7 @@ int main(int argc, char * argv[]) {

    // MXFP8
    {
        const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP8_E4M3, nelems);
        const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP8, nelems);
        std::vector<uint8_t> buf(buf_size);
        std::vector<float> ref_out(nelems);
        std::vector<float> manual_out(nelems);

@@ -1069,7 +1069,7 @@ int main(int argc, char * argv[]) {

    // MXFP6
    {
        const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP6_E2M3, nelems);
        const size_t buf_size = ggml_row_size(GGML_TYPE_MXFP6, nelems);
        std::vector<uint8_t> buf(buf_size);
        std::vector<float> ref_out(nelems);
        std::vector<float> manual_out(nelems);

@@ -484,13 +484,13 @@ static ggml_type ggml_type_from_name(const std::string & s) {
        return GGML_TYPE_IQ4_NL;
    }
    if (s == "mxfp4" || s == "mxfp4_e2m1") {
        return GGML_TYPE_MXFP4_E2M1;
        return GGML_TYPE_MXFP4;
    }
    if (s == "mxfp8" || s == "mxfp8_e4m3") {
        return GGML_TYPE_MXFP8_E4M3;
        return GGML_TYPE_MXFP8;
    }
    if (s == "mxfp6" || s == "mxfp6_e2m3") {
        return GGML_TYPE_MXFP6_E2M3;
        return GGML_TYPE_MXFP6;
    }
    return GGML_TYPE_COUNT;
}