This commit is contained in:
copybara-service[bot] 2026-03-02 11:00:10 +00:00 committed by GitHub
commit 43771bd811
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 29 additions and 43 deletions

View File

@ -627,6 +627,7 @@ cc_library(
"//compression:compress",
"//compression:types",
"@highway//:hwy",
"@highway//:math",
"@highway//:profiler",
],
)

View File

@ -53,6 +53,7 @@
#include "gemma/attention.h"
#include "ops/matmul-inl.h"
#include "ops/ops-inl.h"
#include "hwy/contrib/math/fast_math-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
@ -523,25 +524,22 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
new_max = hn::Max(new_max, old_max_vf);
auto changed_max = hn::Gt(new_max, hn::Set(df4, kNegInf));
hn::StoreU(new_max, df4, old_max);
auto apply_exp = [&](int i, VF& x_p0, VF& x_p1) HWY_ATTR {
const VF new_max_i = hn::Set(df, old_max[i]);
x_p0 = hn::FastExp(df, hn::Sub(x_p0, new_max_i));
x_p1 = hn::FastExp(df, hn::Sub(x_p1, new_max_i));
};
if constexpr (kNumQueries >= 1) {
const VF new_max_0 = hn::Set(df, old_max[0]);
x_0_p0 = hn::CallExp(df, hn::Sub(x_0_p0, new_max_0));
x_0_p1 = hn::CallExp(df, hn::Sub(x_0_p1, new_max_0));
apply_exp(0, x_0_p0, x_0_p1);
}
if constexpr (kNumQueries >= 2) {
const VF new_max_0 = hn::Set(df, old_max[1]);
x_1_p0 = hn::CallExp(df, hn::Sub(x_1_p0, new_max_0));
x_1_p1 = hn::CallExp(df, hn::Sub(x_1_p1, new_max_0));
apply_exp(1, x_1_p0, x_1_p1);
}
if constexpr (kNumQueries >= 3) {
const VF new_max_0 = hn::Set(df, old_max[2]);
x_2_p0 = hn::CallExp(df, hn::Sub(x_2_p0, new_max_0));
x_2_p1 = hn::CallExp(df, hn::Sub(x_2_p1, new_max_0));
apply_exp(2, x_2_p0, x_2_p1);
}
if constexpr (kNumQueries >= 4) {
const VF new_max_0 = hn::Set(df, old_max[3]);
x_3_p0 = hn::CallExp(df, hn::Sub(x_3_p0, new_max_0));
x_3_p1 = hn::CallExp(df, hn::Sub(x_3_p1, new_max_0));
apply_exp(3, x_3_p0, x_3_p1);
}
VF4 old_d_vf = hn::Set(df4, 0.0f);
old_d_vf = hn::LoadU(df4, old_d);
@ -592,10 +590,6 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
}
}
template <class DF, class VF = hn::Vec<DF>>
HWY_NOINLINE VF CallExp(DF df, VF x_p0) {
return hn::Exp(df, x_p0);
}
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8(
DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1,
@ -649,45 +643,36 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8(
new_max = hn::Max(new_max, old_max_vf);
auto changed_max = hn::Gt(new_max, hn::Set(df8, kNegInf));
hn::StoreU(new_max, df8, old_max);
auto apply_exp = [&](int i, VF& x_p0, VF& x_p1) HWY_ATTR {
const VF new_max_i = hn::Set(df, old_max[i]);
x_p0 = hn::Exp(df, hn::Sub(x_p0, new_max_i));
x_p1 = hn::Exp(df, hn::Sub(x_p1, new_max_i));
};
if constexpr (kNumQueries >= 1) {
const VF new_max_0 = hn::Set(df, old_max[0]);
x_0_p0 = hn::CallExp(df, hn::Sub(x_0_p0, new_max_0));
x_0_p1 = hn::CallExp(df, hn::Sub(x_0_p1, new_max_0));
apply_exp(0, x_0_p0, x_0_p1);
}
if constexpr (kNumQueries >= 2) {
const VF new_max_0 = hn::Set(df, old_max[1]);
x_1_p0 = hn::CallExp(df, hn::Sub(x_1_p0, new_max_0));
x_1_p1 = hn::CallExp(df, hn::Sub(x_1_p1, new_max_0));
apply_exp(1, x_1_p0, x_1_p1);
}
if constexpr (kNumQueries >= 3) {
const VF new_max_0 = hn::Set(df, old_max[2]);
x_2_p0 = hn::CallExp(df, hn::Sub(x_2_p0, new_max_0));
x_2_p1 = hn::CallExp(df, hn::Sub(x_2_p1, new_max_0));
apply_exp(2, x_2_p0, x_2_p1);
}
if constexpr (kNumQueries >= 4) {
const VF new_max_0 = hn::Set(df, old_max[3]);
x_3_p0 = hn::CallExp(df, hn::Sub(x_3_p0, new_max_0));
x_3_p1 = hn::CallExp(df, hn::Sub(x_3_p1, new_max_0));
apply_exp(3, x_3_p0, x_3_p1);
}
if constexpr (kNumQueries >= 5) {
const VF new_max_0 = hn::Set(df, old_max[4]);
x_4_p0 = hn::CallExp(df, hn::Sub(x_4_p0, new_max_0));
x_4_p1 = hn::CallExp(df, hn::Sub(x_4_p1, new_max_0));
apply_exp(4, x_4_p0, x_4_p1);
}
if constexpr (kNumQueries >= 6) {
const VF new_max_0 = hn::Set(df, old_max[5]);
x_5_p0 = hn::CallExp(df, hn::Sub(x_5_p0, new_max_0));
x_5_p1 = hn::CallExp(df, hn::Sub(x_5_p1, new_max_0));
apply_exp(5, x_5_p0, x_5_p1);
}
if constexpr (kNumQueries >= 7) {
const VF new_max_0 = hn::Set(df, old_max[6]);
x_6_p0 = hn::CallExp(df, hn::Sub(x_6_p0, new_max_0));
x_6_p1 = hn::CallExp(df, hn::Sub(x_6_p1, new_max_0));
apply_exp(6, x_6_p0, x_6_p1);
}
if constexpr (kNumQueries >= 8) {
const VF new_max_0 = hn::Set(df, old_max[7]);
x_7_p0 = hn::CallExp(df, hn::Sub(x_7_p0, new_max_0));
x_7_p1 = hn::CallExp(df, hn::Sub(x_7_p1, new_max_0));
apply_exp(7, x_7_p0, x_7_p1);
}
VF8 old_d_vf = hn::Set(df8, 0.0f);
old_d_vf = hn::LoadU(df8, old_d);

View File

@ -349,11 +349,11 @@ void TestTiledFlashAttention() {
for (int i = 0; i < num_queries; ++i) {
std::cerr << "exp_d: " << exp_denominator_sums[i]
<< " max_logit: " << max_logits[i] << std::endl;
EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 1e-4f)
EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 4e-2f)
<< "i=" << i;
EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-6f) << "i=" << i;
for (int j = 0; j < qkv_dim; ++j) {
EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 1e-6f);
EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 1e-4f);
}
}
}
@ -464,7 +464,7 @@ void TestTiledFlashAttentionBF16() {
for (int i = 0; i < num_queries; ++i) {
std::cerr << "exp_d: " << exp_denominator_sums[i]
<< " max_logit: " << max_logits[i] << std::endl;
EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 2e-2f)
EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 4e-2f)
<< "i=" << i;
EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i;
for (int j = 0; j < qkv_dim; ++j) {