From 8036edc99aa8b7d6c7d6bbe2b89ca480c3ad9006 Mon Sep 17 00:00:00 2001 From: Tim Burke Date: Sun, 15 Mar 2026 22:55:27 -0400 Subject: [PATCH] ggml: eliminate hot-path heap allocations and fix tiled MXFP multihead dequant Hoist the per-row Hadamard std::vector in set_rows out of the hot loop (one allocation per call instead of one per row), and replace the per-chunk/per-tile std::vector heap allocations in the one_chunk and tiled flash attention paths with fixed-size stack buffers. Fix tiled path to use per-head SoA extraction (matching one_chunk) instead of dequantizing the full multihead region per token. --- ggml/src/ggml-cpu/arch-fallback.h | 12 ++--- ggml/src/ggml-cpu/ops.cpp | 74 ++++++++++++++++++++----------- ggml/src/ggml-quants.h | 10 ++--- tests/test-backend-ops.cpp | 3 +- 4 files changed, 61 insertions(+), 38 deletions(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 612786e941..3f01c0b1c7 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -16,8 +16,8 @@ #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 -#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0 -#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0 +#define ggml_vec_dot_mxfp8_q8_0_generic ggml_vec_dot_mxfp8_q8_0 +#define ggml_vec_dot_mxfp6_q8_0_generic ggml_vec_dot_mxfp6_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -348,9 +348,9 @@ // All other targets use the scalar generic as the public cpu function. 
#if !defined(__aarch64__) && !defined(__arm__) && !defined(_M_ARM) && !defined(_M_ARM64) && \ !defined(__x86_64__) && !defined(__i386__) && !defined(_M_IX86) && !defined(_M_X64) -#define dequantize_row_mxfp8_cpu_generic dequantize_row_mxfp8_cpu -#define dequantize_row_mxfp6_cpu_generic dequantize_row_mxfp6_cpu -#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu -#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu +#define dequantize_row_mxfp8_cpu_generic dequantize_row_mxfp8_cpu +#define dequantize_row_mxfp6_cpu_generic dequantize_row_mxfp6_cpu +#define dequantize_row_mxfp4_soa_cpu_generic dequantize_row_mxfp4_soa_cpu +#define dequantize_row_mxfp8_soa_cpu_generic dequantize_row_mxfp8_soa_cpu #define dequantize_row_mxfp6_soa_cpu_generic dequantize_row_mxfp6_soa_cpu #endif diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 12f04905af..81cbbdb4bf 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -5046,6 +5046,13 @@ static void ggml_compute_forward_set_rows_f32( break; } + // Pre-allocate Hadamard temp buffer once outside the hot loop (nc is constant). + // nc == n_embd_k_gqa which is bounded by model architecture (typically <= 8192). 
+ std::vector had_tmp; + if (apply_hadamard) { + had_tmp.resize(nc); + } + for (int64_t i03 = 0; i03 < ne03; ++i03) { for (int64_t i02 = 0; i02 < ne02; ++i02) { for (int64_t i = ir0; i < ir1; ++i) { @@ -5061,13 +5068,12 @@ static void ggml_compute_forward_set_rows_f32( char * dst_row = ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3); if (apply_hadamard) { - std::vector tmp(nc); - memcpy(tmp.data(), src_row, nc * sizeof(float)); - ggml_apply_hadamard_blocks(tmp.data(), nc); + memcpy(had_tmp.data(), src_row, nc * sizeof(float)); + ggml_apply_hadamard_blocks(had_tmp.data(), nc); if (mxfp_soa_quantize) { - mxfp_soa_quantize(tmp.data(), dst_row, nc); + mxfp_soa_quantize(had_tmp.data(), dst_row, nc); } else { - from_float(tmp.data(), dst_row, nc); + from_float(had_tmp.data(), dst_row, nc); } } else { if (mxfp_soa_quantize) { @@ -8465,12 +8471,10 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( // Per-head SoA temp buffer: holds [qs | e8m0] for one head in multihead mode. // For DK=1024 with MXFP8: 32 blocks * 32 qs + 32 e8m0 = 1056 bytes. - const size_t k_head_soa_size = is_mxfp_k ? (size_t)(mxfp.k_head_qs_bytes + mxfp.k_blocks_per_head) : 0; - const size_t v_head_soa_size = is_mxfp_v ? (size_t)(mxfp.v_head_qs_bytes + mxfp.v_blocks_per_head) : 0; - std::vector k_head_soa_vec(k_head_soa_size); - std::vector v_head_soa_vec(v_head_soa_size); - char * k_head_soa = k_head_soa_vec.data(); - char * v_head_soa = v_head_soa_vec.data(); + // Stack-allocated since sizes are bounded by DK/DV <= 1024. + // Max: 1024/32 * 32(qs) + 1024/32 = 1056 bytes (MXFP8). 
+ char k_head_soa[1088]; // 1056 rounded up for alignment + char v_head_soa[1088]; // Thread-local work buffers (constant across ir loop) float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); @@ -8769,6 +8773,15 @@ static void ggml_compute_forward_flash_attn_ext_tiled( static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q; static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV; + // MXFP dequant scratch buffers — allocated once per thread, reused across all tiles. + // DK/DV bounded by 1024, so per-head dequant fits in stack buffers. + float k_dequant_buf[1024]; + float v_dequant_buf[1024]; + + // Per-head SoA temp buffers for multihead extraction (same as one_chunk path). + char k_head_soa[1088]; + char v_head_soa[1088]; + int ir = ir0; while (ir < ir1) { // q indices for the start of this tile @@ -8847,10 +8860,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled( memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float)); memset(V32, 0, KV_TILE_SZ * DV * sizeof(float)); - // dequant scratch buffers for SoA MXFP — allocated once per tile, reused per KV token - std::vector k_soa_buf(mxfp.k_soa_elems); - std::vector v_soa_buf(mxfp.v_soa_elems); - for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) { const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic); @@ -8891,13 +8900,19 @@ static void ggml_compute_forward_flash_attn_ext_tiled( K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk]; } } else if (mxfp.k_dequantize) { - const char * k_soa_base = mxfp.k_multihead - ? ((const char *)k->data + (ic + tk)*nbk1 + ik3*nbk3) - : k_data; - mxfp.k_dequantize(k_soa_base, k_soa_buf.data(), mxfp.k_soa_elems); - const float * k_head = k_soa_buf.data() + (mxfp.k_multihead ? ik2 * DK : 0); + if (mxfp.k_multihead) { + // Per-head extraction: copy only this head's SoA blocks + e8m0, dequant DK elements. 
+ const char * row = (const char *)k->data + (ic + tk)*nbk1 + ik3*nbk3; + const int kqs = ik2 * mxfp.k_head_qs_bytes; + const int ke8 = (int)mxfp.k_head_e8m0_offset + ik2 * mxfp.k_blocks_per_head; + memcpy(k_head_soa, row + kqs, mxfp.k_head_qs_bytes); + memcpy(k_head_soa + mxfp.k_head_qs_bytes, row + ke8, mxfp.k_blocks_per_head); + mxfp.k_dequantize(k_head_soa, k_dequant_buf, DK); + } else { + mxfp.k_dequantize(k_data, k_dequant_buf, DK); + } for (int64_t dk = 0; dk < DK; dk++) { - K_f32[dk * KV_TILE_SZ + tk] = k_head[dk]; + K_f32[dk * KV_TILE_SZ + tk] = k_dequant_buf[dk]; } } else { float k_tmp[1024]; @@ -8965,11 +8980,18 @@ static void ggml_compute_forward_flash_attn_ext_tiled( } else if (v_type == GGML_TYPE_F32) { memcpy(V32 + tk * DV, v_data, DV * sizeof(float)); } else if (mxfp.v_dequantize) { - const char * v_soa_base = mxfp.v_multihead - ? ((const char *)v->data + (ic + tk)*nbv1 + iv3*nbv3) - : v_data; - mxfp.v_dequantize(v_soa_base, v_soa_buf.data(), mxfp.v_soa_elems); - memcpy(V32 + tk * DV, v_soa_buf.data() + (mxfp.v_multihead ? iv2 * DV : 0), DV * sizeof(float)); + if (mxfp.v_multihead) { + // Per-head extraction: copy only this head's SoA blocks + e8m0, dequant DV elements. 
+ const char * row = (const char *)v->data + (ic + tk)*nbv1 + iv3*nbv3; + const int vqs = iv2 * mxfp.v_head_qs_bytes; + const int ve8 = (int)mxfp.v_head_e8m0_offset + iv2 * mxfp.v_blocks_per_head; + memcpy(v_head_soa, row + vqs, mxfp.v_head_qs_bytes); + memcpy(v_head_soa + mxfp.v_head_qs_bytes, row + ve8, mxfp.v_blocks_per_head); + mxfp.v_dequantize(v_head_soa, v_dequant_buf, DV); + } else { + mxfp.v_dequantize(v_data, v_dequant_buf, DV); + } + memcpy(V32 + tk * DV, v_dequant_buf, DV * sizeof(float)); } else { v_to_float(v_data, V32 + tk * DV, DV); } diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index b386446035..c4d2ae86d3 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -57,11 +57,11 @@ GGML_API void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * // SoA (Struct-of-Arrays) quantize/dequantize — canonical reference for flash attention. // Layout: [qs contiguous][e8m0 contiguous] per row. Same total bytes as AoS. -GGML_API void quantize_row_mxfp4_soa (const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); -GGML_API void dequantize_row_mxfp4_soa (const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); -GGML_API void quantize_row_mxfp8_soa (const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); -GGML_API void dequantize_row_mxfp8_soa (const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); -GGML_API void quantize_row_mxfp6_soa (const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); +GGML_API void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); +GGML_API void dequantize_row_mxfp4_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_mxfp8_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); +GGML_API void dequantize_row_mxfp8_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_mxfp6_soa(const 
float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k); GGML_API void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 6dd245aeab..d123211505 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -150,7 +150,8 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m } } -// SoA quantize/dequantize functions — declared here because ggml-quants.h is not in the test include path. +// MXFP SoA quantize/dequantize (from ggml-quants.h, which is internal to ggml +// and not in the test include path). Signatures must match ggml-quants.h exactly. typedef void (*mxfp_soa_dequantize_fn)(const void *, float *, int64_t); extern "C" { void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);