Fix Softmax on SVE

PiperOrigin-RevId: 640947138
This commit is contained in:
Paul Chang 2024-06-06 10:39:00 -07:00 committed by Copybara-Service
parent 39d4115717
commit 6c0be20fa6
2 changed files with 14 additions and 4 deletions

View File

@@ -29,6 +29,7 @@
#include "compression/sfp.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/detect_targets.h"
#include "hwy/profiler.h"
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
@@ -1247,8 +1248,15 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
// Subtract max (avoid precision loss for large exponents) and exponentiate.
hn::Transform(d, x, mask_pos,
[&vmax](const auto d, const auto value)
HWY_ATTR { return hn::Exp(d, hn::Sub(value, vmax)); });
[&vmax](const auto d, const auto value) HWY_ATTR {
#if HWY_TARGET & HWY_ALL_SVE
// Temporary workaround for buggy SVE codegen: avoid inlined
// Exp().
return hn::CallExp(d, hn::Sub(value, vmax));
#else
return hn::Exp(d, hn::Sub(value, vmax));
#endif
});
auto sum = hn::Zero(d);
Foreach(d, x, mask_pos, sum, [&sum](const auto d, const auto value) HWY_ATTR {

View File

@@ -109,12 +109,12 @@ HWY_NOINLINE void SourceSoftmax(float* HWY_RESTRICT x, size_t size,
HWY_DASSERT(size != 0);
HWY_DASSERT(mask_pos <= size);
float sum = 0.0;
float maxval = *std::max_element(x, x + mask_pos);
const float maxval = *std::max_element(x, x + mask_pos);
for (size_t i = 0; i < mask_pos; ++i) {
x[i] = std::exp(x[i] - maxval);
sum += x[i];
}
float scale = 1.0f / sum;
const float scale = 1.0f / sum;
for (size_t i = 0; i < mask_pos; ++i) {
x[i] *= scale;
}
@@ -237,6 +237,7 @@ struct TestMulByConst {
template <class D>
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
hwy::RandomState& rng) {
if (misalign_b == 0) return;
using T = hn::TFromD<D>;
hwy::AlignedFreeUniquePtr<T[]> px =
@@ -267,6 +268,7 @@ struct TestSoftmax {
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
hwy::RandomState& rng) {
if (count == 0) return; // *Softmax would assert
if (misalign_b == 0) return;
using T = hn::TFromD<D>;
hwy::AlignedFreeUniquePtr<T[]> px =