ggml-cpu: Use tiled FA for prompt-processing (#19012)
* ggml-cpu: Use tiled FA for prompt-processing the FA performance is gimped on CPU on long contexts because it essentially uses a vector kernel. This PR adds a tiled FA for PP. Perf tuning for tile sizes done on a AMD EPYC single-socket 64-c machine. * fix out of bounds for mask * skip rows where there are all masks * skip tile if mask is inf * store mask in worksize * check inf tile earlier
This commit is contained in:
parent
d9c6ce46f7
commit
bcb43163ae
|
|
@ -6,6 +6,9 @@
|
|||
#include "ggml-impl.h"
|
||||
#include "simd-mappings.h"
|
||||
|
||||
#define GGML_FA_TILE_Q 32
|
||||
#define GGML_FA_TILE_KV 16
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
#include <utility>
|
||||
|
|
@ -84,4 +87,9 @@ static std::pair<int64_t, int64_t> get_thread_range(const struct ggml_compute_pa
|
|||
return {ir0, ir1};
|
||||
}
|
||||
|
||||
struct ggml_fa_tile_config {
|
||||
static constexpr size_t Q = GGML_FA_TILE_Q;
|
||||
static constexpr size_t KV = GGML_FA_TILE_KV;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@
|
|||
#include "vec.h"
|
||||
#include "ops.h"
|
||||
#include "ggml.h"
|
||||
#include "common.h"
|
||||
|
||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||
#include <malloc.h> // using malloc.h with MSC/MINGW
|
||||
|
|
@ -2866,10 +2867,12 @@ struct ggml_cplan ggml_graph_plan(
|
|||
} break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
const int64_t ne10 = node->src[1]->ne[0]; // DK
|
||||
const int64_t ne20 = node->src[2]->ne[0]; // DV
|
||||
const int64_t DK = node->src[1]->ne[0];
|
||||
const int64_t DV = node->src[2]->ne[0];
|
||||
|
||||
cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
|
||||
// Tiled flash attention scratch (tile sizes defined in common.h)
|
||||
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding
|
||||
cur = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_BACK:
|
||||
{
|
||||
|
|
|
|||
|
|
@ -8164,6 +8164,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|||
// online softmax / attention
|
||||
// loop over n_kv and n_head_kv
|
||||
// ref: https://arxiv.org/pdf/2112.05682.pdf
|
||||
|
||||
for (int64_t ic = 0; ic < nek1; ++ic) {
|
||||
const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
|
||||
if (mv == -INFINITY) {
|
||||
|
|
@ -8271,6 +8272,280 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_flash_attn_ext_tiled(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst,
|
||||
int ir0, int ir1) {
|
||||
const ggml_tensor * q = dst->src[0];
|
||||
const ggml_tensor * k = dst->src[1];
|
||||
const ggml_tensor * v = dst->src[2];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * sinks = dst->src[4];
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int64_t DK = nek0;
|
||||
const int64_t DV = nev0;
|
||||
const int64_t N = neq1;
|
||||
|
||||
GGML_ASSERT(ne0 == DV);
|
||||
GGML_ASSERT(ne2 == N);
|
||||
|
||||
// input tensor rows must be contiguous
|
||||
GGML_ASSERT(nbq0 == ggml_type_size(q->type));
|
||||
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
|
||||
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
|
||||
|
||||
GGML_ASSERT(neq0 == DK);
|
||||
GGML_ASSERT(nek0 == DK);
|
||||
GGML_ASSERT(nev0 == DV);
|
||||
|
||||
GGML_ASSERT(neq1 == N);
|
||||
|
||||
// dst cannot be transposed or permuted
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
GGML_ASSERT(nb0 <= nb1);
|
||||
GGML_ASSERT(nb1 <= nb2);
|
||||
GGML_ASSERT(nb2 <= nb3);
|
||||
|
||||
GGML_ASSERT(k->type == v->type);
|
||||
const ggml_type kv_type = k->type;
|
||||
|
||||
const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type);
|
||||
const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float;
|
||||
const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot;
|
||||
const size_t kv_type_size = ggml_type_size(kv_type);
|
||||
|
||||
// broadcast factors
|
||||
const int64_t rk2 = neq2/nek2;
|
||||
const int64_t rk3 = neq3/nek3;
|
||||
|
||||
const int64_t rv2 = neq2/nev2;
|
||||
const int64_t rv3 = neq3/nev3;
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float logit_softcap = 0.0f;
|
||||
|
||||
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
||||
memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
|
||||
|
||||
if (logit_softcap != 0) {
|
||||
scale /= logit_softcap;
|
||||
}
|
||||
|
||||
const uint32_t n_head = neq2;
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
int ith = params->ith;
|
||||
|
||||
static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
|
||||
static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
|
||||
|
||||
GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ");
|
||||
|
||||
int ir = ir0;
|
||||
while (ir < ir1) {
|
||||
// q indices for the start of this tile
|
||||
const int iq3 = ir/(neq2*neq1);
|
||||
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
|
||||
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
|
||||
|
||||
// Number of valid rows in this tile:
|
||||
// - limited by tile size (Q_TILE_SZ)
|
||||
// - limited by chunk boundary (ir1 - ir)
|
||||
// - limited by head boundary (neq1 - iq1) to avoid crossing into next head
|
||||
const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1)));
|
||||
GGML_ASSERT(tile_rows > 0);
|
||||
|
||||
const uint32_t h = iq2; // head index
|
||||
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
|
||||
|
||||
float S[Q_TILE_SZ];
|
||||
float M[Q_TILE_SZ];
|
||||
|
||||
for (int i = 0 ; i < Q_TILE_SZ; ++i) {
|
||||
S[i] = 0.;
|
||||
M[i] = -INFINITY;
|
||||
}
|
||||
|
||||
// Per-thread scratch layout:
|
||||
// Q_q: Q_TILE_SZ * DK (converted Q tile in KV type)
|
||||
// KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
|
||||
// mask: Q_TILE_SZ * KV_TILE_SZ (mask in float)
|
||||
// VKQ32: Q_TILE_SZ * DV (FP32 output accumulator)
|
||||
// V32: KV_TILE_SZ * DV (F32 buffer for V tile - used for f166 conversion)
|
||||
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32);
|
||||
|
||||
void * Q_q = base;
|
||||
float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
|
||||
float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
|
||||
float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ;
|
||||
float * V32 = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for V tile
|
||||
|
||||
memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
|
||||
memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
|
||||
|
||||
// k indices
|
||||
const int ik3 = iq3 / rk3;
|
||||
const int ik2 = iq2 / rk2;
|
||||
|
||||
// v indices
|
||||
const int iv3 = iq3 / rv3;
|
||||
const int iv2 = iq2 / rv2;
|
||||
|
||||
for (int tq = 0; tq < tile_rows; tq++) {
|
||||
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
|
||||
kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK);
|
||||
}
|
||||
// Zero-pad remaining rows
|
||||
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
|
||||
memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size);
|
||||
}
|
||||
|
||||
for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
|
||||
|
||||
// skip the tile entirely if all the masks are -inf
|
||||
if (mask) {
|
||||
bool can_skip = true;
|
||||
for (int tq = 0; tq < tile_rows; tq++) {
|
||||
const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
|
||||
if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
|
||||
can_skip = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (can_skip) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
const void * q_row = (const char *)Q_q + tq * DK * kv_type_size;
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3);
|
||||
float s;
|
||||
kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1);
|
||||
KQ[tq * KV_TILE_SZ + tk] = s * scale;
|
||||
}
|
||||
}
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ);
|
||||
ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap);
|
||||
}
|
||||
|
||||
if (mask) {
|
||||
ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32);
|
||||
}
|
||||
|
||||
bool skip[Q_TILE_SZ] = {};
|
||||
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
float * kq_row = KQ + tq * KV_TILE_SZ;
|
||||
|
||||
float tile_max;
|
||||
ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row);
|
||||
|
||||
if (tile_max == -INFINITY) {
|
||||
skip[tq] = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
const float Mold = M[tq];
|
||||
const float Mnew = fmaxf(Mold, tile_max);
|
||||
|
||||
if (Mnew > Mold) {
|
||||
const float ms = expf(Mold - Mnew);
|
||||
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
|
||||
S[tq] *= ms;
|
||||
}
|
||||
M[tq] = Mnew;
|
||||
|
||||
|
||||
S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
|
||||
}
|
||||
|
||||
// Convert V tile to F32 first (if F16), then do MAD
|
||||
// On x86, ggml_vec_mad_f16 internall converts F16<->F32 on every load/store, so pre-converting is faster.
|
||||
// TODO: on ARM, native f16 should be faster
|
||||
if (kv_type == GGML_TYPE_F16) {
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
|
||||
ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV);
|
||||
}
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
if (skip[tq]) continue;
|
||||
float * vkq_row = VKQ32 + tq * DV;
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
const float p = KQ[tq * KV_TILE_SZ + tk];
|
||||
ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
if (skip[tq]) continue;
|
||||
float * vkq_row = VKQ32 + tq * DV;
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
const float p = KQ[tq * KV_TILE_SZ + tk];
|
||||
const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
|
||||
ggml_vec_mad_f32(DV, vkq_row, v_row, p);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sinks (apply only to valid rows in the tile)
|
||||
if (sinks) {
|
||||
const float s = ((float *)((char *) sinks->data))[h];
|
||||
|
||||
for (int tq = 0; tq < tile_rows; tq++) {
|
||||
float ms = 1.0f;
|
||||
float vs = 1.0f;
|
||||
|
||||
if (s > M[tq]) {
|
||||
ms = expf(M[tq] - s);
|
||||
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
|
||||
} else {
|
||||
vs = expf(s - M[tq]);
|
||||
}
|
||||
|
||||
S[tq] = S[tq] * ms + vs;
|
||||
}
|
||||
}
|
||||
|
||||
for (int tq = 0; tq < tile_rows; tq++) {
|
||||
// V /= S
|
||||
const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq];
|
||||
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv);
|
||||
|
||||
// dst indices
|
||||
const int i1 = iq1 + tq;
|
||||
const int i2 = iq2;
|
||||
const int i3 = iq3;
|
||||
|
||||
// permute(0, 2, 1, 3)
|
||||
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1);
|
||||
}
|
||||
|
||||
ir += tile_rows;
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
|
@ -8343,6 +8618,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||
// The number of elements in each chunk
|
||||
const int64_t dr = (nr + nchunk - 1) / nchunk;
|
||||
|
||||
static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
|
||||
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
|
||||
const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
|
||||
const bool use_tiled = (q->type == GGML_TYPE_F32 &&
|
||||
kv_is_f32_or_f16 &&
|
||||
k->type == v->type &&
|
||||
nek1 % KV_TILE_SZ == 0 &&
|
||||
neq1 >= Q_TILE_SZ); // Only use tiled for batch >= tile size
|
||||
|
||||
// The first chunk comes from our thread_id, the rest will get auto-assigned.
|
||||
int current_chunk = ith;
|
||||
|
||||
|
|
@ -8350,7 +8634,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||
const int64_t ir0 = dr * current_chunk;
|
||||
const int64_t ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
|
||||
if (use_tiled) {
|
||||
ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
|
||||
} else {
|
||||
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
|
||||
}
|
||||
|
||||
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue