mirror of https://github.com/google/gemma.cpp.git
Merge pull request #194 from szabadka:softmax-fix
PiperOrigin-RevId: 636848144
This commit is contained in:
commit
93c0088646
12
gemma/ops.h
12
gemma/ops.h
|
|
@ -833,14 +833,16 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
|||
vmax = hn::MaxOfLanes(d, vmax);
|
||||
|
||||
// Subtract max (avoid precision loss for large exponents) and exponentiate.
|
||||
auto sum = hn::Zero(d);
|
||||
hn::Transform(d, x, mask_pos,
|
||||
[&sum, &vmax](const auto d, const auto value) HWY_ATTR {
|
||||
const auto out = hn::Exp(d, hn::Sub(value, vmax));
|
||||
sum = hn::Add(sum, out);
|
||||
return out;
|
||||
[&vmax](const auto d, const auto value) HWY_ATTR {
|
||||
return hn::Exp(d, hn::Sub(value, vmax));
|
||||
});
|
||||
|
||||
auto sum = hn::Zero(d);
|
||||
Foreach(d, x, mask_pos, sum,
|
||||
[&sum](const auto d, const auto value)
|
||||
HWY_ATTR { sum = hn::Add(sum, value); });
|
||||
|
||||
// Normalize to probability distribution
|
||||
const float mul = 1.0f / hn::ReduceSum(d, sum);
|
||||
MulByConst(mul, x, size, mask_pos);
|
||||
|
|
|
|||
|
|
@ -107,42 +107,16 @@ HWY_NOINLINE void SourceSoftmax(float* HWY_RESTRICT x, size_t size,
|
|||
size_t mask_pos) {
|
||||
HWY_DASSERT(size != 0);
|
||||
HWY_DASSERT(mask_pos <= size);
|
||||
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
const D d;
|
||||
const size_t N = hn::Lanes(d);
|
||||
|
||||
const hn::Vec<D> vmin = hn::Set(d, hwy::LowestValue<float>());
|
||||
hn::Vec<D> vmax = vmin;
|
||||
size_t idx = 0;
|
||||
if (mask_pos >= N) {
|
||||
for (; idx <= mask_pos - N; idx += N) {
|
||||
vmax = hn::Max(vmax, LoadU(d, x + idx));
|
||||
}
|
||||
float sum = 0.0;
|
||||
float maxval = *std::max_element(x, x + mask_pos);
|
||||
for (size_t i = 0; i < mask_pos; ++i) {
|
||||
x[i] = std::exp(x[i] - maxval);
|
||||
sum += x[i];
|
||||
}
|
||||
vmax = hn::Max(vmax, LoadNOr(vmin, d, x + idx, mask_pos - idx));
|
||||
vmax = hn::MaxOfLanes(d, vmax); // broadcast
|
||||
|
||||
hn::Vec<D> sum = hn::Zero(d);
|
||||
idx = 0;
|
||||
if (mask_pos >= N) {
|
||||
for (; idx <= mask_pos - N; idx += N) {
|
||||
const hn::Vec<D> out = hn::Exp(d, hn::Sub(hn::LoadU(d, x + idx), vmax));
|
||||
sum = hn::Add(sum, out);
|
||||
hn::StoreU(out, d, x + idx);
|
||||
}
|
||||
float scale = 1.0f / sum;
|
||||
for (size_t i = 0; i < mask_pos; ++i) {
|
||||
x[i] *= scale;
|
||||
}
|
||||
if (mask_pos > idx) {
|
||||
const size_t remaining = mask_pos - idx;
|
||||
const hn::Vec<D> out =
|
||||
hn::Exp(d, hn::Sub(hn::LoadN(d, x + idx, remaining), vmax));
|
||||
sum = hn::Add(sum, out);
|
||||
hn::StoreN(out, d, x + idx, remaining);
|
||||
}
|
||||
|
||||
const float mul = 1.0f / hn::ReduceSum(d, sum);
|
||||
SourceMulByConst(mul, x, size, mask_pos);
|
||||
}
|
||||
|
||||
template <size_t k>
|
||||
|
|
@ -311,8 +285,14 @@ struct TestSoftmax {
|
|||
SourceSoftmax(e, count, count);
|
||||
Softmax(x, count, count);
|
||||
|
||||
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
||||
__LINE__);
|
||||
T sum = 0.0f;
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
sum += x[i];
|
||||
double rel = std::abs(x[i] - e[i]) / e[i];
|
||||
ASSERT_LT(rel, 1e-6)
|
||||
<< "Mismatch on coordinate " << i << " out of " << count;
|
||||
}
|
||||
ASSERT_NEAR(sum, 1.0, 1e-6);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue