mirror of https://github.com/google/gemma.cpp.git
Pre-compress query activations to BF16 before FlashAttention.
PiperOrigin-RevId: 826524997
This commit is contained in:
parent
8a100c1e8d
commit
ab87807a4c
|
|
@ -38,7 +38,10 @@ namespace gcpp {
|
|||
|
||||
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
|
||||
const InferenceArgs& inference)
|
||||
: ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) {
|
||||
: initializer_value_(gcpp::InternalInit()),
|
||||
ctx_(threading),
|
||||
env_(ctx_),
|
||||
gemma_(loader, inference, ctx_) {
|
||||
const ModelConfig& config = gemma_.Config();
|
||||
// Only allocate one for starters because GenerateBatch might not be called.
|
||||
kv_caches_.push_back(KVCache(config, inference, ctx_.allocator));
|
||||
|
|
|
|||
|
|
@ -125,6 +125,8 @@ class GemmaEnv {
|
|||
MatMulEnv& MutableEnv() { return env_; }
|
||||
|
||||
private:
|
||||
// This is used to ensure that InternalInit is called before anything else.
|
||||
int initializer_value_ = 0;
|
||||
ThreadingContext ctx_;
|
||||
MatMulEnv env_;
|
||||
Gemma gemma_;
|
||||
|
|
|
|||
|
|
@ -153,5 +153,3 @@ int main(int argc, char** argv) {
|
|||
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -181,7 +181,6 @@ TEST_F(GemmaTest, CrossEntropySmall) {
|
|||
|
||||
int main(int argc, char** argv) {
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
gcpp::InternalInit();
|
||||
gcpp::GemmaTest::InitEnv(argc, argv);
|
||||
int ret = RUN_ALL_TESTS();
|
||||
gcpp::GemmaTest::DeleteEnv();
|
||||
|
|
|
|||
|
|
@ -54,6 +54,11 @@ struct AttentionActivations {
|
|||
? layer_config.heads * 3 * layer_config.qkv_dim
|
||||
: layer_config.heads * layer_config.qkv_dim,
|
||||
allocator)),
|
||||
q_bf(MatFactory("q_bf", batch_size,
|
||||
config.vocab_size == 0
|
||||
? layer_config.heads * 3 * layer_config.qkv_dim
|
||||
: layer_config.heads * layer_config.qkv_dim,
|
||||
allocator)),
|
||||
q_T(MatFactory("q_T", layer_config.qkv_dim,
|
||||
config.vocab_size == 0
|
||||
? batch_size * layer_config.heads * 3
|
||||
|
|
@ -88,12 +93,14 @@ struct AttentionActivations {
|
|||
// If we forget any MatMul outputs here, debug builds print a warning but
|
||||
// fill them in each MatMul call.
|
||||
q.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
q_bf.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
q_T.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
att_sums.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
}
|
||||
|
||||
void SetBatchSize(size_t batch_size) {
|
||||
q.OverrideRows(batch_size);
|
||||
q_bf.OverrideRows(batch_size);
|
||||
// q_T rows are always qkv_dim!
|
||||
|
||||
pre_att_rms_out.OverrideRows(batch_size);
|
||||
|
|
@ -105,6 +112,7 @@ struct AttentionActivations {
|
|||
}
|
||||
|
||||
MatStorageT<float> q; // query
|
||||
MatStorageT<BF16> q_bf;
|
||||
MatStorageT<BF16> q_T; // Transposed to maximize attention speed.
|
||||
|
||||
MatStorageT<float> pre_att_rms_out;
|
||||
|
|
@ -130,6 +138,7 @@ struct AttentionActivationsPtrs {
|
|||
const AttentionActivations& activations)
|
||||
: AttentionActivationsPtrs(config, seq_len) {
|
||||
q = activations.q;
|
||||
q_bf = activations.q_bf;
|
||||
q_T = activations.q_T;
|
||||
pre_att_rms_out = activations.pre_att_rms_out;
|
||||
att = activations.att;
|
||||
|
|
@ -141,6 +150,7 @@ struct AttentionActivationsPtrs {
|
|||
|
||||
void SetBatchSize(size_t batch_size) {
|
||||
q.OverrideRows(batch_size);
|
||||
q_bf.OverrideRows(batch_size);
|
||||
// q_T rows are always qkv_dim!
|
||||
pre_att_rms_out.OverrideRows(batch_size);
|
||||
att.OverrideRows(batch_size);
|
||||
|
|
@ -151,6 +161,7 @@ struct AttentionActivationsPtrs {
|
|||
|
||||
const ModelConfig& config;
|
||||
MatPtrT<float> q;
|
||||
MatPtrT<BF16> q_bf;
|
||||
MatPtrT<BF16> q_T;
|
||||
MatPtrT<float> pre_att_rms_out;
|
||||
MatPtrT<float> att;
|
||||
|
|
|
|||
|
|
@ -154,7 +154,7 @@ void HWY_INLINE SingleFlashAttentionStep(float x, float cap, float& old_max,
|
|||
|
||||
// Calculates the complete attention outputs for a single row of q.
|
||||
void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
|
||||
const float* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
|
||||
const BF16* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
|
||||
const MatPtrT<KV_t>& v, const size_t layer_idx,
|
||||
const AttentionActivationsPtrs& activations,
|
||||
float* HWY_RESTRICT att_out, ThreadingContext& ctx,
|
||||
|
|
@ -162,17 +162,12 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
|
|||
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention);
|
||||
const hn::ScalableTag<BF16> dbf;
|
||||
const size_t qkv_dim = k.Cols();
|
||||
HWY_ALIGN BF16 q_bf[kMaxQKVDim];
|
||||
|
||||
CompressPerThread tls;
|
||||
const hn::ScalableTag<float> df;
|
||||
CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim),
|
||||
0);
|
||||
const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
|
||||
// TODO: Mixed-mode can be further improved for Turin: we can demote right
|
||||
// before we do the dot product instruction, rather than promote both to f32.
|
||||
// But some potential accuracy loss there, needs evaluation first.
|
||||
float m = Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
|
||||
float m = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
|
||||
if (float cap = activations.config.att_cap; cap > 0.0f) {
|
||||
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
|
||||
m = cap * std::tanh(m / cap);
|
||||
|
|
@ -182,8 +177,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
|
|||
MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), ctx, worker);
|
||||
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
|
||||
const size_t pos_mod = activations.div_seq_len.Remainder(pos);
|
||||
float x =
|
||||
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
|
||||
float x = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
|
||||
SingleFlashAttentionStep(x, activations.config.att_cap, m, d,
|
||||
v.Row(pos_mod), v.Cols(), att_out);
|
||||
}
|
||||
|
|
@ -193,19 +187,15 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
|
|||
// the dot products of NF rows of Q for a single K timestep.
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
|
||||
const size_t k_pos, const MatPtrT<float>& q,
|
||||
const size_t k_pos, const MatPtrT<BF16>& q,
|
||||
const MatPtrT<KV_t>& k) {
|
||||
const hn::ScalableTag<BF16> dbf;
|
||||
const size_t qkv_dim = k.Cols();
|
||||
HWY_ALIGN BF16 q_bf[kMaxQKVDim];
|
||||
CompressPerThread tls;
|
||||
|
||||
hn::TFromD<DF> results[hn::MaxLanes(df)];
|
||||
for (size_t i = 0; i < hn::Lanes(df); ++i) {
|
||||
CompressTraits<BF16>::Compress(df, q.Row(0) + q_offsets[i], qkv_dim, tls,
|
||||
MakeSpan(q_bf, qkv_dim), 0);
|
||||
results[i] =
|
||||
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
|
||||
results[i] = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[i], qkv_dim), 0,
|
||||
k.Row(k_pos), qkv_dim);
|
||||
}
|
||||
return hn::LoadU(df, results);
|
||||
}
|
||||
|
|
@ -290,7 +280,7 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
|
|||
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
|
||||
// max_last_pos].
|
||||
void TileFlashAttention(
|
||||
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
|
||||
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,
|
||||
const StridedView<BF16>& qT, const MatPtrT<KV_t>& k, const size_t start_pos,
|
||||
const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
|
||||
const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
|
||||
|
|
@ -396,7 +386,7 @@ void TileFlashAttention(
|
|||
// This is the result of 4 rows of Q against NF K timesteps, with positions
|
||||
// given by k_offsets[0..NF].
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
|
||||
void QDotKTilex4(DF df, const BF16* HWY_RESTRICT q,
|
||||
const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<KV_t>& k,
|
||||
const int32_t* HWY_RESTRICT k_offsets, VF& sum0, VF& sum1,
|
||||
VF& sum2, VF& sum3) {
|
||||
|
|
@ -411,17 +401,13 @@ void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
|
|||
VI k_offsets_vec = hn::LoadU(di, k_offsets);
|
||||
for (size_t i = 0; i < k.Cols(); ++i) {
|
||||
VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec);
|
||||
VF q_0 = hn::Set(df, hwy::ConvertScalarTo<float>(
|
||||
hwy::ConvertScalarTo<BF16>(q[q_offsets[0] + i])));
|
||||
VF q_0 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[0] + i]));
|
||||
sum0 = hn::MulAdd(q_0, k_vec, sum0);
|
||||
VF q_1 = hn::Set(df, hwy::ConvertScalarTo<float>(
|
||||
hwy::ConvertScalarTo<BF16>(q[q_offsets[1] + i])));
|
||||
VF q_1 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[1] + i]));
|
||||
sum1 = hn::MulAdd(q_1, k_vec, sum1);
|
||||
VF q_2 = hn::Set(df, hwy::ConvertScalarTo<float>(
|
||||
hwy::ConvertScalarTo<BF16>(q[q_offsets[2] + i])));
|
||||
VF q_2 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[2] + i]));
|
||||
sum2 = hn::MulAdd(q_2, k_vec, sum2);
|
||||
VF q_3 = hn::Set(df, hwy::ConvertScalarTo<float>(
|
||||
hwy::ConvertScalarTo<BF16>(q[q_offsets[3] + i])));
|
||||
VF q_3 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[3] + i]));
|
||||
sum3 = hn::MulAdd(q_3, k_vec, sum3);
|
||||
}
|
||||
}
|
||||
|
|
@ -446,7 +432,7 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
|
|||
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
|
||||
// max_last_pos].
|
||||
Tile4FlashState TileFlashAttention4(
|
||||
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
|
||||
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,
|
||||
const MatPtrT<KV_t>& k, const size_t start_pos,
|
||||
const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
|
||||
const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
|
||||
|
|
@ -500,18 +486,13 @@ Tile4FlashState TileFlashAttention4(
|
|||
}
|
||||
const hn::ScalableTag<BF16> dbf;
|
||||
const size_t qkv_dim = k.Cols();
|
||||
HWY_ALIGN BF16 q_bf[kMaxQKVDim];
|
||||
CompressPerThread tls;
|
||||
const hn::ScalableTag<float> df_compress;
|
||||
|
||||
while (position <= max_last_pos) {
|
||||
size_t k_pos = activations.div_seq_len.Remainder(position);
|
||||
if (position <= last_pos[0]) {
|
||||
// Past the last position, x0 doesn't count.
|
||||
CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[0],
|
||||
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
|
||||
float x0 =
|
||||
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
|
||||
float x0 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[0], qkv_dim), 0,
|
||||
k.Row(k_pos), qkv_dim);
|
||||
SingleFlashAttentionStep(x0, activations.config.att_cap,
|
||||
state.row_states[0].max, state.row_states[0].d,
|
||||
v.Row(k_pos), v.Cols(),
|
||||
|
|
@ -519,10 +500,8 @@ Tile4FlashState TileFlashAttention4(
|
|||
}
|
||||
if (position <= last_pos[1]) {
|
||||
// Past the last position, x1 doesn't count.
|
||||
CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[1],
|
||||
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
|
||||
float x1 =
|
||||
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
|
||||
float x1 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[1], qkv_dim), 0,
|
||||
k.Row(k_pos), qkv_dim);
|
||||
SingleFlashAttentionStep(x1, activations.config.att_cap,
|
||||
state.row_states[1].max, state.row_states[1].d,
|
||||
v.Row(k_pos), v.Cols(),
|
||||
|
|
@ -530,10 +509,8 @@ Tile4FlashState TileFlashAttention4(
|
|||
}
|
||||
if (position <= last_pos[2]) {
|
||||
// Past the last position, x2 doesn't count.
|
||||
CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[2],
|
||||
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
|
||||
float x2 =
|
||||
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
|
||||
float x2 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[2], qkv_dim), 0,
|
||||
k.Row(k_pos), qkv_dim);
|
||||
SingleFlashAttentionStep(x2, activations.config.att_cap,
|
||||
state.row_states[2].max, state.row_states[2].d,
|
||||
v.Row(k_pos), v.Cols(),
|
||||
|
|
@ -541,10 +518,8 @@ Tile4FlashState TileFlashAttention4(
|
|||
}
|
||||
if (position <= last_pos[3]) {
|
||||
// Past the last position, x3 doesn't count.
|
||||
CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[3],
|
||||
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
|
||||
float x3 =
|
||||
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
|
||||
float x3 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[3], qkv_dim), 0,
|
||||
k.Row(k_pos), qkv_dim);
|
||||
SingleFlashAttentionStep(x3, activations.config.att_cap,
|
||||
state.row_states[3].max, state.row_states[3].d,
|
||||
v.Row(k_pos), v.Cols(),
|
||||
|
|
@ -642,6 +617,17 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
|
|||
RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q,
|
||||
query_norm_scale, layer_idx, activations, ctx);
|
||||
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||
// Compress q to q_bf.
|
||||
ParallelFor(
|
||||
ParallelismStrategy::kWithinCluster, activations.q.Rows(), ctx,
|
||||
/*cluster_idx=*/0, Callers::kFlashAttention,
|
||||
[&](size_t row, size_t worker) {
|
||||
CompressPerThread tls;
|
||||
const hn::ScalableTag<float> df;
|
||||
CompressTraits<BF16>::Compress(
|
||||
df, activations.q.Row(row), activations.q.Cols(), tls,
|
||||
MakeSpan(activations.q_bf.Row(row), activations.q_bf.Cols()), 0);
|
||||
});
|
||||
const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
|
||||
const size_t qkv_dim = layer_config.qkv_dim;
|
||||
|
||||
|
|
@ -736,8 +722,8 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
|
|||
last_pos[offset] = last;
|
||||
min_last_pos = HWY_MIN(min_last_pos, last);
|
||||
max_last_pos = HWY_MAX(max_last_pos, last);
|
||||
q_offsets[offset] =
|
||||
activations.q.Row(tq_idx) + head * qkv_dim - activations.q.Row(0);
|
||||
q_offsets[offset] = activations.q_bf.Row(tq_idx) + head * qkv_dim -
|
||||
activations.q_bf.Row(0);
|
||||
out_offsets[offset] = activations.att_out.Row(tq_idx) + head * qkv_dim -
|
||||
activations.att_out.Row(0);
|
||||
const size_t kv_index = head / kHeadGroups;
|
||||
|
|
@ -776,12 +762,12 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
|
|||
// kNFx8HTileSize. In this case, qT is never used. Some tasks might
|
||||
// use qT and some might not, which is why the more general condition
|
||||
// is used above to catch all cases where qT will be used.
|
||||
TileFlashAttention(activations.q, q_offsets, qT, k,
|
||||
TileFlashAttention(activations.q_bf, q_offsets, qT, k,
|
||||
start_positions[offset], last_pos, min_last_pos,
|
||||
max_last_pos, v, layer_idx, activations,
|
||||
activations.att_out, out_offsets, ctx, worker);
|
||||
} else if (kVTileSize == 4) {
|
||||
TileFlashAttention4(activations.q, q_offsets, k,
|
||||
TileFlashAttention4(activations.q_bf, q_offsets, k,
|
||||
start_positions[offset], last_pos, min_last_pos,
|
||||
max_last_pos, v, layer_idx, activations,
|
||||
activations.att_out, out_offsets, ctx, worker);
|
||||
|
|
@ -791,7 +777,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
|
|||
break;
|
||||
} else {
|
||||
SingleFlashAttention(start_positions[offset], last_pos[offset],
|
||||
activations.q.Row(0) + q_offsets[offset], k, v,
|
||||
activations.q_bf.Row(0) + q_offsets[offset], k, v,
|
||||
layer_idx, activations,
|
||||
activations.att_out.Row(0) + out_offsets[offset],
|
||||
ctx, worker);
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ namespace gcpp {
|
|||
ThreadingContext& ctx, size_t worker); \
|
||||
\
|
||||
Tile4FlashState TileFlashAttention4( \
|
||||
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets, \
|
||||
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets, \
|
||||
const MatPtrT<KV_t>& k, size_t start_pos, \
|
||||
const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \
|
||||
size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx, \
|
||||
|
|
|
|||
4
io/io.cc
4
io/io.cc
|
|
@ -236,7 +236,9 @@ bool IOBatch::Add(void* mem, size_t bytes) {
|
|||
return true;
|
||||
}
|
||||
|
||||
void InternalInit() {
|
||||
int InternalInit() {
|
||||
// currently unused, except for init list ordering in GemmaEnv.
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint64_t IOBatch::Read(const File& file) const {
|
||||
|
|
|
|||
2
io/io.h
2
io/io.h
|
|
@ -150,7 +150,7 @@ std::string ReadFileToString(const Path& path);
|
|||
|
||||
// No-op in open-source. Must be called at the beginning of a binary, before
|
||||
// any I/O or flag usage.
|
||||
void InternalInit();
|
||||
int InternalInit();
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
|
|
|
|||
|
|
@ -72,7 +72,6 @@ TEST_F(PaliGemmaTest, QueryObjects) {
|
|||
|
||||
int main(int argc, char** argv) {
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
gcpp::InternalInit();
|
||||
|
||||
gcpp::GemmaEnv env(argc, argv);
|
||||
gcpp::s_env = &env;
|
||||
|
|
|
|||
Loading…
Reference in New Issue