mirror of https://github.com/google/gemma.cpp.git
Added zero-initialization to att_out.
Re-enabled flash attention when HWY_NATIVE_DOT_BF16 is not available. PiperOrigin-RevId: 806284756
This commit is contained in:
parent
2695aab5d2
commit
c9b8479f7d
|
|
@ -256,6 +256,11 @@ void TileFlashAttention(
|
||||||
using DI = hn::ScalableTag<uint32_t>;
|
using DI = hn::ScalableTag<uint32_t>;
|
||||||
const DI di;
|
const DI di;
|
||||||
using VI = hn::Vec<DI>;
|
using VI = hn::Vec<DI>;
|
||||||
|
const int kVTileSize = hn::MaxLanes(df);
|
||||||
|
for (int i = 0; i < kVTileSize; ++i) {
|
||||||
|
hwy::ZeroBytes(att_out.Row(0) + out_offsets[i],
|
||||||
|
v.Cols() * sizeof(att_out.Row(0)[0]));
|
||||||
|
}
|
||||||
VI lasts = hn::LoadU(di, last_pos);
|
VI lasts = hn::LoadU(di, last_pos);
|
||||||
VF old_m = hn::Set(df, -std::numeric_limits<float>::max() / 2.0f);
|
VF old_m = hn::Set(df, -std::numeric_limits<float>::max() / 2.0f);
|
||||||
VF old_d = hn::Zero(df);
|
VF old_d = hn::Zero(df);
|
||||||
|
|
|
||||||
|
|
@ -75,7 +75,7 @@ void Attention(LayerAttentionType type, const size_t num_tokens,
|
||||||
if (type == LayerAttentionType::kGemma) {
|
if (type == LayerAttentionType::kGemma) {
|
||||||
// TODO: remove flag to enable FlashAttention.
|
// TODO: remove flag to enable FlashAttention.
|
||||||
GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch,
|
GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch,
|
||||||
env, kAttentionUseOld);
|
env, HWY_NATIVE_DOT_BF16 ? kAttentionUseOld : 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue