From b8e5c583fbdbdb06b61e4cc4d48e69b957304e32 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Thu, 29 Jan 2026 10:24:00 +0100
Subject: [PATCH] ggml-cpu: split across kv for faster TG

---
 ggml/src/ggml-cpu/ggml-cpu.c |  11 +-
 ggml/src/ggml-cpu/ops.cpp    | 249 ++++++++++++++++++++++++++---------
 2 files changed, 195 insertions(+), 65 deletions(-)

diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index b1de2ae871..a8f8975d3a 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -5,7 +5,6 @@
 #include "ggml-backend.h"
 #include "traits.h"
 #include "ggml-cpu-impl.h"
-#include "ggml-cpu.h"
 #include "ggml-impl.h"
 #include "quants.h"
 #include "ggml-threading.h"
@@ -2867,12 +2866,20 @@ struct ggml_cplan ggml_graph_plan(
                 } break;
             case GGML_OP_FLASH_ATTN_EXT:
                 {
+                    const int64_t neq2 = node->src[0]->ne[2]; // number of query heads
                     const int64_t DK = node->src[1]->ne[0];
                     const int64_t DV = node->src[2]->ne[0];

                     // Tiled flash attention scratch (tile sizes defined in common.h)
                     // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding
-                    cur = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
+                    size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
+
+                    // Decode path: n_kv_chunks = n_tasks (one chunk per thread)
+                    // Per-thread: VKQ accumulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ
+                    size_t n_chunks = n_tasks;
+                    size_t decode = sizeof(float)*(neq2*n_chunks*(2+DV) + n_tasks*(DK + 2*DV));
+
+                    cur += MAX(prefill, decode);
                 } break;
             case GGML_OP_FLASH_ATTN_BACK:
                 {
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 48c8964361..549d3cf389 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -8042,12 +8042,14 @@ void ggml_compute_forward_top_k(
     }
 }

-// ggml_compute_forward_flash_attn_ext
-
 static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
         const ggml_compute_params * params,
         ggml_tensor * dst,
-        int ir0, int ir1) {
+        int ir0, int ir1,
+        int64_t ic_start, int64_t ic_end,
+        float * partials, int64_t partial_stride) {
+
+    const bool write_partials = (partials != nullptr);
     const ggml_tensor * q = dst->src[0];
     const ggml_tensor * k = dst->src[1];
     const ggml_tensor * v = dst->src[2];
@@ -8124,7 +8126,6 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(

     int ith = params->ith;

-    // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
         const int iq3 = ir/(neq2*neq1);
@@ -8165,7 +8166,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(

         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
-        for (int64_t ic = 0; ic < nek1; ++ic) {
+        for (int64_t ic = ic_start; ic < ic_end; ++ic) {
             const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
             if (mv == -INFINITY) {
                 continue;
@@ -8238,8 +8239,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
             }
         }

-        // sinks
-        if (sinks) {
+        // sinks - skip when writing partials; the reduce function applies them once
+        if (sinks && !write_partials) {
             const float s = ((float *)((char *) sinks->data))[h];

             float ms = 1.0f;
@@ -8255,20 +8256,26 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
             S = S*ms + vs;
         }

-        // V /= S
-        const float S_inv = S == 0.0f ? 
0.0f : 1.0f/S; - ggml_vec_scale_f32(DV, VKQ32, S_inv); + if (write_partials) { + // Write M, S, VKQ to partials for later reduction + // partials layout: [M, S, VKQ[DV]] per query head + float * partial = partials + ir * partial_stride; + partial[0] = M; + partial[1] = S; + memcpy(partial + 2, VKQ32, DV * sizeof(float)); + } else { + // V /= S + const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; + ggml_vec_scale_f32(DV, VKQ32, S_inv); - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; - // original - //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float)); - - // permute(0, 2, 1, 3) - memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1); + // permute(0, 2, 1, 3) + memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1); + } } } @@ -8546,6 +8553,93 @@ static void ggml_compute_forward_flash_attn_ext_tiled( } } +// Reduction function: combines partial results across KV chunks +// Partials layout in wdata: [n_q_heads][n_chunks][2 + DV] +static void ggml_flash_attn_ext_reduce_partials( + const ggml_compute_params * params, + ggml_tensor * dst, + const int64_t n_chunks, + const int64_t chunk_size) { + + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + const ggml_tensor * sinks = dst->src[4]; + + const int64_t DK = k->ne[0]; + const int64_t DV = v->ne[0]; + const int64_t nek1 = k->ne[1]; + const int64_t n_q_heads = q->ne[2]; + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t wdata_per_thread = DK + 2*DV + CACHE_LINE_SIZE_F32; + float * thread_wdata = (float *) params->wdata + ith * wdata_per_thread; + + const int64_t partials_offset = nth * (DK + 2*DV + CACHE_LINE_SIZE_F32); + const int64_t partial_size = 2 + DV; + const float * partials_base = (const float *) params->wdata + partials_offset; + + // Output layout + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const size_t nb1 = dst->nb[1]; + + // Each thread reduces a subset of query heads + for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) { + float M_final = -INFINITY; + float S_final = 0.0f; + float * VKQ_final = thread_wdata; + memset(VKQ_final, 0, DV * sizeof(float)); + + // Combine partials from all chunks + for (int64_t chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) { + const int64_t ic_start = chunk_idx * chunk_size; + if (ic_start >= nek1) continue; + + const float * partial = partials_base + (q_head * n_chunks + chunk_idx) * partial_size; + const float M_chunk = partial[0]; + const float S_chunk = partial[1]; + const float * VKQ_chunk = partial + 2; + + if (S_chunk == 0.0f) continue; + + const float M_new = fmaxf(M_final, M_chunk); + const float scale_old = expf(M_final - M_new); + const float scale_new = expf(M_chunk - M_new); + + for (int64_t d = 0; d < DV; ++d) { + VKQ_final[d] = VKQ_final[d] * scale_old + VKQ_chunk[d] * scale_new; + } + S_final = S_final * scale_old + S_chunk * scale_new; + M_final = M_new; + } + + // Apply sinks once after combining all chunks + if (sinks) { + const float s = ((float *) sinks->data)[q_head]; + + if (s > M_final) { + const float ms = expf(M_final - s); + ggml_vec_scale_f32(DV, VKQ_final, ms); + S_final = S_final * ms + 1.0f; + M_final = s; + } else { + S_final = S_final + expf(s - M_final); + } + } + + // Normalize and write to output + if (S_final != 0.0f) { + const float S_inv = 1.0f / S_final; 
+ ggml_vec_scale_f32(DV, VKQ_final, S_inv); + } + // iq1=0, iq3=0 for decode + memcpy((char *) dst->data + (0*ne2*ne1 + q_head + 0*ne1)*nb1, VKQ_final, nb1); + } +} + static void ggml_compute_forward_flash_attn_ext_f16( const ggml_compute_params * params, ggml_tensor * dst) { @@ -8567,6 +8661,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t DV = nev0; const int64_t N = neq1; + GGML_ASSERT(ne0 == DV); GGML_ASSERT(ne2 == N); @@ -8587,60 +8682,88 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); - // parallelize by q rows using ggml_vec_dot_f32 - - // total rows in q - const int64_t nr = neq1*neq2*neq3; - - // rows per thread const int ith = params->ith; const int nth = params->nth; - // disable for NUMA - const bool disable_chunking = ggml_is_numa(); - - // 4x chunks per thread - int nth_scaled = nth * 4; - int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled; - int64_t nchunk = (nr + chunk_size - 1) / chunk_size; - - if (nth == 1 || nchunk < nth || disable_chunking) { - nchunk = nth; - } - - if (ith == 0) { - // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. - ggml_threadpool_chunk_set(params->threadpool, nth); - } - - ggml_barrier(params->threadpool); - - // The number of elements in each chunk - const int64_t dr = (nr + nchunk - 1) / nchunk; - - static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV; - static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q; const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16); - const bool use_tiled = (q->type == GGML_TYPE_F32 && - kv_is_f32_or_f16 && - k->type == v->type && - nek1 % KV_TILE_SZ == 0 && - neq1 >= Q_TILE_SZ); // Only use tiled for batch >= tile size + const bool use_split_kv_path = (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512; - // The first chunk comes from our thread_id, the rest will get auto-assigned. 
- int current_chunk = ith; + if (use_split_kv_path) { + const int64_t chunk_size = (nek1 + nth - 1) / nth; - while (current_chunk < nchunk) { - const int64_t ir0 = dr * current_chunk; - const int64_t ir1 = MIN(ir0 + dr, nr); + // Partials buffer layout: [q_head][kv_chunk][M, S, VKQ] + const int64_t partial_size = 2 + DV; + float * partials_base = (float *) params->wdata + nth * (DK + 2*DV + CACHE_LINE_SIZE_F32); - if (use_tiled) { - ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1); + const int64_t ic_start = ith * chunk_size; + const int64_t ic_end = std::min(ic_start + chunk_size, nek1); + + const int64_t partial_stride = nth * partial_size; + float * chunk_partials = partials_base + ith * partial_size; + + if (ic_start < nek1) { + for (int64_t q_head = 0; q_head < neq2; q_head++) { + ggml_compute_forward_flash_attn_ext_f16_one_chunk( + params, dst, q_head, q_head + 1, ic_start, ic_end, + chunk_partials, partial_stride); + } } else { - ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1); + for (int64_t q_head = 0; q_head < neq2; q_head++) { + float * q_partials = chunk_partials + q_head * partial_stride; + q_partials[0] = -INFINITY; // M + q_partials[1] = 0.0f; // S + } } - current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); + ggml_barrier(params->threadpool); + ggml_flash_attn_ext_reduce_partials(params, dst, nth, chunk_size); + } else { + + // total rows in q + const int64_t nr = neq1*neq2*neq3; + + // disable for NUMA + const bool disable_chunking = ggml_is_numa(); + + // 4x chunks per thread + int nth_scaled = nth * 4; + int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled; + int64_t nchunk = (nr + chunk_size - 1) / chunk_size; + + if (nth == 1 || nchunk < nth || disable_chunking) { + nchunk = nth; + } + + if (ith == 0) { + ggml_threadpool_chunk_set(params->threadpool, nth); + } + + ggml_barrier(params->threadpool); + + const int64_t dr = (nr + nchunk - 1) / nchunk; + + static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV; + static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q; + const bool use_tiled = (q->type == GGML_TYPE_F32 && + kv_is_f32_or_f16 && + k->type == v->type && + nek1 % KV_TILE_SZ == 0 && + neq1 >= Q_TILE_SZ); + + int current_chunk = ith; + + while (current_chunk < nchunk) { + const int64_t ir0 = dr * current_chunk; + const int64_t ir1 = MIN(ir0 + dr, nr); + + if (use_tiled) { + ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1); + } else { + ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1, 0, nek1, nullptr, 0); + } + + current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); + } } }
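
For reference, the split-KV decode path above has each thread write, per query head, a partial softmax state [M, S, VKQ] for its KV chunk, and ggml_flash_attn_ext_reduce_partials then folds those states together with the usual online-softmax rescaling before normalizing once. Below is a minimal standalone sketch of that merge rule; the names (fa_partial, merge_partials) are illustrative only and not code from this patch.

// Sketch: merging two per-chunk flash-attention partials for one query head.
// Each partial holds M (the running maximum of the masked, scaled logits in
// that chunk), S (the sum of exp(logit - M)) and vkq[DV] (the exp-weighted
// sum of V rows), both accumulated relative to the chunk's own M.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

struct fa_partial {
    float M;
    float S;
    std::vector<float> vkq;
};

static fa_partial merge_partials(const fa_partial & a, const fa_partial & b) {
    // An empty chunk contributes nothing; this mirrors the
    // "if (S_chunk == 0.0f) continue;" guard in the reduction above.
    if (a.S == 0.0f) return b;
    if (b.S == 0.0f) return a;

    const size_t DV    = a.vkq.size();
    const float  M_new = std::max(a.M, b.M);
    const float  ca    = std::exp(a.M - M_new); // rescale a onto the common maximum
    const float  cb    = std::exp(b.M - M_new); // rescale b onto the common maximum

    fa_partial out;
    out.M = M_new;
    out.S = a.S*ca + b.S*cb;
    out.vkq.resize(DV);
    for (size_t d = 0; d < DV; ++d) {
        out.vkq[d] = a.vkq[d]*ca + b.vkq[d]*cb;
    }
    return out;
}

// The attention output for the head is out.vkq[d] / out.S, computed once after
// all chunks (and, in the patch, the optional attention sinks) are folded in.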