ggml: eliminate hot-path heap allocations and fix tiled MXFP multihead dequant

Replace per-row/per-tile std::vector heap allocations with stack buffers
in set_rows, one_chunk, and tiled flash attention paths. Fix tiled path
to use per-head SoA extraction (matching one_chunk) instead of dequantizing
the full multihead region per token.
This commit is contained in:
Tim Burke 2026-03-15 22:55:27 -04:00
parent b8e8d291d1
commit 8036edc99a
4 changed files with 61 additions and 38 deletions

View File

@@ -16,8 +16,8 @@
#define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0
#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0
#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0
#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
@@ -348,9 +348,9 @@
// All other targets use the scalar generic as the public cpu function.
#if !defined(__aarch64__) && !defined(__arm__) && !defined(_M_ARM) && !defined(_M_ARM64) && \
!defined(__x86_64__) && !defined(__i386__) && !defined(_M_IX86) && !defined(_M_X64)
#define dequantize_row_mxfp8_cpu_generic dequantize_row_mxfp8_cpu
#define dequantize_row_mxfp6_cpu_generic dequantize_row_mxfp6_cpu
#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
#define dequantize_row_mxfp8_cpu_generic dequantize_row_mxfp8_cpu
#define dequantize_row_mxfp6_cpu_generic dequantize_row_mxfp6_cpu
#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu
#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu
#define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu
#endif

View File

