From c913ab36d2e1d2a68125300f6cdae968fc2f2a83 Mon Sep 17 00:00:00 2001 From: Tim Burke Date: Sun, 15 Mar 2026 20:30:01 -0400 Subject: [PATCH] fix buffer overflows for large DK and multi-head MXFP flash attention - Increase q_mxfp_buf from 512 to 2048 bytes (supports DK up to 1024 with MXFP8) - Replace fixed k_soa[4096]/v_soa[4096] stack arrays with dynamically sized vectors - Replace fixed k_head_soa[320]/v_head_soa[320] with dynamically sized vectors - Add soa_bytes divisibility assertion in test init --- ggml/src/ggml-cpu/ops.cpp | 34 +++++++++++++++++++--------------- tests/test-backend-ops.cpp | 2 +- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 2b24e87ae6..27275ca1e1 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8464,10 +8464,13 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( float v_dequant_buf[1024]; // Per-head SoA temp buffer: holds [qs | e8m0] for one head in multihead mode. - // Max size: 32 bytes qs (mxfp8, DK=128) + 4 bytes e8m0 = 36 bytes per head. - // For DK up to 1024: 256 + 32 = 288 bytes. Use fixed-size stack buffer. - alignas(16) char k_head_soa[320]; // enough for DK up to 1024 with any MXFP type - alignas(16) char v_head_soa[320]; + // For DK=1024 with MXFP8: 32 blocks * 32 qs + 32 e8m0 = 1056 bytes. + const size_t k_head_soa_size = is_mxfp_k ? (size_t)(mxfp.k_head_qs_bytes + mxfp.k_blocks_per_head) : 0; + const size_t v_head_soa_size = is_mxfp_v ? 
(size_t)(mxfp.v_head_qs_bytes + mxfp.v_blocks_per_head) : 0; + std::vector<char> k_head_soa_vec(k_head_soa_size); + std::vector<char> v_head_soa_vec(v_head_soa_size); + char * k_head_soa = k_head_soa_vec.data(); + char * v_head_soa = v_head_soa_vec.data(); // Thread-local work buffers (constant across ir loop) float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); @@ -8828,9 +8831,10 @@ static void ggml_compute_forward_flash_attn_ext_tiled( if (mxfp.apply_hadamard) { ggml_apply_hadamard_blocks(Q_f32 + tq * DK, DK); } - // SoA round-trip: quantize Q to SoA, then dequant back to float. - uint8_t q_mxfp_buf[512]; // max: DK=256 * 33/32 = 264 bytes (MXFP8) - GGML_ASSERT(ggml_row_size(k->type, DK) <= sizeof(q_mxfp_buf)); + // SoA round-trip: quantize Q to SoA, then dequant back to float + const size_t q_soa_bytes = ggml_row_size(k->type, DK); + GGML_ASSERT(q_soa_bytes <= 2048); + uint8_t q_mxfp_buf[2048]; // max: DK=1024 * 33/32 = 1056 bytes (MXFP8) mxfp.q_quantize(Q_f32 + tq * DK, q_mxfp_buf, DK); mxfp.k_dequantize(q_mxfp_buf, Q_f32 + tq * DK, DK); } @@ -8843,6 +8847,10 @@ static void ggml_compute_forward_flash_attn_ext_tiled( memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float)); memset(V32, 0, KV_TILE_SZ * DV * sizeof(float)); + // dequant scratch buffers for SoA MXFP — allocated once per tile, reused per KV token + std::vector<float> k_soa_buf(mxfp.k_soa_elems); + std::vector<float> v_soa_buf(mxfp.v_soa_elems); + for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) { const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic); @@ -8886,10 +8894,8 @@ static void ggml_compute_forward_flash_attn_ext_tiled( const char * k_soa_base = mxfp.k_multihead ? ((const char *)k->data + (ic + tk)*nbk1 + ik3*nbk3) : k_data; - float k_soa[4096]; - GGML_ASSERT(mxfp.k_soa_elems <= 4096); - mxfp.k_dequantize(k_soa_base, k_soa, mxfp.k_soa_elems); - const float * k_head = k_soa + (mxfp.k_multihead ? 
ik2 * DK : 0); + mxfp.k_dequantize(k_soa_base, k_soa_buf.data(), mxfp.k_soa_elems); + const float * k_head = k_soa_buf.data() + (mxfp.k_multihead ? ik2 * DK : 0); for (int64_t dk = 0; dk < DK; dk++) { K_f32[dk * KV_TILE_SZ + tk] = k_head[dk]; } @@ -8962,10 +8968,8 @@ static void ggml_compute_forward_flash_attn_ext_tiled( const char * v_soa_base = mxfp.v_multihead ? ((const char *)v->data + (ic + tk)*nbv1 + iv3*nbv3) : v_data; - float v_soa[4096]; - GGML_ASSERT(mxfp.v_soa_elems <= 4096); - mxfp.v_dequantize(v_soa_base, v_soa, mxfp.v_soa_elems); - memcpy(V32 + tk * DV, v_soa + (mxfp.v_multihead ? iv2 * DV : 0), DV * sizeof(float)); + mxfp.v_dequantize(v_soa_base, v_soa_buf.data(), mxfp.v_soa_elems); + memcpy(V32 + tk * DV, v_soa_buf.data() + (mxfp.v_multihead ? iv2 * DV : 0), DV * sizeof(float)); } else { v_to_float(v_data, V32 + tk * DV, DV); } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 3915ab4db6..6dd245aeab 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -202,6 +202,7 @@ static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float // If soa_bytes < nb1 (per-head), iterate over (ne1 * ne2 * ne3) regions with stride nb1. // We use strides to compute offsets, handling views and permutations correctly. const int64_t heads_per_region = (int64_t)(soa_bytes / head_row_sz); + GGML_ASSERT(soa_bytes % head_row_sz == 0 && "soa_bytes must be a multiple of head_row_sz"); // For multi-head regions, we step by nb1 (KV-position stride) between regions. // For per-head, we step through all dimensions. 
@@ -210,7 +211,6 @@ static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float if (heads_per_region > 1) { // Multi-head SoA: iterate over (kv_positions * batches), each region = nb1 bytes for (int64_t i3 = 0; i3 < ne3; i3++) { - // ne2/heads_per_region = number of head groups (for GQA broadcast, usually 1) const int64_t n_groups = ne2 / heads_per_region; for (int64_t ig = 0; ig < n_groups; ig++) { for (int64_t i1 = 0; i1 < ne1; i1++) {