Replace remaining occurrences of Exp with FastExpMinusOrZero in flash attention.

PiperOrigin-RevId: 882817324
This commit is contained in:
Nikhil Dev Goyal 2026-03-12 15:47:51 -07:00 committed by Copybara-Service
parent 529c201eb6
commit 139a8e0964
2 changed files with 35 additions and 16 deletions

View File

@ -46,11 +46,18 @@ static const char* kQuestions =
"Which people first proposed the quark model of hadrons, and when?";
// All phrases in kAnswers must appear in the response in the order given for
// the test to pass.
static const char* kAnswers[] = {
"a ship's anchor", "a dark forest", "an hour",
"enormous sand", "castles", "limpet shells",
"Murray Gell-Mann", "George Zweig", "1964"};
// the test to pass. Multiple acceptable answers can be provided for each
// expected phrase.
static const std::vector<std::vector<const char*>> kAnswers = {
{"rusty metal", "ship's anchor"},
{"dark forest"},
{"an hour"},
{"enormous sand"},
{"castles"},
{"limpet shells"},
{"Murray Gell-Mann"},
{"George Zweig"},
{"1964"}};
std::string LoadPromptFile(const std::string& filename) {
// If the filename is empty, return an empty string.
@ -108,12 +115,22 @@ class GemmaTest : public ::testing::Test {
void TestExpectations(const std::string& response) {
fprintf(stderr, "Response: '%s'\n", response.c_str());
size_t pos = 0;
for (const char* answer : kAnswers) {
auto found = response.find(answer, pos);
EXPECT_NE(found, std::string::npos)
<< "Response does not contain " << answer;
if (found != std::string::npos) {
pos = found + strlen(answer);
for (const auto& answer_group : kAnswers) {
size_t earliest_pos = std::string::npos;
const char* matched_answer = nullptr;
for (const char* answer : answer_group) {
auto found = response.find(answer, pos);
if (found != std::string::npos &&
(earliest_pos == std::string::npos || found < earliest_pos)) {
earliest_pos = found;
matched_answer = answer;
}
}
EXPECT_NE(earliest_pos, std::string::npos)
<< "Response does not contain acceptable answers, e.g., "
<< answer_group[0];
if (earliest_pos != std::string::npos) {
pos = earliest_pos + strlen(matched_answer);
}
}
s_env->PrintProfileResults();

View File

@ -489,7 +489,7 @@ HWY_INLINE float SingleFlashAttentionRowVector(DF df, size_t start_pos,
}
float m = hn::ReduceMax(df, x);
m = std::max(m, old_max);
x = hn::Exp(df, hn::Sub(x, hn::Set(df, m)));
x = hn::FastExpMinusOrZero(df, hn::Sub(x, hn::Set(df, m)));
float scale = old_d * std::exp(old_max - m);
old_d = hn::ReduceSum(df, x) + scale;
old_max = m;
@ -538,8 +538,8 @@ HWY_INLINE float DoubleFlashAttentionRowVector(DF df, size_t start_pos,
float m = hn::ReduceMax(df, x_max);
m = std::max(m, old_max);
VF m_vec = hn::Set(df, m);
x0 = hn::Exp(df, hn::Sub(x0, m_vec));
x1 = hn::Exp(df, hn::Sub(x1, m_vec));
x0 = hn::FastExpMinusOrZero(df, hn::Sub(x0, m_vec));
x1 = hn::FastExpMinusOrZero(df, hn::Sub(x1, m_vec));
float scale = old_d * std::exp(old_max - m);
VF x_sum = hn::Add(x0, x1);
old_d = hn::ReduceSum(df, x_sum) + scale;
@ -672,7 +672,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
x_sum = Reduce4(df, x_0_sum, x_1_sum, x_2_sum, x_3_sum,
[](auto a, auto b) HWY_ATTR { return hn::Add(a, b); });
}
VF4 scale = hn::Mul(old_d_vf, hn::Exp(df4, hn::Sub(old_max_vf, new_max)));
VF4 scale = hn::Mul(
old_d_vf, hn::FastExpMinusOrZero(df4, hn::Sub(old_max_vf, new_max)));
old_d_vf = hn::Add(scale, x_sum);
auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df4, 0.0f));
const VF zero = hn::Zero(df);
@ -810,7 +811,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8(
x_6_sum, x_7_sum,
[](auto a, auto b) HWY_ATTR { return hn::Add(a, b); });
}
VF8 scale = hn::Mul(old_d_vf, hn::Exp(df8, hn::Sub(old_max_vf, new_max)));
VF8 scale = hn::Mul(
old_d_vf, hn::FastExpMinusOrZero(df8, hn::Sub(old_max_vf, new_max)));
old_d_vf = hn::Add(scale, x_sum);
auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df8, 0.0f));
const VF zero = hn::Zero(df);