Pre-compress query activations to BF16 before FlashAttention.

PiperOrigin-RevId: 826524997
This commit is contained in:
Phil Culliton 2025-10-31 09:49:07 -07:00 committed by Copybara-Service
parent 8a100c1e8d
commit ab87807a4c
10 changed files with 59 additions and 59 deletions

View File

@ -38,7 +38,10 @@ namespace gcpp {
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference)
: ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) {
: initializer_value_(gcpp::InternalInit()),
ctx_(threading),
env_(ctx_),
gemma_(loader, inference, ctx_) {
const ModelConfig& config = gemma_.Config();
// Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.push_back(KVCache(config, inference, ctx_.allocator));

View File

@ -125,6 +125,8 @@ class GemmaEnv {
MatMulEnv& MutableEnv() { return env_; }
private:
// This is used to ensure that InternalInit is called before anything else.
int initializer_value_ = 0;
ThreadingContext ctx_;
MatMulEnv env_;
Gemma gemma_;

View File

@ -153,5 +153,3 @@ int main(int argc, char** argv) {
return RUN_ALL_TESTS();
}

View File

@ -181,7 +181,6 @@ TEST_F(GemmaTest, CrossEntropySmall) {
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
gcpp::InternalInit();
gcpp::GemmaTest::InitEnv(argc, argv);
int ret = RUN_ALL_TESTS();
gcpp::GemmaTest::DeleteEnv();

View File

@ -54,6 +54,11 @@ struct AttentionActivations {
? layer_config.heads * 3 * layer_config.qkv_dim
: layer_config.heads * layer_config.qkv_dim,
allocator)),
q_bf(MatFactory("q_bf", batch_size,
config.vocab_size == 0
? layer_config.heads * 3 * layer_config.qkv_dim
: layer_config.heads * layer_config.qkv_dim,
allocator)),
q_T(MatFactory("q_T", layer_config.qkv_dim,
config.vocab_size == 0
? batch_size * layer_config.heads * 3
@ -88,12 +93,14 @@ struct AttentionActivations {
// If we forget any MatMul outputs here, debug builds print a warning but
// fill them in each MatMul call.
q.AllocateAndAttachRowPtrs(row_ptrs);
q_bf.AllocateAndAttachRowPtrs(row_ptrs);
q_T.AllocateAndAttachRowPtrs(row_ptrs);
att_sums.AllocateAndAttachRowPtrs(row_ptrs);
}
void SetBatchSize(size_t batch_size) {
q.OverrideRows(batch_size);
q_bf.OverrideRows(batch_size);
// q_T rows are always qkv_dim!
pre_att_rms_out.OverrideRows(batch_size);
@ -105,6 +112,7 @@ struct AttentionActivations {
}
MatStorageT<float> q; // query
MatStorageT<BF16> q_bf;
MatStorageT<BF16> q_T; // Transposed to maximize attention speed.
MatStorageT<float> pre_att_rms_out;
@ -130,6 +138,7 @@ struct AttentionActivationsPtrs {
const AttentionActivations& activations)
: AttentionActivationsPtrs(config, seq_len) {
q = activations.q;
q_bf = activations.q_bf;
q_T = activations.q_T;
pre_att_rms_out = activations.pre_att_rms_out;
att = activations.att;
@ -141,6 +150,7 @@ struct AttentionActivationsPtrs {
void SetBatchSize(size_t batch_size) {
q.OverrideRows(batch_size);
q_bf.OverrideRows(batch_size);
// q_T rows are always qkv_dim!
pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size);
@ -151,6 +161,7 @@ struct AttentionActivationsPtrs {
const ModelConfig& config;
MatPtrT<float> q;
MatPtrT<BF16> q_bf;
MatPtrT<BF16> q_T;
MatPtrT<float> pre_att_rms_out;
MatPtrT<float> att;

View File

@ -154,7 +154,7 @@ void HWY_INLINE SingleFlashAttentionStep(float x, float cap, float& old_max,
// Calculates the complete attention outputs for a single row of q.
void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
const float* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
const BF16* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
const MatPtrT<KV_t>& v, const size_t layer_idx,
const AttentionActivationsPtrs& activations,
float* HWY_RESTRICT att_out, ThreadingContext& ctx,
@ -162,17 +162,12 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention);
const hn::ScalableTag<BF16> dbf;
const size_t qkv_dim = k.Cols();
HWY_ALIGN BF16 q_bf[kMaxQKVDim];
CompressPerThread tls;
const hn::ScalableTag<float> df;
CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim),
0);
const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
// TODO: Mixed-mode can be further improved for Turin: we can demote right
// before we do the dot product instruction, rather than promote both to f32.
// But some potential accuracy loss there, needs evaluation first.
float m = Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
float m = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
if (float cap = activations.config.att_cap; cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
m = cap * std::tanh(m / cap);
@ -182,8 +177,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), ctx, worker);
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
const size_t pos_mod = activations.div_seq_len.Remainder(pos);
float x =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
float x = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
SingleFlashAttentionStep(x, activations.config.att_cap, m, d,
v.Row(pos_mod), v.Cols(), att_out);
}
@ -193,19 +187,15 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
// the dot products of NF rows of Q for a single K timestep.
template <class DF, class VF = hn::Vec<DF>>
VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
const size_t k_pos, const MatPtrT<float>& q,
const size_t k_pos, const MatPtrT<BF16>& q,
const MatPtrT<KV_t>& k) {
const hn::ScalableTag<BF16> dbf;
const size_t qkv_dim = k.Cols();
HWY_ALIGN BF16 q_bf[kMaxQKVDim];
CompressPerThread tls;
hn::TFromD<DF> results[hn::MaxLanes(df)];
for (size_t i = 0; i < hn::Lanes(df); ++i) {
CompressTraits<BF16>::Compress(df, q.Row(0) + q_offsets[i], qkv_dim, tls,
MakeSpan(q_bf, qkv_dim), 0);
results[i] =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
results[i] = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[i], qkv_dim), 0,
k.Row(k_pos), qkv_dim);
}
return hn::LoadU(df, results);
}
@ -290,7 +280,7 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
// max_last_pos].
void TileFlashAttention(
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,
const StridedView<BF16>& qT, const MatPtrT<KV_t>& k, const size_t start_pos,
const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
@ -396,7 +386,7 @@ void TileFlashAttention(
// This is the result of 4 rows of Q against NF K timesteps, with positions
// given by k_offsets[0..NF].
template <class DF, class VF = hn::Vec<DF>>
void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
void QDotKTilex4(DF df, const BF16* HWY_RESTRICT q,
const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<KV_t>& k,
const int32_t* HWY_RESTRICT k_offsets, VF& sum0, VF& sum1,
VF& sum2, VF& sum3) {
@ -411,17 +401,13 @@ void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
VI k_offsets_vec = hn::LoadU(di, k_offsets);
for (size_t i = 0; i < k.Cols(); ++i) {
VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec);
VF q_0 = hn::Set(df, hwy::ConvertScalarTo<float>(
hwy::ConvertScalarTo<BF16>(q[q_offsets[0] + i])));
VF q_0 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[0] + i]));
sum0 = hn::MulAdd(q_0, k_vec, sum0);
VF q_1 = hn::Set(df, hwy::ConvertScalarTo<float>(
hwy::ConvertScalarTo<BF16>(q[q_offsets[1] + i])));
VF q_1 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[1] + i]));
sum1 = hn::MulAdd(q_1, k_vec, sum1);
VF q_2 = hn::Set(df, hwy::ConvertScalarTo<float>(
hwy::ConvertScalarTo<BF16>(q[q_offsets[2] + i])));
VF q_2 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[2] + i]));
sum2 = hn::MulAdd(q_2, k_vec, sum2);
VF q_3 = hn::Set(df, hwy::ConvertScalarTo<float>(
hwy::ConvertScalarTo<BF16>(q[q_offsets[3] + i])));
VF q_3 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[3] + i]));
sum3 = hn::MulAdd(q_3, k_vec, sum3);
}
}
@ -446,7 +432,7 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
// max_last_pos].
Tile4FlashState TileFlashAttention4(
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,
const MatPtrT<KV_t>& k, const size_t start_pos,
const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
@ -500,18 +486,13 @@ Tile4FlashState TileFlashAttention4(
}
const hn::ScalableTag<BF16> dbf;
const size_t qkv_dim = k.Cols();
HWY_ALIGN BF16 q_bf[kMaxQKVDim];
CompressPerThread tls;
const hn::ScalableTag<float> df_compress;
while (position <= max_last_pos) {
size_t k_pos = activations.div_seq_len.Remainder(position);
if (position <= last_pos[0]) {
// Past the last position, x0 doesn't count.
CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[0],
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
float x0 =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
float x0 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[0], qkv_dim), 0,
k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x0, activations.config.att_cap,
state.row_states[0].max, state.row_states[0].d,
v.Row(k_pos), v.Cols(),
@ -519,10 +500,8 @@ Tile4FlashState TileFlashAttention4(
}
if (position <= last_pos[1]) {
// Past the last position, x1 doesn't count.
CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[1],
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
float x1 =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
float x1 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[1], qkv_dim), 0,
k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x1, activations.config.att_cap,
state.row_states[1].max, state.row_states[1].d,
v.Row(k_pos), v.Cols(),
@ -530,10 +509,8 @@ Tile4FlashState TileFlashAttention4(
}
if (position <= last_pos[2]) {
// Past the last position, x2 doesn't count.
CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[2],
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
float x2 =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
float x2 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[2], qkv_dim), 0,
k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x2, activations.config.att_cap,
state.row_states[2].max, state.row_states[2].d,
v.Row(k_pos), v.Cols(),
@ -541,10 +518,8 @@ Tile4FlashState TileFlashAttention4(
}
if (position <= last_pos[3]) {
// Past the last position, x3 doesn't count.
CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[3],
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
float x3 =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
float x3 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[3], qkv_dim), 0,
k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x3, activations.config.att_cap,
state.row_states[3].max, state.row_states[3].d,
v.Row(k_pos), v.Cols(),
@ -642,6 +617,17 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q,
query_norm_scale, layer_idx, activations, ctx);
const hwy::Divisor div_qbatch(qbatch.Size());
// Compress q to q_bf.
ParallelFor(
ParallelismStrategy::kWithinCluster, activations.q.Rows(), ctx,
/*cluster_idx=*/0, Callers::kFlashAttention,
[&](size_t row, size_t worker) {
CompressPerThread tls;
const hn::ScalableTag<float> df;
CompressTraits<BF16>::Compress(
df, activations.q.Row(row), activations.q.Cols(), tls,
MakeSpan(activations.q_bf.Row(row), activations.q_bf.Cols()), 0);
});
const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
const size_t qkv_dim = layer_config.qkv_dim;
@ -736,8 +722,8 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
last_pos[offset] = last;
min_last_pos = HWY_MIN(min_last_pos, last);
max_last_pos = HWY_MAX(max_last_pos, last);
q_offsets[offset] =
activations.q.Row(tq_idx) + head * qkv_dim - activations.q.Row(0);
q_offsets[offset] = activations.q_bf.Row(tq_idx) + head * qkv_dim -
activations.q_bf.Row(0);
out_offsets[offset] = activations.att_out.Row(tq_idx) + head * qkv_dim -
activations.att_out.Row(0);
const size_t kv_index = head / kHeadGroups;
@ -776,12 +762,12 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
// kNFx8HTileSize. In this case, qT is never used. Some tasks might
// use qT and some might not, which is why the more general condition
// is used above to catch all cases where qT will be used.
TileFlashAttention(activations.q, q_offsets, qT, k,
TileFlashAttention(activations.q_bf, q_offsets, qT, k,
start_positions[offset], last_pos, min_last_pos,
max_last_pos, v, layer_idx, activations,
activations.att_out, out_offsets, ctx, worker);
} else if (kVTileSize == 4) {
TileFlashAttention4(activations.q, q_offsets, k,
TileFlashAttention4(activations.q_bf, q_offsets, k,
start_positions[offset], last_pos, min_last_pos,
max_last_pos, v, layer_idx, activations,
activations.att_out, out_offsets, ctx, worker);
@ -791,7 +777,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
break;
} else {
SingleFlashAttention(start_positions[offset], last_pos[offset],
activations.q.Row(0) + q_offsets[offset], k, v,
activations.q_bf.Row(0) + q_offsets[offset], k, v,
layer_idx, activations,
activations.att_out.Row(0) + out_offsets[offset],
ctx, worker);

View File

@ -45,7 +45,7 @@ namespace gcpp {
ThreadingContext& ctx, size_t worker); \
\
Tile4FlashState TileFlashAttention4( \
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets, \
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets, \
const MatPtrT<KV_t>& k, size_t start_pos, \
const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \
size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx, \

View File

@ -236,7 +236,9 @@ bool IOBatch::Add(void* mem, size_t bytes) {
return true;
}
void InternalInit() {
int InternalInit() {
// currently unused, except for init list ordering in GemmaEnv.
return 0;
}
uint64_t IOBatch::Read(const File& file) const {

View File

@ -150,7 +150,7 @@ std::string ReadFileToString(const Path& path);
// No-op in open-source. Must be called at the beginning of a binary, before
// any I/O or flag usage.
void InternalInit();
int InternalInit();
} // namespace gcpp

View File

@ -72,7 +72,6 @@ TEST_F(PaliGemmaTest, QueryObjects) {
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
gcpp::InternalInit();
gcpp::GemmaEnv env(argc, argv);
gcpp::s_env = &env;