Merge b8e5c583fb into 2634ed207a
This commit is contained in:
commit
a7f0b8cc39
|
|
@ -5,7 +5,6 @@
|
|||
#include "ggml-backend.h"
|
||||
#include "traits.h"
|
||||
#include "ggml-cpu-impl.h"
|
||||
#include "ggml-cpu.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "quants.h"
|
||||
#include "ggml-threading.h"
|
||||
|
|
@ -2867,12 +2866,20 @@ struct ggml_cplan ggml_graph_plan(
|
|||
} break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
const int64_t neq2 = node->src[0]->ne[2]; // number of query heads
|
||||
const int64_t DK = node->src[1]->ne[0];
|
||||
const int64_t DV = node->src[2]->ne[0];
|
||||
|
||||
// Tiled flash attention scratch (tile sizes defined in common.h)
|
||||
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding
|
||||
cur = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
|
||||
size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
|
||||
|
||||
// Decode path: n_kv_chunks = n_tasks (one chunk per thread)
|
||||
// Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ
|
||||
size_t n_chunks = n_tasks;
|
||||
size_t decode = sizeof(float)*(neq2*n_chunks*(2+DV) + n_tasks*(DK + 2*DV));
|
||||
|
||||
cur += MAX(prefill, decode);
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_BACK:
|
||||
{
|
||||
|
|
|
|||
|
|
@ -8042,12 +8042,14 @@ void ggml_compute_forward_top_k(
|
|||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_flash_attn_ext
|
||||
|
||||
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst,
|
||||
int ir0, int ir1) {
|
||||
int ir0, int ir1,
|
||||
int64_t ic_start, int64_t ic_end,
|
||||
float * partials, int64_t partial_stride) {
|
||||
|
||||
const bool write_partials = (partials != nullptr);
|
||||
const ggml_tensor * q = dst->src[0];
|
||||
const ggml_tensor * k = dst->src[1];
|
||||
const ggml_tensor * v = dst->src[2];
|
||||
|
|
@ -8124,7 +8126,6 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|||
|
||||
int ith = params->ith;
|
||||
|
||||
// loop over n_batch and n_head
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
// q indices
|
||||
const int iq3 = ir/(neq2*neq1);
|
||||
|
|
@ -8165,7 +8166,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|||
// loop over n_kv and n_head_kv
|
||||
// ref: https://arxiv.org/pdf/2112.05682.pdf
|
||||
|
||||
for (int64_t ic = 0; ic < nek1; ++ic) {
|
||||
for (int64_t ic = ic_start; ic < ic_end; ++ic) {
|
||||
const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
|
||||
if (mv == -INFINITY) {
|
||||
continue;
|
||||
|
|
@ -8238,8 +8239,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|||
}
|
||||
}
|
||||
|
||||
// sinks
|
||||
if (sinks) {
|
||||
// sinks - skip when writing partials, reduce function will apply once
|
||||
if (sinks && !write_partials) {
|
||||
const float s = ((float *)((char *) sinks->data))[h];
|
||||
|
||||
float ms = 1.0f;
|
||||
|
|
@ -8255,20 +8256,26 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|||
S = S*ms + vs;
|
||||
}
|
||||
|
||||
// V /= S
|
||||
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
|
||||
ggml_vec_scale_f32(DV, VKQ32, S_inv);
|
||||
if (write_partials) {
|
||||
// Write M, S, VKQ to partials for later reduction
|
||||
// partials layout: [M, S, VKQ[DV]] per query head
|
||||
float * partial = partials + ir * partial_stride;
|
||||
partial[0] = M;
|
||||
partial[1] = S;
|
||||
memcpy(partial + 2, VKQ32, DV * sizeof(float));
|
||||
} else {
|
||||
// V /= S
|
||||
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
|
||||
ggml_vec_scale_f32(DV, VKQ32, S_inv);
|
||||
|
||||
// dst indices
|
||||
const int i1 = iq1;
|
||||
const int i2 = iq2;
|
||||
const int i3 = iq3;
|
||||
// dst indices
|
||||
const int i1 = iq1;
|
||||
const int i2 = iq2;
|
||||
const int i3 = iq3;
|
||||
|
||||
// original
|
||||
//memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
|
||||
|
||||
// permute(0, 2, 1, 3)
|
||||
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
|
||||
// permute(0, 2, 1, 3)
|
||||
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -8546,6 +8553,93 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
|||
}
|
||||
}
|
||||
|
||||
// Reduction function: combines partial results across KV chunks
|
||||
// Partials layout in wdata: [n_q_heads][n_chunks][2 + DV]
|
||||
static void ggml_flash_attn_ext_reduce_partials(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst,
|
||||
const int64_t n_chunks,
|
||||
const int64_t chunk_size) {
|
||||
|
||||
const ggml_tensor * q = dst->src[0];
|
||||
const ggml_tensor * k = dst->src[1];
|
||||
const ggml_tensor * v = dst->src[2];
|
||||
const ggml_tensor * sinks = dst->src[4];
|
||||
|
||||
const int64_t DK = k->ne[0];
|
||||
const int64_t DV = v->ne[0];
|
||||
const int64_t nek1 = k->ne[1];
|
||||
const int64_t n_q_heads = q->ne[2];
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int64_t wdata_per_thread = DK + 2*DV + CACHE_LINE_SIZE_F32;
|
||||
float * thread_wdata = (float *) params->wdata + ith * wdata_per_thread;
|
||||
|
||||
const int64_t partials_offset = nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
|
||||
const int64_t partial_size = 2 + DV;
|
||||
const float * partials_base = (const float *) params->wdata + partials_offset;
|
||||
|
||||
// Output layout
|
||||
const int64_t ne1 = dst->ne[1];
|
||||
const int64_t ne2 = dst->ne[2];
|
||||
const size_t nb1 = dst->nb[1];
|
||||
|
||||
// Each thread reduces a subset of query heads
|
||||
for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) {
|
||||
float M_final = -INFINITY;
|
||||
float S_final = 0.0f;
|
||||
float * VKQ_final = thread_wdata;
|
||||
memset(VKQ_final, 0, DV * sizeof(float));
|
||||
|
||||
// Combine partials from all chunks
|
||||
for (int64_t chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) {
|
||||
const int64_t ic_start = chunk_idx * chunk_size;
|
||||
if (ic_start >= nek1) continue;
|
||||
|
||||
const float * partial = partials_base + (q_head * n_chunks + chunk_idx) * partial_size;
|
||||
const float M_chunk = partial[0];
|
||||
const float S_chunk = partial[1];
|
||||
const float * VKQ_chunk = partial + 2;
|
||||
|
||||
if (S_chunk == 0.0f) continue;
|
||||
|
||||
const float M_new = fmaxf(M_final, M_chunk);
|
||||
const float scale_old = expf(M_final - M_new);
|
||||
const float scale_new = expf(M_chunk - M_new);
|
||||
|
||||
for (int64_t d = 0; d < DV; ++d) {
|
||||
VKQ_final[d] = VKQ_final[d] * scale_old + VKQ_chunk[d] * scale_new;
|
||||
}
|
||||
S_final = S_final * scale_old + S_chunk * scale_new;
|
||||
M_final = M_new;
|
||||
}
|
||||
|
||||
// Apply sinks once after combining all chunks
|
||||
if (sinks) {
|
||||
const float s = ((float *) sinks->data)[q_head];
|
||||
|
||||
if (s > M_final) {
|
||||
const float ms = expf(M_final - s);
|
||||
ggml_vec_scale_f32(DV, VKQ_final, ms);
|
||||
S_final = S_final * ms + 1.0f;
|
||||
M_final = s;
|
||||
} else {
|
||||
S_final = S_final + expf(s - M_final);
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize and write to output
|
||||
if (S_final != 0.0f) {
|
||||
const float S_inv = 1.0f / S_final;
|
||||
ggml_vec_scale_f32(DV, VKQ_final, S_inv);
|
||||
}
|
||||
// iq1=0, iq3=0 for decode
|
||||
memcpy((char *) dst->data + (0*ne2*ne1 + q_head + 0*ne1)*nb1, VKQ_final, nb1);
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
|
@ -8567,6 +8661,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||
const int64_t DV = nev0;
|
||||
const int64_t N = neq1;
|
||||
|
||||
|
||||
GGML_ASSERT(ne0 == DV);
|
||||
GGML_ASSERT(ne2 == N);
|
||||
|
||||
|
|
@ -8587,60 +8682,88 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||
GGML_ASSERT(nb1 <= nb2);
|
||||
GGML_ASSERT(nb2 <= nb3);
|
||||
|
||||
// parallelize by q rows using ggml_vec_dot_f32
|
||||
|
||||
// total rows in q
|
||||
const int64_t nr = neq1*neq2*neq3;
|
||||
|
||||
// rows per thread
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
// disable for NUMA
|
||||
const bool disable_chunking = ggml_is_numa();
|
||||
|
||||
// 4x chunks per thread
|
||||
int nth_scaled = nth * 4;
|
||||
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
|
||||
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
|
||||
|
||||
if (nth == 1 || nchunk < nth || disable_chunking) {
|
||||
nchunk = nth;
|
||||
}
|
||||
|
||||
if (ith == 0) {
|
||||
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
|
||||
ggml_threadpool_chunk_set(params->threadpool, nth);
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
// The number of elements in each chunk
|
||||
const int64_t dr = (nr + nchunk - 1) / nchunk;
|
||||
|
||||
static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
|
||||
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
|
||||
const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
|
||||
const bool use_tiled = (q->type == GGML_TYPE_F32 &&
|
||||
kv_is_f32_or_f16 &&
|
||||
k->type == v->type &&
|
||||
nek1 % KV_TILE_SZ == 0 &&
|
||||
neq1 >= Q_TILE_SZ); // Only use tiled for batch >= tile size
|
||||
const bool use_split_kv_path = (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512;
|
||||
|
||||
// The first chunk comes from our thread_id, the rest will get auto-assigned.
|
||||
int current_chunk = ith;
|
||||
if (use_split_kv_path) {
|
||||
const int64_t chunk_size = (nek1 + nth - 1) / nth;
|
||||
|
||||
while (current_chunk < nchunk) {
|
||||
const int64_t ir0 = dr * current_chunk;
|
||||
const int64_t ir1 = MIN(ir0 + dr, nr);
|
||||
// Partials buffer layout: [q_head][kv_chunk][M, S, VKQ]
|
||||
const int64_t partial_size = 2 + DV;
|
||||
float * partials_base = (float *) params->wdata + nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
|
||||
|
||||
if (use_tiled) {
|
||||
ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
|
||||
const int64_t ic_start = ith * chunk_size;
|
||||
const int64_t ic_end = std::min(ic_start + chunk_size, nek1);
|
||||
|
||||
const int64_t partial_stride = nth * partial_size;
|
||||
float * chunk_partials = partials_base + ith * partial_size;
|
||||
|
||||
if (ic_start < nek1) {
|
||||
for (int64_t q_head = 0; q_head < neq2; q_head++) {
|
||||
ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
||||
params, dst, q_head, q_head + 1, ic_start, ic_end,
|
||||
chunk_partials, partial_stride);
|
||||
}
|
||||
} else {
|
||||
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
|
||||
for (int64_t q_head = 0; q_head < neq2; q_head++) {
|
||||
float * q_partials = chunk_partials + q_head * partial_stride;
|
||||
q_partials[0] = -INFINITY; // M
|
||||
q_partials[1] = 0.0f; // S
|
||||
}
|
||||
}
|
||||
|
||||
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
||||
ggml_barrier(params->threadpool);
|
||||
ggml_flash_attn_ext_reduce_partials(params, dst, nth, chunk_size);
|
||||
} else {
|
||||
|
||||
// total rows in q
|
||||
const int64_t nr = neq1*neq2*neq3;
|
||||
|
||||
// disable for NUMA
|
||||
const bool disable_chunking = ggml_is_numa();
|
||||
|
||||
// 4x chunks per thread
|
||||
int nth_scaled = nth * 4;
|
||||
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
|
||||
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
|
||||
|
||||
if (nth == 1 || nchunk < nth || disable_chunking) {
|
||||
nchunk = nth;
|
||||
}
|
||||
|
||||
if (ith == 0) {
|
||||
ggml_threadpool_chunk_set(params->threadpool, nth);
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
const int64_t dr = (nr + nchunk - 1) / nchunk;
|
||||
|
||||
static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
|
||||
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
|
||||
const bool use_tiled = (q->type == GGML_TYPE_F32 &&
|
||||
kv_is_f32_or_f16 &&
|
||||
k->type == v->type &&
|
||||
nek1 % KV_TILE_SZ == 0 &&
|
||||
neq1 >= Q_TILE_SZ);
|
||||
|
||||
int current_chunk = ith;
|
||||
|
||||
while (current_chunk < nchunk) {
|
||||
const int64_t ir0 = dr * current_chunk;
|
||||
const int64_t ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
if (use_tiled) {
|
||||
ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
|
||||
} else {
|
||||
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1, 0, nek1, nullptr, 0);
|
||||
}
|
||||
|
||||
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue