Pre-compress query activations to BF16 before FlashAttention.

PiperOrigin-RevId: 826524997
This commit is contained in:
Phil Culliton 2025-10-31 09:49:07 -07:00 committed by Copybara-Service
parent 8a100c1e8d
commit ab87807a4c
10 changed files with 59 additions and 59 deletions

View File

@ -38,7 +38,10 @@ namespace gcpp {
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading, GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference) const InferenceArgs& inference)
: ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) { : initializer_value_(gcpp::InternalInit()),
ctx_(threading),
env_(ctx_),
gemma_(loader, inference, ctx_) {
const ModelConfig& config = gemma_.Config(); const ModelConfig& config = gemma_.Config();
// Only allocate one for starters because GenerateBatch might not be called. // Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.push_back(KVCache(config, inference, ctx_.allocator)); kv_caches_.push_back(KVCache(config, inference, ctx_.allocator));

View File

@ -125,6 +125,8 @@ class GemmaEnv {
MatMulEnv& MutableEnv() { return env_; } MatMulEnv& MutableEnv() { return env_; }
private: private:
// This is used to ensure that InternalInit is called before anything else.
int initializer_value_ = 0;
ThreadingContext ctx_; ThreadingContext ctx_;
MatMulEnv env_; MatMulEnv env_;
Gemma gemma_; Gemma gemma_;

View File

@ -153,5 +153,3 @@ int main(int argc, char** argv) {
return RUN_ALL_TESTS(); return RUN_ALL_TESTS();
} }

View File

@ -181,7 +181,6 @@ TEST_F(GemmaTest, CrossEntropySmall) {
int main(int argc, char** argv) { int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
gcpp::InternalInit();
gcpp::GemmaTest::InitEnv(argc, argv); gcpp::GemmaTest::InitEnv(argc, argv);
int ret = RUN_ALL_TESTS(); int ret = RUN_ALL_TESTS();
gcpp::GemmaTest::DeleteEnv(); gcpp::GemmaTest::DeleteEnv();

View File

@ -54,6 +54,11 @@ struct AttentionActivations {
? layer_config.heads * 3 * layer_config.qkv_dim ? layer_config.heads * 3 * layer_config.qkv_dim
: layer_config.heads * layer_config.qkv_dim, : layer_config.heads * layer_config.qkv_dim,
allocator)), allocator)),
q_bf(MatFactory("q_bf", batch_size,
config.vocab_size == 0
? layer_config.heads * 3 * layer_config.qkv_dim
: layer_config.heads * layer_config.qkv_dim,
allocator)),
q_T(MatFactory("q_T", layer_config.qkv_dim, q_T(MatFactory("q_T", layer_config.qkv_dim,
config.vocab_size == 0 config.vocab_size == 0
? batch_size * layer_config.heads * 3 ? batch_size * layer_config.heads * 3
@ -88,12 +93,14 @@ struct AttentionActivations {
// If we forget any MatMul outputs here, debug builds print a warning but // If we forget any MatMul outputs here, debug builds print a warning but
// fill them in each MatMul call. // fill them in each MatMul call.
q.AllocateAndAttachRowPtrs(row_ptrs); q.AllocateAndAttachRowPtrs(row_ptrs);
q_bf.AllocateAndAttachRowPtrs(row_ptrs);
q_T.AllocateAndAttachRowPtrs(row_ptrs); q_T.AllocateAndAttachRowPtrs(row_ptrs);
att_sums.AllocateAndAttachRowPtrs(row_ptrs); att_sums.AllocateAndAttachRowPtrs(row_ptrs);
} }
void SetBatchSize(size_t batch_size) { void SetBatchSize(size_t batch_size) {
q.OverrideRows(batch_size); q.OverrideRows(batch_size);
q_bf.OverrideRows(batch_size);
// q_T rows are always qkv_dim! // q_T rows are always qkv_dim!
pre_att_rms_out.OverrideRows(batch_size); pre_att_rms_out.OverrideRows(batch_size);
@ -105,6 +112,7 @@ struct AttentionActivations {
} }
MatStorageT<float> q; // query MatStorageT<float> q; // query
MatStorageT<BF16> q_bf;
MatStorageT<BF16> q_T; // Transposed to maximize attention speed. MatStorageT<BF16> q_T; // Transposed to maximize attention speed.
MatStorageT<float> pre_att_rms_out; MatStorageT<float> pre_att_rms_out;
@ -130,6 +138,7 @@ struct AttentionActivationsPtrs {
const AttentionActivations& activations) const AttentionActivations& activations)
: AttentionActivationsPtrs(config, seq_len) { : AttentionActivationsPtrs(config, seq_len) {
q = activations.q; q = activations.q;
q_bf = activations.q_bf;
q_T = activations.q_T; q_T = activations.q_T;
pre_att_rms_out = activations.pre_att_rms_out; pre_att_rms_out = activations.pre_att_rms_out;
att = activations.att; att = activations.att;
@ -141,6 +150,7 @@ struct AttentionActivationsPtrs {
void SetBatchSize(size_t batch_size) { void SetBatchSize(size_t batch_size) {
q.OverrideRows(batch_size); q.OverrideRows(batch_size);
q_bf.OverrideRows(batch_size);
// q_T rows are always qkv_dim! // q_T rows are always qkv_dim!
pre_att_rms_out.OverrideRows(batch_size); pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size); att.OverrideRows(batch_size);
@ -151,6 +161,7 @@ struct AttentionActivationsPtrs {
const ModelConfig& config; const ModelConfig& config;
MatPtrT<float> q; MatPtrT<float> q;
MatPtrT<BF16> q_bf;
MatPtrT<BF16> q_T; MatPtrT<BF16> q_T;
MatPtrT<float> pre_att_rms_out; MatPtrT<float> pre_att_rms_out;
MatPtrT<float> att; MatPtrT<float> att;

View File

@ -154,7 +154,7 @@ void HWY_INLINE SingleFlashAttentionStep(float x, float cap, float& old_max,
// Calculates the complete attention outputs for a single row of q. // Calculates the complete attention outputs for a single row of q.
void SingleFlashAttention(const size_t start_pos, const size_t last_pos, void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
const float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const BF16* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
const MatPtrT<KV_t>& v, const size_t layer_idx, const MatPtrT<KV_t>& v, const size_t layer_idx,
const AttentionActivationsPtrs& activations, const AttentionActivationsPtrs& activations,
float* HWY_RESTRICT att_out, ThreadingContext& ctx, float* HWY_RESTRICT att_out, ThreadingContext& ctx,
@ -162,17 +162,12 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention); GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention);
const hn::ScalableTag<BF16> dbf; const hn::ScalableTag<BF16> dbf;
const size_t qkv_dim = k.Cols(); const size_t qkv_dim = k.Cols();
HWY_ALIGN BF16 q_bf[kMaxQKVDim];
CompressPerThread tls;
const hn::ScalableTag<float> df;
CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim),
0);
const size_t pos_mod = activations.div_seq_len.Remainder(start_pos); const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
// TODO: Mixed-mode can be further improved for Turin: we can demote right // TODO: Mixed-mode can be further improved for Turin: we can demote right
// before we do the dot product instruction, rather than promote both to f32. // before we do the dot product instruction, rather than promote both to f32.
// But some potential accuracy loss there, needs evaluation first. // But some potential accuracy loss there, needs evaluation first.
float m = Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim); float m = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
if (float cap = activations.config.att_cap; cap > 0.0f) { if (float cap = activations.config.att_cap; cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x. // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
m = cap * std::tanh(m / cap); m = cap * std::tanh(m / cap);
@ -182,8 +177,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), ctx, worker); MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), ctx, worker);
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
const size_t pos_mod = activations.div_seq_len.Remainder(pos); const size_t pos_mod = activations.div_seq_len.Remainder(pos);
float x = float x = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
SingleFlashAttentionStep(x, activations.config.att_cap, m, d, SingleFlashAttentionStep(x, activations.config.att_cap, m, d,
v.Row(pos_mod), v.Cols(), att_out); v.Row(pos_mod), v.Cols(), att_out);
} }
@ -193,19 +187,15 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
// the dot products of NF rows of Q for a single K timestep. // the dot products of NF rows of Q for a single K timestep.
template <class DF, class VF = hn::Vec<DF>> template <class DF, class VF = hn::Vec<DF>>
VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets, VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
const size_t k_pos, const MatPtrT<float>& q, const size_t k_pos, const MatPtrT<BF16>& q,
const MatPtrT<KV_t>& k) { const MatPtrT<KV_t>& k) {
const hn::ScalableTag<BF16> dbf; const hn::ScalableTag<BF16> dbf;
const size_t qkv_dim = k.Cols(); const size_t qkv_dim = k.Cols();
HWY_ALIGN BF16 q_bf[kMaxQKVDim];
CompressPerThread tls;
hn::TFromD<DF> results[hn::MaxLanes(df)]; hn::TFromD<DF> results[hn::MaxLanes(df)];
for (size_t i = 0; i < hn::Lanes(df); ++i) { for (size_t i = 0; i < hn::Lanes(df); ++i) {
CompressTraits<BF16>::Compress(df, q.Row(0) + q_offsets[i], qkv_dim, tls, results[i] = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[i], qkv_dim), 0,
MakeSpan(q_bf, qkv_dim), 0); k.Row(k_pos), qkv_dim);
results[i] =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
} }
return hn::LoadU(df, results); return hn::LoadU(df, results);
} }
@ -290,7 +280,7 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos, // min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
// max_last_pos]. // max_last_pos].
void TileFlashAttention( void TileFlashAttention(
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,
const StridedView<BF16>& qT, const MatPtrT<KV_t>& k, const size_t start_pos, const StridedView<BF16>& qT, const MatPtrT<KV_t>& k, const size_t start_pos,
const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos, const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx, const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
@ -396,7 +386,7 @@ void TileFlashAttention(
// This is the result of 4 rows of Q against NF K timesteps, with positions // This is the result of 4 rows of Q against NF K timesteps, with positions
// given by k_offsets[0..NF]. // given by k_offsets[0..NF].
template <class DF, class VF = hn::Vec<DF>> template <class DF, class VF = hn::Vec<DF>>
void QDotKTilex4(DF df, const float* HWY_RESTRICT q, void QDotKTilex4(DF df, const BF16* HWY_RESTRICT q,
const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<KV_t>& k, const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<KV_t>& k,
const int32_t* HWY_RESTRICT k_offsets, VF& sum0, VF& sum1, const int32_t* HWY_RESTRICT k_offsets, VF& sum0, VF& sum1,
VF& sum2, VF& sum3) { VF& sum2, VF& sum3) {
@ -411,17 +401,13 @@ void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
VI k_offsets_vec = hn::LoadU(di, k_offsets); VI k_offsets_vec = hn::LoadU(di, k_offsets);
for (size_t i = 0; i < k.Cols(); ++i) { for (size_t i = 0; i < k.Cols(); ++i) {
VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec); VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec);
VF q_0 = hn::Set(df, hwy::ConvertScalarTo<float>( VF q_0 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[0] + i]));
hwy::ConvertScalarTo<BF16>(q[q_offsets[0] + i])));
sum0 = hn::MulAdd(q_0, k_vec, sum0); sum0 = hn::MulAdd(q_0, k_vec, sum0);
VF q_1 = hn::Set(df, hwy::ConvertScalarTo<float>( VF q_1 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[1] + i]));
hwy::ConvertScalarTo<BF16>(q[q_offsets[1] + i])));
sum1 = hn::MulAdd(q_1, k_vec, sum1); sum1 = hn::MulAdd(q_1, k_vec, sum1);
VF q_2 = hn::Set(df, hwy::ConvertScalarTo<float>( VF q_2 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[2] + i]));
hwy::ConvertScalarTo<BF16>(q[q_offsets[2] + i])));
sum2 = hn::MulAdd(q_2, k_vec, sum2); sum2 = hn::MulAdd(q_2, k_vec, sum2);
VF q_3 = hn::Set(df, hwy::ConvertScalarTo<float>( VF q_3 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[3] + i]));
hwy::ConvertScalarTo<BF16>(q[q_offsets[3] + i])));
sum3 = hn::MulAdd(q_3, k_vec, sum3); sum3 = hn::MulAdd(q_3, k_vec, sum3);
} }
} }
@ -446,7 +432,7 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos, // min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
// max_last_pos]. // max_last_pos].
Tile4FlashState TileFlashAttention4( Tile4FlashState TileFlashAttention4(
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,
const MatPtrT<KV_t>& k, const size_t start_pos, const MatPtrT<KV_t>& k, const size_t start_pos,
const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos, const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx, const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
@ -500,18 +486,13 @@ Tile4FlashState TileFlashAttention4(
} }
const hn::ScalableTag<BF16> dbf; const hn::ScalableTag<BF16> dbf;
const size_t qkv_dim = k.Cols(); const size_t qkv_dim = k.Cols();
HWY_ALIGN BF16 q_bf[kMaxQKVDim];
CompressPerThread tls;
const hn::ScalableTag<float> df_compress;
while (position <= max_last_pos) { while (position <= max_last_pos) {
size_t k_pos = activations.div_seq_len.Remainder(position); size_t k_pos = activations.div_seq_len.Remainder(position);
if (position <= last_pos[0]) { if (position <= last_pos[0]) {
// Past the last position, x0 doesn't count. // Past the last position, x0 doesn't count.
CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[0], float x0 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[0], qkv_dim), 0,
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0); k.Row(k_pos), qkv_dim);
float x0 =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x0, activations.config.att_cap, SingleFlashAttentionStep(x0, activations.config.att_cap,
state.row_states[0].max, state.row_states[0].d, state.row_states[0].max, state.row_states[0].d,
v.Row(k_pos), v.Cols(), v.Row(k_pos), v.Cols(),
@ -519,10 +500,8 @@ Tile4FlashState TileFlashAttention4(
} }
if (position <= last_pos[1]) { if (position <= last_pos[1]) {
// Past the last position, x1 doesn't count. // Past the last position, x1 doesn't count.
CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[1], float x1 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[1], qkv_dim), 0,
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0); k.Row(k_pos), qkv_dim);
float x1 =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x1, activations.config.att_cap, SingleFlashAttentionStep(x1, activations.config.att_cap,
state.row_states[1].max, state.row_states[1].d, state.row_states[1].max, state.row_states[1].d,
v.Row(k_pos), v.Cols(), v.Row(k_pos), v.Cols(),
@ -530,10 +509,8 @@ Tile4FlashState TileFlashAttention4(
} }
if (position <= last_pos[2]) { if (position <= last_pos[2]) {
// Past the last position, x2 doesn't count. // Past the last position, x2 doesn't count.
CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[2], float x2 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[2], qkv_dim), 0,
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0); k.Row(k_pos), qkv_dim);
float x2 =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x2, activations.config.att_cap, SingleFlashAttentionStep(x2, activations.config.att_cap,
state.row_states[2].max, state.row_states[2].d, state.row_states[2].max, state.row_states[2].d,
v.Row(k_pos), v.Cols(), v.Row(k_pos), v.Cols(),
@ -541,10 +518,8 @@ Tile4FlashState TileFlashAttention4(
} }
if (position <= last_pos[3]) { if (position <= last_pos[3]) {
// Past the last position, x3 doesn't count. // Past the last position, x3 doesn't count.
CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[3], float x3 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[3], qkv_dim), 0,
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0); k.Row(k_pos), qkv_dim);
float x3 =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x3, activations.config.att_cap, SingleFlashAttentionStep(x3, activations.config.att_cap,
state.row_states[3].max, state.row_states[3].d, state.row_states[3].max, state.row_states[3].d,
v.Row(k_pos), v.Cols(), v.Row(k_pos), v.Cols(),
@ -642,6 +617,17 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q,
query_norm_scale, layer_idx, activations, ctx); query_norm_scale, layer_idx, activations, ctx);
const hwy::Divisor div_qbatch(qbatch.Size()); const hwy::Divisor div_qbatch(qbatch.Size());
// Compress q to q_bf.
ParallelFor(
ParallelismStrategy::kWithinCluster, activations.q.Rows(), ctx,
/*cluster_idx=*/0, Callers::kFlashAttention,
[&](size_t row, size_t worker) {
CompressPerThread tls;
const hn::ScalableTag<float> df;
CompressTraits<BF16>::Compress(
df, activations.q.Row(row), activations.q.Cols(), tls,
MakeSpan(activations.q_bf.Row(row), activations.q_bf.Cols()), 0);
});
const LayerConfig& layer_config = activations.config.layer_configs[layer_idx]; const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
const size_t qkv_dim = layer_config.qkv_dim; const size_t qkv_dim = layer_config.qkv_dim;
@ -736,8 +722,8 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
last_pos[offset] = last; last_pos[offset] = last;
min_last_pos = HWY_MIN(min_last_pos, last); min_last_pos = HWY_MIN(min_last_pos, last);
max_last_pos = HWY_MAX(max_last_pos, last); max_last_pos = HWY_MAX(max_last_pos, last);
q_offsets[offset] = q_offsets[offset] = activations.q_bf.Row(tq_idx) + head * qkv_dim -
activations.q.Row(tq_idx) + head * qkv_dim - activations.q.Row(0); activations.q_bf.Row(0);
out_offsets[offset] = activations.att_out.Row(tq_idx) + head * qkv_dim - out_offsets[offset] = activations.att_out.Row(tq_idx) + head * qkv_dim -
activations.att_out.Row(0); activations.att_out.Row(0);
const size_t kv_index = head / kHeadGroups; const size_t kv_index = head / kHeadGroups;
@ -776,12 +762,12 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
// kNFx8HTileSize. In this case, qT is never used. Some tasks might // kNFx8HTileSize. In this case, qT is never used. Some tasks might
// use qT and some might not, which is why the more general condition // use qT and some might not, which is why the more general condition
// is used above to catch all cases where qT will be used. // is used above to catch all cases where qT will be used.
TileFlashAttention(activations.q, q_offsets, qT, k, TileFlashAttention(activations.q_bf, q_offsets, qT, k,
start_positions[offset], last_pos, min_last_pos, start_positions[offset], last_pos, min_last_pos,
max_last_pos, v, layer_idx, activations, max_last_pos, v, layer_idx, activations,
activations.att_out, out_offsets, ctx, worker); activations.att_out, out_offsets, ctx, worker);
} else if (kVTileSize == 4) { } else if (kVTileSize == 4) {
TileFlashAttention4(activations.q, q_offsets, k, TileFlashAttention4(activations.q_bf, q_offsets, k,
start_positions[offset], last_pos, min_last_pos, start_positions[offset], last_pos, min_last_pos,
max_last_pos, v, layer_idx, activations, max_last_pos, v, layer_idx, activations,
activations.att_out, out_offsets, ctx, worker); activations.att_out, out_offsets, ctx, worker);
@ -791,7 +777,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
break; break;
} else { } else {
SingleFlashAttention(start_positions[offset], last_pos[offset], SingleFlashAttention(start_positions[offset], last_pos[offset],
activations.q.Row(0) + q_offsets[offset], k, v, activations.q_bf.Row(0) + q_offsets[offset], k, v,
layer_idx, activations, layer_idx, activations,
activations.att_out.Row(0) + out_offsets[offset], activations.att_out.Row(0) + out_offsets[offset],
ctx, worker); ctx, worker);

View File

@ -45,7 +45,7 @@ namespace gcpp {
ThreadingContext& ctx, size_t worker); \ ThreadingContext& ctx, size_t worker); \
\ \
Tile4FlashState TileFlashAttention4( \ Tile4FlashState TileFlashAttention4( \
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets, \ const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets, \
const MatPtrT<KV_t>& k, size_t start_pos, \ const MatPtrT<KV_t>& k, size_t start_pos, \
const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \ const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \
size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx, \ size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx, \

View File

@ -236,7 +236,9 @@ bool IOBatch::Add(void* mem, size_t bytes) {
return true; return true;
} }
void InternalInit() { int InternalInit() {
// currently unused, except for init list ordering in GemmaEnv.
return 0;
} }
uint64_t IOBatch::Read(const File& file) const { uint64_t IOBatch::Read(const File& file) const {

View File

@ -150,7 +150,7 @@ std::string ReadFileToString(const Path& path);
// No-op in open-source. Must be called at the beginning of a binary, before // No-op in open-source. Must be called at the beginning of a binary, before
// any I/O or flag usage. // any I/O or flag usage.
void InternalInit(); int InternalInit();
} // namespace gcpp } // namespace gcpp

View File

@ -72,7 +72,6 @@ TEST_F(PaliGemmaTest, QueryObjects) {
int main(int argc, char** argv) { int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
gcpp::InternalInit();
gcpp::GemmaEnv env(argc, argv); gcpp::GemmaEnv env(argc, argv);
gcpp::s_env = &env; gcpp::s_env = &env;