ggml-cpu: FA split across kv for faster TG (#19209)

* ggml-cpu: split across kv for faster TG

* simplify sinks application

* add ref impl
Aman Gupta 2026-02-03 01:19:55 +08:00 committed by GitHub
parent a3fa035822
commit 9f682fb640
6 changed files with 220 additions and 69 deletions


@@ -19,6 +19,9 @@ extern "C" {
// abort ggml_graph_compute when true
ggml_abort_callback abort_callback;
void * abort_callback_data;
// use only reference implementations
bool use_ref;
};
// numa strategies
@@ -132,6 +135,8 @@ extern "C" {
GGML_BACKEND_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);
GGML_BACKEND_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
GGML_BACKEND_API void ggml_backend_cpu_set_use_ref(ggml_backend_t backend_cpu, bool use_ref);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);
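
A minimal usage sketch for the new setter (illustrative only, not part of this commit); it assumes a graph built elsewhere and uses only existing backend API calls:

// Illustrative only: run a prebuilt graph with the reference FA path forced.
#include "ggml-backend.h"
#include "ggml-cpu.h"

static enum ggml_status compute_with_ref_fa(struct ggml_cgraph * graph) {
    ggml_backend_t cpu = ggml_backend_cpu_init();
    ggml_backend_cpu_set_use_ref(cpu, true);               // new API from this commit
    enum ggml_status st = ggml_backend_graph_compute(cpu, graph);
    ggml_backend_free(cpu);
    return st;
}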


@@ -24,6 +24,9 @@ struct ggml_compute_params {
void * wdata;
struct ggml_threadpool * threadpool;
// use reference implementation
bool use_ref;
};


@@ -5,7 +5,6 @@
#include "ggml-backend.h"
#include "traits.h"
#include "ggml-cpu-impl.h"
#include "ggml-cpu.h"
#include "ggml-impl.h"
#include "quants.h"
#include "ggml-threading.h"
@@ -2867,12 +2866,20 @@ struct ggml_cplan ggml_graph_plan(
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
const int64_t neq2 = node->src[0]->ne[2]; // number of query heads
const int64_t DK = node->src[1]->ne[0];
const int64_t DV = node->src[2]->ne[0];
// Tiled flash attention scratch (tile sizes defined in common.h)
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding
cur = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
// Decode path: n_kv_chunks = n_tasks (one chunk per thread)
// Per-thread: VKQ accumulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ
size_t n_chunks = n_tasks;
size_t decode = sizeof(float)*(neq2*n_chunks*(2+DV) + n_tasks*(DK + 2*DV));
cur += MAX(prefill, decode);
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
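
As a rough worked example of the new decode term above (numbers are illustrative, not from this commit): with DK = DV = 128, neq2 = 32 query heads and n_tasks = 8 threads (so n_chunks = 8), the decode scratch comes to sizeof(float) * (32*8*(2 + 128) + 8*(128 + 2*128)) = 4 * (33280 + 3072) = 145408 bytes, roughly 142 KiB, independent of the KV length.
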
@@ -2929,11 +2936,12 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
set_numa_thread_affinity(state->ith);
struct ggml_compute_params params = {
/*.ith =*/ state->ith,
/*.nth =*/ atomic_load_explicit(&tp->n_graph, memory_order_relaxed) & GGML_THREADPOOL_N_THREADS_MASK,
/*.wsize =*/ cplan->work_size,
/*.wdata =*/ cplan->work_data,
/*.threadpool=*/ tp,
/*.ith =*/ state->ith,
/*.nth =*/ atomic_load_explicit(&tp->n_graph, memory_order_relaxed) & GGML_THREADPOOL_N_THREADS_MASK,
/*.wsize =*/ cplan->work_size,
/*.wdata =*/ cplan->work_data,
/*.threadpool =*/ tp,
/*.use_ref =*/ cplan->use_ref,
};
GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d \n", state->ith, cplan, state->last_graph);


@@ -105,6 +105,8 @@ struct ggml_backend_cpu_context {
ggml_abort_callback abort_callback;
void * abort_callback_data;
bool use_ref; // use reference implementation
};
static const char * ggml_backend_cpu_get_name(ggml_backend_t backend) {
@@ -143,6 +145,7 @@ static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend
cpu_plan->cplan.abort_callback = cpu_ctx->abort_callback;
cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data;
cpu_plan->cplan.use_ref = cpu_ctx->use_ref;
return cpu_plan;
}
@@ -182,6 +185,7 @@ static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, s
cplan.abort_callback = cpu_ctx->abort_callback;
cplan.abort_callback_data = cpu_ctx->abort_callback_data;
cplan.use_ref = cpu_ctx->use_ref;
return ggml_graph_compute(cgraph, &cplan);
}
@@ -223,6 +227,7 @@ ggml_backend_t ggml_backend_cpu_init(void) {
ctx->work_size = 0;
ctx->abort_callback = NULL;
ctx->abort_callback_data = NULL;
ctx->use_ref = false;
ggml_backend_t cpu_backend = new ggml_backend {
/* .guid = */ ggml_backend_cpu_guid(),
@@ -270,6 +275,13 @@ void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_
ctx->abort_callback_data = abort_callback_data;
}
void ggml_backend_cpu_set_use_ref(ggml_backend_t backend_cpu, bool use_ref) {
GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
ctx->use_ref = use_ref;
}
// CPU backend - device
struct ggml_backend_cpu_device_context {
@@ -646,6 +658,9 @@ static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const ch
if (strcmp(name, "ggml_backend_cpu_is_numa") == 0) {
return (void *)ggml_is_numa;
}
if (strcmp(name, "ggml_backend_cpu_set_use_ref") == 0) {
return (void *)ggml_backend_cpu_set_use_ref;
}
// threadpool - TODO: move to ggml-base
if (strcmp(name, "ggml_threadpool_new") == 0) {


@@ -8042,12 +8042,14 @@ void ggml_compute_forward_top_k(
}
}
// ggml_compute_forward_flash_attn_ext
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
const ggml_compute_params * params,
ggml_tensor * dst,
int ir0, int ir1) {
int ir0, int ir1,
int64_t ic_start, int64_t ic_end,
float * partials, int64_t partial_stride) {
const bool write_partials = (partials != nullptr);
const ggml_tensor * q = dst->src[0];
const ggml_tensor * k = dst->src[1];
const ggml_tensor * v = dst->src[2];
@@ -8124,7 +8126,6 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
int ith = params->ith;
// loop over n_batch and n_head
for (int ir = ir0; ir < ir1; ++ir) {
// q indices
const int iq3 = ir/(neq2*neq1);
@@ -8165,7 +8166,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
// loop over n_kv and n_head_kv
// ref: https://arxiv.org/pdf/2112.05682.pdf
for (int64_t ic = 0; ic < nek1; ++ic) {
for (int64_t ic = ic_start; ic < ic_end; ++ic) {
const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
if (mv == -INFINITY) {
continue;
@@ -8238,8 +8239,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
}
}
// sinks
if (sinks) {
// sinks - apply only on the first kv-chunk
if (sinks && ic_start == 0) {
const float s = ((float *)((char *) sinks->data))[h];
float ms = 1.0f;
@@ -8247,6 +8248,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
if (s > M) {
ms = expf(M - s);
M = s;
ggml_vec_scale_f32(DV, VKQ32, ms);
} else {
vs = expf(s - M);
@@ -8255,20 +8257,26 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
S = S*ms + vs;
}
// V /= S
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
ggml_vec_scale_f32(DV, VKQ32, S_inv);
if (write_partials) {
// Write M, S, VKQ to partials for later reduction
// partials layout: [M, S, VKQ[DV]] per query head
float * partial = partials + ir * partial_stride;
partial[0] = M;
partial[1] = S;
memcpy(partial + 2, VKQ32, DV * sizeof(float));
} else {
// V /= S
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
ggml_vec_scale_f32(DV, VKQ32, S_inv);
// dst indices
const int i1 = iq1;
const int i2 = iq2;
const int i3 = iq3;
// dst indices
const int i1 = iq1;
const int i2 = iq2;
const int i3 = iq3;
// original
//memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
// permute(0, 2, 1, 3)
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
// permute(0, 2, 1, 3)
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
}
}
}
@@ -8546,6 +8554,78 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
}
}
// Reduction function: combines partial results across KV chunks
// Partials layout in wdata: [n_q_heads][n_chunks][2 + DV]
static void ggml_flash_attn_ext_reduce_partials(
const ggml_compute_params * params,
ggml_tensor * dst,
const int64_t n_chunks,
const int64_t chunk_size) {
const ggml_tensor * q = dst->src[0];
const ggml_tensor * k = dst->src[1];
const ggml_tensor * v = dst->src[2];
const int64_t DK = k->ne[0];
const int64_t DV = v->ne[0];
const int64_t nek1 = k->ne[1];
const int64_t n_q_heads = q->ne[2];
const int ith = params->ith;
const int nth = params->nth;
const int64_t wdata_per_thread = DK + 2*DV + CACHE_LINE_SIZE_F32;
float * thread_wdata = (float *) params->wdata + ith * wdata_per_thread;
const int64_t partials_offset = nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
const int64_t partial_size = 2 + DV;
const float * partials_base = (const float *) params->wdata + partials_offset;
// Output layout
const int64_t ne1 = dst->ne[1];
const int64_t ne2 = dst->ne[2];
const size_t nb1 = dst->nb[1];
// Each thread reduces a subset of query heads
for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) {
float M_final = -INFINITY;
float S_final = 0.0f;
float * VKQ_final = thread_wdata;
memset(VKQ_final, 0, DV * sizeof(float));
// Combine partials from all chunks
for (int64_t chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) {
const int64_t ic_start = chunk_idx * chunk_size;
if (ic_start >= nek1) continue;
const float * partial = partials_base + (q_head * n_chunks + chunk_idx) * partial_size;
const float M_chunk = partial[0];
const float S_chunk = partial[1];
const float * VKQ_chunk = partial + 2;
if (S_chunk == 0.0f) continue;
const float M_new = fmaxf(M_final, M_chunk);
const float scale_old = expf(M_final - M_new);
const float scale_new = expf(M_chunk - M_new);
for (int64_t d = 0; d < DV; ++d) {
VKQ_final[d] = VKQ_final[d] * scale_old + VKQ_chunk[d] * scale_new;
}
S_final = S_final * scale_old + S_chunk * scale_new;
M_final = M_new;
}
// Normalize and write to output
if (S_final != 0.0f) {
const float S_inv = 1.0f / S_final;
ggml_vec_scale_f32(DV, VKQ_final, S_inv);
}
// iq1=0, iq3=0 for decode
memcpy((char *) dst->data + (0*ne2*ne1 + q_head + 0*ne1)*nb1, VKQ_final, nb1);
}
}
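
The chunk loop above is the standard online-softmax merge of partial results over disjoint KV ranges. A minimal standalone sketch of that merge step (illustrative only; assumes fmaxf/expf from <math.h> and int64_t from <stdint.h>; empty chunks with S2 == 0 are skipped by the caller, as in the loop above):

// Merge partial (M2, S2, V2[0..DV)) into the running accumulator (M1, S1, V1[0..DV)).
static void fa_merge_partials(float * M1, float * S1, float * V1,
                              float M2, float S2, const float * V2, int64_t DV) {
    const float M  = fmaxf(*M1, M2);   // new running maximum of the logits
    const float c1 = expf(*M1 - M);    // rescale the old accumulator
    const float c2 = expf(M2  - M);    // rescale the incoming chunk
    for (int64_t d = 0; d < DV; ++d) {
        V1[d] = V1[d]*c1 + V2[d]*c2;
    }
    *S1 = *S1*c1 + S2*c2;
    *M1 = M;
}
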
static void ggml_compute_forward_flash_attn_ext_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {
@@ -8567,6 +8647,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int64_t DV = nev0;
const int64_t N = neq1;
GGML_ASSERT(ne0 == DV);
GGML_ASSERT(ne2 == N);
@@ -8587,60 +8668,92 @@ static void ggml_compute_forward_flash_attn_ext_f16(
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
// parallelize by q rows using ggml_vec_dot_f32
// total rows in q
const int64_t nr = neq1*neq2*neq3;
// rows per thread
const int ith = params->ith;
const int nth = params->nth;
// disable for NUMA
const bool disable_chunking = ggml_is_numa();
// When use_ref is set, force the vec-only reference implementation (no tiling, no KV-chunking)
const bool use_ref = params->use_ref;
// 4x chunks per thread
int nth_scaled = nth * 4;
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
if (nth == 1 || nchunk < nth || disable_chunking) {
nchunk = nth;
}
if (ith == 0) {
// Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
ggml_threadpool_chunk_set(params->threadpool, nth);
}
ggml_barrier(params->threadpool);
// The number of elements in each chunk
const int64_t dr = (nr + nchunk - 1) / nchunk;
static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
const bool use_tiled = (q->type == GGML_TYPE_F32 &&
kv_is_f32_or_f16 &&
k->type == v->type &&
nek1 % KV_TILE_SZ == 0 &&
neq1 >= Q_TILE_SZ); // Only use tiled for batch >= tile size
const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512;
// The first chunk comes from our thread_id, the rest will get auto-assigned.
int current_chunk = ith;
if (use_split_kv_path) {
const int64_t chunk_size = (nek1 + nth - 1) / nth;
while (current_chunk < nchunk) {
const int64_t ir0 = dr * current_chunk;
const int64_t ir1 = MIN(ir0 + dr, nr);
// Partials buffer layout: [q_head][kv_chunk][M, S, VKQ]
const int64_t partial_size = 2 + DV;
float * partials_base = (float *) params->wdata + nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
if (use_tiled) {
ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
const int64_t ic_start = ith * chunk_size;
const int64_t ic_end = std::min(ic_start + chunk_size, nek1);
const int64_t partial_stride = nth * partial_size;
float * chunk_partials = partials_base + ith * partial_size;
if (ic_start < nek1) {
for (int64_t q_head = 0; q_head < neq2; q_head++) {
ggml_compute_forward_flash_attn_ext_f16_one_chunk(
params, dst, q_head, q_head + 1, ic_start, ic_end,
chunk_partials, partial_stride);
}
} else {
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
for (int64_t q_head = 0; q_head < neq2; q_head++) {
float * q_partials = chunk_partials + q_head * partial_stride;
q_partials[0] = -INFINITY; // M
q_partials[1] = 0.0f; // S
}
}
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
ggml_barrier(params->threadpool);
ggml_flash_attn_ext_reduce_partials(params, dst, nth, chunk_size);
} else {
// total rows in q
const int64_t nr = neq1*neq2*neq3;
// disable for NUMA
const bool disable_chunking = ggml_is_numa();
// 4x chunks per thread
int nth_scaled = nth * 4;
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
if (nth == 1 || nchunk < nth || disable_chunking) {
nchunk = nth;
}
if (ith == 0) {
ggml_threadpool_chunk_set(params->threadpool, nth);
}
ggml_barrier(params->threadpool);
const int64_t dr = (nr + nchunk - 1) / nchunk;
static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
const bool use_tiled = !use_ref &&
(q->type == GGML_TYPE_F32 &&
kv_is_f32_or_f16 &&
k->type == v->type &&
nek1 % KV_TILE_SZ == 0 &&
neq1 >= Q_TILE_SZ);
int current_chunk = ith;
while (current_chunk < nchunk) {
const int64_t ir0 = dr * current_chunk;
const int64_t ir1 = MIN(ir0 + dr, nr);
if (use_tiled) {
ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
} else {
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1, 0, nek1, nullptr, 0);
}
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
}
}
}
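
To summarize the new control flow above: for single-token decode (neq1 == 1, neq3 == 1) with F32 queries, F16/F32 K/V of matching type, and at least 512 KV positions, each thread now processes one contiguous KV chunk for every query head, writes an (M, S, VKQ) partial per head into wdata, and after a barrier the partials are merged per head by ggml_flash_attn_ext_reduce_partials. The prefill/batch path (and the use_ref reference path) keeps the previous scheme of chunking over query rows.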


@@ -8591,6 +8591,13 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
output_printer->print_operation(info);
return false;
}
// Use reference implementation on the CPU backend for comparison
using ggml_backend_cpu_set_use_ref_t = void (*)(ggml_backend_t, bool);
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
auto * set_use_ref = (ggml_backend_cpu_set_use_ref_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_use_ref");
if (set_use_ref) {
set_use_ref(backend_cpu, true);
}
size_t n_ok = 0;
size_t tests_run = 0;
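
Note on the test change above: the reference-mode switch is fetched through the backend registry's proc-address table rather than called directly, presumably so the test still works when the CPU backend is built as a dynamically loaded module; if the symbol is absent, the test simply runs without forcing the reference path.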