diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 2267eaa27b..2b24e87ae6 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -8271,6 +8271,18 @@ struct mxfp_fa_params {
     int64_t k_soa_elems;
     int64_t v_soa_elems;
     bool apply_hadamard;
+    // Per-head SoA addressing (avoids dequanting all heads in multihead mode);
+    // zero-initialized so non-MXFP paths never carry indeterminate values.
+    // qs_per_block: bytes of quantized data per 32-element block.
+    // head_qs_bytes = blocks_per_head * qs_per_block; head_e8m0_offset: byte offset from row start to e8m0 region.
+    int     k_qs_per_block     = 0;
+    int     v_qs_per_block     = 0;
+    int     k_head_qs_bytes    = 0;
+    int     v_head_qs_bytes    = 0;
+    int64_t k_head_e8m0_offset = 0;
+    int64_t v_head_e8m0_offset = 0;
+    int     k_blocks_per_head  = 0;
+    int     v_blocks_per_head  = 0;
 };
 
 static mxfp_fa_params mxfp_fa_params_init(
@@ -8314,6 +8326,34 @@ static mxfp_fa_params mxfp_fa_params_init(
     p.v_multihead = is_mxfp_v && (nbv2 == (size_t)ggml_row_size(v->type, DV));
     p.v_soa_elems = is_mxfp_v ? (p.v_multihead ? nev2 * DV : DV) : 0;
 
+    // Per-head SoA addressing for multihead mode.
+    // Precompute byte offsets so the hot loop can skip per-head pointer math.
+    auto mxfp_qs_per_block = [](ggml_type type) -> int {
+        switch (type) {
+            case GGML_TYPE_MXFP4_E2M1: return MXFP4_SOA_QS_PER_BLOCK;
+            case GGML_TYPE_MXFP8_E4M3: return MXFP8_SOA_QS_PER_BLOCK;
+            case GGML_TYPE_MXFP6_E2M3: return MXFP6_SOA_QS_PER_BLOCK;
+            default: return 0;
+        }
+    };
+
+    if (is_mxfp_k) {
+        p.k_qs_per_block = mxfp_qs_per_block(k->type);
+        p.k_blocks_per_head = (int)(DK / 32);
+        p.k_head_qs_bytes = p.k_blocks_per_head * p.k_qs_per_block;
+        // e8m0 offset from row start = total_blocks * qs_per_block
+        const int64_t k_total_blocks = p.k_multihead ? nek2 * p.k_blocks_per_head : p.k_blocks_per_head;
+        p.k_head_e8m0_offset = k_total_blocks * p.k_qs_per_block;
+    }
+
+    if (is_mxfp_v) {
+        p.v_qs_per_block = mxfp_qs_per_block(v->type);
+        p.v_blocks_per_head = (int)(DV / 32);
+        p.v_head_qs_bytes = p.v_blocks_per_head * p.v_qs_per_block;
+        const int64_t v_total_blocks = p.v_multihead ? nev2 * p.v_blocks_per_head : p.v_blocks_per_head;
+        p.v_head_e8m0_offset = v_total_blocks * p.v_qs_per_block;
+    }
+
     return p;
 }
 
@@ -8417,15 +8457,33 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
     int ith = params->ith;
 
-    // Pre-allocate dequant buffers for MXFP SoA (avoids per-iteration allocation)
-    std::vector<float> k_dequant_buf(is_mxfp_k ? mxfp.k_soa_elems : 0);
-    std::vector<float> v_dequant_buf(is_mxfp_v ? mxfp.v_soa_elems : 0);
+    // Dequant buffers for MXFP SoA — stack-allocated, no heap allocation in the hot path.
+    // In multihead mode, only dequant one head (DK or DV elements) instead of all heads.
+    // DK/DV are bounded by 1024 for MXFP (same bound the Q_f32[1024] buffer below relies on).
+    float k_dequant_buf[1024];
+    float v_dequant_buf[1024];
+
+    // Per-head SoA temp buffer: holds [qs | e8m0] for one head in multihead mode.
+    // Worst case is MXFP8 at D = 1024: 32 blocks * 32 B qs + 32 B e8m0 = 1056 B
+    // (MXFP6: 800 B, MXFP4: 544 B) — sized for the worst case.
+    alignas(16) char k_head_soa[1056]; // enough for DK up to 1024 with any MXFP type
+    alignas(16) char v_head_soa[1056];
+
+    // Thread-local work buffers (constant across ir loop)
+    float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32);
+    float * V32 = (VKQ32 + 1*DV);
+    ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV);
+    ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*DV);
+
+    const bool v_is_f16 = (v->type == GGML_TYPE_F16);
+    const bool use_softcap = (logit_softcap != 0.0f);
+    const int64_t neq2_x_neq1 = neq2 * neq1;
 
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
-        const int iq3 = ir/(neq2*neq1);
-        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
-        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+        const int iq3 = ir / neq2_x_neq1;
+        const int iq2 = (ir - iq3*neq2_x_neq1) / neq1;
+        const int iq1 = (ir - iq3*neq2_x_neq1 - iq2*neq1);
 
         const uint32_t h = iq2; // head index
         const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
@@ -8433,12 +8491,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
         float S = 0.0f; // sum
         float M = -INFINITY; // maximum KQ value
 
-        float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
-        float * V32 = (VKQ32 + 1*DV); // (temporary) FP32 V buffer
-        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
-        ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
-
-        if (v->type == GGML_TYPE_F16) {
+        if (v_is_f16) {
             memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
         } else {
             memset(VKQ32, 0, DV*sizeof(float));
@@ -8446,14 +8499,31 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
         }
 
         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
 
-        // k indices
+        // k/v head indices — constant for this query row
         const int ik3 = iq3 / rk3;
         const int ik2 = iq2 / rk2;
-
-        // v indices
         const int iv3 = iq3 / rv3;
         const int iv2 = iq2 / rv2;
 
+        // Precompute loop-invariant base pointer offsets for K and V.
+        // Only ic varies in the inner loop; head/batch offsets are constant.
+        const size_t k_base_offset = ik2*nbk2 + ik3*nbk3;
+        const size_t v_base_offset = iv2*nbv2 + iv3*nbv3;
+        const char * k_base = (const char *) k->data + k_base_offset;
+        const char * v_base = (const char *) v->data + v_base_offset;
+
+        // For multihead MXFP: precompute per-head SoA byte offsets (constant per query row).
+        // head_qs_start: byte offset to this head's qs blocks within the SoA row.
+        // head_e8m0_start: byte offset to this head's e8m0 scales within the SoA row (kept 64-bit; no narrowing).
+        const int k_head_qs_start = mxfp.k_multihead ? ik2 * mxfp.k_head_qs_bytes : 0;
+        const int64_t k_head_e8m0_start = mxfp.k_multihead ? mxfp.k_head_e8m0_offset + ik2 * mxfp.k_blocks_per_head : 0;
+        const int v_head_qs_start = mxfp.v_multihead ? iv2 * mxfp.v_head_qs_bytes : 0;
+        const int64_t v_head_e8m0_start = mxfp.v_multihead ? mxfp.v_head_e8m0_offset + iv2 * mxfp.v_blocks_per_head : 0;
+
+        // Multihead MXFP row base (without head offset) — only ic varies.
+        const char * k_row_base = mxfp.k_multihead ? ((const char *) k->data + ik3*nbk3) : nullptr;
+        const char * v_row_base = mxfp.v_multihead ? ((const char *) v->data + iv3*nbv3) : nullptr;
+
         const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
         float Q_f32[1024];
         if (is_mxfp_k) {
@@ -8493,23 +8563,25 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
                 float s; // KQ value
 
-                const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
                 if (is_mxfp_k) {
-                    // Dequant SoA data. Multi-head: full row base, extract head portion.
-                    // Per-head: use k_data directly.
-                    const char * k_soa_base = mxfp.k_multihead
-                        ? ((const char *) k->data + ic*nbk1 + ik3*nbk3)
-                        : k_data;
-                    mxfp.k_dequantize(k_soa_base, k_dequant_buf.data(), mxfp.k_soa_elems);
-                    const float * k_head = k_dequant_buf.data() + (mxfp.k_multihead ? ik2 * DK : 0);
-                    ggml_vec_dot_f32(DK, &s, 0, k_head, 0, Q_f32, 0, 1);
+                    if (mxfp.k_multihead) {
+                        // Multihead: extract this head's SoA blocks into temp buffer, dequant only DK elements.
+                        // Copy qs blocks then e8m0 scales for this head into contiguous [qs|e8m0] layout.
+                        const char * row = k_row_base + ic*nbk1;
+                        memcpy(k_head_soa, row + k_head_qs_start, mxfp.k_head_qs_bytes);
+                        memcpy(k_head_soa + mxfp.k_head_qs_bytes, row + k_head_e8m0_start, mxfp.k_blocks_per_head);
+                        mxfp.k_dequantize(k_head_soa, k_dequant_buf, DK);
+                    } else {
+                        mxfp.k_dequantize(k_base + ic*nbk1, k_dequant_buf, DK);
+                    }
+                    ggml_vec_dot_f32(DK, &s, 0, k_dequant_buf, 0, Q_f32, 0, 1);
                 } else {
-                    kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
+                    kq_vec_dot(DK, &s, 0, k_base + ic*nbk1, 0, Q_q, 0, 1);
                 }
 
                 s = s*scale; // scale KQ value
 
-                if (logit_softcap != 0.0f) {
+                if (use_softcap) {
                     s = logit_softcap*tanhf(s);
                 }
 
@@ -8520,49 +8592,42 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
                 float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
                 float vs = 1.0f; // post-softmax KQ value, expf(s - M)
 
-                const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
-
-                if (v->type == GGML_TYPE_F16) {
+                if (v_is_f16) {
                     if (s > M) {
-                        // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
                         M = s;
                         ms = expf(Mold - M);
-
-                        // V = V*expf(Mold - M)
                         ggml_vec_scale_f16(DV, VKQ16, ms);
                     } else {
-                        // no new maximum, ms == 1.0f, vs != 1.0f
                         vs = expf(s - M);
                     }
 
                     // V += v*expf(s - M)
-                    ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
+                    ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) (v_base + ic*nbv1), vs);
                 } else {
                     if (s > M) {
-                        // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
                         M = s;
                         ms = expf(Mold - M);
-
-                        // V = V*expf(Mold - M)
                         ggml_vec_scale_f32(DV, VKQ32, ms);
                     } else {
-                        // no new maximum, ms == 1.0f, vs != 1.0f
                        vs = expf(s - M);
                     }
 
                     // V += v*expf(s - M)
                     if (mxfp.v_dequantize) {
-                        const char * v_soa_base = mxfp.v_multihead
-                            ? ((const char *) v->data + ic*nbv1 + iv3*nbv3)
-                            : v_data;
-                        mxfp.v_dequantize(v_soa_base, v_dequant_buf.data(), mxfp.v_soa_elems);
-                        ggml_vec_mad_f32(DV, VKQ32, v_dequant_buf.data() + (mxfp.v_multihead ? iv2 * DV : 0), vs);
+                        if (mxfp.v_multihead) {
+                            const char * row = v_row_base + ic*nbv1;
+                            memcpy(v_head_soa, row + v_head_qs_start, mxfp.v_head_qs_bytes);
+                            memcpy(v_head_soa + mxfp.v_head_qs_bytes, row + v_head_e8m0_start, mxfp.v_blocks_per_head);
+                            mxfp.v_dequantize(v_head_soa, v_dequant_buf, DV);
+                        } else {
+                            mxfp.v_dequantize(v_base + ic*nbv1, v_dequant_buf, DV);
+                        }
+                        ggml_vec_mad_f32(DV, VKQ32, v_dequant_buf, vs);
                     } else if (v_to_float) {
-                        v_to_float(v_data, V32, DV);
+                        v_to_float(v_base + ic*nbv1, V32, DV);
                         ggml_vec_mad_f32(DV, VKQ32, V32, vs);
                     } else {
-                        // V is F32
-                        ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
+                        ggml_vec_mad_f32(DV, VKQ32, (const float *) (v_base + ic*nbv1), vs);
                     }
                 }