diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index e018ab8..b8c105a 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -453,8 +453,15 @@ static VF4 HWY_INLINE Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3, HWY_LANES_CONSTEXPR size_t kLanes = hn::Lanes(df); HWY_ALIGN T x_transposed[4 * kMaxLanes]; hn::StoreInterleaved4(x_0, x_1, x_2, x_3, df, x_transposed); + VF x01 = + reducer(hn::Load(df, x_transposed), hn::Load(df, x_transposed + kLanes)); + VF x23 = reducer(hn::Load(df, x_transposed + 2 * kLanes), + hn::Load(df, x_transposed + 3 * kLanes)); + VF x0123 = reducer(x01, x23); + hn::Store(x0123, df, x_transposed); + VF4 result = hn::Load(df4, x_transposed); - for (int i = 1; i < kLanes; ++i) { + for (int i = 1; i < kLanes / 4; ++i) { result = reducer(result, hn::Load(df4, x_transposed + i * 4)); } return result;