Minor refactor in Softmax

This commit is contained in:
enum-class 2024-03-20 00:20:14 +08:00
parent 858d5b08c2
commit 4400842337
1 changed file with 24 additions and 43 deletions

65
ops.h
View File

@ -619,49 +619,29 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
MulByConstAndAdd(c, x, out, size, size);
}
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, size_t size,
size_t mask_pos) {
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
const size_t mask_pos) {
HWY_DASSERT(size != 0);
HWY_DASSERT(mask_pos <= size);
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
const size_t N = hn::Lanes(d);
// Find max so we can subtract it below. Avoid hn::Foreach because SVE vectors
// cannot be lambda-captured.
// TODO(janwas): could be replaced with an hn::Accumulate algo.
const hn::Vec<D> vmin = hn::Set(d, hwy::LowestValue<float>());
hn::Vec<D> vmax = vmin;
size_t idx = 0;
if (mask_pos >= N) {
for (; idx <= mask_pos - N; idx += N) {
vmax = hn::Max(vmax, LoadU(d, x + idx));
}
}
vmax = hn::Max(vmax, LoadNOr(vmin, d, x + idx, mask_pos - idx));
vmax = hn::MaxOfLanes(d, vmax); // broadcast
const auto vmin = hn::Set(d, hwy::LowestValue<float>());
auto vmax = vmin;
Foreach(d, x, mask_pos, vmin,
[&vmax](const auto d, const auto value)
HWY_ATTR { vmax = hn::Max(vmax, value); });
// Subtract max (avoid precision loss for large exponents) and exponentiate.
// Also avoid hn::Transform because the additional `sum` output vector cannot
// be captured by a lambda.
hn::Vec<D> sum = hn::Zero(d);
idx = 0;
if (mask_pos >= N) {
for (; idx <= mask_pos - N; idx += N) {
const hn::Vec<D> out = hn::Exp(d, hn::Sub(hn::LoadU(d, x + idx), vmax));
auto sum = hn::Zero(d);
hn::Transform(d, x, mask_pos,
[&sum, &vmax](const auto d, const auto value) HWY_ATTR {
const auto out = hn::Exp(d, hn::Sub(value, vmax));
sum = hn::Add(sum, out);
hn::StoreU(out, d, x + idx);
}
}
if (mask_pos > idx) {
const size_t remaining = mask_pos - idx;
const hn::Vec<D> out =
hn::Exp(d, hn::Sub(hn::LoadN(d, x + idx, remaining), vmax));
sum = hn::Add(sum, out);
hn::StoreN(out, d, x + idx, remaining);
}
return out;
});
// Normalize to probability distribution
const float mul = 1.0f / hn::ReduceSum(d, sum);
@ -669,29 +649,30 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, size_t size,
}
// Convenience overload of Softmax: forwards to the three-argument Softmax
// with `mask_pos` equal to `size`.
//
// x:    input/output buffer.
// size: number of elements in `x`; also passed as `mask_pos`.
static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x,
                                                const size_t size) {
  Softmax(x, size, size);
}
// Applies a soft cap to `x` in place: each element v is replaced by
// cap * tanh(v / cap).
//
// cap:     capping constant; the reciprocal 1 / cap is computed once as a
//          vector and reused for every element.
// x:       input/output buffer, transformed in place.
// size:    number of elements passed to hn::Transform.
// max_pos: debug-asserted to be <= size.
//
// NOTE(review): `max_pos` is only checked by the assertion below; the
// transform itself covers `size` elements, not `max_pos`. Confirm with callers
// whether the range was meant to be limited to `max_pos`.
static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
                                       const size_t size,
                                       const size_t max_pos) {
  HWY_DASSERT(max_pos <= size);

  namespace hn = hwy::HWY_NAMESPACE;
  using D = hn::ScalableTag<float>;
  const D d;

  // Broadcast the constants once, outside the per-vector functor.
  const auto vcap = hn::Set(d, cap);
  const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);

  hn::Transform(d, x, size, [&vcap, &vinv_cap](D d, hn::Vec<D> v) HWY_ATTR {
    return hn::Mul(vcap, hn::Tanh(d, hn::Mul(v, vinv_cap)));
  });
}
// Convenience overload of LogitsSoftCap: forwards to the four-argument
// LogitsSoftCap with `max_pos` equal to `size`.
//
// cap:  capping constant forwarded unchanged.
// x:    input/output buffer.
// size: number of elements in `x`; also passed as `max_pos`.
static HWY_INLINE HWY_MAYBE_UNUSED void LogitsSoftCap(const float cap,
                                                      float* HWY_RESTRICT x,
                                                      const size_t size) {
  LogitsSoftCap(cap, x, size, size);
}
@ -716,8 +697,8 @@ create_distribution(std::array<float, k>& top_k, float temperature) {
using D = hn::ScalableTag<float>;
const D d;
const auto one = hn::Set(d, 1.0f);
const auto temperature_inv = hn::Div(one, hn::Set(d, temperature));
const auto temperature_inv =
hn::Div(hn::Set(d, 1.0f), hn::Set(d, temperature));
hn::Transform(d, top_k.data(), top_k.size(),
[&temperature_inv](D d, hn::Vec<D> v) HWY_ATTR {