From 44dfd69b9b94848e0d5627383bdd0af59aa0033a Mon Sep 17 00:00:00 2001 From: Krzysztof Rymski Date: Mon, 15 Dec 2025 07:14:04 -0800 Subject: [PATCH] Internal changes PiperOrigin-RevId: 844759322 --- BUILD.bazel | 1 + gemma/flash_attention.cc | 148 ++++++++++++++++++++++++++++++++++ gemma/flash_attention_test.cc | 3 + util/mat.h | 15 ++++ 4 files changed, 167 insertions(+) diff --git a/BUILD.bazel b/BUILD.bazel index 1dfaa8f..606a2fb 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -142,6 +142,7 @@ cc_test( ":mat", ":matmul", ":query", + ":test_util", ":threading_context", ":weights", "@googletest//:gtest_main", # buildcleaner: keep diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index 7432f7b..671efb4 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -20,7 +20,9 @@ #include #include #include +#include #include +#include #include "compression/types.h" // GEMMA_DISABLED_TARGETS #include "gemma/flash_structs.h" @@ -438,6 +440,152 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max, return scale; } +// Reduces each of x and stores in following lanes of max (tested with float32) +template , + class DF4 = hn::CappedTag, class VF4 = hn::Vec, + class VF = hn::Vec, typename F> +static VF4 HWY_INLINE Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3, + F reducer) { + const DF4 df4; + constexpr size_t kMaxLanes = hn::MaxLanes(df); + HWY_LANES_CONSTEXPR size_t kLanes = hn::Lanes(df); + HWY_ALIGN T x_transposed[4 * kMaxLanes]; + hn::StoreInterleaved4(x_0, x_1, x_2, x_3, df, x_transposed); + VF4 result = hn::Load(df4, x_transposed); + for (int i = 1; i < kLanes; ++i) { + result = reducer(result, hn::Load(df4, x_transposed + i * 4)); + } + return result; +} + +// Handles Up to 4 Q rows by NF*2 timesteps of flash attention. +template > +static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap( + DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1, + VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1, + float* HWY_RESTRICT old_max, float* HWY_RESTRICT old_d, + float* HWY_RESTRICT scales) { + using DF4 = hn::CappedTag; + const DF4 df4; + using VF4 = hn::Vec; + static_assert(kNumQueries >= 1 && kNumQueries <= 4); + VF4 new_max = hn::Set(df4, -std::numeric_limits::max() / 2.0f); + VF max_0, max_1, max_2, max_3 = hn::Zero(df); + max_0 = hn::Max(x_0_p0, x_0_p1); + if constexpr (kNumQueries >= 2) { + max_1 = hn::Max(x_1_p0, x_1_p1); + } + if constexpr (kNumQueries >= 3) { + max_2 = hn::Max(x_2_p0, x_2_p1); + } + if constexpr (kNumQueries >= 4) { + max_3 = hn::Max(x_3_p0, x_3_p1); + } + if constexpr (kNumQueries == 1) { + new_max = hn::InsertLane(new_max, 0, hn::ReduceMax(df, max_0)); + } else { + new_max = Reduce4(df, max_0, max_1, max_2, max_3, + [](auto a, auto b) { return hn::Max(a, b); }); + } + if (att_cap > 0.0f) { + VF4 cap = hn::Set(df4, att_cap); + VF4 one_over_cap = hn::Set(df4, one_over_att_cap); + new_max = hn::Mul(cap, hn::Tanh(df4, hn::Mul(new_max, one_over_cap))); + } + VF4 old_max_vf = hn::Set(df4, -std::numeric_limits::max() / 2.0f); + old_max_vf = hn::LoadU(df4, old_max); + new_max = hn::Max(new_max, old_max_vf); + // TODO figure out what was wrong with broadcasts and change to that. + HWY_ALIGN float tmp_max[4]; + hn::Store(new_max, df4, tmp_max); + if constexpr (kNumQueries >= 1) { + const VF new_max_0 = hn::Set(df, tmp_max[0]); + x_0_p0 = hn::Exp(df, hn::Sub(x_0_p0 , new_max_0)); + x_0_p1 = hn::Exp(df, hn::Sub(x_0_p1, new_max_0)); + } + if constexpr (kNumQueries >= 2) { + const VF new_max_0 = hn::Set(df, tmp_max[1]); + x_1_p0 = hn::Exp(df, hn::Sub(x_1_p0, new_max_0)); + x_1_p1 = hn::Exp(df, hn::Sub(x_1_p1, new_max_0)); + } + if constexpr (kNumQueries >= 3) { + const VF new_max_0 = hn::Set(df, tmp_max[2]); + x_2_p0 = hn::Exp(df, hn::Sub(x_2_p0, new_max_0)); + x_2_p1 = hn::Exp(df, hn::Sub(x_2_p1, new_max_0)); + } + if constexpr (kNumQueries >= 4) { + const VF new_max_0 = hn::Set(df, tmp_max[3]); + x_3_p0 = hn::Exp(df, hn::Sub(x_3_p0, new_max_0)); + x_3_p1 = hn::Exp(df, hn::Sub(x_3_p1, new_max_0)); + } + VF4 old_d_vf = hn::Set(df4, 0.0f); + old_d_vf = hn::LoadU(df4, old_d); + VF4 scale = hn::Mul(old_d_vf, hn::Exp(df4, hn::Sub(old_max_vf, new_max))); + + hn::StoreU(new_max, df4, old_max); + + VF4 x_sum = hn::Zero(df4); + if constexpr (kNumQueries == 1) { + x_sum = hn::Set(df4, hn::ReduceSum(df, x_0_p0) + hn::ReduceSum(df, x_0_p1)); + } else { + VF x_0_sum = hn::Add(x_0_p0, x_0_p1); + VF x_1_sum = hn::Add(x_1_p0, x_1_p1); + VF x_2_sum = hn::Add(x_2_p0, x_2_p1); + VF x_3_sum = hn::Add(x_3_p0, x_3_p1); + x_sum = Reduce4(df, x_0_sum, x_1_sum, x_2_sum, x_3_sum, + [](auto a, auto b) { return hn::Add(a, b); }); + } + old_d_vf = hn::Add(scale, x_sum); + auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df4, 0.0f)); + const VF zero = hn::Zero(df); + const VF4 zero4 = hn::Zero(df4); + const VF4 one_over_d = + hn::MaskedDivOr(zero4, non_zero_mask, hn::Set(df4, 1.0f), old_d_vf); + float tmp_one_over_d[4]; + hn::Store(one_over_d, df4, tmp_one_over_d); + hn::Store(old_d_vf, df4, old_d); + scale = hn::Mul(scale, one_over_d); + hn::Store(scale, df4, scales); + if (hn::ExtractLane(old_d_vf, 0) > 0.0f) { + const VF one_over_d_0 = hn::Set(df, tmp_one_over_d[0]); + x_0_p0 = hn::Mul(x_0_p0, one_over_d_0); + x_0_p1 = hn::Mul(x_0_p1, one_over_d_0); + } else { + x_0_p0 = zero; + x_0_p1 = zero; + } + if constexpr (kNumQueries >= 2) { + if (hn::ExtractLane(old_d_vf, 1) > 0.0f) { + const VF one_over_d_1 = hn::Set(df, tmp_one_over_d[1]); + x_1_p0 = hn::Mul(x_1_p0, one_over_d_1); + x_1_p1 = hn::Mul(x_1_p1, one_over_d_1); + } else { + x_1_p0 = zero; + x_1_p1 = zero; + } + } + if constexpr (kNumQueries >= 3) { + if (hn::ExtractLane(old_d_vf, 2) > 0.0f) { + const VF one_over_d_2 = hn::Set(df, tmp_one_over_d[2]); + x_2_p0 = hn::Mul(x_2_p0, one_over_d_2); + x_2_p1 = hn::Mul(x_2_p1, one_over_d_2); + } else { + x_2_p0 = zero; + x_2_p1 = zero; + } + } + if constexpr (kNumQueries >= 4) { + if (hn::ExtractLane(old_d_vf, 3) > 0.0f) { + const VF one_over_d_3 = hn::Set(df, tmp_one_over_d[3]); + x_3_p0 = hn::Mul(x_3_p0, one_over_d_3); + x_3_p1 = hn::Mul(x_3_p1, one_over_d_3); + } else { + x_3_p0 = zero; + x_3_p1 = zero; + } + } +} + // Implements flash attention for a strip of 4 query vectors. // It iterates through timesteps in K from `start_pos` up to `max_last_pos`. // Timesteps up to `min_last_pos` (*) are processed in tiles of shape 4 Q rows diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc index 944277f..cecede0 100644 --- a/gemma/flash_attention_test.cc +++ b/gemma/flash_attention_test.cc @@ -14,6 +14,8 @@ // limitations under the License. #include +#include +#include #include #include @@ -24,6 +26,7 @@ #include "gemma/kv_cache.h" #include "gemma/weights.h" #include "ops/matmul.h" +#include "util/test_util.h" #ifndef HWY_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS diff --git a/util/mat.h b/util/mat.h index 753e194..83d03b1 100644 --- a/util/mat.h +++ b/util/mat.h @@ -454,6 +454,21 @@ decltype(auto) CallUpcastedActivation(const MatPtr* base, const Func& func, } } +// Like CallUpcasted, but only for kv_cache types: kBF16 and kF32. +template +decltype(auto) CallUpcastedKV(const MatPtr* base, const Func& func, + Args&&... args) { + if (base->GetType() == Type::kF32) { + const MatPtrT mat(*base); + return func(&mat, std::forward(args)...); + } else if (base->GetType() == Type::kBF16) { + const MatPtrT mat(*base); + return func(&mat, std::forward(args)...); + } else { + HWY_ABORT("Unhandled type %s.", TypeName(base->GetType())); + } +} + void CopyMat(const MatPtr& from, MatPtr& to); void ZeroInit(MatPtr& mat);