Internal changes

PiperOrigin-RevId: 871792281
This commit is contained in:
The gemma.cpp Authors 2026-02-18 04:06:59 -08:00 committed by Copybara-Service
parent c6696342fa
commit 34739fd9f0
2 changed files with 42 additions and 243 deletions

View File

@ -19,7 +19,6 @@
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <vector>
@ -465,28 +464,9 @@ static HWY_INLINE VF4 Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3,
return result;
}
// Returns vector with 8 lanes. Shouldn't be on architectures with less than 8
// lanes per vector.
template <class DF, typename T = hn::TFromD<DF>,
class DF8 = hn::CappedTag<T, 8>, class VF8 = hn::Vec<DF8>,
class VF = hn::Vec<DF>, typename F>
static HWY_INLINE VF8 Reduce8(DF df, VF x_0, VF x_1, VF x_2, VF x_3, VF x_4,
VF x_5, VF x_6, VF x_7, F reducer) {
auto res0123 = Reduce4(df, x_0, x_1, x_2, x_3, reducer);
auto res4567 = Reduce4(df, x_4, x_5, x_6, x_7, reducer);
using DF4 = hn::CappedTag<T, 4>;
const DF4 df4;
const DF8 df8;
HWY_ALIGN T buf[8];
hn::Store(res0123, df4, buf);
hn::Store(res4567, df4, buf + 4);
return hn::Load(df8, buf);
}
// Handles Up to 4 Q rows by NF*2 timesteps of flash attention.
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap(
DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1,
VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1,
float* HWY_RESTRICT old_max, float* HWY_RESTRICT old_d,
@ -522,29 +502,31 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
old_max_vf = hn::LoadU(df4, old_max);
new_max = hn::Max(new_max, old_max_vf);
auto changed_max = hn::Gt(new_max, hn::Set(df4, kNegInf));
// TODO figure out what was wrong with broadcasts and change to that.
hn::StoreU(new_max, df4, old_max);
if constexpr (kNumQueries >= 1) {
const VF new_max_0 = hn::Set(df, old_max[0]);
x_0_p0 = hn::CallExp(df, hn::Sub(x_0_p0, new_max_0));
x_0_p1 = hn::CallExp(df, hn::Sub(x_0_p1, new_max_0));
x_0_p0 = hn::Exp(df, hn::Sub(x_0_p0, new_max_0));
x_0_p1 = hn::Exp(df, hn::Sub(x_0_p1, new_max_0));
}
if constexpr (kNumQueries >= 2) {
const VF new_max_0 = hn::Set(df, old_max[1]);
x_1_p0 = hn::CallExp(df, hn::Sub(x_1_p0, new_max_0));
x_1_p1 = hn::CallExp(df, hn::Sub(x_1_p1, new_max_0));
x_1_p0 = hn::Exp(df, hn::Sub(x_1_p0, new_max_0));
x_1_p1 = hn::Exp(df, hn::Sub(x_1_p1, new_max_0));
}
if constexpr (kNumQueries >= 3) {
const VF new_max_0 = hn::Set(df, old_max[2]);
x_2_p0 = hn::CallExp(df, hn::Sub(x_2_p0, new_max_0));
x_2_p1 = hn::CallExp(df, hn::Sub(x_2_p1, new_max_0));
x_2_p0 = hn::Exp(df, hn::Sub(x_2_p0, new_max_0));
x_2_p1 = hn::Exp(df, hn::Sub(x_2_p1, new_max_0));
}
if constexpr (kNumQueries >= 4) {
const VF new_max_0 = hn::Set(df, old_max[3]);
x_3_p0 = hn::CallExp(df, hn::Sub(x_3_p0, new_max_0));
x_3_p1 = hn::CallExp(df, hn::Sub(x_3_p1, new_max_0));
x_3_p0 = hn::Exp(df, hn::Sub(x_3_p0, new_max_0));
x_3_p1 = hn::Exp(df, hn::Sub(x_3_p1, new_max_0));
}
VF4 old_d_vf = hn::Set(df4, 0.0f);
old_d_vf = hn::LoadU(df4, old_d);
VF4 scale = hn::Mul(old_d_vf, hn::Exp(df4, hn::Sub(old_max_vf, new_max)));
VF4 x_sum = hn::Zero(df4);
if constexpr (kNumQueries == 1) {
@ -557,7 +539,6 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
x_sum = Reduce4(df, x_0_sum, x_1_sum, x_2_sum, x_3_sum,
[](auto a, auto b) HWY_ATTR { return hn::Add(a, b); });
}
VF4 scale = hn::Mul(old_d_vf, hn::Exp(df4, hn::Sub(old_max_vf, new_max)));
old_d_vf = hn::Add(scale, x_sum);
auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df4, 0.0f));
const VF zero = hn::Zero(df);
@ -569,225 +550,43 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
hn::BlendedStore(old_d_vf, changed_max, df4, old_d);
scale = hn::Mul(scale, one_over_d);
hn::BlendedStore(scale, changed_max, df4, scales);
// same as lambda
auto mul_or_zero = [&](VF& x_p0, VF& x_p1, int i) HWY_ATTR {
if (HWY_LIKELY(old_d[i] > 0.0f && scales[i] != 1.0f)) {
const VF one_over_d_i = hn::Set(df, tmp_one_over_d[i]);
x_p0 = hn::Mul(x_p0, one_over_d_i);
x_p1 = hn::Mul(x_p1, one_over_d_i);
if (hn::ExtractLane(old_d_vf, 0) > 0.0f && scales[0] != 1.0f) {
const VF one_over_d_0 = hn::Set(df, tmp_one_over_d[0]);
x_0_p0 = hn::Mul(x_0_p0, one_over_d_0);
x_0_p1 = hn::Mul(x_0_p1, one_over_d_0);
} else {
x_0_p0 = zero;
x_0_p1 = zero;
}
if constexpr (kNumQueries >= 2) {
if (hn::ExtractLane(old_d_vf, 1) > 0.0f && scales[1] != 1.0f) {
const VF one_over_d_1 = hn::Set(df, tmp_one_over_d[1]);
x_1_p0 = hn::Mul(x_1_p0, one_over_d_1);
x_1_p1 = hn::Mul(x_1_p1, one_over_d_1);
} else {
x_p0 = zero;
x_p1 = zero;
x_1_p0 = zero;
x_1_p1 = zero;
}
};
mul_or_zero(x_0_p0, x_0_p1, 0);
if constexpr (kNumQueries >= 2) {
mul_or_zero(x_1_p0, x_1_p1, 1);
}
if constexpr (kNumQueries >= 3) {
mul_or_zero(x_2_p0, x_2_p1, 2);
}
if constexpr (kNumQueries >= 4) {
mul_or_zero(x_3_p0, x_3_p1, 3);
}
}
template <class DF, class VF = hn::Vec<DF>>
HWY_NOINLINE VF CallExp(DF df, VF x_p0) {
return hn::Exp(df, x_p0);
}
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8(
DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1,
VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1,
VF& x_4_p0, VF& x_4_p1, VF& x_5_p0, VF& x_5_p1, VF& x_6_p0, VF& x_6_p1,
VF& x_7_p0, VF& x_7_p1, float* HWY_RESTRICT old_max,
float* HWY_RESTRICT old_d, float* HWY_RESTRICT scales) {
using DF8 = hn::CappedTag<float, 8>;
const DF8 df8;
using VF8 = hn::Vec<DF8>;
static_assert(kNumQueries >= 1 && kNumQueries <= 8);
VF8 new_max = hn::Set(df8, kNegInf);
VF max_0, max_1, max_2, max_3, max_4, max_5, max_6, max_7 = hn::Zero(df);
max_0 = hn::Max(x_0_p0, x_0_p1);
if constexpr (kNumQueries >= 2) {
max_1 = hn::Max(x_1_p0, x_1_p1);
}
if constexpr (kNumQueries >= 3) {
max_2 = hn::Max(x_2_p0, x_2_p1);
}
if constexpr (kNumQueries >= 4) {
max_3 = hn::Max(x_3_p0, x_3_p1);
}
if constexpr (kNumQueries >= 5) {
max_4 = hn::Max(x_4_p0, x_4_p1);
}
if constexpr (kNumQueries >= 6) {
max_5 = hn::Max(x_5_p0, x_5_p1);
}
if constexpr (kNumQueries >= 7) {
max_6 = hn::Max(x_6_p0, x_6_p1);
}
if constexpr (kNumQueries >= 8) {
max_7 = hn::Max(x_7_p0, x_7_p1);
}
if constexpr (kNumQueries == 1) {
new_max = hn::InsertLane(new_max, 0, hn::ReduceMax(df, max_0));
} else {
new_max =
Reduce8(df, max_0, max_1, max_2, max_3, max_4, max_5, max_6, max_7,
[](auto a, auto b) HWY_ATTR { return hn::Max(a, b); });
}
if (att_cap > 0.0f) {
VF8 cap = hn::Set(df8, att_cap);
VF8 one_over_cap = hn::Set(df8, one_over_att_cap);
new_max = hn::Mul(cap, hn::Tanh(df8, hn::Mul(new_max, one_over_cap)));
}
VF8 old_max_vf = hn::Set(df8, kNegInf);
old_max_vf = hn::LoadU(df8, old_max);
new_max = hn::Max(new_max, old_max_vf);
auto changed_max = hn::Gt(new_max, hn::Set(df8, kNegInf));
hn::StoreU(new_max, df8, old_max);
if constexpr (kNumQueries >= 1) {
const VF new_max_0 = hn::Set(df, old_max[0]);
x_0_p0 = hn::CallExp(df, hn::Sub(x_0_p0, new_max_0));
x_0_p1 = hn::CallExp(df, hn::Sub(x_0_p1, new_max_0));
}
if constexpr (kNumQueries >= 2) {
const VF new_max_0 = hn::Set(df, old_max[1]);
x_1_p0 = hn::CallExp(df, hn::Sub(x_1_p0, new_max_0));
x_1_p1 = hn::CallExp(df, hn::Sub(x_1_p1, new_max_0));
}
if constexpr (kNumQueries >= 3) {
const VF new_max_0 = hn::Set(df, old_max[2]);
x_2_p0 = hn::CallExp(df, hn::Sub(x_2_p0, new_max_0));
x_2_p1 = hn::CallExp(df, hn::Sub(x_2_p1, new_max_0));
}
if constexpr (kNumQueries >= 4) {
const VF new_max_0 = hn::Set(df, old_max[3]);
x_3_p0 = hn::CallExp(df, hn::Sub(x_3_p0, new_max_0));
x_3_p1 = hn::CallExp(df, hn::Sub(x_3_p1, new_max_0));
}
if constexpr (kNumQueries >= 5) {
const VF new_max_0 = hn::Set(df, old_max[4]);
x_4_p0 = hn::CallExp(df, hn::Sub(x_4_p0, new_max_0));
x_4_p1 = hn::CallExp(df, hn::Sub(x_4_p1, new_max_0));
}
if constexpr (kNumQueries >= 6) {
const VF new_max_0 = hn::Set(df, old_max[5]);
x_5_p0 = hn::CallExp(df, hn::Sub(x_5_p0, new_max_0));
x_5_p1 = hn::CallExp(df, hn::Sub(x_5_p1, new_max_0));
}
if constexpr (kNumQueries >= 7) {
const VF new_max_0 = hn::Set(df, old_max[6]);
x_6_p0 = hn::CallExp(df, hn::Sub(x_6_p0, new_max_0));
x_6_p1 = hn::CallExp(df, hn::Sub(x_6_p1, new_max_0));
}
if constexpr (kNumQueries >= 8) {
const VF new_max_0 = hn::Set(df, old_max[7]);
x_7_p0 = hn::CallExp(df, hn::Sub(x_7_p0, new_max_0));
x_7_p1 = hn::CallExp(df, hn::Sub(x_7_p1, new_max_0));
}
VF8 old_d_vf = hn::Set(df8, 0.0f);
old_d_vf = hn::LoadU(df8, old_d);
VF8 x_sum = hn::Zero(df8);
if constexpr (kNumQueries == 1) {
x_sum = hn::Set(df8, hn::ReduceSum(df, x_0_p0) + hn::ReduceSum(df, x_0_p1));
} else {
VF x_0_sum = hn::Add(x_0_p0, x_0_p1);
VF x_1_sum = hn::Add(x_1_p0, x_1_p1);
VF x_2_sum = hn::Add(x_2_p0, x_2_p1);
VF x_3_sum = hn::Add(x_3_p0, x_3_p1);
VF x_4_sum = hn::Add(x_4_p0, x_4_p1);
VF x_5_sum = hn::Add(x_5_p0, x_5_p1);
VF x_6_sum = hn::Add(x_6_p0, x_6_p1);
VF x_7_sum = hn::Add(x_7_p0, x_7_p1);
x_sum = Reduce8(df, x_0_sum, x_1_sum, x_2_sum, x_3_sum, x_4_sum, x_5_sum,
x_6_sum, x_7_sum,
[](auto a, auto b) HWY_ATTR { return hn::Add(a, b); });
}
VF8 scale = hn::Mul(old_d_vf, hn::Exp(df8, hn::Sub(old_max_vf, new_max)));
old_d_vf = hn::Add(scale, x_sum);
auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df8, 0.0f));
const VF zero = hn::Zero(df);
const VF8 zero8 = hn::Zero(df8);
const VF8 one_over_d =
hn::MaskedDivOr(zero8, non_zero_mask, hn::Set(df8, 1.0f), old_d_vf);
HWY_ALIGN float tmp_one_over_d[8];
hn::Store(one_over_d, df8, tmp_one_over_d);
hn::BlendedStore(old_d_vf, changed_max, df8, old_d);
scale = hn::Mul(scale, one_over_d);
hn::BlendedStore(scale, changed_max, df8, scales);
auto mul_or_zero = [&](VF& x_p0, VF& x_p1, int i) HWY_ATTR {
if (HWY_LIKELY(old_d[i] > 0.0f && scales[i] != 1.0f)) {
const VF one_over_d_i = hn::Set(df, tmp_one_over_d[i]);
x_p0 = hn::Mul(x_p0, one_over_d_i);
x_p1 = hn::Mul(x_p1, one_over_d_i);
if (hn::ExtractLane(old_d_vf, 2) > 0.0f && scales[2] != 1.0f) {
const VF one_over_d_2 = hn::Set(df, tmp_one_over_d[2]);
x_2_p0 = hn::Mul(x_2_p0, one_over_d_2);
x_2_p1 = hn::Mul(x_2_p1, one_over_d_2);
} else {
x_p0 = zero;
x_p1 = zero;
x_2_p0 = zero;
x_2_p1 = zero;
}
};
mul_or_zero(x_0_p0, x_0_p1, 0);
if constexpr (kNumQueries >= 2) {
mul_or_zero(x_1_p0, x_1_p1, 1);
}
if constexpr (kNumQueries >= 3) {
mul_or_zero(x_2_p0, x_2_p1, 2);
}
if constexpr (kNumQueries >= 4) {
mul_or_zero(x_3_p0, x_3_p1, 3);
}
if constexpr (kNumQueries >= 5) {
mul_or_zero(x_4_p0, x_4_p1, 4);
}
if constexpr (kNumQueries >= 6) {
mul_or_zero(x_5_p0, x_5_p1, 5);
}
if constexpr (kNumQueries >= 7) {
mul_or_zero(x_6_p0, x_6_p1, 6);
}
if constexpr (kNumQueries >= 8) {
mul_or_zero(x_7_p0, x_7_p1, 7);
}
}
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap(
DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1,
VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1,
VF& x_4_p0, VF& x_4_p1, VF& x_5_p0, VF& x_5_p1, VF& x_6_p0, VF& x_6_p1,
VF& x_7_p0, VF& x_7_p1, float* HWY_RESTRICT old_max,
float* HWY_RESTRICT old_d, float* HWY_RESTRICT scales, size_t q_group_idx,
size_t kNumQueriesPerGroup) {
constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4);
constexpr int kSecondHalfAmountOfQueries =
kNumQueries - kFirstHalfAmountOfQueries;
if constexpr (kNumQueries <= 4) {
FlashAttentionTileStepAndApplySoftCap4<kFirstHalfAmountOfQueries>(
df, att_cap, one_over_att_cap, x_0_p0, x_0_p1, x_1_p0, x_1_p1, x_2_p0,
x_2_p1, x_3_p0, x_3_p1, old_max + (q_group_idx)*kNumQueriesPerGroup,
old_d + (q_group_idx)*kNumQueriesPerGroup, scales);
} else {
#if HWY_MAX_BYTES <= 16
FlashAttentionTileStepAndApplySoftCap4<4>(
df, att_cap, one_over_att_cap, x_0_p0, x_0_p1, x_1_p0, x_1_p1, x_2_p0,
x_2_p1, x_3_p0, x_3_p1, old_max + (q_group_idx)*kNumQueriesPerGroup,
old_d + (q_group_idx)*kNumQueriesPerGroup, scales);
FlashAttentionTileStepAndApplySoftCap4<kSecondHalfAmountOfQueries>(
df, att_cap, one_over_att_cap, x_4_p0, x_4_p1, x_5_p0, x_5_p1, x_6_p0,
x_6_p1, x_7_p0, x_7_p1,
old_max + (q_group_idx + 1) * kNumQueriesPerGroup,
old_d + (q_group_idx + 1) * kNumQueriesPerGroup,
scales + kNumQueriesPerGroup);
#else
FlashAttentionTileStepAndApplySoftCap8<kNumQueries>(
df, att_cap, one_over_att_cap, x_0_p0, x_0_p1, x_1_p0, x_1_p1, x_2_p0,
x_2_p1, x_3_p0, x_3_p1, x_4_p0, x_4_p1, x_5_p0, x_5_p1, x_6_p0, x_6_p1,
x_7_p0, x_7_p1, old_max + (q_group_idx)*kNumQueriesPerGroup,
old_d + (q_group_idx)*kNumQueriesPerGroup, scales);
#endif
if (hn::ExtractLane(old_d_vf, 3) > 0.0f && scales[3] != 1.0f) {
const VF one_over_d_3 = hn::Set(df, tmp_one_over_d[3]);
x_3_p0 = hn::Mul(x_3_p0, one_over_d_3);
x_3_p1 = hn::Mul(x_3_p1, one_over_d_3);
} else {
x_3_p0 = zero;
x_3_p1 = zero;
}
}
}

View File

@ -68,7 +68,7 @@ void SetMat(const size_t offset, MatPtrT<float>& mat) {
const float i_scale = 1.0f / kInner;
const float j_scale = 1.0f / kOuter;
for (size_t i = 0; i < kOuter; ++i) {
float* HWY_RESTRICT row = mat.Row(i);
float* row = mat.Row(i);
for (size_t j = 0; j < kInner; ++j) {
row[j] =
static_cast<float>((i * kInner * i_scale + (j + offset) * j_scale));
@ -190,7 +190,7 @@ HWY_AFTER_NAMESPACE();
namespace gcpp {
HWY_BEFORE_TEST(FlashAttentionTest);
// HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestAttention);
HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestAttention);
HWY_AFTER_TEST();
} // namespace gcpp