mirror of https://github.com/google/gemma.cpp.git
Use FastExpMinusOrZero in Softmax().
PiperOrigin-RevId: 895740071
This commit is contained in:
parent
f01cc59218
commit
70513a1e0f
|
|
@@ -55,6 +55,7 @@
|
|||
#include "ops/matmul_static.h" // includes highway.h
|
||||
#include "ops/sum-inl.h"
|
||||
#include "hwy/contrib/algo/transform-inl.h"
|
||||
+#include "hwy/contrib/math/fast_math-inl.h"
|
||||
#include "hwy/contrib/math/math-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
|
|
@@ -1442,10 +1443,11 @@ static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx,
|
|||
hn::Transform(d, logits.data(), logits.size(),
|
||||
[pmax](const auto d, const V value) HWY_ATTR {
|
||||
if constexpr (HWY_TARGET & HWY_ALL_SVE) {
|
||||
-// Workaround for buggy SVE codegen: avoid inlined Exp().
|
||||
-return hn::CallExp(d, hn::Sub(value, *pmax));
|
||||
+// Workaround for buggy SVE codegen: avoid inlined
|
||||
+// FastExpMinusOrZero().
|
||||
+return hn::CallFastExpMinusOrZero(d, hn::Sub(value, *pmax));
|
||||
} else {
|
||||
-return hn::Exp(d, hn::Sub(value, *pmax));
|
||||
+return hn::FastExpMinusOrZero(d, hn::Sub(value, *pmax));
|
||||
}
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@@ -321,10 +321,10 @@ class TestSoftmax {
|
|||
for (size_t i = 0; i < count; ++i) {
|
||||
sum += x[i];
|
||||
double rel = std::abs(x[i] - e[i]) / e[i];
|
||||
-ASSERT_LT(rel, 1e-6) << "Mismatch on coordinate " << i << " out of "
|
||||
+ASSERT_LT(rel, 2e-5) << "Mismatch on coordinate " << i << " out of "
|
||||
<< count;
|
||||
}
|
||||
-ASSERT_NEAR(sum, 1.0, 1e-6);
|
||||
+ASSERT_NEAR(sum, 1.0, 2e-5);
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@@ -384,7 +384,7 @@ class TestSoftmaxState {
|
|||
}
|
||||
|
||||
ASSERT_NEAR(softmax_max, maxval, 1e-6);
|
||||
-ASSERT_NEAR(softmax_d, sum_exp, 1e-6);
|
||||
+ASSERT_NEAR(softmax_d, sum_exp, 2e-5);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue