Use FastExpMinusOrZero in Softmax().

PiperOrigin-RevId: 895740071
This commit is contained in:
Nikhil Dev Goyal 2026-04-07 01:19:17 -07:00 committed by Copybara-Service
parent f01cc59218
commit 70513a1e0f
2 changed files with 8 additions and 6 deletions

View File

@ -55,6 +55,7 @@
#include "ops/matmul_static.h" // includes highway.h
#include "ops/sum-inl.h"
#include "hwy/contrib/algo/transform-inl.h"
#include "hwy/contrib/math/fast_math-inl.h"
#include "hwy/contrib/math/math-inl.h"
HWY_BEFORE_NAMESPACE();
@ -1442,10 +1443,11 @@ static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx,
hn::Transform(d, logits.data(), logits.size(),
[pmax](const auto d, const V value) HWY_ATTR {
if constexpr (HWY_TARGET & HWY_ALL_SVE) {
// Workaround for buggy SVE codegen: avoid inlined Exp().
return hn::CallExp(d, hn::Sub(value, *pmax));
// Workaround for buggy SVE codegen: avoid inlined
// FastExpMinusOrZero().
return hn::CallFastExpMinusOrZero(d, hn::Sub(value, *pmax));
} else {
return hn::Exp(d, hn::Sub(value, *pmax));
return hn::FastExpMinusOrZero(d, hn::Sub(value, *pmax));
}
});

View File

@ -321,10 +321,10 @@ class TestSoftmax {
for (size_t i = 0; i < count; ++i) {
sum += x[i];
double rel = std::abs(x[i] - e[i]) / e[i];
ASSERT_LT(rel, 1e-6) << "Mismatch on coordinate " << i << " out of "
ASSERT_LT(rel, 2e-5) << "Mismatch on coordinate " << i << " out of "
<< count;
}
ASSERT_NEAR(sum, 1.0, 1e-6);
ASSERT_NEAR(sum, 1.0, 2e-5);
}
private:
@ -384,7 +384,7 @@ class TestSoftmaxState {
}
ASSERT_NEAR(softmax_max, maxval, 1e-6);
ASSERT_NEAR(softmax_d, sum_exp, 1e-6);
ASSERT_NEAR(softmax_d, sum_exp, 2e-5);
}
};