mirror of https://github.com/google/gemma.cpp.git
Merge pull request #194 from szabadka:softmax-fix
PiperOrigin-RevId: 636848144
This commit is contained in:
commit
93c0088646
12
gemma/ops.h
12
gemma/ops.h
|
|
@@ -833,14 +833,16 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
||||||
vmax = hn::MaxOfLanes(d, vmax);
|
vmax = hn::MaxOfLanes(d, vmax);
|
||||||
|
|
||||||
// Subtract max (avoid precision loss for large exponents) and exponentiate.
|
// Subtract max (avoid precision loss for large exponents) and exponentiate.
|
||||||
auto sum = hn::Zero(d);
|
|
||||||
hn::Transform(d, x, mask_pos,
|
hn::Transform(d, x, mask_pos,
|
||||||
[&sum, &vmax](const auto d, const auto value) HWY_ATTR {
|
[&vmax](const auto d, const auto value) HWY_ATTR {
|
||||||
const auto out = hn::Exp(d, hn::Sub(value, vmax));
|
return hn::Exp(d, hn::Sub(value, vmax));
|
||||||
sum = hn::Add(sum, out);
|
|
||||||
return out;
|
|
||||||
});
|
});
|
||||||
|
|
||||||
|
auto sum = hn::Zero(d);
|
||||||
|
Foreach(d, x, mask_pos, sum,
|
||||||
|
[&sum](const auto d, const auto value)
|
||||||
|
HWY_ATTR { sum = hn::Add(sum, value); });
|
||||||
|
|
||||||
// Normalize to probability distribution
|
// Normalize to probability distribution
|
||||||
const float mul = 1.0f / hn::ReduceSum(d, sum);
|
const float mul = 1.0f / hn::ReduceSum(d, sum);
|
||||||
MulByConst(mul, x, size, mask_pos);
|
MulByConst(mul, x, size, mask_pos);
|
||||||
|
|
|
||||||
|
|
@@ -107,42 +107,16 @@ HWY_NOINLINE void SourceSoftmax(float* HWY_RESTRICT x, size_t size,
|
||||||
size_t mask_pos) {
|
size_t mask_pos) {
|
||||||
HWY_DASSERT(size != 0);
|
HWY_DASSERT(size != 0);
|
||||||
HWY_DASSERT(mask_pos <= size);
|
HWY_DASSERT(mask_pos <= size);
|
||||||
|
float sum = 0.0;
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
float maxval = *std::max_element(x, x + mask_pos);
|
||||||
using D = hn::ScalableTag<float>;
|
for (size_t i = 0; i < mask_pos; ++i) {
|
||||||
const D d;
|
x[i] = std::exp(x[i] - maxval);
|
||||||
const size_t N = hn::Lanes(d);
|
sum += x[i];
|
||||||
|
|
||||||
const hn::Vec<D> vmin = hn::Set(d, hwy::LowestValue<float>());
|
|
||||||
hn::Vec<D> vmax = vmin;
|
|
||||||
size_t idx = 0;
|
|
||||||
if (mask_pos >= N) {
|
|
||||||
for (; idx <= mask_pos - N; idx += N) {
|
|
||||||
vmax = hn::Max(vmax, LoadU(d, x + idx));
|
|
||||||
}
|
}
|
||||||
|
float scale = 1.0f / sum;
|
||||||
|
for (size_t i = 0; i < mask_pos; ++i) {
|
||||||
|
x[i] *= scale;
|
||||||
}
|
}
|
||||||
vmax = hn::Max(vmax, LoadNOr(vmin, d, x + idx, mask_pos - idx));
|
|
||||||
vmax = hn::MaxOfLanes(d, vmax); // broadcast
|
|
||||||
|
|
||||||
hn::Vec<D> sum = hn::Zero(d);
|
|
||||||
idx = 0;
|
|
||||||
if (mask_pos >= N) {
|
|
||||||
for (; idx <= mask_pos - N; idx += N) {
|
|
||||||
const hn::Vec<D> out = hn::Exp(d, hn::Sub(hn::LoadU(d, x + idx), vmax));
|
|
||||||
sum = hn::Add(sum, out);
|
|
||||||
hn::StoreU(out, d, x + idx);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (mask_pos > idx) {
|
|
||||||
const size_t remaining = mask_pos - idx;
|
|
||||||
const hn::Vec<D> out =
|
|
||||||
hn::Exp(d, hn::Sub(hn::LoadN(d, x + idx, remaining), vmax));
|
|
||||||
sum = hn::Add(sum, out);
|
|
||||||
hn::StoreN(out, d, x + idx, remaining);
|
|
||||||
}
|
|
||||||
|
|
||||||
const float mul = 1.0f / hn::ReduceSum(d, sum);
|
|
||||||
SourceMulByConst(mul, x, size, mask_pos);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <size_t k>
|
template <size_t k>
|
||||||
|
|
@@ -311,8 +285,14 @@ struct TestSoftmax {
|
||||||
SourceSoftmax(e, count, count);
|
SourceSoftmax(e, count, count);
|
||||||
Softmax(x, count, count);
|
Softmax(x, count, count);
|
||||||
|
|
||||||
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
T sum = 0.0f;
|
||||||
__LINE__);
|
for (size_t i = 0; i < count; ++i) {
|
||||||
|
sum += x[i];
|
||||||
|
double rel = std::abs(x[i] - e[i]) / e[i];
|
||||||
|
ASSERT_LT(rel, 1e-6)
|
||||||
|
<< "Mismatch on coordinate " << i << " out of " << count;
|
||||||
|
}
|
||||||
|
ASSERT_NEAR(sum, 1.0, 1e-6);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue