ggml-cpu: Use tiled FA for prompt-processing (#19012)

* ggml-cpu: Use tiled FA for prompt-processing

FA performance on CPU degrades badly at long contexts because the existing path is essentially a vector kernel, processing one query row at a time. This PR adds a tiled FA kernel for prompt processing (PP); the sketch after the commit metadata below illustrates the core online-softmax tiling idea. Tile sizes were tuned on an AMD EPYC single-socket 64-core machine.

* fix out of bounds for mask

* skip rows where the mask is all -inf

* skip tile if mask is inf

* store mask in the work buffer

* check inf tile earlier
Aman Gupta, 2026-01-25 23:25:58 +08:00, committed by GitHub
parent d9c6ce46f7
commit bcb43163ae
3 changed files with 303 additions and 4 deletions
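
For orientation, here is a minimal, self-contained sketch of the online-softmax tiling idea used by the new kernel, reduced to a single query row with scalar "values" (all names and sizes are illustrative, not taken from the diff):

#include <math.h>
#include <stdio.h>

// Online softmax over KV tiles: scores are consumed tile by tile while a
// running max M and running sum S keep the computation numerically stable.
#define KV   32
#define TILE 8

int main(void) {
    float s[KV];
    for (int i = 0; i < KV; ++i) s[i] = sinf((float) i); // toy scores

    float M = -INFINITY, S = 0.0f, acc = 0.0f; // acc stands in for the V accumulator
    for (int t = 0; t < KV; t += TILE) {
        float m = -INFINITY;
        for (int j = 0; j < TILE; ++j) m = fmaxf(m, s[t + j]);
        const float Mnew = fmaxf(M, m);
        const float ms = expf(M - Mnew); // rescale previous contributions
        S   *= ms;
        acc *= ms;
        for (int j = 0; j < TILE; ++j) {
            const float p = expf(s[t + j] - Mnew);
            S   += p;
            acc += p * 1.0f; // a real kernel accumulates p * V[t + j] here
        }
        M = Mnew;
    }
    printf("softmax-weighted sum: %f\n", acc / S); // 1.0 for unit "values"
    return 0;
}

The kernel below does the same per query row, but for a 32x16 tile of (Q, KV) pairs at a time, with mask, ALiBi slope, softcap, and sink handling layered on top.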


@@ -6,6 +6,9 @@
#include "ggml-impl.h"
#include "simd-mappings.h"
#define GGML_FA_TILE_Q 32
#define GGML_FA_TILE_KV 16
#ifdef __cplusplus
#include <utility>
@@ -84,4 +87,9 @@ static std::pair<int64_t, int64_t> get_thread_range(const struct ggml_compute_params *
return {ir0, ir1};
}
struct ggml_fa_tile_config {
static constexpr size_t Q = GGML_FA_TILE_Q;
static constexpr size_t KV = GGML_FA_TILE_KV;
};
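// Mirrors the GGML_FA_TILE_* macros for C++ code: the tiled FA kernel in
// ops.cpp reads Q/KV from here, while ggml_graph_plan (C) sizes the
// per-thread scratch from the macros directly.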
#endif


@@ -14,6 +14,7 @@
#include "vec.h"
#include "ops.h"
#include "ggml.h"
#include "common.h"
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
@@ -2866,10 +2867,12 @@ struct ggml_cplan ggml_graph_plan(
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
const int64_t ne10 = node->src[1]->ne[0]; // DK
const int64_t ne20 = node->src[2]->ne[0]; // DV
const int64_t DK = node->src[1]->ne[0];
const int64_t DV = node->src[2]->ne[0];
cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
// Tiled flash attention scratch (tile sizes defined in common.h)
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding
cur = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV + CACHE_LINE_SIZE_F32)*n_tasks;
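// e.g. (hypothetical sizes): DK = DV = 128 gives 32*128 + 2*32*16 + 32*128 + 16*128
//                            = 11264 floats = 44 KiB of scratch per thread, plus the pad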
} break;
case GGML_OP_FLASH_ATTN_BACK:
{


@@ -8164,6 +8164,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
// online softmax / attention
// loop over n_kv and n_head_kv
// ref: https://arxiv.org/pdf/2112.05682.pdf
for (int64_t ic = 0; ic < nek1; ++ic) {
const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
if (mv == -INFINITY) {
@@ -8271,6 +8272,280 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
}
}
static void ggml_compute_forward_flash_attn_ext_tiled(
const ggml_compute_params * params,
ggml_tensor * dst,
int ir0, int ir1) {
const ggml_tensor * q = dst->src[0];
const ggml_tensor * k = dst->src[1];
const ggml_tensor * v = dst->src[2];
const ggml_tensor * mask = dst->src[3];
const ggml_tensor * sinks = dst->src[4];
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int64_t DK = nek0;
const int64_t DV = nev0;
const int64_t N = neq1;
GGML_ASSERT(ne0 == DV);
GGML_ASSERT(ne2 == N);
// input tensor rows must be contiguous
GGML_ASSERT(nbq0 == ggml_type_size(q->type));
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
GGML_ASSERT(neq0 == DK);
GGML_ASSERT(nek0 == DK);
GGML_ASSERT(nev0 == DV);
GGML_ASSERT(neq1 == N);
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb0 <= nb1);
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
GGML_ASSERT(k->type == v->type);
const ggml_type kv_type = k->type;
const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type);
const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float;
const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot;
const size_t kv_type_size = ggml_type_size(kv_type);
// broadcast factors
const int64_t rk2 = neq2/nek2;
const int64_t rk3 = neq3/nek3;
const int64_t rv2 = neq2/nev2;
const int64_t rv3 = neq3/nev3;
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
if (logit_softcap != 0) {
scale /= logit_softcap;
}
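// Softcap: scores are logit_softcap * tanh((q.k) * scale / logit_softcap);
// the division is folded into `scale` here, and the multiply by
// logit_softcap is applied after the tanh in the tile loop below.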
const uint32_t n_head = neq2;
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
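// ALiBi slope bases: head h uses m0^(h+1) while h < n_head_log2, and
// m1^(2*(h - n_head_log2) + 1) otherwise (computed per tile below).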
int ith = params->ith;
static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ");
int ir = ir0;
while (ir < ir1) {
// q indices for the start of this tile
const int iq3 = ir/(neq2*neq1);
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
// Number of valid rows in this tile:
// - limited by tile size (Q_TILE_SZ)
// - limited by chunk boundary (ir1 - ir)
// - limited by head boundary (neq1 - iq1) to avoid crossing into next head
const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1)));
GGML_ASSERT(tile_rows > 0);
const uint32_t h = iq2; // head index
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
float S[Q_TILE_SZ];
float M[Q_TILE_SZ];
for (int i = 0 ; i < Q_TILE_SZ; ++i) {
S[i] = 0.;
M[i] = -INFINITY;
}
// Per-thread scratch layout:
// Q_q: Q_TILE_SZ * DK (converted Q tile in KV type)
// KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
// mask: Q_TILE_SZ * KV_TILE_SZ (mask in float)
// VKQ32: Q_TILE_SZ * DV (FP32 output accumulator)
// V32: KV_TILE_SZ * DV (F32 buffer for V tile - used for f16 conversion)
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32);
void * Q_q = base;
float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ;
float * V32 = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for V tile
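// Q_q stores Q_TILE_SZ rows of DK values in kv_type; its region is sized
// for f32 in ggml_graph_plan, which also covers the smaller f16 case.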
memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
// k indices
const int ik3 = iq3 / rk3;
const int ik2 = iq2 / rk2;
// v indices
const int iv3 = iq3 / rv3;
const int iv2 = iq2 / rv2;
for (int tq = 0; tq < tile_rows; tq++) {
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK);
}
// Zero-pad remaining rows
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size);
}
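// Padded rows still flow through the KQ dot products below; only the
// first tile_rows rows are written back to dst at the end of the tile.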
for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
// skip the tile entirely if all the masks are -inf
if (mask) {
bool can_skip = true;
for (int tq = 0; tq < tile_rows; tq++) {
const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
can_skip = false;
}
}
}
if (can_skip) {
continue;
}
}
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
const void * q_row = (const char *)Q_q + tq * DK * kv_type_size;
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3);
float s;
kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1);
KQ[tq * KV_TILE_SZ + tk] = s * scale;
}
}
if (logit_softcap != 0.0f) {
ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ);
ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap);
}
if (mask) {
ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32);
}
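// Online-softmax update per row: for a tile with max m, Mnew = max(M, m);
// the running sum S and the VKQ accumulator are rescaled by exp(M - Mnew)
// before this tile's exp(s - Mnew) terms are added, keeping the softmax
// numerically stable in a single pass over the KV sequence.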
bool skip[Q_TILE_SZ] = {};
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
float * kq_row = KQ + tq * KV_TILE_SZ;
float tile_max;
ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row);
if (tile_max == -INFINITY) {
skip[tq] = true;
continue;
}
const float Mold = M[tq];
const float Mnew = fmaxf(Mold, tile_max);
if (Mnew > Mold) {
const float ms = expf(Mold - Mnew);
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
S[tq] *= ms;
}
M[tq] = Mnew;
S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
}
// Convert V tile to F32 first (if F16), then do MAD
// On x86, ggml_vec_mad_f16 internally converts F16<->F32 on every load/store, so pre-converting is faster.
// TODO: on ARM, native f16 should be faster
if (kv_type == GGML_TYPE_F16) {
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV);
}
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
if (skip[tq]) continue;
float * vkq_row = VKQ32 + tq * DV;
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const float p = KQ[tq * KV_TILE_SZ + tk];
ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p);
}
}
} else {
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
if (skip[tq]) continue;
float * vkq_row = VKQ32 + tq * DV;
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const float p = KQ[tq * KV_TILE_SZ + tk];
const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
ggml_vec_mad_f32(DV, vkq_row, v_row, p);
}
}
}
}
// sinks (apply only to valid rows in the tile)
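// A sink acts as one extra logit s: with Mnew = max(M[tq], s), the sum
// becomes S*exp(M - Mnew) + exp(s - Mnew), i.e. S[tq]*ms + vs below;
// the accumulator is rescaled only when s exceeds the running max.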
if (sinks) {
const float s = ((float *)((char *) sinks->data))[h];
for (int tq = 0; tq < tile_rows; tq++) {
float ms = 1.0f;
float vs = 1.0f;
if (s > M[tq]) {
ms = expf(M[tq] - s);
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
} else {
vs = expf(s - M[tq]);
}
S[tq] = S[tq] * ms + vs;
}
}
for (int tq = 0; tq < tile_rows; tq++) {
// V /= S
const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq];
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv);
// dst indices
const int i1 = iq1 + tq;
const int i2 = iq2;
const int i3 = iq3;
// permute(0, 2, 1, 3)
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1);
}
ir += tile_rows;
}
}
static void ggml_compute_forward_flash_attn_ext_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {
@@ -8343,6 +8618,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
// The number of elements in each chunk
const int64_t dr = (nr + nchunk - 1) / nchunk;
static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
const bool use_tiled = (q->type == GGML_TYPE_F32 &&
kv_is_f32_or_f16 &&
k->type == v->type &&
nek1 % KV_TILE_SZ == 0 &&
neq1 >= Q_TILE_SZ); // Only use tiled for batch >= tile size
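// Otherwise fall back to the per-row kernel: decode/small batches,
// quantized KV, or a KV length not divisible by the tile size.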
// The first chunk comes from our thread_id, the rest will get auto-assigned.
int current_chunk = ith;
@@ -8350,7 +8634,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int64_t ir0 = dr * current_chunk;
const int64_t ir1 = MIN(ir0 + dr, nr);
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
if (use_tiled) {
ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
} else {
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
}
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
}