ggml: address PR review — fix buffer overflows, add assertions, normalize MXFP6 naming
Fix potential buffer overflows flagged in PR #20609 review:

- set_rows: replace fixed float tmp[1024] with std::vector for large n_embd_k_gqa
- tiled FA: guard q_mxfp_buf with a ggml_row_size assertion (the buffer is now a fixed 512 bytes instead of an unchecked fixed 1024)
- one_chunk FA: pre-allocate k/v dequant buffers from mxfp.{k,v}_soa_elems instead of hard-coded float[4096] stack arrays
- kv-cache: assert n_embd_k_gqa % qk == 0 before integer division
- test init: assert soa_bytes % block_size == 0

Normalize MXFP6 function naming to match the MXFP8 convention (short form without the element-format suffix): mxfp6_e2m3 → mxfp6 in all function identifiers across 14 files. Format-specific items (type enums, traits, lookup tables, constants) retain their _e2m3 suffix.
This commit is contained in:
parent
5c3a9523ef
commit
a51ff77fae
|
|
@ -17,7 +17,7 @@
|
|||
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
|
||||
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
|
||||
#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0
|
||||
#define ggml_vec_dot_mxfp6_e2m3_q8_0_generic ggml_vec_dot_mxfp6_e2m3_q8_0
|
||||
#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0
|
||||
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
|
||||
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
|
||||
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
|
||||
|
|
@ -349,7 +349,7 @@
|
|||
#if !defined(__aarch64__) && !defined(__arm__) && !defined(_M_ARM) && !defined(_M_ARM64) && \
|
||||
!defined(__x86_64__) && !defined(__i386__) && !defined(_M_IX86) && !defined(_M_X64)
|
||||
#define dequantize_row_mxfp8_cpu_generic dequantize_row_mxfp8_cpu
|
||||
#define dequantize_row_mxfp6_e2m3_cpu_generic dequantize_row_mxfp6_e2m3_cpu
|
||||
#define dequantize_row_mxfp6_cpu_generic dequantize_row_mxfp6_cpu
|
||||
#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
|
||||
#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
|
||||
#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu
|
||||
|
|
|
|||
|
|
@ -4333,7 +4333,7 @@ static inline void ggml_vec_dot_mxfp6_q8_0_neon(
|
|||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
|
||||
#if defined(__ARM_NEON)
|
||||
|
|
@ -4342,7 +4342,7 @@ void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, con
|
|||
MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT,
|
||||
MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE);
|
||||
#else
|
||||
ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
@ -4471,13 +4471,13 @@ void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRIC
|
|||
#endif
|
||||
}
|
||||
|
||||
void dequantize_row_mxfp6_e2m3_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
#if defined(__ARM_NEON)
|
||||
dequantize_row_mxfp6_neon(x, y, k, sizeof(block_mxfp6),
|
||||
MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT,
|
||||
MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE);
|
||||
#else
|
||||
dequantize_row_mxfp6_e2m3_cpu_generic(x, y, k);
|
||||
dequantize_row_mxfp6_cpu_generic(x, y, k);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2165,6 +2165,6 @@ void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
}
|
||||
|
||||
void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2307,6 +2307,6 @@ void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
}
|
||||
|
||||
void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3612,6 +3612,6 @@ void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
}
|
||||
|
||||
void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1468,6 +1468,6 @@ void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
}
|
||||
|
||||
void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1227,6 +1227,6 @@ void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
}
|
||||
|
||||
void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3995,7 +3995,7 @@ static inline void ggml_vec_dot_mxfp6_q8_0_avx2(
|
|||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
|
||||
#if defined(__AVX2__)
|
||||
|
|
@ -4004,7 +4004,7 @@ void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, con
|
|||
MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT,
|
||||
MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE);
|
||||
#else
|
||||
ggml_vec_dot_mxfp6_e2m3_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
@ -4130,13 +4130,13 @@ void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRIC
|
|||
#endif
|
||||
}
|
||||
|
||||
void dequantize_row_mxfp6_e2m3_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
#if defined(__AVX2__)
|
||||
dequantize_row_mxfp6_avx2(x, y, k, sizeof(block_mxfp6),
|
||||
MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT,
|
||||
MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE);
|
||||
#else
|
||||
dequantize_row_mxfp6_e2m3_cpu_generic(x, y, k);
|
||||
dequantize_row_mxfp6_cpu_generic(x, y, k);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -284,9 +284,9 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
|
|||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_MXFP6_E2M3] = {
|
||||
.from_float = quantize_row_mxfp6_e2m3,
|
||||
.to_float = dequantize_row_mxfp6_e2m3_cpu,
|
||||
.vec_dot = ggml_vec_dot_mxfp6_e2m3_q8_0,
|
||||
.from_float = quantize_row_mxfp6,
|
||||
.to_float = dequantize_row_mxfp6_cpu,
|
||||
.vec_dot = ggml_vec_dot_mxfp6_q8_0,
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
.nrows = 1,
|
||||
},
|
||||
|
|
|
|||
|
|
@ -5061,14 +5061,13 @@ static void ggml_compute_forward_set_rows_f32(
|
|||
char * dst_row = ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3);
|
||||
|
||||
if (apply_hadamard) {
|
||||
GGML_ASSERT(nc <= 1024);
|
||||
float tmp[1024];
|
||||
memcpy(tmp, src_row, nc * sizeof(float));
|
||||
ggml_apply_hadamard_blocks(tmp, nc);
|
||||
std::vector<float> tmp(nc);
|
||||
memcpy(tmp.data(), src_row, nc * sizeof(float));
|
||||
ggml_apply_hadamard_blocks(tmp.data(), nc);
|
||||
if (mxfp_soa_quantize) {
|
||||
mxfp_soa_quantize(tmp, dst_row, nc);
|
||||
mxfp_soa_quantize(tmp.data(), dst_row, nc);
|
||||
} else {
|
||||
from_float(tmp, dst_row, nc);
|
||||
from_float(tmp.data(), dst_row, nc);
|
||||
}
|
||||
} else {
|
||||
if (mxfp_soa_quantize) {
|
||||
|
|
@ -8418,6 +8417,10 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|||
|
||||
int ith = params->ith;
|
||||
|
||||
// Pre-allocate dequant buffers for MXFP SoA (avoids per-iteration allocation)
|
||||
std::vector<float> k_dequant_buf(is_mxfp_k ? mxfp.k_soa_elems : 0);
|
||||
std::vector<float> v_dequant_buf(is_mxfp_v ? mxfp.v_soa_elems : 0);
|
||||
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
// q indices
|
||||
const int iq3 = ir/(neq2*neq1);
|
||||
|
|
@ -8497,10 +8500,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|||
const char * k_soa_base = mxfp.k_multihead
|
||||
? ((const char *) k->data + ic*nbk1 + ik3*nbk3)
|
||||
: k_data;
|
||||
float k_soa_f32[4096];
|
||||
GGML_ASSERT(mxfp.k_soa_elems <= 4096);
|
||||
mxfp.k_dequantize(k_soa_base, k_soa_f32, mxfp.k_soa_elems);
|
||||
const float * k_head = k_soa_f32 + (mxfp.k_multihead ? ik2 * DK : 0);
|
||||
mxfp.k_dequantize(k_soa_base, k_dequant_buf.data(), mxfp.k_soa_elems);
|
||||
const float * k_head = k_dequant_buf.data() + (mxfp.k_multihead ? ik2 * DK : 0);
|
||||
ggml_vec_dot_f32(DK, &s, 0, k_head, 0, Q_f32, 0, 1);
|
||||
} else {
|
||||
kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
|
||||
|
|
@ -8554,10 +8555,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|||
const char * v_soa_base = mxfp.v_multihead
|
||||
? ((const char *) v->data + ic*nbv1 + iv3*nbv3)
|
||||
: v_data;
|
||||
float v_soa_f32[4096];
|
||||
GGML_ASSERT(mxfp.v_soa_elems <= 4096);
|
||||
mxfp.v_dequantize(v_soa_base, v_soa_f32, mxfp.v_soa_elems);
|
||||
ggml_vec_mad_f32(DV, VKQ32, v_soa_f32 + (mxfp.v_multihead ? iv2 * DV : 0), vs);
|
||||
mxfp.v_dequantize(v_soa_base, v_dequant_buf.data(), mxfp.v_soa_elems);
|
||||
ggml_vec_mad_f32(DV, VKQ32, v_dequant_buf.data() + (mxfp.v_multihead ? iv2 * DV : 0), vs);
|
||||
} else if (v_to_float) {
|
||||
v_to_float(v_data, V32, DV);
|
||||
ggml_vec_mad_f32(DV, VKQ32, V32, vs);
|
||||
|
|
@ -8765,7 +8764,8 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
|||
ggml_apply_hadamard_blocks(Q_f32 + tq * DK, DK);
|
||||
}
|
||||
// SoA round-trip: quantize Q to SoA, then dequant back to float.
|
||||
uint8_t q_mxfp_buf[1024];
|
||||
uint8_t q_mxfp_buf[512]; // max: DK=256 * 33/32 = 264 bytes (MXFP8)
|
||||
GGML_ASSERT(ggml_row_size(k->type, DK) <= sizeof(q_mxfp_buf));
|
||||
mxfp.q_quantize(Q_f32 + tq * DK, q_mxfp_buf, DK);
|
||||
mxfp.k_dequantize(q_mxfp_buf, Q_f32 + tq * DK, DK);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -58,8 +58,8 @@ void quantize_row_mxfp8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, i
|
|||
quantize_row_mxfp8_ref(x, y, k);
|
||||
}
|
||||
|
||||
void quantize_row_mxfp6_e2m3(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
|
||||
quantize_row_mxfp6_e2m3_ref(x, y, k);
|
||||
void quantize_row_mxfp6(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
|
||||
quantize_row_mxfp6_ref(x, y, k);
|
||||
}
|
||||
|
||||
//
|
||||
|
|
@ -301,14 +301,14 @@ void ggml_vec_dot_mxfp8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
|||
(ggml_to_float_t)dequantize_row_mxfp8);
|
||||
}
|
||||
|
||||
void ggml_vec_dot_mxfp6_e2m3_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
ggml_vec_dot_mxfp_q8_0_impl(n, s, vx, sizeof(block_mxfp6), vy,
|
||||
(ggml_to_float_t)dequantize_row_mxfp6_e2m3);
|
||||
(ggml_to_float_t)dequantize_row_mxfp6);
|
||||
}
|
||||
|
||||
// Generic (scalar) dequant wrappers — delegates to ggml-quants.c reference implementations.
|
||||
|
|
@ -316,8 +316,8 @@ void ggml_vec_dot_mxfp6_e2m3_q8_0_generic(int n, float * GGML_RESTRICT s, size_t
|
|||
void dequantize_row_mxfp8_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
dequantize_row_mxfp8(x, y, k);
|
||||
}
|
||||
void dequantize_row_mxfp6_e2m3_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
dequantize_row_mxfp6_e2m3(x, y, k);
|
||||
void dequantize_row_mxfp6_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
dequantize_row_mxfp6(x, y, k);
|
||||
}
|
||||
void dequantize_row_mxfp4_soa_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
dequantize_row_mxfp4_soa(x, y, k);
|
||||
|
|
|
|||
|
|
@ -22,11 +22,11 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
|
|||
void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_mxfp8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_mxfp6_e2m3(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_mxfp6(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
|
||||
// Dequantization (SIMD-optimized, arch-dispatched)
|
||||
void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
void dequantize_row_mxfp6_e2m3_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
|
||||
void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
|
|
@ -51,7 +51,7 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
|||
void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
|
||||
void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
|
|
@ -85,10 +85,10 @@ void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, c
|
|||
void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_mxfp8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_mxfp6_e2m3_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_mxfp6_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
|
||||
void dequantize_row_mxfp8_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
void dequantize_row_mxfp6_e2m3_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
void dequantize_row_mxfp6_cpu_generic(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
|
||||
// SoA dequant (SIMD-optimized for FA)
|
||||
void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
|
|
|
|||
|
|
@ -797,11 +797,11 @@ void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_REST
|
|||
dequantize_row_mxfp8_impl(x, y, k, &mxfp8_e4m3_traits);
|
||||
}
|
||||
|
||||
void quantize_row_mxfp6_e2m3_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k) {
|
||||
void quantize_row_mxfp6_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k) {
|
||||
quantize_row_mxfp6_impl(x, y, k, &mxfp6_e2m3_traits);
|
||||
}
|
||||
|
||||
void dequantize_row_mxfp6_e2m3(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
dequantize_row_mxfp6_impl(x, y, k, &mxfp6_e2m3_traits);
|
||||
}
|
||||
|
||||
|
|
@ -2627,9 +2627,9 @@ size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
|
|||
return nrow * ggml_row_size(GGML_TYPE_MXFP8_E4M3, n_per_row);
|
||||
}
|
||||
|
||||
size_t quantize_mxfp6_e2m3(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
||||
size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
||||
GGML_UNUSED(quant_weights);
|
||||
quantize_row_mxfp6_e2m3_ref(src, dst, (int64_t)nrow*n_per_row);
|
||||
quantize_row_mxfp6_ref(src, dst, (int64_t)nrow*n_per_row);
|
||||
return nrow * ggml_row_size(GGML_TYPE_MXFP6_E2M3, n_per_row);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 *
|
|||
GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
|
||||
GGML_API void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k);
|
||||
GGML_API void quantize_row_mxfp8_ref(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k);
|
||||
GGML_API void quantize_row_mxfp6_e2m3_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k);
|
||||
GGML_API void quantize_row_mxfp6_ref(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y, int64_t k);
|
||||
|
||||
GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
|
||||
GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
|
||||
|
|
@ -53,7 +53,7 @@ GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GG
|
|||
GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
GGML_API void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
GGML_API void dequantize_row_mxfp8(const block_mxfp8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
GGML_API void dequantize_row_mxfp6_e2m3(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
GGML_API void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
|
||||
// SoA (Struct-of-Arrays) quantize/dequantize — canonical reference for flash attention.
|
||||
// Layout: [qs contiguous][e8m0 contiguous] per row. Same total bytes as AoS.
|
||||
|
|
@ -112,7 +112,7 @@ GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTR
|
|||
GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
GGML_API size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
GGML_API size_t quantize_mxfp8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
GGML_API size_t quantize_mxfp6_e2m3(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
GGML_API size_t quantize_mxfp6(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
|
||||
//
|
||||
// MXFP element-level conversion functions (reference implementations)
|
||||
|
|
|
|||
|
|
@ -739,8 +739,8 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
|
|||
.blck_size = QK_MXFP6,
|
||||
.type_size = sizeof(block_mxfp6),
|
||||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_mxfp6_e2m3,
|
||||
.from_float_ref = (ggml_from_float_t)quantize_row_mxfp6_e2m3_ref,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_mxfp6,
|
||||
.from_float_ref = (ggml_from_float_t)quantize_row_mxfp6_ref,
|
||||
},
|
||||
[GGML_TYPE_Q2_K] = {
|
||||
.type_name = "q2_K",
|
||||
|
|
@ -7692,7 +7692,7 @@ size_t ggml_quantize_chunk(
|
|||
case GGML_TYPE_MXFP4_E2M1: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_MXFP8_E4M3: result = quantize_mxfp8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_MXFP6_E2M3: result = quantize_mxfp6_e2m3(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_MXFP6_E2M3: result = quantize_mxfp6(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
|
|
|
|||
|
|
@ -141,6 +141,7 @@ llama_kv_cache::llama_kv_cache(
|
|||
const bool is_mxfp_k = ggml_is_type_mxfp(type_k);
|
||||
if (is_mxfp_k) {
|
||||
const int qk = (int)ggml_blck_size(type_k); // 32 for all MXFP types
|
||||
GGML_ASSERT(n_embd_k_gqa % qk == 0 && "MXFP K cache requires n_embd_k_gqa divisible by block size");
|
||||
const int blocks = (int)n_embd_k_gqa / qk;
|
||||
const int blocks_aligned = (blocks + 15) & ~15; // align to 16
|
||||
n_embd_k_alloc = (uint32_t)(blocks_aligned * qk);
|
||||
|
|
|
|||
|
|
@ -181,6 +181,7 @@ static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float
|
|||
const size_t block_size = ggml_type_size(tensor->type);
|
||||
const size_t head_row_sz = ggml_row_size(tensor->type, tensor->ne[0]);
|
||||
if (soa_bytes == 0) { soa_bytes = head_row_sz; }
|
||||
GGML_ASSERT(soa_bytes % block_size == 0 && "soa_bytes must be a multiple of block_size");
|
||||
const int64_t soa_elems = (int64_t)(soa_bytes / block_size) * qk;
|
||||
|
||||
std::default_random_engine gen(42);
|
||||
|
|
|
|||
Loading…
Reference in New Issue