Merge pull request #194 from szabadka:softmax-fix

PiperOrigin-RevId: 636848144
This commit is contained in:
Copybara-Service 2024-05-24 02:48:17 -07:00
commit 93c0088646
2 changed files with 23 additions and 41 deletions

View File

@@ -833,14 +833,16 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
vmax = hn::MaxOfLanes(d, vmax);
// Subtract max (avoid precision loss for large exponents) and exponentiate.
auto sum = hn::Zero(d);
hn::Transform(d, x, mask_pos,
[&sum, &vmax](const auto d, const auto value) HWY_ATTR {
const auto out = hn::Exp(d, hn::Sub(value, vmax));
sum = hn::Add(sum, out);
return out;
[&vmax](const auto d, const auto value) HWY_ATTR {
return hn::Exp(d, hn::Sub(value, vmax));
});
auto sum = hn::Zero(d);
Foreach(d, x, mask_pos, sum,
[&sum](const auto d, const auto value)
HWY_ATTR { sum = hn::Add(sum, value); });
// Normalize to probability distribution
const float mul = 1.0f / hn::ReduceSum(d, sum);
MulByConst(mul, x, size, mask_pos);

View File

@@ -107,42 +107,16 @@ HWY_NOINLINE void SourceSoftmax(float* HWY_RESTRICT x, size_t size,
size_t mask_pos) {
HWY_DASSERT(size != 0);
HWY_DASSERT(mask_pos <= size);
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
const size_t N = hn::Lanes(d);
const hn::Vec<D> vmin = hn::Set(d, hwy::LowestValue<float>());
hn::Vec<D> vmax = vmin;
size_t idx = 0;
if (mask_pos >= N) {
for (; idx <= mask_pos - N; idx += N) {
vmax = hn::Max(vmax, LoadU(d, x + idx));
}
float sum = 0.0;
float maxval = *std::max_element(x, x + mask_pos);
for (size_t i = 0; i < mask_pos; ++i) {
x[i] = std::exp(x[i] - maxval);
sum += x[i];
}
vmax = hn::Max(vmax, LoadNOr(vmin, d, x + idx, mask_pos - idx));
vmax = hn::MaxOfLanes(d, vmax); // broadcast
hn::Vec<D> sum = hn::Zero(d);
idx = 0;
if (mask_pos >= N) {
for (; idx <= mask_pos - N; idx += N) {
const hn::Vec<D> out = hn::Exp(d, hn::Sub(hn::LoadU(d, x + idx), vmax));
sum = hn::Add(sum, out);
hn::StoreU(out, d, x + idx);
}
float scale = 1.0f / sum;
for (size_t i = 0; i < mask_pos; ++i) {
x[i] *= scale;
}
if (mask_pos > idx) {
const size_t remaining = mask_pos - idx;
const hn::Vec<D> out =
hn::Exp(d, hn::Sub(hn::LoadN(d, x + idx, remaining), vmax));
sum = hn::Add(sum, out);
hn::StoreN(out, d, x + idx, remaining);
}
const float mul = 1.0f / hn::ReduceSum(d, sum);
SourceMulByConst(mul, x, size, mask_pos);
}
template <size_t k>
@@ -311,8 +285,14 @@ struct TestSoftmax {
SourceSoftmax(e, count, count);
Softmax(x, count, count);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
__LINE__);
T sum = 0.0f;
for (size_t i = 0; i < count; ++i) {
sum += x[i];
double rel = std::abs(x[i] - e[i]) / e[i];
ASSERT_LT(rel, 1e-6)
<< "Mismatch on coordinate " << i << " out of " << count;
}
ASSERT_NEAR(sum, 1.0, 1e-6);
}
};