fix buffer overflows for large DK and multi-head MXFP flash attention
- Increase q_mxfp_buf from 512 to 2048 bytes (supports DK up to 1024 with MXFP8)
- Replace fixed k_soa[4096]/v_soa[4096] stack arrays with dynamically sized vectors
- Replace fixed k_head_soa[320]/v_head_soa[320] with dynamically sized vectors
- Add soa_bytes divisibility assertion in test init
This commit is contained in:
parent
f603c036ec
commit
c913ab36d2
|
|
@ -8464,10 +8464,13 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|||
float v_dequant_buf[1024];
|
||||
|
||||
// Per-head SoA temp buffer: holds [qs | e8m0] for one head in multihead mode.
|
||||
// Max size: 32 bytes qs (mxfp8, DK=128) + 4 bytes e8m0 = 36 bytes per head.
|
||||
// For DK up to 1024: 256 + 32 = 288 bytes. Use fixed-size stack buffer.
|
||||
alignas(16) char k_head_soa[320]; // enough for DK up to 1024 with any MXFP type
|
||||
alignas(16) char v_head_soa[320];
|
||||
// For DK=1024 with MXFP8: 32 blocks * 32 qs + 32 e8m0 = 1056 bytes.
|
||||
const size_t k_head_soa_size = is_mxfp_k ? (size_t)(mxfp.k_head_qs_bytes + mxfp.k_blocks_per_head) : 0;
|
||||
const size_t v_head_soa_size = is_mxfp_v ? (size_t)(mxfp.v_head_qs_bytes + mxfp.v_blocks_per_head) : 0;
|
||||
std::vector<char> k_head_soa_vec(k_head_soa_size);
|
||||
std::vector<char> v_head_soa_vec(v_head_soa_size);
|
||||
char * k_head_soa = k_head_soa_vec.data();
|
||||
char * v_head_soa = v_head_soa_vec.data();
|
||||
|
||||
// Thread-local work buffers (constant across ir loop)
|
||||
float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32);
|
||||
|
|
@ -8828,9 +8831,10 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
|||
if (mxfp.apply_hadamard) {
|
||||
ggml_apply_hadamard_blocks(Q_f32 + tq * DK, DK);
|
||||
}
|
||||
// SoA round-trip: quantize Q to SoA, then dequant back to float.
|
||||
uint8_t q_mxfp_buf[512]; // max: DK=256 * 33/32 = 264 bytes (MXFP8)
|
||||
GGML_ASSERT(ggml_row_size(k->type, DK) <= sizeof(q_mxfp_buf));
|
||||
// SoA round-trip: quantize Q to SoA, then dequant back to float
|
||||
const size_t q_soa_bytes = ggml_row_size(k->type, DK);
|
||||
GGML_ASSERT(q_soa_bytes <= 2048);
|
||||
uint8_t q_mxfp_buf[2048]; // max: DK=1024 * 33/32 = 1056 bytes (MXFP8)
|
||||
mxfp.q_quantize(Q_f32 + tq * DK, q_mxfp_buf, DK);
|
||||
mxfp.k_dequantize(q_mxfp_buf, Q_f32 + tq * DK, DK);
|
||||
}
|
||||
|
|
@ -8843,6 +8847,10 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
|||
memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float));
|
||||
memset(V32, 0, KV_TILE_SZ * DV * sizeof(float));
|
||||
|
||||
// dequant scratch buffers for SoA MXFP — allocated once per tile, reused per KV token
|
||||
std::vector<float> k_soa_buf(mxfp.k_soa_elems);
|
||||
std::vector<float> v_soa_buf(mxfp.v_soa_elems);
|
||||
|
||||
for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
|
||||
const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic);
|
||||
|
||||
|
|
@ -8886,10 +8894,8 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
|||
const char * k_soa_base = mxfp.k_multihead
|
||||
? ((const char *)k->data + (ic + tk)*nbk1 + ik3*nbk3)
|
||||
: k_data;
|
||||
float k_soa[4096];
|
||||
GGML_ASSERT(mxfp.k_soa_elems <= 4096);
|
||||
mxfp.k_dequantize(k_soa_base, k_soa, mxfp.k_soa_elems);
|
||||
const float * k_head = k_soa + (mxfp.k_multihead ? ik2 * DK : 0);
|
||||
mxfp.k_dequantize(k_soa_base, k_soa_buf.data(), mxfp.k_soa_elems);
|
||||
const float * k_head = k_soa_buf.data() + (mxfp.k_multihead ? ik2 * DK : 0);
|
||||
for (int64_t dk = 0; dk < DK; dk++) {
|
||||
K_f32[dk * KV_TILE_SZ + tk] = k_head[dk];
|
||||
}
|
||||
|
|
@ -8962,10 +8968,8 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
|||
const char * v_soa_base = mxfp.v_multihead
|
||||
? ((const char *)v->data + (ic + tk)*nbv1 + iv3*nbv3)
|
||||
: v_data;
|
||||
float v_soa[4096];
|
||||
GGML_ASSERT(mxfp.v_soa_elems <= 4096);
|
||||
mxfp.v_dequantize(v_soa_base, v_soa, mxfp.v_soa_elems);
|
||||
memcpy(V32 + tk * DV, v_soa + (mxfp.v_multihead ? iv2 * DV : 0), DV * sizeof(float));
|
||||
mxfp.v_dequantize(v_soa_base, v_soa_buf.data(), mxfp.v_soa_elems);
|
||||
memcpy(V32 + tk * DV, v_soa_buf.data() + (mxfp.v_multihead ? iv2 * DV : 0), DV * sizeof(float));
|
||||
} else {
|
||||
v_to_float(v_data, V32 + tk * DV, DV);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -202,6 +202,7 @@ static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float
|
|||
// If soa_bytes < nb1 (per-head), iterate over (ne1 * ne2 * ne3) regions with stride nb1.
|
||||
// We use strides to compute offsets, handling views and permutations correctly.
|
||||
const int64_t heads_per_region = (int64_t)(soa_bytes / head_row_sz);
|
||||
GGML_ASSERT(soa_bytes % head_row_sz == 0 && "soa_bytes must be a multiple of head_row_sz");
|
||||
|
||||
// For multi-head regions, we step by nb1 (KV-position stride) between regions.
|
||||
// For per-head, we step through all dimensions.
|
||||
|
|
@ -210,7 +211,6 @@ static void init_tensor_mxfp_soa(ggml_tensor * tensor, float min = -1.0f, float
|
|||
if (heads_per_region > 1) {
|
||||
// Multi-head SoA: iterate over (kv_positions * batches), each region = nb1 bytes
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
// ne2/heads_per_region = number of head groups (for GQA broadcast, usually 1)
|
||||
const int64_t n_groups = ne2 / heads_per_region;
|
||||
for (int64_t ig = 0; ig < n_groups; ig++) {
|
||||
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue