Added a smaller tile size to flash attention for smaller batch sizes

PiperOrigin-RevId: 813226193
This commit is contained in:
Ray Smith 2025-09-30 05:48:50 -07:00 committed by Copybara-Service
parent 4974f24832
commit 2f6cbde8ff
6 changed files with 445 additions and 89 deletions

View File

@ -548,6 +548,7 @@ cc_library(
":gemma_args", ":gemma_args",
":kv_cache", ":kv_cache",
":mat", ":mat",
":matmul",
":matmul_env", ":matmul_env",
":model_store", ":model_store",
":ops", ":ops",

View File

@ -358,7 +358,8 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch, DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
env.ctx); env.ctx);
} else { } else {
FlashAttention(num_tokens, layer_idx, layer, activations, qbatch, env.ctx); FlashAttention(num_tokens, /*target_parallelism=*/64, layer_idx, layer,
activations, qbatch, env.ctx);
} }
SumHeads(layer, activations, env); SumHeads(layer, activations, env);
} }

View File

@ -44,6 +44,7 @@
// After highway.h // After highway.h
#include "compression/compress-inl.h" #include "compression/compress-inl.h"
#include "gemma/attention.h" #include "gemma/attention.h"
#include "ops/matmul-inl.h"
#include "ops/ops-inl.h" #include "ops/ops-inl.h"
HWY_BEFORE_NAMESPACE(); HWY_BEFORE_NAMESPACE();
@ -114,6 +115,27 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
} }
} }
// Handles a single v row of flash attention for a single q.k dot product.
void HWY_INLINE SingleFlashAttentionStep(
float x, float cap, float& old_max, float& old_d,
const float* HWY_RESTRICT v, const size_t v_cols,
float* HWY_RESTRICT att_out, hwy::Profiler& p, const size_t worker) {
if (cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
x = cap * std::tanh(x / cap);
}
float m = std::max(x, old_max);
x = std::exp(x - m);
float scale = old_d * std::exp(old_max - m);
old_d = x + scale;
old_max = m;
float one_over_d = 1.0f / old_d;
scale *= one_over_d;
x *= one_over_d;
MulByConst(scale, att_out, v_cols, p, worker);
MulByConstAndAdd(x, v, att_out, v_cols, p, worker);
}
// Calculates the complete attention outputs for a single row of q. // Calculates the complete attention outputs for a single row of q.
void SingleFlashAttention(const size_t start_pos, const size_t last_pos, void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
const float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const float* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
@ -136,21 +158,8 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
const size_t pos_mod = activations.div_seq_len.Remainder(pos); const size_t pos_mod = activations.div_seq_len.Remainder(pos);
float x = Dot(q, k.Row(pos_mod), k.Cols()); float x = Dot(q, k.Row(pos_mod), k.Cols());
if (activations.config.att_cap > 0.0f) { SingleFlashAttentionStep(x, activations.config.att_cap, m, d,
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x. v.Row(pos_mod), v.Cols(), att_out, p, worker);
x = activations.config.att_cap *
std::tanh(x / activations.config.att_cap);
}
float m_new = std::max(m, x);
float scale = d * std::exp(m - m_new);
x = std::exp(x - m_new);
m = m_new;
d = scale + x;
float one_over_d = 1.0f / d;
x *= one_over_d;
scale *= one_over_d;
MulByConst(scale, att_out, v.Cols(), p, worker);
MulByConstAndAdd(x, v.Row(pos_mod), att_out, v.Cols(), p, worker);
} }
} }
@ -167,7 +176,8 @@ VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
return hn::LoadU(df, results); return hn::LoadU(df, results);
} }
// Returns an 8xNF tile of Q.K dot products, in single precision. // Returns an NF Q rows by 8 K rows tile of Q.K dot products, in single
// precision.
// This is the result of NF rows of Q against 8 K timesteps, with positions // This is the result of NF rows of Q against 8 K timesteps, with positions
// given by k_pos[0..7]. Q has been transposed so that the NF rows are read in // given by k_pos[0..7]. Q has been transposed so that the NF rows are read in
// consecutive elements, and other columns by adding q_stride. // consecutive elements, and other columns by adding q_stride.
@ -240,8 +250,9 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
return hn::Add(sum0, sum2); return hn::Add(sum0, sum2);
} }
// Sweeps a tile of 8xNF accumulators from start_pos to min_last_pos, then // Sweeps a tile of NF Q rows by 8 K timesteps accumulators from start_pos to
// sweeps the remaining timesteps in the range (min_last_pos, max_last_pos]. // min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
// max_last_pos].
void TileFlashAttention( void TileFlashAttention(
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
const StridedView<float>& qT, const MatPtrT<KV_t>& k, const StridedView<float>& qT, const MatPtrT<KV_t>& k,
@ -260,7 +271,7 @@ void TileFlashAttention(
using DI = hn::ScalableTag<uint32_t>; using DI = hn::ScalableTag<uint32_t>;
const DI di; const DI di;
using VI = hn::Vec<DI>; using VI = hn::Vec<DI>;
const int kVTileSize = hn::MaxLanes(df); const int kVTileSize = hn::Lanes(df);
for (int i = 0; i < kVTileSize; ++i) { for (int i = 0; i < kVTileSize; ++i) {
hwy::ZeroBytes(att_out.Row(0) + out_offsets[i], hwy::ZeroBytes(att_out.Row(0) + out_offsets[i],
v.Cols() * sizeof(att_out.Row(0)[0])); v.Cols() * sizeof(att_out.Row(0)[0]));
@ -348,38 +359,217 @@ void TileFlashAttention(
} }
} }
// Returns an 4 Q rows by NF K tile of Q.K dot products, in single precision.
// This is the result of 4 rows of Q against NF K timesteps, with positions
// given by k_offsets[0..NF]. Q has been transposed so that the 4 rows are read
// in consecutive elements, and other columns by adding q_stride.
template <class DF, class VF = hn::Vec<DF>>
void QDotKTilex4(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
const MatPtrT<KV_t>& k, const int32_t* HWY_RESTRICT k_offsets,
hwy::Profiler& p, const size_t worker, VF& sum0, VF& sum1,
VF& sum2, VF& sum3) {
sum0 = hn::Zero(df);
sum1 = hn::Zero(df);
sum2 = hn::Zero(df);
sum3 = hn::Zero(df);
const float* HWY_RESTRICT k_base = k.Row(0);
using DI = hn::ScalableTag<int32_t>;
const DI di;
using VI = hn::Vec<DI>;
VI k_offsets_vec = hn::LoadU(di, k_offsets);
for (size_t i = 0; i < k.Cols(); ++i) {
VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec);
VF q_0 = hn::Set(df, q[0]);
sum0 = hn::MulAdd(q_0, k_vec, sum0);
VF q_1 = hn::Set(df, q[1]);
sum1 = hn::MulAdd(q_1, k_vec, sum1);
VF q_2 = hn::Set(df, q[2]);
sum2 = hn::MulAdd(q_2, k_vec, sum2);
VF q_3 = hn::Set(df, q[3]);
sum3 = hn::MulAdd(q_3, k_vec, sum3);
q += q_stride;
}
}
// Handles NF v rows of flash attention for NF q.k dot products from one q row.
template <class DF, class VF = hn::Vec<DF>>
float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
float& old_d) {
float m = hn::ReduceMax(df, x);
m = std::max(m, old_max);
x = hn::Exp(df, x - hn::Set(df, m));
float scale = old_d * std::exp(old_max - m);
old_d = hn::ReduceSum(df, x) + scale;
old_max = m;
float one_over_d = 1.0f / old_d;
scale *= one_over_d;
x = hn::Mul(x, hn::Set(df, one_over_d));
return scale;
}
// Sweeps a tile of 4 Q rows by NF K timesteps accumulators from start_pos to
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
// max_last_pos].
void TileFlashAttention4(
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
const StridedView<float>& qT, const MatPtrT<KV_t>& k,
const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos,
const size_t min_last_pos, const size_t max_last_pos,
const MatPtrT<KV_t>& v, const size_t layer_idx,
const LayerWeightsPtrs& layer, const AttentionActivations& activations,
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
hwy::Profiler& p, const size_t worker) {
static const auto zone = p.AddZone("Gen.Attention.TileFlashAttention4");
PROFILER_ZONE3(p, worker, zone);
using DF = hn::ScalableTag<float>;
const DF df;
using VF = hn::Vec<DF>;
constexpr size_t kMaxNF = hn::MaxLanes(df);
const size_t kHTileSize = hn::Lanes(df);
HWY_DASSERT(kHTileSize <= kMaxNF);
constexpr size_t kVTileSize = 4;
float scales[kVTileSize];
for (size_t i = 0; i < kVTileSize; ++i) {
hwy::ZeroBytes(att_out.Row(0) + out_offsets[i],
v.Cols() * sizeof(att_out.Row(0)[0]));
}
float old_m0 = -std::numeric_limits<float>::max() / 2.0f;
float old_m1 = -std::numeric_limits<float>::max() / 2.0f;
float old_m2 = -std::numeric_limits<float>::max() / 2.0f;
float old_m3 = -std::numeric_limits<float>::max() / 2.0f;
float old_d0 = 0.0f;
float old_d1 = 0.0f;
float old_d2 = 0.0f;
float old_d3 = 0.0f;
const float* HWY_RESTRICT qT_row = qT.Row(0);
const size_t qT_stride = qT.Stride();
size_t position = start_pos;
while (position + kHTileSize - 1 <= min_last_pos) {
int32_t k_offsets[kMaxNF];
size_t v_pos[kMaxNF];
for (size_t i = 0; i < kHTileSize; ++i) {
v_pos[i] = activations.div_seq_len.Remainder(position + i);
k_offsets[i] = k.Row(v_pos[i]) - k.Row(0);
}
VF x0, x1, x2, x3;
QDotKTilex4(df, qT_row, qT_stride, k, k_offsets, p, worker, x0, x1, x2, x3);
if (activations.config.att_cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
VF cap = hn::Set(df, activations.config.att_cap);
VF one_over_cap = hn::Div(hn::Set(df, 1.0f), cap);
x0 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x0, one_over_cap)));
x1 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x1, one_over_cap)));
x2 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x2, one_over_cap)));
x3 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x3, one_over_cap)));
}
scales[0] = SingleFlashAttentionRowVector(df, x0, old_m0, old_d0);
scales[1] = SingleFlashAttentionRowVector(df, x1, old_m1, old_d1);
scales[2] = SingleFlashAttentionRowVector(df, x2, old_m2, old_d2);
scales[3] = SingleFlashAttentionRowVector(df, x3, old_m3, old_d3);
MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0),
out_offsets, v.Cols(), p, worker);
position += kHTileSize;
}
while (position <= max_last_pos) {
size_t k_pos = activations.div_seq_len.Remainder(position);
if (position <= last_pos[0]) {
// Past the last position, x0 doesn't count.
float x0 = Dot(q.Row(0) + q_offsets[0], k.Row(k_pos), k.Cols());
SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[0], p, worker);
}
if (position <= last_pos[1]) {
// Past the last position, x1 doesn't count.
float x1 = Dot(q.Row(0) + q_offsets[1], k.Row(k_pos), k.Cols());
SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[1], p, worker);
}
if (position <= last_pos[2]) {
// Past the last position, x2 doesn't count.
float x2 = Dot(q.Row(0) + q_offsets[2], k.Row(k_pos), k.Cols());
SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[2], p, worker);
}
if (position <= last_pos[3]) {
// Past the last position, x3 doesn't count.
float x3 = Dot(q.Row(0) + q_offsets[3], k.Row(k_pos), k.Cols());
SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[3], p, worker);
}
++position;
}
}
// Rounds n to a number that can be used as the number of Q rows in a tile
// of flash attention.
static size_t RoundToSuitablePowerOf2(size_t n) {
if (n < 4) return 1;
if (n < 8) return 4;
if (n < 16) return 8;
if (n < 32) return 16;
return 32;
}
// The nominal aim of attention is to combine 3 inputs Q[L,D], K[L,D], V[L,D] // The nominal aim of attention is to combine 3 inputs Q[L,D], K[L,D], V[L,D]
// into a single output O[L,D]. // into a single output O[L,D].
// Conventional attention first computes A[L,L] = Q . KT // Conventional attention first computes A[L,L] = Q . KT
// followed by A = softmax(A) (over invididual rows). // followed by A = softmax(A) (over invididual rows).
// Then A is multiplied by V to get O[L,D]. // Then A is multiplied by V to get O[L,D].
// For each row of O, this takes a read of one row of Q L times, all of K, // For each row of O, this takes a read of one row of Q L times, all of K,
// 3 write/reads of one row of A, read all of V, an read.write the one row of O // 3 write/reads of one row of A, read all of V, and read/write the one row of O
// L times. Ignoring the computation for now, and focusing just on memory, // L times. Ignoring the computation for now, and focusing just on memory,
// the one row of O takes L(4D+3) reads and L(D+3) writes. // the one row of O takes L(4D+3) reads and L(D+3) writes.
// For the whole of Q, this is L^2(4D+3) reads and L^2(D+3) writes. // For the whole of Q, this is L^2(4D+3) reads and L^2(D+3) writes.
// //
// Flash attention fuses these operations together, and (where possible) // Flash attention fuses these operations together, and has 3 operating modes:
// computes NF rows of the result using 8 accumulator registers and two more to // 1. NF rows of the result computed using tiles of registers of shape NFx8.
// keep running results. NF is the number of float lanes in a register, being 16 // 2. 4 rows of the result computed using tiles of registers of shape 4xNF.
// for AVX3. The softmax is converted to streaming form using the // 3. One row (of Q and the result) at a time.
// algortihm from: // In all cases the intermediate result (Q.KT) is never stored to memory.
// NF is the number of float lanes in a register, being 16 for AVX3. The softmax
// is converted to streaming form using the algorithm from:
// https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf. // https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf.
// Q is transposed to Q_T[D,L] to make the dot product computation efficient. // Q is transposed to Q_T[D,L] to make the dot product computation efficient.
// QDotKTileFloat computes 8xNF rows of Q.K dot products in one go, reducing //
// reads of Q by 8 and reads of K by NF. The streaming softmax is computed // In mode 1:
// entirely in registers, and a further NF registers to accumulate the results // QDotKTileFloat computes NF Q rows x 8 K timesteps of Q.K dot products in one
// of the product of the softmax and V, reduce the number of reads of V by NF, // go, reducing reads of Q by 8 and reads of K by NF. The streaming softmax is
// and the reads/writes of O by 8. // computed entirely in registers, and a further NF registers to accumulate the
// results of the product of the softmax and V, reduce the number of reads of V
// by NF, and the reads/writes of O by 8.
// The reads are thus reduced to 2DL^2(1/8+1/NF) and writes reduced to DL^2/8, // The reads are thus reduced to 2DL^2(1/8+1/NF) and writes reduced to DL^2/8,
// which on AVX3 is an overall reduction by about a factor of 10. // which on AVX3 is an overall reduction by about a factor of 10.
// Mode 1 can only be accessed if there is a large Qbatch size, or in multi-turn
// prefill, since in other cases, there is either a single K timestep (prefill)
// or a single num_heads set of Q rows (decode).
//
// In mode 2, the 4 rows of Q are computed against NF K timesteps in a tile,
// reducing the reads of Q by NF, and the reads of K by 4. The softmax and
// accumulation of the result is done in registers, cutting the reads of V by 4.
// The reads/writes of O are reduced by a factor of NF.
// The overall reduction is limited by the need to use gather to load K.
// Transposing K would be possible, but is complicated by the wraparound.
// Mode 2 can be used in all cases when there are at least 4 attention heads,
// but it may be prefereable to use mode 3 when the batch size is small to
// maximise parallelism.
//
// In mode 3, a single row of Q is computed against a single K timestep at a
// time, using SingleFlashAttention. In this case there is no reduction in the
// reads of Q or K, or V, or O, but the reads/writes of the intermediate A are
// still eliminated.
// //
// A further complication is that real attention is not as simple as documented // A further complication is that real attention is not as simple as documented
// in the paper and above. There are multiple query heads, differing KV, and // in the paper and above. There are multiple query heads, differing KV, and
// different sequence lengths, so a lot of the work in FlashAttention is making // different sequence lengths, so a lot of the work in FlashAttention is making
// sure that a collection of q rows can use the TileFlashAttention path. // sure that a collection of q rows with the same KV and sequence length are
void FlashAttention(const size_t num_tokens, const size_t layer_idx, // grouped together so that mode 1 or 2 can be used, and choosing which of the
const LayerWeightsPtrs& layer, // 3 modes to use for best efficiency.
void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
const size_t layer_idx, const LayerWeightsPtrs& layer,
AttentionActivations& activations, QBatch& qbatch, AttentionActivations& activations, QBatch& qbatch,
ThreadingContext& ctx) { ThreadingContext& ctx) {
static const auto zone = ctx.profiler.AddZone("Gen.Attention.FlashAttention"); static const auto zone = ctx.profiler.AddZone("Gen.Attention.FlashAttention");
@ -392,15 +582,28 @@ void FlashAttention(const size_t num_tokens, const size_t layer_idx,
// A "head group" in the context of GQA refers to a collection of query // A "head group" in the context of GQA refers to a collection of query
// heads that share the same key and value heads. // heads that share the same key and value heads.
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads; const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
using DF = hn::ScalableTag<float>;
const DF df;
constexpr size_t kVTileSize = hn::MaxLanes(df);
const size_t cache_layer_size = layer_config.CacheLayerSize(); const size_t cache_layer_size = layer_config.CacheLayerSize();
const size_t seq_len = const size_t seq_len =
static_cast<size_t>(activations.div_seq_len.GetDivisor()); static_cast<size_t>(activations.div_seq_len.GetDivisor());
const size_t token_batch = num_tokens * div_qbatch.GetDivisor(); const size_t token_batch = num_tokens * div_qbatch.GetDivisor();
const size_t total_tasks = token_batch * layer_config.heads; const size_t total_tasks = token_batch * layer_config.heads;
using DF = hn::ScalableTag<float>;
const DF df;
const size_t kNF = hn::Lanes(df);
constexpr size_t kMaxNF = hn::MaxLanes(df);
HWY_DASSERT(kNF <= kMaxNF);
// The vertical tile size is determined by the ability to use tiling and the
// target_parallelism. In practice the possible tile sizes in order of
// preference for efficiency are kNF, 4, 1, where kNF is likely to be 4 8 or
// 16. The final tile size is chosen to be the largest possible that allows
// for target_parallelism parallel tasks.
const size_t kMaxEqualK = RoundToSuitablePowerOf2(kHeadGroups * num_tokens);
const size_t kMinTileSize = (total_tasks / 4 >= target_parallelism) ? 4 : 1;
const size_t kVTileSize =
(kNF <= kMaxEqualK && total_tasks / kNF >= target_parallelism)
? kNF
: std::min(kMinTileSize, kMaxEqualK);
// q has shape [batch, qbatch][head, qkv_dim]. // q has shape [batch, qbatch][head, qkv_dim].
// We transpose it to [qkv_dim][qbatch, head, batch] in order to make the // We transpose it to [qkv_dim][qbatch, head, batch] in order to make the
// maximum possible number of consecutive columns have the same KV matrices. // maximum possible number of consecutive columns have the same KV matrices.
@ -416,26 +619,26 @@ void FlashAttention(const size_t num_tokens, const size_t layer_idx,
const auto func = [&](const size_t task, size_t worker) HWY_ATTR { const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
PROFILER_ZONE3(ctx.profiler, worker, zone); PROFILER_ZONE3(ctx.profiler, worker, zone);
// Offsets into original Q for each row in the tile. // Offsets into original Q for each row in the tile.
uint32_t q_offsets[kVTileSize]; uint32_t q_offsets[kMaxNF];
// Offsets into att_out for each row in the tile. // Offsets into att_out for each row in the tile.
uint32_t out_offsets[kVTileSize]; uint32_t out_offsets[kMaxNF];
// Start positions for each row in the tile. // Start positions for each row in the tile.
size_t start_positions[kVTileSize]; size_t start_positions[kMaxNF];
// Last positions for each row in the tile. Inclusive. // Last positions for each row in the tile. Inclusive.
uint32_t last_pos[kVTileSize]; uint32_t last_pos[kMaxNF];
// min and max last positions across all rows in the tile determines when // min and max last positions across all rows in the tile determines when
// TileFlashAttention switches to single vector mode to handle the // TileFlashAttention switches to single vector mode to handle the
// ragged sequence lengths. // ragged sequence lengths.
size_t min_last_pos = std::numeric_limits<size_t>::max(); size_t min_last_pos = std::numeric_limits<size_t>::max();
size_t max_last_pos = 0; size_t max_last_pos = 0;
// Indices into the qbatch.KV for each row in the tile. // Indices into the qbatch.KV for each row in the tile.
size_t qi_indices[kVTileSize]; size_t qi_indices[kMaxNF];
// Indices into the kv_cache for each row in the tile. // Indices into the kv_cache for each row in the tile.
size_t kv_offsets[kVTileSize]; size_t kv_offsets[kMaxNF];
// first_task is [qbatch, head, token]. // first_task is [qbatch, head, token].
const size_t first_task = task * kVTileSize; const size_t first_task = task * kVTileSize;
const size_t last_task = first_task + kVTileSize - 1; const size_t last_task = first_task + kVTileSize - 1;
bool use_tile_attention = last_task < total_tasks; bool use_tile_attention = kVTileSize > 1 && last_task < total_tasks;
for (size_t offset = 0; for (size_t offset = 0;
offset < kVTileSize && first_task + offset < total_tasks; ++offset) { offset < kVTileSize && first_task + offset < total_tasks; ++offset) {
const size_t batch_idx = div_tokens.Remainder(first_task + offset); const size_t batch_idx = div_tokens.Remainder(first_task + offset);
@ -486,15 +689,26 @@ void FlashAttention(const size_t num_tokens, const size_t layer_idx,
kv_cache.Stride()); kv_cache.Stride());
if (use_tile_attention) { if (use_tile_attention) {
// To avoid duplicating the code to setup K and V, the call to // To avoid duplicating the code to setup K and V, the call to
// TileFlashAttention is inside the loop over tasks, even thought it // TileFlashAttention is inside the loop over tasks, even though it
// handles all rows in the task at once. // handles all rows in the task at once.
StridedView<float> qT = StridedView<float> qT =
StridedView<float>(activations.q_T.Row(0) + first_task, kVTileSize, StridedView<float>(activations.q_T.Row(0) + first_task, kVTileSize,
activations.q_T.Stride()); activations.q_T.Stride());
TileFlashAttention( if (kVTileSize == kNF) {
activations.q, q_offsets, qT, k, start_positions[offset], last_pos, TileFlashAttention(activations.q, q_offsets, qT, k,
min_last_pos, max_last_pos, v, layer_idx, layer, activations, start_positions[offset], last_pos, min_last_pos,
activations.att_out, out_offsets, ctx.profiler, worker); max_last_pos, v, layer_idx, layer, activations,
activations.att_out, out_offsets, ctx.profiler,
worker);
} else if (kVTileSize == 4) {
TileFlashAttention4(activations.q, q_offsets, qT, k,
start_positions[offset], last_pos, min_last_pos,
max_last_pos, v, layer_idx, layer, activations,
activations.att_out, out_offsets, ctx.profiler,
worker);
} else {
HWY_UNREACHABLE;
}
break; break;
} else { } else {
SingleFlashAttention(start_positions[offset], last_pos[offset], SingleFlashAttention(start_positions[offset], last_pos[offset],

View File

@ -42,8 +42,8 @@ namespace gcpp {
float* HWY_RESTRICT att_out, hwy::Profiler& p, \ float* HWY_RESTRICT att_out, hwy::Profiler& p, \
size_t worker); \ size_t worker); \
\ \
void FlashAttention(size_t num_tokens, size_t layer_idx, \ void FlashAttention(size_t num_tokens, size_t target_parallelism, \
const LayerWeightsPtrs& layer, \ size_t layer_idx, const LayerWeightsPtrs& layer, \
AttentionActivations& activations, QBatch& qbatch, \ AttentionActivations& activations, QBatch& qbatch, \
ThreadingContext& ctx); \ ThreadingContext& ctx); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \

View File

@ -98,13 +98,14 @@ void AssertClose(const MatPtrT<float>& a, const MatPtrT<float>& b) {
} }
} }
void TestAttention() { void TestFlashAttention(size_t target_parallelism) {
ThreadingArgs threading_args; ThreadingArgs threading_args;
ThreadingContext ctx(threading_args); ThreadingContext ctx(threading_args);
// hwy::ThreadPool& pool = ctx.pools.Pool(); // hwy::ThreadPool& pool = ctx.pools.Pool();
constexpr size_t kOuter = 1024; constexpr size_t kOuter = 1024;
constexpr size_t kInner = 256; constexpr size_t kInner = 256;
ModelConfig config(Model::GEMMA2_2B, Type::kF32, PromptWrapping::GEMMA_PT); ModelConfig config(Model::GEMMA2_2B, Type::kF32, PromptWrapping::GEMMA_PT);
config.att_cap = 1024.0f;
TensorInfoRegistry tensor_info_registry(config); TensorInfoRegistry tensor_info_registry(config);
const LayerConfig& layer_config = config.layer_configs[0]; const LayerConfig& layer_config = config.layer_configs[0];
const LayerWeightsPtrs layers(0, layer_config, tensor_info_registry); const LayerWeightsPtrs layers(0, layer_config, tensor_info_registry);
@ -149,10 +150,17 @@ void TestAttention() {
// Copy the output to saved_att to allow for comparison. // Copy the output to saved_att to allow for comparison.
auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator); auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator);
SetMat(1, attention.q); SetMat(1, attention.q);
FlashAttention(tokens.size(), 0, layers, attention, qbatch, ctx); FlashAttention(tokens.size(), target_parallelism, 0, layers, attention,
qbatch, ctx);
AssertClose(attention.att_out, *saved_att); AssertClose(attention.att_out, *saved_att);
} }
void TestAttention() {
TestFlashAttention(8192);
TestFlashAttention(2048);
TestFlashAttention(256);
}
// NOLINTNEXTLINE(google-readability-namespace-comments) // NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE } // namespace HWY_NAMESPACE
} // namespace gcpp } // namespace gcpp

View File

@ -747,7 +747,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile(
static const auto zone = p.AddZone("Ops.MulByConstAndAdd"); static const auto zone = p.AddZone("Ops.MulByConstAndAdd");
PROFILER_ZONE3(p, worker, zone); PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
HWY_LANES_CONSTEXPR size_t NF = hn::MaxLanes(df); HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
size_t i = 0; size_t i = 0;
while (i + NF <= size) { while (i + NF <= size) {
@ -882,8 +882,162 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile(
} }
i += NF; i += NF;
} }
const size_t remaining = size - i; HWY_DASSERT(size == i);
HWY_DASSERT(remaining == 0); }
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF c0,
const VF c1, const VF c2, const VF c3,
VF& sum0, VF& sum1, VF& sum2,
VF& sum3) {
sum0 = hn::MulAdd(common, c0, sum0);
sum1 = hn::MulAdd(common, c1, sum1);
sum2 = hn::MulAdd(common, c2, sum2);
sum3 = hn::MulAdd(common, c3, sum3);
}
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4Lanes(DF df, const MatPtrT<float>& v,
const size_t* HWY_RESTRICT pos,
const size_t offset, const VF c0,
const VF c1, const VF c2,
const VF c3, VF& sum0, VF& sum1,
VF& sum2, VF& sum3) {
// TODO(rays): Check whether a transpose of c0-c3 is applicable and faster.
VF x0 = hn::Load(df, v.Row(pos[0]) + offset);
MulAdd4(df, x0, hn::BroadcastLane<0>(c0), hn::BroadcastLane<0>(c1),
hn::BroadcastLane<0>(c2), hn::BroadcastLane<0>(c3), sum0, sum1, sum2,
sum3);
VF x1 = hn::Load(df, v.Row(pos[1]) + offset);
MulAdd4(df, x1, hn::BroadcastLane<1>(c0), hn::BroadcastLane<1>(c1),
hn::BroadcastLane<1>(c2), hn::BroadcastLane<1>(c3), sum0, sum1, sum2,
sum3);
VF x2 = hn::Load(df, v.Row(pos[2]) + offset);
MulAdd4(df, x2, hn::BroadcastLane<2>(c0), hn::BroadcastLane<2>(c1),
hn::BroadcastLane<2>(c2), hn::BroadcastLane<2>(c3), sum0, sum1, sum2,
sum3);
VF x3 = hn::Load(df, v.Row(pos[3]) + offset);
MulAdd4(df, x3, hn::BroadcastLane<3>(c0), hn::BroadcastLane<3>(c1),
hn::BroadcastLane<3>(c2), hn::BroadcastLane<3>(c3), sum0, sum1, sum2,
sum3);
}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 31)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond4Lanes(
DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
VF x4 = hn::Load(df, v.Row(pos[4]) + offset);
MulAdd4(df, x4, hn::BroadcastLane<4>(c0), hn::BroadcastLane<4>(c1),
hn::BroadcastLane<4>(c2), hn::BroadcastLane<4>(c3), sum0, sum1, sum2,
sum3);
VF x5 = hn::Load(df, v.Row(pos[5]) + offset);
MulAdd4(df, x5, hn::BroadcastLane<5>(c0), hn::BroadcastLane<5>(c1),
hn::BroadcastLane<5>(c2), hn::BroadcastLane<5>(c3), sum0, sum1, sum2,
sum3);
VF x6 = hn::Load(df, v.Row(pos[6]) + offset);
MulAdd4(df, x6, hn::BroadcastLane<6>(c0), hn::BroadcastLane<6>(c1),
hn::BroadcastLane<6>(c2), hn::BroadcastLane<6>(c3), sum0, sum1, sum2,
sum3);
VF x7 = hn::Load(df, v.Row(pos[7]) + offset);
MulAdd4(df, x7, hn::BroadcastLane<7>(c0), hn::BroadcastLane<7>(c1),
hn::BroadcastLane<7>(c2), hn::BroadcastLane<7>(c3), sum0, sum1, sum2,
sum3);
}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 31)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond4Lanes(
DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
VF& sum0, VF& sum1, VF& sum2, VF& sum3) {}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 63)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond8Lanes(
DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
VF x8 = hn::Load(df, v.Row(pos[8]) + offset);
MulAdd4(df, x8, hn::BroadcastLane<8>(c0), hn::BroadcastLane<8>(c1),
hn::BroadcastLane<8>(c2), hn::BroadcastLane<8>(c3), sum0, sum1, sum2,
sum3);
VF x9 = hn::Load(df, v.Row(pos[9]) + offset);
MulAdd4(df, x9, hn::BroadcastLane<9>(c0), hn::BroadcastLane<9>(c1),
hn::BroadcastLane<9>(c2), hn::BroadcastLane<9>(c3), sum0, sum1, sum2,
sum3);
VF x10 = hn::Load(df, v.Row(pos[10]) + offset);
MulAdd4(df, x10, hn::BroadcastLane<10>(c0), hn::BroadcastLane<10>(c1),
hn::BroadcastLane<10>(c2), hn::BroadcastLane<10>(c3), sum0, sum1,
sum2, sum3);
VF x11 = hn::Load(df, v.Row(pos[11]) + offset);
MulAdd4(df, x11, hn::BroadcastLane<11>(c0), hn::BroadcastLane<11>(c1),
hn::BroadcastLane<11>(c2), hn::BroadcastLane<11>(c3), sum0, sum1,
sum2, sum3);
VF x12 = hn::Load(df, v.Row(pos[12]) + offset);
MulAdd4(df, x12, hn::BroadcastLane<12>(c0), hn::BroadcastLane<12>(c1),
hn::BroadcastLane<12>(c2), hn::BroadcastLane<12>(c3), sum0, sum1,
sum2, sum3);
VF x13 = hn::Load(df, v.Row(pos[13]) + offset);
MulAdd4(df, x13, hn::BroadcastLane<13>(c0), hn::BroadcastLane<13>(c1),
hn::BroadcastLane<13>(c2), hn::BroadcastLane<13>(c3), sum0, sum1,
sum2, sum3);
VF x14 = hn::Load(df, v.Row(pos[14]) + offset);
MulAdd4(df, x14, hn::BroadcastLane<14>(c0), hn::BroadcastLane<14>(c1),
hn::BroadcastLane<14>(c2), hn::BroadcastLane<14>(c3), sum0, sum1,
sum2, sum3);
VF x15 = hn::Load(df, v.Row(pos[15]) + offset);
MulAdd4(df, x15, hn::BroadcastLane<15>(c0), hn::BroadcastLane<15>(c1),
hn::BroadcastLane<15>(c2), hn::BroadcastLane<15>(c3), sum0, sum1,
sum2, sum3);
}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 63)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond8Lanes(
DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
VF& sum0, VF& sum1, VF& sum2, VF& sum3) {}
// For an NFx4 tile of float values in 4xNF-lane registers, multiplies NF rows
// of V by the corresponding values in c0-c3 and adds them to NF rows of out,
// after first prescaling out by scale.
// The depth (size) must be a multiple of NF.
template <class DF, class VF = hn::Vec<DF>>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile4(
DF df, const float* HWY_RESTRICT scales, const VF c0, const VF c1,
const VF c2, const VF c3, const MatPtrT<float>& v,
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
const uint32_t* HWY_RESTRICT out_offsets, const size_t size,
hwy::Profiler& p, const size_t worker) {
static const auto zone = p.AddZone("Ops.MulByConstAndAddTile4");
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE;
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
size_t i = 0;
while (i + NF <= size) {
VF out0, out1, out2, out3;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
out2 = hn::Load(df, out + i + out_offsets[2]);
out3 = hn::Load(df, out + i + out_offsets[3]);
out0 = hn::Mul(out0, hn::Set(df, scales[0]));
out1 = hn::Mul(out1, hn::Set(df, scales[1]));
out2 = hn::Mul(out2, hn::Set(df, scales[2]));
out3 = hn::Mul(out3, hn::Set(df, scales[3]));
MulAdd4Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2, out3);
if HWY_LANES_CONSTEXPR (NF >= 8) {
MulAddSecond4Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2, out3);
if HWY_LANES_CONSTEXPR (NF >= 16) {
MulAddSecond8Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2,
out3);
}
}
hn::Store(out0, df, out + i + out_offsets[0]);
hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]);
i += NF;
}
HWY_DASSERT(size == i);
} }
// Prescales NF rows of out by scale, then multiplies 1 row of V by the // Prescales NF rows of out by scale, then multiplies 1 row of V by the
@ -898,11 +1052,11 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
static const auto zone = p.AddZone("Ops.MulByConstAndAdd"); static const auto zone = p.AddZone("Ops.MulByConstAndAdd");
PROFILER_ZONE3(p, worker, zone); PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
const size_t NF = hn::MaxLanes(df); HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
size_t i = 0; size_t i = 0;
while (i + NF <= size) { while (i + NF <= size) {
if constexpr (NF == 16) { if HWY_LANES_CONSTEXPR (NF == 16) {
VF out0, out1, out2, out3, out4, out5, out6, out7; VF out0, out1, out2, out3, out4, out5, out6, out7;
VF out8, out9, out10, out11, out12, out13, out14, out15; VF out8, out9, out10, out11, out12, out13, out14, out15;
out0 = hn::Load(df, out + i + out_offsets[0]); out0 = hn::Load(df, out + i + out_offsets[0]);
@ -921,22 +1075,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
out13 = hn::Load(df, out + i + out_offsets[13]); out13 = hn::Load(df, out + i + out_offsets[13]);
out14 = hn::Load(df, out + i + out_offsets[14]); out14 = hn::Load(df, out + i + out_offsets[14]);
out15 = hn::Load(df, out + i + out_offsets[15]); out15 = hn::Load(df, out + i + out_offsets[15]);
out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale)); Mul16(df, scale, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale)); out9, out10, out11, out12, out13, out14, out15);
out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale));
out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale));
out4 = hn::Mul(out4, hn::BroadcastLane<4>(scale));
out5 = hn::Mul(out5, hn::BroadcastLane<5>(scale));
out6 = hn::Mul(out6, hn::BroadcastLane<6>(scale));
out7 = hn::Mul(out7, hn::BroadcastLane<7>(scale));
out8 = hn::Mul(out8, hn::BroadcastLane<8>(scale));
out9 = hn::Mul(out9, hn::BroadcastLane<9>(scale));
out10 = hn::Mul(out10, hn::BroadcastLane<10>(scale));
out11 = hn::Mul(out11, hn::BroadcastLane<11>(scale));
out12 = hn::Mul(out12, hn::BroadcastLane<12>(scale));
out13 = hn::Mul(out13, hn::BroadcastLane<13>(scale));
out14 = hn::Mul(out14, hn::BroadcastLane<14>(scale));
out15 = hn::Mul(out15, hn::BroadcastLane<15>(scale));
VF x0 = hn::Load(df, v.Row(pos) + i); VF x0 = hn::Load(df, v.Row(pos) + i);
MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8, MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15); out9, out10, out11, out12, out13, out14, out15);
@ -956,7 +1096,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
hn::Store(out13, df, out + i + out_offsets[13]); hn::Store(out13, df, out + i + out_offsets[13]);
hn::Store(out14, df, out + i + out_offsets[14]); hn::Store(out14, df, out + i + out_offsets[14]);
hn::Store(out15, df, out + i + out_offsets[15]); hn::Store(out15, df, out + i + out_offsets[15]);
} else if constexpr (NF == 8) { }
if HWY_LANES_CONSTEXPR (NF == 8) {
VF out0, out1, out2, out3, out4, out5, out6, out7; VF out0, out1, out2, out3, out4, out5, out6, out7;
out0 = hn::Load(df, out + i + out_offsets[0]); out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]); out1 = hn::Load(df, out + i + out_offsets[1]);
@ -966,14 +1107,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
out5 = hn::Load(df, out + i + out_offsets[5]); out5 = hn::Load(df, out + i + out_offsets[5]);
out6 = hn::Load(df, out + i + out_offsets[6]); out6 = hn::Load(df, out + i + out_offsets[6]);
out7 = hn::Load(df, out + i + out_offsets[7]); out7 = hn::Load(df, out + i + out_offsets[7]);
out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale)); Mul8(df, scale, out0, out1, out2, out3, out4, out5, out6, out7);
out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale));
out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale));
out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale));
out4 = hn::Mul(out4, hn::BroadcastLane<4>(scale));
out5 = hn::Mul(out5, hn::BroadcastLane<5>(scale));
out6 = hn::Mul(out6, hn::BroadcastLane<6>(scale));
out7 = hn::Mul(out7, hn::BroadcastLane<7>(scale));
VF x0 = hn::Load(df, v.Row(pos) + i); VF x0 = hn::Load(df, v.Row(pos) + i);
MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7); MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7);
hn::Store(out0, df, out + i + out_offsets[0]); hn::Store(out0, df, out + i + out_offsets[0]);
@ -984,7 +1118,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
hn::Store(out5, df, out + i + out_offsets[5]); hn::Store(out5, df, out + i + out_offsets[5]);
hn::Store(out6, df, out + i + out_offsets[6]); hn::Store(out6, df, out + i + out_offsets[6]);
hn::Store(out7, df, out + i + out_offsets[7]); hn::Store(out7, df, out + i + out_offsets[7]);
} else if constexpr (NF == 4) { }
if HWY_LANES_CONSTEXPR (NF == 4) {
VF out0, out1, out2, out3; VF out0, out1, out2, out3;
out0 = hn::Load(df, out + i + out_offsets[0]); out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]); out1 = hn::Load(df, out + i + out_offsets[1]);
@ -1000,13 +1135,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
hn::Store(out1, df, out + i + out_offsets[1]); hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]); hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]); hn::Store(out3, df, out + i + out_offsets[3]);
} else {
HWY_DASSERT(false);
} }
i += NF; i += NF;
} }
const size_t remaining = size - i; HWY_DASSERT(size == i);
HWY_DASSERT(remaining == 0);
} }
// See below for a specialized version for top-1 sampling. // See below for a specialized version for top-1 sampling.