@@ -5046,6 +5046,13 @@ static void ggml_compute_forward_set_rows_f32(
break;
}
// Pre-allocate Hadamard temp buffer once outside the hot loop (nc is constant).
// nc == n_embd_k_gqa which is bounded by model architecture (typically <= 8192).
std::vector<float> had_tmp;
if (apply_hadamard) {
had_tmp.resize(nc);
}
for (int64_t i03 = 0; i03 < ne03; ++i03) {
for (int64_t i02 = 0; i02 < ne02; ++i02) {
for (int64_t i = ir0; i < ir1; ++i) {
@@ -5061,13 +5068,12 @@ static void ggml_compute_forward_set_rows_f32(
char * dst_row = ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3);
if (apply_hadamard) {
std::vector<float> tmp(nc);
memcpy(tmp.data(), src_row, nc * sizeof(float));
ggml_apply_hadamard_blocks(tmp.data(), nc);
memcpy(had_tmp.data(), src_row, nc * sizeof(float));
ggml_apply_hadamard_blocks(had_tmp.data(), nc);
if (mxfp_soa_quantize) {
mxfp_soa_quantize(tmp.data(), dst_row, nc);
mxfp_soa_quantize(had_tmp.data(), dst_row, nc);
} else {
from_float(tmp.data(), dst_row, nc);
from_float(had_tmp.data(), dst_row, nc);
}
} else {
if (mxfp_soa_quantize) {
@@ -8465,12 +8471,10 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
// Per-head SoA temp buffer: holds [qs | e8m0] for one head in multihead mode.
// For DK=1024 with MXFP8: 32 blocks * 32 qs + 32 e8m0 = 1056 bytes.
const size_t k_head_soa_size = is_mxfp_k ? (size_t)(mxfp.k_head_qs_bytes + mxfp.k_blocks_per_head) : 0;
const size_t v_head_soa_size = is_mxfp_v ? (size_t)(mxfp.v_head_qs_bytes + mxfp.v_blocks_per_head) : 0;
std::vector<char> k_head_soa_vec(k_head_soa_size);
std::vector<char> v_head_soa_vec(v_head_soa_size);
char * k_head_soa = k_head_soa_vec.data();
char * v_head_soa = v_head_soa_vec.data();
// Stack-allocated since sizes are bounded by DK/DV <= 1024.
// Max: 1024/32 * 32(qs) + 1024/32 = 1056 bytes (MXFP8).
char k_head_soa[1088]; // 1056 rounded up for alignment
char v_head_soa[1088];
// Thread-local work buffers (constant across ir loop)
float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32);
@@ -8769,6 +8773,15 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
// MXFP dequant scratch buffers — allocated once per thread, reused across all tiles.
// DK/DV bounded by 1024, so per-head dequant fits in stack buffers.
float k_dequant_buf[1024];
float v_dequant_buf[1024];
// Per-head SoA temp buffers for multihead extraction (same as one_chunk path).
char k_head_soa[1088];
char v_head_soa[1088];
int ir = ir0;
while (ir < ir1) {
// q indices for the start of this tile
@@ -8847,10 +8860,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float));
memset(V32, 0, KV_TILE_SZ * DV * sizeof(float));
// dequant scratch buffers for SoA MXFP — allocated once per tile, reused per KV token
std::vector<float> k_soa_buf(mxfp.k_soa_elems);
std::vector<float> v_soa_buf(mxfp.v_soa_elems);
for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic);
@@ -8891,13 +8900,19 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
}
} else if (mxfp.k_dequantize) {
const char * k_soa_base = mxfp.k_multihead
? ((const char *)k->data + (ic + tk)*nbk1 + ik3*nbk3)
: k_data;
mxfp.k_dequantize(k_soa_base, k_soa_buf.data(), mxfp.k_soa_elems);
const float * k_head = k_soa_buf.data() + (mxfp.k_multihead ? ik2 * DK : 0);
if (mxfp.k_multihead) {
// Per-head extraction: copy only this head's SoA blocks + e8m0, dequant DK elements.
const char * row = (const char *)k->data + (ic + tk)*nbk1 + ik3*nbk3;
const int kqs = ik2 * mxfp.k_head_qs_bytes;
const int ke8 = (int)mxfp.k_head_e8m0_offset + ik2 * mxfp.k_blocks_per_head;
memcpy(k_head_soa, row + kqs, mxfp.k_head_qs_bytes);
memcpy(k_head_soa + mxfp.k_head_qs_bytes, row + ke8, mxfp.k_blocks_per_head);
mxfp.k_dequantize(k_head_soa, k_dequant_buf, DK);
} else {
mxfp.k_dequantize(k_data, k_dequant_buf, DK);
}
for (int64_t dk = 0; dk < DK; dk++) {
K_f32[dk * KV_TILE_SZ + tk] = k_head[dk];
K_f32[dk * KV_TILE_SZ + tk] = k_dequant_buf[dk];
}
} else {
float k_tmp[1024];
@@ -8965,11 +8980,18 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
} else if (v_type == GGML_TYPE_F32) {
memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
} else if (mxfp.v_dequantize) {
const char * v_soa_base = mxfp.v_multihead
? ((const char *)v->data + (ic + tk)*nbv1 + iv3*nbv3)
: v_data;
mxfp.v_dequantize(v_soa_base, v_soa_buf.data(), mxfp.v_soa_elems);
memcpy(V32 + tk * DV, v_soa_buf.data() + (mxfp.v_multihead ? iv2 * DV : 0), DV * sizeof(float));
if (mxfp.v_multihead) {
// Per-head extraction: copy only this head's SoA blocks + e8m0, dequant DV elements.
const char * row = (const char *)v->data + (ic + tk)*nbv1 + iv3*nbv3;
const int vqs = iv2 * mxfp.v_head_qs_bytes;
const int ve8 = (int)mxfp.v_head_e8m0_offset + iv2 * mxfp.v_blocks_per_head;
memcpy(v_head_soa, row + vqs, mxfp.v_head_qs_bytes);
memcpy(v_head_soa + mxfp.v_head_qs_bytes, row + ve8, mxfp.v_blocks_per_head);
mxfp.v_dequantize(v_head_soa, v_dequant_buf, DV);
} else {
mxfp.v_dequantize(v_data, v_dequant_buf, DV);
}
memcpy(V32 + tk * DV, v_dequant_buf, DV * sizeof(float));
} else {
v_to_float(v_data, V32 + tk * DV, DV);
}

View File

@@ -57,11 +57,11 @@ GGML_API void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float *
// SoA (Struct-of-Arrays) quantize/dequantize — canonical reference for flash attention.
// Layout: [qs contiguous][e8m0 contiguous] per row. Same total bytes as AoS.
GGML_API void quantize_row_mxfp4_soa (const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
GGML_API void dequantize_row_mxfp4_soa (const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp8_soa (const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
GGML_API void dequantize_row_mxfp8_soa (const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp6_soa (const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
GGML_API void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
GGML_API void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
GGML_API void dequantize_row_mxfp8_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
GGML_API void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);

View File

@@ -150,7 +150,8 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
}
}
// SoA quantize/dequantize functions — declared here because ggml-quants.h is not in the test include path.
// MXFP SoA quantize/dequantize (from ggml-quants.h, which is internal to ggml
// and not in the test include path). Signatures must match ggml-quants.h exactly.
typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t);
extern "C" {
void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);