ggml: optimize CPU MXFP flash attention hot loop
- Per-head dequant: multihead MXFP now extracts only the needed head's SoA blocks (e.g. 20 bytes for mxfp4 DK=128) into a stack buffer and dequants DK elements, instead of dequanting all heads (nek2*DK). For 8 KV heads this is 8x less dequant work per KV position.
- Hoist loop invariants: base pointer offsets (k_base, v_base), per-head SoA byte offsets, and multihead row bases are computed once per query row instead of per KV position in the inner loop.
- Precompute SoA addressing in mxfp_fa_params_init: qs_per_block, blocks_per_head, head_qs_bytes, and head_e8m0_offset are calculated once at init rather than derived per iteration.
- Move thread-local buffer pointers (VKQ32, V32, VKQ16, Q_q) and the v_is_f16 check outside the ir loop.
This commit is contained in:
parent
a51ff77fae
commit
c2f2ff7814
|
|
@ -8271,6 +8271,18 @@ struct mxfp_fa_params {
|
|||
int64_t k_soa_elems;
|
||||
int64_t v_soa_elems;
|
||||
bool apply_hadamard;
|
||||
// Per-head SoA addressing (avoids dequanting all heads in multihead mode).
|
||||
// qs_per_block: bytes of quantized data per 32-element block.
|
||||
// head_qs_bytes: total qs bytes for one head (blocks_per_head * qs_per_block).
|
||||
// head_e8m0_offset: byte offset from row start to e8m0 region.
|
||||
int k_qs_per_block;
|
||||
int v_qs_per_block;
|
||||
int k_head_qs_bytes;
|
||||
int v_head_qs_bytes;
|
||||
int64_t k_head_e8m0_offset;
|
||||
int64_t v_head_e8m0_offset;
|
||||
int k_blocks_per_head;
|
||||
int v_blocks_per_head;
|
||||
};
|
||||
|
||||
static mxfp_fa_params mxfp_fa_params_init(
|
||||
|
|
@ -8314,6 +8326,34 @@ static mxfp_fa_params mxfp_fa_params_init(
|
|||
p.v_multihead = is_mxfp_v && (nbv2 == (size_t)ggml_row_size(v->type, DV));
|
||||
p.v_soa_elems = is_mxfp_v ? (p.v_multihead ? nev2 * DV : DV) : 0;
|
||||
|
||||
// Per-head SoA addressing for multihead mode.
|
||||
// Precompute byte offsets so the hot loop can skip per-head pointer math.
|
||||
auto mxfp_qs_per_block = [](ggml_type type) -> int {
|
||||
switch (type) {
|
||||
case GGML_TYPE_MXFP4_E2M1: return MXFP4_SOA_QS_PER_BLOCK;
|
||||
case GGML_TYPE_MXFP8_E4M3: return MXFP8_SOA_QS_PER_BLOCK;
|
||||
case GGML_TYPE_MXFP6_E2M3: return MXFP6_SOA_QS_PER_BLOCK;
|
||||
default: return 0;
|
||||
}
|
||||
};
|
||||
|
||||
if (is_mxfp_k) {
|
||||
p.k_qs_per_block = mxfp_qs_per_block(k->type);
|
||||
p.k_blocks_per_head = (int)(DK / 32);
|
||||
p.k_head_qs_bytes = p.k_blocks_per_head * p.k_qs_per_block;
|
||||
// e8m0 offset from row start = total_blocks * qs_per_block
|
||||
const int64_t k_total_blocks = p.k_multihead ? nek2 * p.k_blocks_per_head : p.k_blocks_per_head;
|
||||
p.k_head_e8m0_offset = k_total_blocks * p.k_qs_per_block;
|
||||
}
|
||||
|
||||
if (is_mxfp_v) {
|
||||
p.v_qs_per_block = mxfp_qs_per_block(v->type);
|
||||
p.v_blocks_per_head = (int)(DV / 32);
|
||||
p.v_head_qs_bytes = p.v_blocks_per_head * p.v_qs_per_block;
|
||||
const int64_t v_total_blocks = p.v_multihead ? nev2 * p.v_blocks_per_head : p.v_blocks_per_head;
|
||||
p.v_head_e8m0_offset = v_total_blocks * p.v_qs_per_block;
|
||||
}
|
||||
|
||||
return p;
|
||||
}
|
||||
|
||||
|
|
@ -8417,15 +8457,33 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|||
|
||||
int ith = params->ith;
|
||||
|
||||
// Pre-allocate dequant buffers for MXFP SoA (avoids per-iteration allocation)
|
||||
std::vector<float> k_dequant_buf(is_mxfp_k ? mxfp.k_soa_elems : 0);
|
||||
std::vector<float> v_dequant_buf(is_mxfp_v ? mxfp.v_soa_elems : 0);
|
||||
// Dequant buffers for MXFP SoA — stack-allocated, no heap allocation in the hot path.
|
||||
// In multihead mode, only dequant one head (DK or DV elements) instead of all heads.
|
||||
// DK/DV are bounded by 1024 (asserted below for MXFP).
|
||||
float k_dequant_buf[1024];
|
||||
float v_dequant_buf[1024];
|
||||
|
||||
// Per-head SoA temp buffer: holds [qs | e8m0] for one head in multihead mode.
|
||||
// Max size: 32 bytes qs (mxfp8, DK=128) + 4 bytes e8m0 = 36 bytes per head.
|
||||
// For DK up to 1024: 256 + 32 = 288 bytes. Use fixed-size stack buffer.
|
||||
alignas(16) char k_head_soa[320]; // enough for DK up to 1024 with any MXFP type
|
||||
alignas(16) char v_head_soa[320];
|
||||
|
||||
// Thread-local work buffers (constant across ir loop)
|
||||
float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32);
|
||||
float * V32 = (VKQ32 + 1*DV);
|
||||
ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV);
|
||||
ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*DV);
|
||||
|
||||
const bool v_is_f16 = (v->type == GGML_TYPE_F16);
|
||||
const bool use_softcap = (logit_softcap != 0.0f);
|
||||
const int64_t neq2_x_neq1 = neq2 * neq1;
|
||||
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
// q indices
|
||||
const int iq3 = ir/(neq2*neq1);
|
||||
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
|
||||
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
|
||||
const int iq3 = ir / neq2_x_neq1;
|
||||
const int iq2 = (ir - iq3*neq2_x_neq1) / neq1;
|
||||
const int iq1 = (ir - iq3*neq2_x_neq1 - iq2*neq1);
|
||||
|
||||
const uint32_t h = iq2; // head index
|
||||
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
|
||||
|
|
@ -8433,12 +8491,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|||
float S = 0.0f; // sum
|
||||
float M = -INFINITY; // maximum KQ value
|
||||
|
||||
float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
|
||||
float * V32 = (VKQ32 + 1*DV); // (temporary) FP32 V buffer
|
||||
ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
|
||||
ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
|
||||
|
||||
if (v->type == GGML_TYPE_F16) {
|
||||
if (v_is_f16) {
|
||||
memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
|
||||
} else {
|
||||
memset(VKQ32, 0, DV*sizeof(float));
|
||||
|
|
@ -8446,14 +8499,31 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|||
|
||||
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
|
||||
|
||||
// k indices
|
||||
// k/v head indices — constant for this query row
|
||||
const int ik3 = iq3 / rk3;
|
||||
const int ik2 = iq2 / rk2;
|
||||
|
||||
// v indices
|
||||
const int iv3 = iq3 / rv3;
|
||||
const int iv2 = iq2 / rv2;
|
||||
|
||||
// Precompute loop-invariant base pointer offsets for K and V.
|
||||
// Only ic varies in the inner loop; head/batch offsets are constant.
|
||||
const size_t k_base_offset = ik2*nbk2 + ik3*nbk3;
|
||||
const size_t v_base_offset = iv2*nbv2 + iv3*nbv3;
|
||||
const char * k_base = (const char *) k->data + k_base_offset;
|
||||
const char * v_base = (const char *) v->data + v_base_offset;
|
||||
|
||||
// For multihead MXFP: precompute per-head SoA byte offsets (constant per query row).
|
||||
// head_qs_start: byte offset to this head's qs blocks within the SoA row.
|
||||
// head_e8m0_start: byte offset to this head's e8m0 scales within the SoA row.
|
||||
const int k_head_qs_start = mxfp.k_multihead ? ik2 * mxfp.k_head_qs_bytes : 0;
|
||||
const int k_head_e8m0_start = mxfp.k_multihead ? (int)mxfp.k_head_e8m0_offset + ik2 * mxfp.k_blocks_per_head : 0;
|
||||
const int v_head_qs_start = mxfp.v_multihead ? iv2 * mxfp.v_head_qs_bytes : 0;
|
||||
const int v_head_e8m0_start = mxfp.v_multihead ? (int)mxfp.v_head_e8m0_offset + iv2 * mxfp.v_blocks_per_head : 0;
|
||||
|
||||
// Multihead MXFP row base (without head offset) — only ic varies.
|
||||
const char * k_row_base = mxfp.k_multihead ? ((const char *) k->data + ik3*nbk3) : nullptr;
|
||||
const char * v_row_base = mxfp.v_multihead ? ((const char *) v->data + iv3*nbv3) : nullptr;
|
||||
|
||||
const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
|
||||
float Q_f32[1024];
|
||||
if (is_mxfp_k) {
|
||||
|
|
@ -8493,23 +8563,25 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|||
|
||||
float s; // KQ value
|
||||
|
||||
const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
|
||||
if (is_mxfp_k) {
|
||||
// Dequant SoA data. Multi-head: full row base, extract head portion.
|
||||
// Per-head: use k_data directly.
|
||||
const char * k_soa_base = mxfp.k_multihead
|
||||
? ((const char *) k->data + ic*nbk1 + ik3*nbk3)
|
||||
: k_data;
|
||||
mxfp.k_dequantize(k_soa_base, k_dequant_buf.data(), mxfp.k_soa_elems);
|
||||
const float * k_head = k_dequant_buf.data() + (mxfp.k_multihead ? ik2 * DK : 0);
|
||||
ggml_vec_dot_f32(DK, &s, 0, k_head, 0, Q_f32, 0, 1);
|
||||
if (mxfp.k_multihead) {
|
||||
// Multihead: extract this head's SoA blocks into temp buffer, dequant only DK elements.
|
||||
// Copy qs blocks then e8m0 scales for this head into contiguous [qs|e8m0] layout.
|
||||
const char * row = k_row_base + ic*nbk1;
|
||||
memcpy(k_head_soa, row + k_head_qs_start, mxfp.k_head_qs_bytes);
|
||||
memcpy(k_head_soa + mxfp.k_head_qs_bytes, row + k_head_e8m0_start, mxfp.k_blocks_per_head);
|
||||
mxfp.k_dequantize(k_head_soa, k_dequant_buf, DK);
|
||||
} else {
|
||||
mxfp.k_dequantize(k_base + ic*nbk1, k_dequant_buf, DK);
|
||||
}
|
||||
ggml_vec_dot_f32(DK, &s, 0, k_dequant_buf, 0, Q_f32, 0, 1);
|
||||
} else {
|
||||
kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
|
||||
kq_vec_dot(DK, &s, 0, k_base + ic*nbk1, 0, Q_q, 0, 1);
|
||||
}
|
||||
|
||||
s = s*scale; // scale KQ value
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
if (use_softcap) {
|
||||
s = logit_softcap*tanhf(s);
|
||||
}
|
||||
|
||||
|
|
@ -8520,49 +8592,42 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|||
float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
|
||||
float vs = 1.0f; // post-softmax KQ value, expf(s - M)
|
||||
|
||||
const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
|
||||
|
||||
if (v->type == GGML_TYPE_F16) {
|
||||
if (v_is_f16) {
|
||||
if (s > M) {
|
||||
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
|
||||
M = s;
|
||||
ms = expf(Mold - M);
|
||||
|
||||
// V = V*expf(Mold - M)
|
||||
ggml_vec_scale_f16(DV, VKQ16, ms);
|
||||
} else {
|
||||
// no new maximum, ms == 1.0f, vs != 1.0f
|
||||
vs = expf(s - M);
|
||||
}
|
||||
|
||||
// V += v*expf(s - M)
|
||||
ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
|
||||
ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) (v_base + ic*nbv1), vs);
|
||||
} else {
|
||||
if (s > M) {
|
||||
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
|
||||
M = s;
|
||||
ms = expf(Mold - M);
|
||||
|
||||
// V = V*expf(Mold - M)
|
||||
ggml_vec_scale_f32(DV, VKQ32, ms);
|
||||
} else {
|
||||
// no new maximum, ms == 1.0f, vs != 1.0f
|
||||
vs = expf(s - M);
|
||||
}
|
||||
|
||||
// V += v*expf(s - M)
|
||||
if (mxfp.v_dequantize) {
|
||||
const char * v_soa_base = mxfp.v_multihead
|
||||
? ((const char *) v->data + ic*nbv1 + iv3*nbv3)
|
||||
: v_data;
|
||||
mxfp.v_dequantize(v_soa_base, v_dequant_buf.data(), mxfp.v_soa_elems);
|
||||
ggml_vec_mad_f32(DV, VKQ32, v_dequant_buf.data() + (mxfp.v_multihead ? iv2 * DV : 0), vs);
|
||||
if (mxfp.v_multihead) {
|
||||
const char * row = v_row_base + ic*nbv1;
|
||||
memcpy(v_head_soa, row + v_head_qs_start, mxfp.v_head_qs_bytes);
|
||||
memcpy(v_head_soa + mxfp.v_head_qs_bytes, row + v_head_e8m0_start, mxfp.v_blocks_per_head);
|
||||
mxfp.v_dequantize(v_head_soa, v_dequant_buf, DV);
|
||||
} else {
|
||||
mxfp.v_dequantize(v_base + ic*nbv1, v_dequant_buf, DV);
|
||||
}
|
||||
ggml_vec_mad_f32(DV, VKQ32, v_dequant_buf, vs);
|
||||
} else if (v_to_float) {
|
||||
v_to_float(v_data, V32, DV);
|
||||
v_to_float(v_base + ic*nbv1, V32, DV);
|
||||
ggml_vec_mad_f32(DV, VKQ32, V32, vs);
|
||||
} else {
|
||||
// V is F32
|
||||
ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
|
||||
ggml_vec_mad_f32(DV, VKQ32, (const float *) (v_base + ic*nbv1), vs);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue