mirror of https://github.com/google/gemma.cpp.git
Pre-compress query activations to BF16 before FlashAttention.
PiperOrigin-RevId: 826524997
Parent: 8a100c1e8d
Commit: ab87807a4c
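For context on the commit title: compressing a query activation from f32 to BF16 keeps the sign, the full 8-bit exponent and the top 7 mantissa bits, so roughly three significant decimal digits survive. A minimal, standalone sketch of that conversion (illustrative only; the repository's actual conversion is CompressTraits<BF16>::Compress, and NaN handling is omitted here):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Round-to-nearest-even demotion of an IEEE-754 float to BF16 (top 16 bits).
static uint16_t FloatToBF16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  const uint32_t rounding = 0x7FFFu + ((bits >> 16) & 1u);  // ties to even
  return static_cast<uint16_t>((bits + rounding) >> 16);
}

// Promotion back to float just places the 16 bits in the high half.
static float BF16ToFloat(uint16_t bf) {
  const uint32_t bits = static_cast<uint32_t>(bf) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  const float q = 0.123456789f;
  std::printf("f32 %.9f -> bf16 %.9f\n", q, BF16ToFloat(FloatToBF16(q)));
  return 0;
}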
@@ -38,7 +38,10 @@ namespace gcpp {
 GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
                    const InferenceArgs& inference)
-    : ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) {
+    : initializer_value_(gcpp::InternalInit()),
+      ctx_(threading),
+      env_(ctx_),
+      gemma_(loader, inference, ctx_) {
   const ModelConfig& config = gemma_.Config();
   // Only allocate one for starters because GenerateBatch might not be called.
   kv_caches_.push_back(KVCache(config, inference, ctx_.allocator));
@@ -125,6 +125,8 @@ class GemmaEnv {
   MatMulEnv& MutableEnv() { return env_; }

  private:
+  // This is used to ensure that InternalInit is called before anything else.
+  int initializer_value_ = 0;
   ThreadingContext ctx_;
   MatMulEnv env_;
   Gemma gemma_;
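The new initializer_value_ member exists only for ordering: non-static data members are initialized in declaration order, so an int declared before ctx_, env_ and gemma_ and initialized from gcpp::InternalInit() guarantees the init routine runs before those members are constructed (which is also why InternalInit now returns int; see io/io.cc below). A minimal sketch of the idiom, using hypothetical stand-in types:

#include <cstdio>

int InternalInit() {           // stand-in for gcpp::InternalInit()
  std::puts("InternalInit runs first");
  return 0;                    // value is irrelevant; only the ordering matters
}

struct Heavy {                 // stand-in for ThreadingContext / Gemma etc.
  Heavy() { std::puts("heavy member constructed"); }
};

class Env {
 public:
  Env() : initializer_value_(InternalInit()), heavy_() {}

 private:
  // Declared first, so it is initialized before heavy_ regardless of the order
  // written in the constructor's initializer list.
  int initializer_value_ = 0;
  Heavy heavy_;
};

int main() {
  Env env;  // prints "InternalInit runs first" before "heavy member constructed"
  return 0;
}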
@@ -153,5 +153,3 @@ int main(int argc, char** argv) {

   return RUN_ALL_TESTS();
 }
@@ -181,7 +181,6 @@ TEST_F(GemmaTest, CrossEntropySmall) {

 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
-  gcpp::InternalInit();
   gcpp::GemmaTest::InitEnv(argc, argv);
   int ret = RUN_ALL_TESTS();
   gcpp::GemmaTest::DeleteEnv();
@@ -54,6 +54,11 @@ struct AttentionActivations {
                       ? layer_config.heads * 3 * layer_config.qkv_dim
                       : layer_config.heads * layer_config.qkv_dim,
                   allocator)),
+        q_bf(MatFactory("q_bf", batch_size,
+                        config.vocab_size == 0
+                            ? layer_config.heads * 3 * layer_config.qkv_dim
+                            : layer_config.heads * layer_config.qkv_dim,
+                        allocator)),
         q_T(MatFactory("q_T", layer_config.qkv_dim,
                        config.vocab_size == 0
                            ? batch_size * layer_config.heads * 3
@@ -88,12 +93,14 @@ struct AttentionActivations {
     // If we forget any MatMul outputs here, debug builds print a warning but
     // fill them in each MatMul call.
     q.AllocateAndAttachRowPtrs(row_ptrs);
+    q_bf.AllocateAndAttachRowPtrs(row_ptrs);
     q_T.AllocateAndAttachRowPtrs(row_ptrs);
     att_sums.AllocateAndAttachRowPtrs(row_ptrs);
   }

   void SetBatchSize(size_t batch_size) {
     q.OverrideRows(batch_size);
+    q_bf.OverrideRows(batch_size);
     // q_T rows are always qkv_dim!

     pre_att_rms_out.OverrideRows(batch_size);
@@ -105,6 +112,7 @@ struct AttentionActivations {
   }

   MatStorageT<float> q;  // query
+  MatStorageT<BF16> q_bf;
   MatStorageT<BF16> q_T;  // Transposed to maximize attention speed.

   MatStorageT<float> pre_att_rms_out;
@@ -130,6 +138,7 @@ struct AttentionActivationsPtrs {
                            const AttentionActivations& activations)
       : AttentionActivationsPtrs(config, seq_len) {
     q = activations.q;
+    q_bf = activations.q_bf;
     q_T = activations.q_T;
     pre_att_rms_out = activations.pre_att_rms_out;
     att = activations.att;
@@ -141,6 +150,7 @@ struct AttentionActivationsPtrs {

   void SetBatchSize(size_t batch_size) {
     q.OverrideRows(batch_size);
+    q_bf.OverrideRows(batch_size);
     // q_T rows are always qkv_dim!
     pre_att_rms_out.OverrideRows(batch_size);
     att.OverrideRows(batch_size);
@@ -151,6 +161,7 @@ struct AttentionActivationsPtrs {

   const ModelConfig& config;
   MatPtrT<float> q;
+  MatPtrT<BF16> q_bf;
   MatPtrT<BF16> q_T;
   MatPtrT<float> pre_att_rms_out;
   MatPtrT<float> att;
@@ -154,7 +154,7 @@ void HWY_INLINE SingleFlashAttentionStep(float x, float cap, float& old_max,

 // Calculates the complete attention outputs for a single row of q.
 void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
-                          const float* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
+                          const BF16* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
                           const MatPtrT<KV_t>& v, const size_t layer_idx,
                           const AttentionActivationsPtrs& activations,
                           float* HWY_RESTRICT att_out, ThreadingContext& ctx,
@@ -162,17 +162,12 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
   GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention);
   const hn::ScalableTag<BF16> dbf;
   const size_t qkv_dim = k.Cols();
-  HWY_ALIGN BF16 q_bf[kMaxQKVDim];

-  CompressPerThread tls;
-  const hn::ScalableTag<float> df;
-  CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim),
-                                 0);
   const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
   // TODO: Mixed-mode can be further improved for Turin: we can demote right
   // before we do the dot product instruction, rather than promote both to f32.
   // But some potential accuracy loss there, needs evaluation first.
-  float m = Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
+  float m = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
   if (float cap = activations.config.att_cap; cap > 0.0f) {
     // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
     m = cap * std::tanh(m / cap);
@@ -182,8 +177,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
   MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), ctx, worker);
   for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
     const size_t pos_mod = activations.div_seq_len.Remainder(pos);
-    float x =
-        Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
+    float x = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
     SingleFlashAttentionStep(x, activations.config.att_cap, m, d,
                              v.Row(pos_mod), v.Cols(), att_out);
   }
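SingleFlashAttention and SingleFlashAttentionStep follow the usual online-softmax recurrence: a running maximum m and denominator d are kept, and the accumulated output is rescaled whenever a larger logit appears, with the optional att_cap applying LogitsSoftCap (cap * tanh(x / cap)) to each logit first. A scalar sketch of that recurrence, assuming plain float inputs (illustrative, not the repository's vectorized code):

#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Streaming softmax-weighted sum over value rows: one pass over the logits,
// O(1) extra state (running max m, denominator d), and the output is rescaled
// whenever a new maximum is seen. Equivalent to softmax(logits) * V without
// storing the full probability vector.
void OnlineSoftmaxAttention(const std::vector<float>& logits,
                            const std::vector<std::vector<float>>& v_rows,
                            float cap, std::vector<float>& out) {
  const size_t dim = v_rows.front().size();
  out.assign(dim, 0.0f);
  float m = -std::numeric_limits<float>::infinity();
  float d = 0.0f;
  for (size_t t = 0; t < logits.size(); ++t) {
    float x = logits[t];
    if (cap > 0.0f) x = cap * std::tanh(x / cap);  // LogitsSoftCap
    if (x > m) {
      // New maximum: rescale the previous denominator and accumulator so that
      // everything stays expressed relative to the new maximum.
      const float scale = std::exp(m - x);
      d = d * scale + 1.0f;
      for (size_t i = 0; i < dim; ++i) out[i] = out[i] * scale + v_rows[t][i];
      m = x;
    } else {
      const float w = std::exp(x - m);
      d += w;
      for (size_t i = 0; i < dim; ++i) out[i] += w * v_rows[t][i];
    }
  }
  for (size_t i = 0; i < dim; ++i) out[i] /= d;  // normalize once at the end
}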
@@ -193,19 +187,15 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
 // the dot products of NF rows of Q for a single K timestep.
 template <class DF, class VF = hn::Vec<DF>>
 VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
-               const size_t k_pos, const MatPtrT<float>& q,
+               const size_t k_pos, const MatPtrT<BF16>& q,
                const MatPtrT<KV_t>& k) {
   const hn::ScalableTag<BF16> dbf;
   const size_t qkv_dim = k.Cols();
-  HWY_ALIGN BF16 q_bf[kMaxQKVDim];
-  CompressPerThread tls;

   hn::TFromD<DF> results[hn::MaxLanes(df)];
   for (size_t i = 0; i < hn::Lanes(df); ++i) {
-    CompressTraits<BF16>::Compress(df, q.Row(0) + q_offsets[i], qkv_dim, tls,
-                                   MakeSpan(q_bf, qkv_dim), 0);
-    results[i] =
-        Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
+    results[i] = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[i], qkv_dim), 0,
+                     k.Row(k_pos), qkv_dim);
   }
   return hn::LoadU(df, results);
 }
@@ -290,7 +280,7 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
 // min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
 // max_last_pos].
 void TileFlashAttention(
-    const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
+    const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,
     const StridedView<BF16>& qT, const MatPtrT<KV_t>& k, const size_t start_pos,
     const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
     const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
@@ -396,7 +386,7 @@ void TileFlashAttention(
 // This is the result of 4 rows of Q against NF K timesteps, with positions
 // given by k_offsets[0..NF].
 template <class DF, class VF = hn::Vec<DF>>
-void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
+void QDotKTilex4(DF df, const BF16* HWY_RESTRICT q,
                  const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<KV_t>& k,
                  const int32_t* HWY_RESTRICT k_offsets, VF& sum0, VF& sum1,
                  VF& sum2, VF& sum3) {
@@ -411,17 +401,13 @@ void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
   VI k_offsets_vec = hn::LoadU(di, k_offsets);
   for (size_t i = 0; i < k.Cols(); ++i) {
     VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec);
-    VF q_0 = hn::Set(df, hwy::ConvertScalarTo<float>(
-                             hwy::ConvertScalarTo<BF16>(q[q_offsets[0] + i])));
+    VF q_0 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[0] + i]));
     sum0 = hn::MulAdd(q_0, k_vec, sum0);
-    VF q_1 = hn::Set(df, hwy::ConvertScalarTo<float>(
-                             hwy::ConvertScalarTo<BF16>(q[q_offsets[1] + i])));
+    VF q_1 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[1] + i]));
     sum1 = hn::MulAdd(q_1, k_vec, sum1);
-    VF q_2 = hn::Set(df, hwy::ConvertScalarTo<float>(
-                             hwy::ConvertScalarTo<BF16>(q[q_offsets[2] + i])));
+    VF q_2 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[2] + i]));
     sum2 = hn::MulAdd(q_2, k_vec, sum2);
-    VF q_3 = hn::Set(df, hwy::ConvertScalarTo<float>(
-                             hwy::ConvertScalarTo<BF16>(q[q_offsets[3] + i])));
+    VF q_3 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[3] + i]));
     sum3 = hn::MulAdd(q_3, k_vec, sum3);
   }
 }
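QDotKTilex4 walks the qkv dimension once, gathering one K element per timestep and broadcasting one scalar from each of the four query rows into a fused multiply-add, so after the loop sum0..sum3 hold the dot products of the four query rows against NF K timesteps. A scalar reference of the same arithmetic (names mirror the kernel but are illustrative; the real code promotes BF16 to f32 and vectorizes the timestep dimension with Highway gathers):

#include <cstddef>
#include <cstdint>

// sums[r][t] accumulates the dot product of query row r (at q + q_offsets[r])
// with the K row starting at k_base + k_offsets[t], over qkv_dim elements.
// kNF is the number of K timesteps handled per tile.
template <size_t kNF>
void QDotKTile4Scalar(const float* q, const uint32_t q_offsets[4],
                      const float* k_base, const int32_t k_offsets[kNF],
                      size_t qkv_dim, float sums[4][kNF]) {
  for (size_t r = 0; r < 4; ++r) {
    for (size_t t = 0; t < kNF; ++t) {
      float sum = 0.0f;
      for (size_t i = 0; i < qkv_dim; ++i) {
        sum += q[q_offsets[r] + i] * k_base[k_offsets[t] + i];
      }
      sums[r][t] = sum;
    }
  }
}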
@@ -446,7 +432,7 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
 // min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
 // max_last_pos].
 Tile4FlashState TileFlashAttention4(
-    const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
+    const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,
     const MatPtrT<KV_t>& k, const size_t start_pos,
     const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
     const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
@@ -500,18 +486,13 @@ Tile4FlashState TileFlashAttention4(
   }
   const hn::ScalableTag<BF16> dbf;
   const size_t qkv_dim = k.Cols();
-  HWY_ALIGN BF16 q_bf[kMaxQKVDim];
-  CompressPerThread tls;
-  const hn::ScalableTag<float> df_compress;

   while (position <= max_last_pos) {
     size_t k_pos = activations.div_seq_len.Remainder(position);
     if (position <= last_pos[0]) {
       // Past the last position, x0 doesn't count.
-      CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[0],
-                                     qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
-      float x0 =
-          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
+      float x0 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[0], qkv_dim), 0,
+                     k.Row(k_pos), qkv_dim);
       SingleFlashAttentionStep(x0, activations.config.att_cap,
                                state.row_states[0].max, state.row_states[0].d,
                                v.Row(k_pos), v.Cols(),
@@ -519,10 +500,8 @@ Tile4FlashState TileFlashAttention4(
     }
     if (position <= last_pos[1]) {
       // Past the last position, x1 doesn't count.
-      CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[1],
-                                     qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
-      float x1 =
-          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
+      float x1 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[1], qkv_dim), 0,
+                     k.Row(k_pos), qkv_dim);
       SingleFlashAttentionStep(x1, activations.config.att_cap,
                                state.row_states[1].max, state.row_states[1].d,
                                v.Row(k_pos), v.Cols(),
@@ -530,10 +509,8 @@ Tile4FlashState TileFlashAttention4(
     }
    if (position <= last_pos[2]) {
       // Past the last position, x2 doesn't count.
-      CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[2],
-                                     qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
-      float x2 =
-          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
+      float x2 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[2], qkv_dim), 0,
+                     k.Row(k_pos), qkv_dim);
       SingleFlashAttentionStep(x2, activations.config.att_cap,
                                state.row_states[2].max, state.row_states[2].d,
                                v.Row(k_pos), v.Cols(),
@@ -541,10 +518,8 @@ Tile4FlashState TileFlashAttention4(
     }
     if (position <= last_pos[3]) {
       // Past the last position, x3 doesn't count.
-      CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[3],
-                                     qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
-      float x3 =
-          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
+      float x3 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[3], qkv_dim), 0,
+                     k.Row(k_pos), qkv_dim);
       SingleFlashAttentionStep(x3, activations.config.att_cap,
                                state.row_states[3].max, state.row_states[3].d,
                                v.Row(k_pos), v.Cols(),
@@ -642,6 +617,17 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
   RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q,
                                query_norm_scale, layer_idx, activations, ctx);
   const hwy::Divisor div_qbatch(qbatch.Size());
+  // Compress q to q_bf.
+  ParallelFor(
+      ParallelismStrategy::kWithinCluster, activations.q.Rows(), ctx,
+      /*cluster_idx=*/0, Callers::kFlashAttention,
+      [&](size_t row, size_t worker) {
+        CompressPerThread tls;
+        const hn::ScalableTag<float> df;
+        CompressTraits<BF16>::Compress(
+            df, activations.q.Row(row), activations.q.Cols(), tls,
+            MakeSpan(activations.q_bf.Row(row), activations.q_bf.Cols()), 0);
+      });
   const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
   const size_t qkv_dim = layer_config.qkv_dim;
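The design choice here: previously every attention call re-compressed its query row into a per-call stack buffer, so the same row was converted once per (head, timestep) dot product; the new ParallelFor converts each row of q exactly once into q_bf, and all later dot products read the shared BF16 copy. A serial sketch of that restructuring, assuming a simple truncating conversion (the real path uses CompressTraits<BF16>::Compress and ParallelFor):

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Truncating f32 -> bf16, enough to illustrate the restructuring (the real
// conversion rounds to nearest even).
static uint16_t TruncateToBF16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

// One pre-pass over all rows of q; afterwards every head/timestep dot product
// reads q_bf instead of converting q on the fly. Serial here, parallel over
// rows in the diff above.
void CompressQ(const std::vector<std::vector<float>>& q,
               std::vector<std::vector<uint16_t>>& q_bf) {
  q_bf.resize(q.size());
  for (std::size_t r = 0; r < q.size(); ++r) {
    q_bf[r].resize(q[r].size());
    for (std::size_t c = 0; c < q[r].size(); ++c) {
      q_bf[r][c] = TruncateToBF16(q[r][c]);
    }
  }
}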
@@ -736,8 +722,8 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
       last_pos[offset] = last;
       min_last_pos = HWY_MIN(min_last_pos, last);
       max_last_pos = HWY_MAX(max_last_pos, last);
-      q_offsets[offset] =
-          activations.q.Row(tq_idx) + head * qkv_dim - activations.q.Row(0);
+      q_offsets[offset] = activations.q_bf.Row(tq_idx) + head * qkv_dim -
+                          activations.q_bf.Row(0);
       out_offsets[offset] = activations.att_out.Row(tq_idx) + head * qkv_dim -
                             activations.att_out.Row(0);
       const size_t kv_index = head / kHeadGroups;
@@ -776,12 +762,12 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
         // kNFx8HTileSize. In this case, qT is never used. Some tasks might
         // use qT and some might not, which is why the more general condition
         // is used above to catch all cases where qT will be used.
-        TileFlashAttention(activations.q, q_offsets, qT, k,
+        TileFlashAttention(activations.q_bf, q_offsets, qT, k,
                            start_positions[offset], last_pos, min_last_pos,
                            max_last_pos, v, layer_idx, activations,
                            activations.att_out, out_offsets, ctx, worker);
       } else if (kVTileSize == 4) {
-        TileFlashAttention4(activations.q, q_offsets, k,
+        TileFlashAttention4(activations.q_bf, q_offsets, k,
                             start_positions[offset], last_pos, min_last_pos,
                             max_last_pos, v, layer_idx, activations,
                             activations.att_out, out_offsets, ctx, worker);
@@ -791,7 +777,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
         break;
       } else {
         SingleFlashAttention(start_positions[offset], last_pos[offset],
-                             activations.q.Row(0) + q_offsets[offset], k, v,
+                             activations.q_bf.Row(0) + q_offsets[offset], k, v,
                              layer_idx, activations,
                              activations.att_out.Row(0) + out_offsets[offset],
                              ctx, worker);
@@ -45,7 +45,7 @@ namespace gcpp {
       ThreadingContext& ctx, size_t worker);                            \
                                                                         \
   Tile4FlashState TileFlashAttention4(                                  \
-      const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,  \
+      const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,   \
       const MatPtrT<KV_t>& k, size_t start_pos,                         \
       const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos,       \
       size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx,    \
io/io.cc (4 changed lines)
@@ -236,7 +236,9 @@ bool IOBatch::Add(void* mem, size_t bytes) {
   return true;
 }

-void InternalInit() {
+int InternalInit() {
+  // currently unused, except for init list ordering in GemmaEnv.
+  return 0;
 }

 uint64_t IOBatch::Read(const File& file) const {
io/io.h (2 changed lines)
@@ -150,7 +150,7 @@ std::string ReadFileToString(const Path& path);

 // No-op in open-source. Must be called at the beginning of a binary, before
 // any I/O or flag usage.
-void InternalInit();
+int InternalInit();

 }  // namespace gcpp

@@ -72,7 +72,6 @@ TEST_F(PaliGemmaTest, QueryObjects) {

 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
-  gcpp::InternalInit();

   gcpp::GemmaEnv env(argc, argv);
   gcpp::s_env = &env;