mirror of https://github.com/google/gemma.cpp.git
BF16 mixed-mode flash attention
PiperOrigin-RevId: 825433929
parent 4bd465ffd3
commit 116cd6eff6
@@ -104,8 +104,8 @@ struct AttentionActivations {
   // `inv_timescale*` are not batched.
 }

   MatStorageT<float> q;  // query
-  MatStorageT<float> q_T;  // Transposed to maximize attention speed.
+  MatStorageT<BF16> q_T;  // Transposed to maximize attention speed.

   MatStorageT<float> pre_att_rms_out;
   MatStorageT<float> att;  // attention vector
@@ -151,7 +151,7 @@ struct AttentionActivationsPtrs {

   const ModelConfig& config;
   MatPtrT<float> q;
-  MatPtrT<float> q_T;
+  MatPtrT<BF16> q_T;
   MatPtrT<float> pre_att_rms_out;
   MatPtrT<float> att;
   MatPtrT<float> att_out;
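The transposed query q_T is now stored as BF16 instead of f32, halving the bytes streamed through the transposed-Q attention path. As a rough illustration of the precision being traded, here is a small, self-contained round-to-nearest-even float-to-BF16 model (hypothetical code; gemma.cpp itself goes through hwy::ConvertScalarTo<BF16> and CompressTraits<BF16>, whose exact rounding and NaN handling may differ):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Hypothetical scalar model of float -> BF16 demotion (round to nearest
    // even, NaN handling omitted). Illustration only, not the library code.
    static uint16_t F32ToBF16Bits(float f) {
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof(bits));
      const uint32_t round = 0x7FFFu + ((bits >> 16) & 1u);
      return static_cast<uint16_t>((bits + round) >> 16);
    }

    static float BF16BitsToF32(uint16_t b) {
      const uint32_t bits = static_cast<uint32_t>(b) << 16;
      float f;
      std::memcpy(&f, &bits, sizeof(f));
      return f;
    }

    int main() {
      const float x = 0.3141592f;
      // BF16 keeps an 8-bit significand, so the relative error is about 2^-9.
      std::printf("%.7f -> %.7f\n", x, BF16BitsToF32(F32ToBF16Bits(x)));
    }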
@@ -57,16 +57,27 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
                              const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
                              ThreadingContext& ctx, const size_t worker) {
   GCPP_ZONE(ctx, worker, Zones::kGenAttentionQDotK);
+  const hn::ScalableTag<BF16> dbf;
+  const size_t qkv_dim = k.Cols();
+  HWY_ALIGN BF16 q_bf[kMaxQKVDim];
+
+  CompressPerThread tls;
+  const hn::ScalableTag<float> df;
+  CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim),
+                                 0);
+
   if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
     // Slightly faster: no wraparound.
     for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-      const float score = Dot(q, k.Row(pos), k.Cols());
+      const float score =
+          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos), qkv_dim);
       att[pos] = score;
     }
   } else {
     for (size_t pos = start_pos; pos <= last_pos; ++pos) {
       const size_t pos_modulo = div_seq_len.Remainder(pos);
-      const float score = Dot(q, k.Row(pos_modulo), k.Cols());
+      const float score =
+          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_modulo), qkv_dim);
       att[pos_modulo] = score;
     }
   }
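QDotK now demotes the query row to BF16 once per call, then evaluates every Q.K dot product in mixed mode: BF16 inputs, f32 accumulation. A plain-C++ sketch of that arithmetic, with a hypothetical truncation-based BF16 round trip standing in for CompressTraits<BF16>::Compress and the Highway Dot kernel:

    #include <cstddef>
    #include <cstdint>
    #include <cstring>
    #include <vector>

    // Hypothetical truncation model of float -> BF16 -> float; the real
    // conversion also rounds. Illustration only.
    static float ToBF16Approx(float f) {
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof(bits));
      bits &= 0xFFFF0000u;
      std::memcpy(&f, &bits, sizeof(f));
      return f;
    }

    // Scalar model of mixed-mode QDotK: demote the query once, keep the
    // accumulator in f32 for every key timestep.
    static void QDotKModel(const float* q,
                           const std::vector<std::vector<float>>& k,
                           size_t start_pos, size_t last_pos, size_t qkv_dim,
                           float* att) {
      std::vector<float> q_bf(qkv_dim);
      for (size_t i = 0; i < qkv_dim; ++i) q_bf[i] = ToBF16Approx(q[i]);
      for (size_t pos = start_pos; pos <= last_pos; ++pos) {
        float score = 0.0f;  // f32 accumulator
        for (size_t i = 0; i < qkv_dim; ++i) score += q_bf[i] * k[pos][i];
        att[pos] = score;
      }
    }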
@@ -58,7 +58,7 @@ static constexpr size_t kNFx8HTileSize = 8;
 // q has shape [batch, qbatch][head, qkv_dim].
 // q_t has shape [qkv_dim][qbatch, head, batch] in order to make the maximum
 // possible consecutive elements have the same KV.
-static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
+static void TransposeQ(const MatPtrT<float>& q, MatPtrT<BF16>& q_t,
                        const size_t qbatch_size, ThreadingContext& ctx) {
   // Group floats by the number of floats in a cache line.
   const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float);
@@ -69,12 +69,13 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
     for (size_t lane = 0; lane < kNF; ++lane) {
       size_t q_row = task * kNF + lane;
       if (q_row >= q_t.Rows()) break;
-      float* HWY_RESTRICT qt_row = q_t.Row(q_row);
+      BF16* HWY_RESTRICT qt_row = q_t.Row(q_row);
       for (size_t qi = 0; qi < qbatch_size; ++qi) {
         for (size_t h = 0; h < num_heads; ++h) {
           for (size_t b = 0; b < batch_size; ++b) {
             qt_row[(qi * num_heads + h) * batch_size + b] =
-                q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row];
+                hwy::ConvertScalarTo<BF16>(
+                    q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row]);
           }
         }
       }
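TransposeQ reshapes q from [batch, qbatch][head, qkv_dim] into q_t with shape [qkv_dim][qbatch, head, batch], so that consecutive q_t elements multiply the same K value, and the demotion to BF16 now happens during this transpose. A minimal standalone model of the index mapping (plain arrays, float output instead of BF16, hypothetical names):

    #include <cstddef>
    #include <vector>

    // Hypothetical scalar model of the TransposeQ index mapping.
    // q:   [batch * qbatch] rows, each holding [num_heads * qkv_dim] floats.
    // q_t: [qkv_dim] rows, each holding [qbatch * num_heads * batch] values.
    std::vector<std::vector<float>> TransposeQModel(
        const std::vector<std::vector<float>>& q, size_t batch_size,
        size_t qbatch_size, size_t num_heads, size_t qkv_dim) {
      std::vector<std::vector<float>> q_t(
          qkv_dim, std::vector<float>(qbatch_size * num_heads * batch_size));
      for (size_t q_row = 0; q_row < qkv_dim; ++q_row) {
        for (size_t qi = 0; qi < qbatch_size; ++qi) {
          for (size_t h = 0; h < num_heads; ++h) {
            for (size_t b = 0; b < batch_size; ++b) {
              // The real code additionally demotes this element to BF16.
              q_t[q_row][(qi * num_heads + h) * batch_size + b] =
                  q[b * qbatch_size + qi][h * qkv_dim + q_row];
            }
          }
        }
      }
      return q_t;
    }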
@@ -158,8 +159,19 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
                           float* HWY_RESTRICT att_out, ThreadingContext& ctx,
                           const size_t worker) {
   GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention);
+  const hn::ScalableTag<BF16> dbf;
+  const size_t qkv_dim = k.Cols();
+  HWY_ALIGN BF16 q_bf[kMaxQKVDim];
+
+  CompressPerThread tls;
+  const hn::ScalableTag<float> df;
+  CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim),
+                                 0);
   const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
-  float m = Dot(q, k.Row(pos_mod), k.Cols());
+  // TODO: Mixed-mode can be further improved for Turin: we can demote right
+  // before we do the dot product instruction, rather than promote both to f32.
+  // But some potential accuracy loss there, needs evaluation first.
+  float m = Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
   if (float cap = activations.config.att_cap; cap > 0.0f) {
     // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
     m = cap * std::tanh(m / cap);
@@ -169,7 +181,8 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
   MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), ctx, worker);
   for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
     const size_t pos_mod = activations.div_seq_len.Remainder(pos);
-    float x = Dot(q, k.Row(pos_mod), k.Cols());
+    float x =
+        Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
     SingleFlashAttentionStep(x, activations.config.att_cap, m, d,
                              v.Row(pos_mod), v.Cols(), att_out);
   }
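For context on the loop above: SingleFlashAttention maintains a running maximum m and denominator d and rescales the partial weighted sum of V whenever a larger score arrives, which is the standard online-softmax recurrence behind flash attention. A self-contained scalar sketch of that recurrence, including the optional tanh soft cap (hypothetical standalone code; the real SingleFlashAttentionStep is vectorized and defined elsewhere in this file):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Scalar model of one online-softmax update: fold score x for value row v
    // into the running max m, denominator d and unnormalized output acc.
    static void FlashStepModel(float x, float att_cap, float& m, float& d,
                               const std::vector<float>& v,
                               std::vector<float>& acc) {
      if (att_cap > 0.0f) x = att_cap * std::tanh(x / att_cap);  // soft cap
      const float new_m = std::max(m, x);
      const float scale = std::exp(m - new_m);  // rescale prior contributions
      const float w = std::exp(x - new_m);      // weight of the new timestep
      for (size_t i = 0; i < v.size(); ++i) acc[i] = acc[i] * scale + w * v[i];
      d = d * scale + w;
      m = new_m;
    }

    // Usage: initialize m to a very negative value (the code above uses
    // -FLT_MAX / 2), d to 0 and acc to zeros; after folding in every timestep,
    // acc[i] / d equals softmax(scores) . V without materializing the scores.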
@@ -179,25 +192,31 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
 // the dot products of NF rows of Q for a single K timestep.
 template <class DF, class VF = hn::Vec<DF>>
 VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
-               const size_t k_pos, const MatPtrT<KV_t>& q,
+               const size_t k_pos, const MatPtrT<float>& q,
                const MatPtrT<KV_t>& k) {
+  const hn::ScalableTag<BF16> dbf;
+  const size_t qkv_dim = k.Cols();
+  HWY_ALIGN BF16 q_bf[kMaxQKVDim];
+  CompressPerThread tls;
+
   hn::TFromD<DF> results[hn::MaxLanes(df)];
   for (size_t i = 0; i < hn::Lanes(df); ++i) {
-    results[i] = Dot(q.Row(0) + q_offsets[i], k.Row(k_pos), k.Cols());
+    CompressTraits<BF16>::Compress(df, q.Row(0) + q_offsets[i], qkv_dim, tls,
+                                   MakeSpan(q_bf, qkv_dim), 0);
+    results[i] =
+        Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
   }
   return hn::LoadU(df, results);
 }

-// Returns an NF Q rows by 8 K rows tile of Q.K dot products, in single
-// precision.
+// Returns an NF Q rows by 8 K rows tile of Q.K dot products.
 // This is the result of NF rows of Q against 8 K timesteps, with positions
 // given by k_pos[0..7]. Q has been transposed so that the NF rows are read in
 // consecutive elements, and other columns by adding q_stride.
 template <class DF, class VF = hn::Vec<DF>>
-void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
-                    const MatPtrT<KV_t>& k, const size_t* k_pos, VF& sum0,
-                    VF& sum1, VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6,
-                    VF& sum7) {
+void QDotKTile(DF df, const BF16* HWY_RESTRICT q, const size_t q_stride,
+               const MatPtrT<KV_t>& k, const size_t* k_pos, VF& sum0, VF& sum1,
+               VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7) {
   constexpr size_t kHTileSize = kNFx8HTileSize;
   sum0 = hn::Zero(df);
   sum1 = hn::Zero(df);
@@ -211,8 +230,13 @@ void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
   for (int i = 0; i < kHTileSize; ++i) {
     k_row[i] = k.Row(k_pos[i]);
   }
+
+  const hn::Rebind<BF16, DF> dbfh;
+  using VBF = hn::Vec<decltype(dbfh)>;
+
   for (size_t i = 0; i < k.Cols(); ++i) {
-    VF q_vec = hn::Load(df, q);
+    const VBF q_vec_bf = hn::Load(dbfh, q);
+    const VF q_vec = hn::PromoteTo(df, q_vec_bf);
     VF k_0 = hn::Set(df, k_row[0][i]);
     sum0 = hn::MulAdd(q_vec, k_0, sum0);
     VF k_1 = hn::Set(df, k_row[1][i]);
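QDotKTile accumulates an NF x 8 block of Q.K scores in registers: for each K column i it loads NF consecutive (now BF16) elements of the transposed Q, promotes them to f32, and multiply-adds them against a broadcast K[k_pos[j]][i] for each of the 8 timesteps. A plain-C++ model of that tiling, without SIMD (hypothetical names):

    #include <cstddef>
    #include <vector>

    // Scalar model of an NF x 8 Q.K tile: q_t is the transposed query matrix,
    // q_t[row][col] with `row` indexing qkv_dim and `col` indexing queries.
    // k[pos] is one K row of length qkv_dim. Accumulation stays in f32.
    void QDotKTileModel(const std::vector<std::vector<float>>& q_t,
                        size_t first_query, size_t NF,
                        const std::vector<std::vector<float>>& k,
                        const size_t k_pos[8], float sums[/*NF*/][8]) {
      for (size_t q = 0; q < NF; ++q)
        for (size_t j = 0; j < 8; ++j) sums[q][j] = 0.0f;
      const size_t qkv_dim = q_t.size();
      for (size_t i = 0; i < qkv_dim; ++i) {
        for (size_t j = 0; j < 8; ++j) {
          const float k_ij = k[k_pos[j]][i];   // broadcast one K element
          for (size_t q = 0; q < NF; ++q) {    // NF consecutive Q lanes
            sums[q][j] += q_t[i][first_query + q] * k_ij;
          }
        }
      }
    }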
@@ -264,17 +288,14 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
 // Sweeps a tile of NF Q rows by 8 K timesteps accumulators from start_pos to
 // min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
 // max_last_pos].
-void TileFlashAttention(const MatPtrT<float>& q,
-                        const uint32_t* HWY_RESTRICT q_offsets,
-                        const StridedView<float>& qT, const MatPtrT<KV_t>& k,
-                        const size_t start_pos,
-                        const uint32_t* HWY_RESTRICT last_pos,
-                        const size_t min_last_pos, const size_t max_last_pos,
-                        const MatPtrT<KV_t>& v, const size_t layer_idx,
-                        const AttentionActivationsPtrs& activations,
-                        MatPtrT<float>& att_out,
-                        const uint32_t* HWY_RESTRICT out_offsets,
-                        ThreadingContext& ctx, const size_t worker) {
+void TileFlashAttention(
+    const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
+    const StridedView<BF16>& qT, const MatPtrT<KV_t>& k, const size_t start_pos,
+    const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
+    const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
+    const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,
+    const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx,
+    const size_t worker) {
   GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention);
   constexpr int kHTileSize = kNFx8HTileSize;
   using DF = hn::ScalableTag<float>;
@@ -291,7 +312,7 @@ void TileFlashAttention(const MatPtrT<float>& q,
   VI lasts = hn::LoadU(di, last_pos);
   VF old_m = hn::Set(df, -std::numeric_limits<float>::max() / 2.0f);
   VF old_d = hn::Zero(df);
-  const float* HWY_RESTRICT qT_row = qT.Row(0);
+  const BF16* HWY_RESTRICT qT_row = qT.Row(0);
   const size_t qT_stride = qT.Stride();
   size_t position = start_pos;
   while (position + kHTileSize - 1 <= min_last_pos) {
@@ -300,8 +321,7 @@ void TileFlashAttention(const MatPtrT<float>& q,
       k_pos[i] = activations.div_seq_len.Remainder(position + i);
     }
     VF x0, x1, x2, x3, x4, x5, x6, x7;
-    QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6,
-                   x7);
+    QDotKTile(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6, x7);
     if (activations.config.att_cap > 0.0f) {
       // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
       VF cap = hn::Set(df, activations.config.att_cap);
@@ -390,13 +410,17 @@ void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
   VI k_offsets_vec = hn::LoadU(di, k_offsets);
   for (size_t i = 0; i < k.Cols(); ++i) {
     VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec);
-    VF q_0 = hn::Set(df, q[q_offsets[0] + i]);
+    VF q_0 = hn::Set(df, hwy::ConvertScalarTo<float>(
+                             hwy::ConvertScalarTo<BF16>(q[q_offsets[0] + i])));
     sum0 = hn::MulAdd(q_0, k_vec, sum0);
-    VF q_1 = hn::Set(df, q[q_offsets[1] + i]);
+    VF q_1 = hn::Set(df, hwy::ConvertScalarTo<float>(
+                             hwy::ConvertScalarTo<BF16>(q[q_offsets[1] + i])));
     sum1 = hn::MulAdd(q_1, k_vec, sum1);
-    VF q_2 = hn::Set(df, q[q_offsets[2] + i]);
+    VF q_2 = hn::Set(df, hwy::ConvertScalarTo<float>(
+                             hwy::ConvertScalarTo<BF16>(q[q_offsets[2] + i])));
     sum2 = hn::MulAdd(q_2, k_vec, sum2);
-    VF q_3 = hn::Set(df, q[q_offsets[3] + i]);
+    VF q_3 = hn::Set(df, hwy::ConvertScalarTo<float>(
+                             hwy::ConvertScalarTo<BF16>(q[q_offsets[3] + i])));
     sum3 = hn::MulAdd(q_3, k_vec, sum3);
   }
 }
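In QDotKTilex4 the query stays in f32 memory, so each element is demoted to BF16 and promoted straight back before the FMA, presumably to keep this gather-based path numerically consistent with the paths that read BF16 q_T. A scalar illustration of why that round trip changes the result (truncation-based BF16 model; the library conversion rounds):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Truncation-based BF16 round-trip model, for illustration only.
    static float BF16RoundTripApprox(float f) {
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof(bits));
      bits &= 0xFFFF0000u;  // keep sign, exponent and top 7 mantissa bits
      std::memcpy(&f, &bits, sizeof(f));
      return f;
    }

    int main() {
      const float q = 1.2345678f, k = 0.75f;
      std::printf("f32 product:       %.7f\n", q * k);
      std::printf("bf16(q) * k (f32): %.7f\n", BF16RoundTripApprox(q) * k);
    }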
@@ -478,32 +502,50 @@ void TileFlashAttention4(const MatPtrT<float>& q,
                              out_offsets, v.Cols());
     position += kHTileSize;
   }
+  const hn::ScalableTag<BF16> dbf;
+  const size_t qkv_dim = k.Cols();
+  HWY_ALIGN BF16 q_bf[kMaxQKVDim];
+  CompressPerThread tls;
+  const hn::ScalableTag<float> df_compress;
+
   while (position <= max_last_pos) {
     size_t k_pos = activations.div_seq_len.Remainder(position);
     if (position <= last_pos[0]) {
       // Past the last position, x0 doesn't count.
-      float x0 = Dot(q.Row(0) + q_offsets[0], k.Row(k_pos), k.Cols());
+      CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[0],
+                                     qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
+      float x0 =
+          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
       SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0,
                                v.Row(k_pos), v.Cols(),
                                att_out.Row(0) + out_offsets[0]);
     }
     if (position <= last_pos[1]) {
       // Past the last position, x1 doesn't count.
-      float x1 = Dot(q.Row(0) + q_offsets[1], k.Row(k_pos), k.Cols());
+      CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[1],
+                                     qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
+      float x1 =
+          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
       SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1,
                                v.Row(k_pos), v.Cols(),
                                att_out.Row(0) + out_offsets[1]);
     }
     if (position <= last_pos[2]) {
       // Past the last position, x2 doesn't count.
-      float x2 = Dot(q.Row(0) + q_offsets[2], k.Row(k_pos), k.Cols());
+      CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[2],
+                                     qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
+      float x2 =
+          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
       SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2,
                                v.Row(k_pos), v.Cols(),
                                att_out.Row(0) + out_offsets[2]);
     }
     if (position <= last_pos[3]) {
       // Past the last position, x3 doesn't count.
-      float x3 = Dot(q.Row(0) + q_offsets[3], k.Row(k_pos), k.Cols());
+      CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[3],
+                                     qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
+      float x3 =
+          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
       SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3,
                                v.Row(k_pos), v.Cols(),
                                att_out.Row(0) + out_offsets[3]);
@@ -722,9 +764,9 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
       // To avoid duplicating the code to setup K and V, the call to
       // TileFlashAttention is inside the loop over tasks, even though it
       // handles all rows in the task at once.
-      StridedView<float> qT =
-          StridedView<float>(activations.q_T.Row(0) + first_task, kVTileSize,
-                             activations.q_T.Stride());
+      StridedView<BF16> qT =
+          StridedView<BF16>(activations.q_T.Row(0) + first_task, kVTileSize,
+                            activations.q_T.Stride());
       if (kVTileSize == kNF) {
         // We can still use TileFlashAttention even if we didn't transpose Q
         // above. The condition used for transposing Q above is more general
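FlashAttention hands TileFlashAttention a StridedView over the BF16 q_T buffer: a (pointer, rows, stride) window into the larger matrix rather than a copy. A hypothetical minimal model of such a view, just to make the Row()/Stride() calls above concrete (not the actual gemma.cpp class):

    #include <cstddef>

    // Hypothetical minimal strided view: row r starts r * stride elements past
    // `data`, so a sub-block of a matrix can be referenced without copying.
    template <typename T>
    class StridedViewModel {
     public:
      StridedViewModel(T* data, size_t rows, size_t stride)
          : data_(data), rows_(rows), stride_(stride) {}
      T* Row(size_t r) const { return data_ + r * stride_; }
      size_t Rows() const { return rows_; }
      size_t Stride() const { return stride_; }

     private:
      T* data_;
      size_t rows_;
      size_t stride_;
    };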
@@ -413,7 +413,8 @@ using DotKernelDefault =
 template <class D, typename WT, typename VT>
 HWY_INLINE float Dot(D d, const PackedSpan<const WT>& w, size_t w_ofs,
                      const VT* HWY_RESTRICT vec, size_t num) {
-  return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelDefault());
+  return DecompressAndCall(d, w, w_ofs, MakeConstSpan(vec, num),
+                           DotKernelDefault());
 }

 // Adapter for two pointers, no bounds checking.