fix buffer overflows for large DK and multi-head MXFP flash attention

- Increase q_mxfp_buf from 512 to 2048 bytes (supports DK up to 1024 with MXFP8)
- Replace fixed k_soa[4096]/v_soa[4096] stack arrays with dynamically sized vectors
- Replace fixed k_head_soa[320]/v_head_soa[320] with dynamically sized vectors
- Add soa_bytes divisibility assertion in test init
This commit is contained in:
Tim Burke 2026-03-15 20:30:01 -04:00
parent f603c036ec
commit c913ab36d2
2 changed files with 20 additions and 16 deletions

View File

@ -8464,10 +8464,13 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
float v_dequant_buf[1024];
// Per-head SoA temp buffer: holds [qs | e8m0] for one head in multihead mode.
// Max size per head (MXFP8, DK=128): 4 blocks * 32 qs bytes + 4 e8m0 bytes = 132 bytes.
// For DK up to 1024: 256 + 32 = 288 bytes. Use fixed-size stack buffer.
alignas(16) char k_head_soa[320]; // enough for DK up to 1024 with any MXFP type
alignas(16) char v_head_soa[320];
// For DK=1024 with MXFP8: 32 blocks * 32 qs + 32 e8m0 = 1056 bytes.
const size_t k_head_soa_size = is_mxfp_k ? (size_t)(mxfp.k_head_qs_bytes + mxfp.k_blocks_per_head) : 0;
const size_t v_head_soa_size = is_mxfp_v ? (size_t)(mxfp.v_head_qs_bytes + mxfp.v_blocks_per_head) : 0;
std::vector<char> k_head_soa_vec(k_head_soa_size);
std::vector<char> v_head_soa_vec(v_head_soa_size);
char * k_head_soa = k_head_soa_vec.data();
char * v_head_soa = v_head_soa_vec.data();
// Thread-local work buffers (constant across ir loop)
float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32);
@ -8828,9 +8831,10 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
if (mxfp.apply_hadamard) {
ggml_apply_hadamard_blocks(Q_f32 + tq * DK, DK);
}
// SoA round-trip: quantize Q to SoA, then dequant back to float.
uint8_t q_mxfp_buf[512]; // max: DK=256 * 33/32 = 264 bytes (MXFP8)
GGML_ASSERT(ggml_row_size(k->type, DK) <= sizeof(q_mxfp_buf));
// SoA round-trip: quantize Q to SoA, then dequant back to float
const size_t q_soa_bytes = ggml_row_size(k->type, DK);
GGML_ASSERT(q_soa_bytes <= 2048);
uint8_t q_mxfp_buf[2048]; // max: DK=1024 * 33/32 = 1056 bytes (MXFP8)
mxfp.q_quantize(Q_f32 + tq * DK, q_mxfp_buf, DK);
mxfp.k_dequantize(q_mxfp_buf, Q_f32 + tq * DK, DK);
}
@ -8843,6 +8847,10 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float));
memset(V32, 0, KV_TILE_SZ * DV * sizeof(float));
// dequant scratch buffers for SoA MXFP — allocated once per tile, reused per KV token
std::vector<float> k_soa_buf(mxfp.k_soa_elems);
std::vector<float> v_soa_buf(mxfp.v_soa_elems);
for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic);
@ -8886,10 +8894,8 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
const char * k_soa_base = mxfp.k_multihead
? ((const char *)k->data + (ic + tk)*nbk1 + ik3*nbk3)
: k_data;
float k_soa[4096];
GGML_ASSERT(mxfp.k_soa_elems <= 4096);
mxfp.k_dequantize(k_soa_base, k_soa, mxfp.k_soa_elems);
const float * k_head = k_soa + (mxfp.k_multihead ? ik2 * DK : 0);
mxfp.k_dequantize(k_soa_base, k_soa_buf.data(), mxfp.k_soa_elems);
const float * k_head = k_soa_buf.data() + (mxfp.k_multihead ? ik2 * DK : 0);
for (int64_t dk = 0; dk < DK; dk++) {
K_f32[dk * KV_TILE_SZ + tk] = k_head[dk];
}
@ -8962,10 +8968,8 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
const char * v_soa_base = mxfp.v_multihead
? ((const char *)v->data + (ic + tk)*nbv1 + iv3*nbv3)
: v_data;
float v_soa[4096];
GGML_ASSERT(mxfp.v_soa_elems <= 4096);
mxfp.v_dequantize(v_soa_base, v_soa, mxfp.v_soa_elems);
memcpy(V32 + tk * DV, v_soa + (mxfp.v_multihead ? iv2 * DV : 0), DV * sizeof(float));
mxfp.v_dequantize(v_soa_base, v_soa_buf.data(), mxfp.v_soa_elems);
memcpy(V32 + tk * DV, v_soa_buf.data() + (mxfp.v_multihead ? iv2 * DV : 0), DV * sizeof(float));
} else {
v_to_float(v_data, V32 + tk * DV, DV);
}

View File

@ -202,6 +202,7 @@ static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float
// If soa_bytes < nb1 (per-head), iterate over (ne1 * ne2 * ne3) regions with stride nb1.
// We use strides to compute offsets, handling views and permutations correctly.
const int64_t heads_per_region = (int64_t)(soa_bytes / head_row_sz);
GGML_ASSERT(soa_bytes % head_row_sz == 0 && "soa_bytes must be a multiple of head_row_sz");
// For multi-head regions, we step by nb1 (KV-position stride) between regions.
// For per-head, we step through all dimensions.
@ -210,7 +211,6 @@ static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float
if (heads_per_region > 1) {
// Multi-head SoA: iterate over (kv_positions * batches), each region = nb1 bytes
for (int64_t i3 = 0; i3 < ne3; i3++) {
// ne2/heads_per_region = number of head groups (for GQA broadcast, usually 1)
const int64_t n_groups = ne2 / heads_per_region;
for (int64_t ig = 0; ig < n_groups; ig++) {
for (int64_t i1 = 0; i1 < ne1; i1++) {