BF16 mixed-mode flash attention

PiperOrigin-RevId: 825433929
Phil Culliton authored 2025-10-29 01:47:59 -07:00, committed by Copybara-Service
parent 4bd465ffd3
commit 116cd6eff6
4 changed files with 99 additions and 45 deletions

View File

@@ -105,7 +105,7 @@ struct AttentionActivations {
}
MatStorageT<float> q; // query
- MatStorageT<float> q_T; // Transposed to maximize attention speed.
+ MatStorageT<BF16> q_T; // Transposed to maximize attention speed.
MatStorageT<float> pre_att_rms_out;
MatStorageT<float> att; // attention vector
@@ -151,7 +151,7 @@ struct AttentionActivationsPtrs {
const ModelConfig& config;
MatPtrT<float> q;
- MatPtrT<float> q_T;
+ MatPtrT<BF16> q_T;
MatPtrT<float> pre_att_rms_out;
MatPtrT<float> att;
MatPtrT<float> att_out;
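
This change stores the transposed query q_T as BF16 instead of f32 in both the storage struct and its pointer view, halving the bytes the attention tile kernels stream for Q. BF16 keeps the sign and the full 8-bit exponent of an IEEE-754 binary32 value but only the top 7 mantissa bits. As an illustrative scalar model (assumed semantics, not the Highway implementation behind hwy::ConvertScalarTo<BF16> / CompressTraits<BF16>), the conversion is a round-to-nearest-even truncation of the low 16 bits:

    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    // Illustrative scalar BF16 helpers (assumed semantics; NaN handling omitted).
    // A BF16 value is the upper 16 bits of an IEEE-754 binary32 encoding.
    static inline uint16_t F32ToBF16Bits(float f) {
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof(bits));
      // Round to nearest even before dropping the low 16 mantissa bits.
      const uint32_t rounding_bias = 0x7FFFu + ((bits >> 16) & 1u);
      return static_cast<uint16_t>((bits + rounding_bias) >> 16);
    }

    static inline float BF16BitsToF32(uint16_t b) {
      const uint32_t bits = static_cast<uint32_t>(b) << 16;
      float f;
      std::memcpy(&f, &bits, sizeof(f));
      return f;
    }

Only the transposed Q is narrowed by this commit; the attention scores, the softmax state, and att_out remain f32.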

View File

@@ -57,16 +57,27 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
ThreadingContext& ctx, const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kGenAttentionQDotK);
+ const hn::ScalableTag<BF16> dbf;
+ const size_t qkv_dim = k.Cols();
+ HWY_ALIGN BF16 q_bf[kMaxQKVDim];
+ CompressPerThread tls;
+ const hn::ScalableTag<float> df;
+ CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim),
+ 0);
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
// Slightly faster: no wraparound.
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
- const float score = Dot(q, k.Row(pos), k.Cols());
+ const float score =
+ Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos), qkv_dim);
att[pos] = score;
}
} else {
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
const size_t pos_modulo = div_seq_len.Remainder(pos);
- const float score = Dot(q, k.Row(pos_modulo), k.Cols());
+ const float score =
+ Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_modulo), qkv_dim);
att[pos_modulo] = score;
}
}
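
The rewritten QDotK demotes the f32 query row into a BF16 stack buffer once per call, then computes every score with the mixed-precision Dot overload, which promotes both operands back to f32 lanes and accumulates in f32. A scalar model of that order of operations, reusing the BF16 helpers sketched earlier (illustrative only; it assumes the KV cache type KV_t is also BF16):

    // Scalar model of mixed-mode QDotK: q is rounded to BF16 once, k stays in its
    // stored type, and the multiply-accumulate runs in f32 (the real kernel is the
    // vectorized Dot call above).
    static float QDotKScalar(const float* q, const uint16_t* k_bf16, size_t qkv_dim) {
      float sum = 0.0f;
      for (size_t i = 0; i < qkv_dim; ++i) {
        const float qi = BF16BitsToF32(F32ToBF16Bits(q[i]));  // demote once, promote back
        const float ki = BF16BitsToF32(k_bf16[i]);             // promote stored K
        sum += qi * ki;                                        // accumulate in f32
      }
      return sum;
    }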

View File

@@ -58,7 +58,7 @@ static constexpr size_t kNFx8HTileSize = 8;
// q has shape [batch, qbatch][head, qkv_dim].
// q_t has shape [qkv_dim][qbatch, head, batch] in order to make the maximum
// possible consecutive elements have the same KV.
- static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
+ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<BF16>& q_t,
const size_t qbatch_size, ThreadingContext& ctx) {
// Group floats by the number of floats in a cache line.
const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float);
@@ -69,12 +69,13 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
for (size_t lane = 0; lane < kNF; ++lane) {
size_t q_row = task * kNF + lane;
if (q_row >= q_t.Rows()) break;
- float* HWY_RESTRICT qt_row = q_t.Row(q_row);
+ BF16* HWY_RESTRICT qt_row = q_t.Row(q_row);
for (size_t qi = 0; qi < qbatch_size; ++qi) {
for (size_t h = 0; h < num_heads; ++h) {
for (size_t b = 0; b < batch_size; ++b) {
qt_row[(qi * num_heads + h) * batch_size + b] =
- q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row];
+ hwy::ConvertScalarTo<BF16>(
+ q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row]);
}
}
}
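
TransposeQ now demotes while it transposes: each f32 query element lands in the [qkv_dim][qbatch, head, batch] layout as BF16, so the tile kernel below can load NF consecutive BF16 values for one dimension. A scalar sketch of the index mapping, with flat row-major pointers standing in for MatPtrT and the BF16 helpers from above reused (illustrative):

    // q:   [batch * qbatch][num_heads * qkv_dim], f32, row-major
    // q_t: [qkv_dim][qbatch * num_heads * batch], BF16, row-major
    static void TransposeQScalar(const float* q, uint16_t* q_t, size_t batch_size,
                                 size_t qbatch_size, size_t num_heads,
                                 size_t qkv_dim) {
      const size_t qt_cols = qbatch_size * num_heads * batch_size;
      for (size_t d = 0; d < qkv_dim; ++d) {        // q_row in the code above
        for (size_t qi = 0; qi < qbatch_size; ++qi) {
          for (size_t h = 0; h < num_heads; ++h) {
            for (size_t b = 0; b < batch_size; ++b) {
              q_t[d * qt_cols + (qi * num_heads + h) * batch_size + b] =
                  F32ToBF16Bits(q[(b * qbatch_size + qi) * (num_heads * qkv_dim) +
                                  h * qkv_dim + d]);
            }
          }
        }
      }
    }
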
@@ -158,8 +159,19 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
float* HWY_RESTRICT att_out, ThreadingContext& ctx,
const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention);
+ const hn::ScalableTag<BF16> dbf;
+ const size_t qkv_dim = k.Cols();
+ HWY_ALIGN BF16 q_bf[kMaxQKVDim];
+ CompressPerThread tls;
+ const hn::ScalableTag<float> df;
+ CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim),
+ 0);
const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
- float m = Dot(q, k.Row(pos_mod), k.Cols());
+ // TODO: Mixed-mode can be further improved for Turin: we could demote right
+ // before the dot product instruction rather than promoting both to f32, but
+ // there is some potential accuracy loss there that needs evaluation first.
+ float m = Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
if (float cap = activations.config.att_cap; cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
m = cap * std::tanh(m / cap);
@@ -169,7 +181,8 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), ctx, worker);
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
const size_t pos_mod = activations.div_seq_len.Remainder(pos);
- float x = Dot(q, k.Row(pos_mod), k.Cols());
+ float x =
+ Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
SingleFlashAttentionStep(x, activations.config.att_cap, m, d,
v.Row(pos_mod), v.Cols(), att_out);
}
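
SingleFlashAttention follows the same pattern: demote q once, then feed each mixed-precision score into SingleFlashAttentionStep. That helper is not part of this diff; assuming it implements the usual online-softmax recurrence over a running maximum m, denominator d, and unnormalized accumulator att_out, a scalar sketch looks like this:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>

    // Hedged sketch of the assumed online-softmax update. After the last timestep
    // the caller divides att_out by d (outside the hunks shown here).
    static void SingleFlashAttentionStepSketch(float x, float cap, float& m, float& d,
                                               const float* v_row, size_t v_cols,
                                               float* att_out) {
      if (cap > 0.0f) x = cap * std::tanh(x / cap);  // LogitsSoftCap on the scalar x
      const float m_new = std::max(m, x);
      const float rescale = std::exp(m - m_new);  // shrink previously accumulated state
      const float w = std::exp(x - m_new);        // weight of the new timestep
      for (size_t i = 0; i < v_cols; ++i) {
        att_out[i] = att_out[i] * rescale + w * v_row[i];
      }
      d = d * rescale + w;
      m = m_new;
    }
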
@@ -179,25 +192,31 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
// the dot products of NF rows of Q for a single K timestep.
template <class DF, class VF = hn::Vec<DF>>
VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
- const size_t k_pos, const MatPtrT<KV_t>& q,
+ const size_t k_pos, const MatPtrT<float>& q,
const MatPtrT<KV_t>& k) {
+ const hn::ScalableTag<BF16> dbf;
+ const size_t qkv_dim = k.Cols();
+ HWY_ALIGN BF16 q_bf[kMaxQKVDim];
+ CompressPerThread tls;
hn::TFromD<DF> results[hn::MaxLanes(df)];
for (size_t i = 0; i < hn::Lanes(df); ++i) {
- results[i] = Dot(q.Row(0) + q_offsets[i], k.Row(k_pos), k.Cols());
+ CompressTraits<BF16>::Compress(df, q.Row(0) + q_offsets[i], qkv_dim, tls,
+ MakeSpan(q_bf, qkv_dim), 0);
+ results[i] =
+ Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
}
return hn::LoadU(df, results);
}
- // Returns an NF Q rows by 8 K rows tile of Q.K dot products, in single
- // precision.
+ // Returns an NF Q rows by 8 K rows tile of Q.K dot products.
// This is the result of NF rows of Q against 8 K timesteps, with positions
// given by k_pos[0..7]. Q has been transposed so that the NF rows are read in
// consecutive elements, and other columns by adding q_stride.
template <class DF, class VF = hn::Vec<DF>>
- void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
- const MatPtrT<KV_t>& k, const size_t* k_pos, VF& sum0,
- VF& sum1, VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6,
- VF& sum7) {
+ void QDotKTile(DF df, const BF16* HWY_RESTRICT q, const size_t q_stride,
+ const MatPtrT<KV_t>& k, const size_t* k_pos, VF& sum0, VF& sum1,
+ VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7) {
constexpr size_t kHTileSize = kNFx8HTileSize;
sum0 = hn::Zero(df);
sum1 = hn::Zero(df);
@@ -211,8 +230,13 @@ void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
for (int i = 0; i < kHTileSize; ++i) {
k_row[i] = k.Row(k_pos[i]);
}
+ const hn::Rebind<BF16, DF> dbfh;
+ using VBF = hn::Vec<decltype(dbfh)>;
for (size_t i = 0; i < k.Cols(); ++i) {
- VF q_vec = hn::Load(df, q);
+ const VBF q_vec_bf = hn::Load(dbfh, q);
+ const VF q_vec = hn::PromoteTo(df, q_vec_bf);
VF k_0 = hn::Set(df, k_row[0][i]);
sum0 = hn::MulAdd(q_vec, k_0, sum0);
VF k_1 = hn::Set(df, k_row[1][i]);
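
In the renamed QDotKTile, the inner loop now loads NF BF16 lanes of transposed Q and promotes them to f32 (hn::PromoteTo) before the eight broadcast-and-FMA steps against the K timesteps. A scalar model of the NF x 8 accumulation (illustrative; it assumes f32 K rows, that q advances by q_stride per dimension, and it reuses the BF16 helpers above):

    // sums holds 8 accumulators of NF lanes each: sums[t * NF + lane] is the dot
    // product of transposed-Q lane `lane` with K timestep t.
    static void QDotKTileScalar(const uint16_t* qT, size_t qT_stride, size_t qkv_dim,
                                const float* const k_rows[8], size_t NF,
                                float* sums) {
      for (size_t s = 0; s < 8 * NF; ++s) sums[s] = 0.0f;
      for (size_t i = 0; i < qkv_dim; ++i) {
        for (size_t lane = 0; lane < NF; ++lane) {
          const float q_val = BF16BitsToF32(qT[i * qT_stride + lane]);  // promote
          for (size_t t = 0; t < 8; ++t) {
            sums[t * NF + lane] += q_val * k_rows[t][i];  // MulAdd vs. broadcast k
          }
        }
      }
    }
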
@@ -264,17 +288,14 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
// Sweeps a tile of NF Q rows by 8 K timesteps accumulators from start_pos to
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
// max_last_pos].
- void TileFlashAttention(const MatPtrT<float>& q,
- const uint32_t* HWY_RESTRICT q_offsets,
- const StridedView<float>& qT, const MatPtrT<KV_t>& k,
- const size_t start_pos,
- const uint32_t* HWY_RESTRICT last_pos,
- const size_t min_last_pos, const size_t max_last_pos,
- const MatPtrT<KV_t>& v, const size_t layer_idx,
- const AttentionActivationsPtrs& activations,
- MatPtrT<float>& att_out,
- const uint32_t* HWY_RESTRICT out_offsets,
- ThreadingContext& ctx, const size_t worker) {
+ void TileFlashAttention(
+ const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
+ const StridedView<BF16>& qT, const MatPtrT<KV_t>& k, const size_t start_pos,
+ const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
+ const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
+ const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,
+ const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx,
+ const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention);
constexpr int kHTileSize = kNFx8HTileSize;
using DF = hn::ScalableTag<float>;
@@ -291,7 +312,7 @@ void TileFlashAttention(const MatPtrT<float>& q,
VI lasts = hn::LoadU(di, last_pos);
VF old_m = hn::Set(df, -std::numeric_limits<float>::max() / 2.0f);
VF old_d = hn::Zero(df);
- const float* HWY_RESTRICT qT_row = qT.Row(0);
+ const BF16* HWY_RESTRICT qT_row = qT.Row(0);
const size_t qT_stride = qT.Stride();
size_t position = start_pos;
while (position + kHTileSize - 1 <= min_last_pos) {
@@ -300,8 +321,7 @@ void TileFlashAttention(const MatPtrT<float>& q,
k_pos[i] = activations.div_seq_len.Remainder(position + i);
}
VF x0, x1, x2, x3, x4, x5, x6, x7;
- QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6,
- x7);
+ QDotKTile(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6, x7);
if (activations.config.att_cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
VF cap = hn::Set(df, activations.config.att_cap);
@@ -390,13 +410,17 @@ void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
VI k_offsets_vec = hn::LoadU(di, k_offsets);
for (size_t i = 0; i < k.Cols(); ++i) {
VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec);
- VF q_0 = hn::Set(df, q[q_offsets[0] + i]);
+ VF q_0 = hn::Set(df, hwy::ConvertScalarTo<float>(
+ hwy::ConvertScalarTo<BF16>(q[q_offsets[0] + i])));
sum0 = hn::MulAdd(q_0, k_vec, sum0);
- VF q_1 = hn::Set(df, q[q_offsets[1] + i]);
+ VF q_1 = hn::Set(df, hwy::ConvertScalarTo<float>(
+ hwy::ConvertScalarTo<BF16>(q[q_offsets[1] + i])));
sum1 = hn::MulAdd(q_1, k_vec, sum1);
- VF q_2 = hn::Set(df, q[q_offsets[2] + i]);
+ VF q_2 = hn::Set(df, hwy::ConvertScalarTo<float>(
+ hwy::ConvertScalarTo<BF16>(q[q_offsets[2] + i])));
sum2 = hn::MulAdd(q_2, k_vec, sum2);
- VF q_3 = hn::Set(df, q[q_offsets[3] + i]);
+ VF q_3 = hn::Set(df, hwy::ConvertScalarTo<float>(
+ hwy::ConvertScalarTo<BF16>(q[q_offsets[3] + i])));
sum3 = hn::MulAdd(q_3, k_vec, sum3);
}
}
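
QDotKTilex4 still gathers f32 Q elements, but each one is now round-tripped f32 to BF16 and back before the FMA, presumably so this path stays numerically consistent with the BF16 tile and single-row paths. The round trip simply snaps the value onto the BF16 grid (illustrative, reusing the scalar helpers sketched earlier):

    // Scalar equivalent of hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<BF16>(x)):
    // drop the low 16 mantissa bits (with rounding) but keep doing f32 arithmetic.
    static inline float QuantizeToBF16Grid(float x) {
      return BF16BitsToF32(F32ToBF16Bits(x));
    }
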
@@ -478,32 +502,50 @@ void TileFlashAttention4(const MatPtrT<float>& q,
out_offsets, v.Cols());
position += kHTileSize;
}
+ const hn::ScalableTag<BF16> dbf;
+ const size_t qkv_dim = k.Cols();
+ HWY_ALIGN BF16 q_bf[kMaxQKVDim];
+ CompressPerThread tls;
+ const hn::ScalableTag<float> df_compress;
while (position <= max_last_pos) {
size_t k_pos = activations.div_seq_len.Remainder(position);
if (position <= last_pos[0]) {
// Past the last position, x0 doesn't count.
- float x0 = Dot(q.Row(0) + q_offsets[0], k.Row(k_pos), k.Cols());
+ CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[0],
+ qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
+ float x0 =
+ Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[0]);
}
if (position <= last_pos[1]) {
// Past the last position, x1 doesn't count.
- float x1 = Dot(q.Row(0) + q_offsets[1], k.Row(k_pos), k.Cols());
+ CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[1],
+ qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
+ float x1 =
+ Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[1]);
}
if (position <= last_pos[2]) {
// Past the last position, x2 doesn't count.
- float x2 = Dot(q.Row(0) + q_offsets[2], k.Row(k_pos), k.Cols());
+ CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[2],
+ qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
+ float x2 =
+ Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[2]);
}
if (position <= last_pos[3]) {
// Past the last position, x3 doesn't count.
- float x3 = Dot(q.Row(0) + q_offsets[3], k.Row(k_pos), k.Cols());
+ CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[3],
+ qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
+ float x3 =
+ Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[3]);
@@ -722,8 +764,8 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
// To avoid duplicating the code to setup K and V, the call to
// TileFlashAttention is inside the loop over tasks, even though it
// handles all rows in the task at once.
- StridedView<float> qT =
- StridedView<float>(activations.q_T.Row(0) + first_task, kVTileSize,
+ StridedView<BF16> qT =
+ StridedView<BF16>(activations.q_T.Row(0) + first_task, kVTileSize,
activations.q_T.Stride());
if (kVTileSize == kNF) {
// We can still use TileFlashAttention even if we didn't transpose Q

View File

@@ -413,7 +413,8 @@ using DotKernelDefault =
template <class D, typename WT, typename VT>
HWY_INLINE float Dot(D d, const PackedSpan<const WT>& w, size_t w_ofs,
const VT* HWY_RESTRICT vec, size_t num) {
- return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelDefault());
+ return DecompressAndCall(d, w, w_ofs, MakeConstSpan(vec, num),
+ DotKernelDefault());
}
// Adapter for two pointers, no bounds checking.