mirror of https://github.com/google/gemma.cpp.git
Minor refactor in Softmax
This commit is contained in:
parent
858d5b08c2
commit
4400842337
65
ops.h
65
ops.h
|
|
@ -619,49 +619,29 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
|||
MulByConstAndAdd(c, x, out, size, size);
|
||||
}
|
||||
|
||||
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, size_t size,
|
||||
size_t mask_pos) {
|
||||
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
||||
const size_t mask_pos) {
|
||||
HWY_DASSERT(size != 0);
|
||||
HWY_DASSERT(mask_pos <= size);
|
||||
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
const D d;
|
||||
const size_t N = hn::Lanes(d);
|
||||
|
||||
// Find max so we can subtract it below. Avoid hn::Foreach because SVE vectors
|
||||
// cannot be lambda-captured.
|
||||
// TODO(janwas): could be replaced with an hn::Accumulate algo.
|
||||
const hn::Vec<D> vmin = hn::Set(d, hwy::LowestValue<float>());
|
||||
hn::Vec<D> vmax = vmin;
|
||||
size_t idx = 0;
|
||||
if (mask_pos >= N) {
|
||||
for (; idx <= mask_pos - N; idx += N) {
|
||||
vmax = hn::Max(vmax, LoadU(d, x + idx));
|
||||
}
|
||||
}
|
||||
vmax = hn::Max(vmax, LoadNOr(vmin, d, x + idx, mask_pos - idx));
|
||||
vmax = hn::MaxOfLanes(d, vmax); // broadcast
|
||||
const auto vmin = hn::Set(d, hwy::LowestValue<float>());
|
||||
auto vmax = vmin;
|
||||
Foreach(d, x, mask_pos, vmin,
|
||||
[&vmax](const auto d, const auto value)
|
||||
HWY_ATTR { vmax = hn::Max(vmax, value); });
|
||||
|
||||
// Subtract max (avoid precision loss for large exponents) and exponentiate.
|
||||
// Also avoid hn::Transform because the additional `sum` output vector cannot
|
||||
// be captured by a lambda.
|
||||
hn::Vec<D> sum = hn::Zero(d);
|
||||
idx = 0;
|
||||
if (mask_pos >= N) {
|
||||
for (; idx <= mask_pos - N; idx += N) {
|
||||
const hn::Vec<D> out = hn::Exp(d, hn::Sub(hn::LoadU(d, x + idx), vmax));
|
||||
auto sum = hn::Zero(d);
|
||||
hn::Transform(d, x, mask_pos,
|
||||
[&sum, &vmax](const auto d, const auto value) HWY_ATTR {
|
||||
const auto out = hn::Exp(d, hn::Sub(value, vmax));
|
||||
sum = hn::Add(sum, out);
|
||||
hn::StoreU(out, d, x + idx);
|
||||
}
|
||||
}
|
||||
if (mask_pos > idx) {
|
||||
const size_t remaining = mask_pos - idx;
|
||||
const hn::Vec<D> out =
|
||||
hn::Exp(d, hn::Sub(hn::LoadN(d, x + idx, remaining), vmax));
|
||||
sum = hn::Add(sum, out);
|
||||
hn::StoreN(out, d, x + idx, remaining);
|
||||
}
|
||||
return out;
|
||||
});
|
||||
|
||||
// Normalize to probability distribution
|
||||
const float mul = 1.0f / hn::ReduceSum(d, sum);
|
||||
|
|
@ -669,29 +649,30 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, size_t size,
|
|||
}
|
||||
|
||||
static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x,
|
||||
size_t size) {
|
||||
const size_t size) {
|
||||
Softmax(x, size, size);
|
||||
}
|
||||
|
||||
static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
|
||||
size_t size, size_t max_pos) {
|
||||
const size_t size,
|
||||
const size_t max_pos) {
|
||||
HWY_DASSERT(max_pos <= size);
|
||||
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
const D d;
|
||||
|
||||
const float inv_cap = 1.0f / cap;
|
||||
const auto vcap = hn::Set(d, cap);
|
||||
const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
|
||||
|
||||
hn::Transform(d, x, size, [cap, inv_cap](D d, hn::Vec<D> v) HWY_ATTR {
|
||||
return hn::Mul(hn::Set(d, cap),
|
||||
hn::Tanh(d, hn::Mul(v, hn::Set(d, inv_cap))));
|
||||
hn::Transform(d, x, size, [&vcap, &vinv_cap](D d, hn::Vec<D> v) HWY_ATTR {
|
||||
return hn::Mul(vcap, hn::Tanh(d, hn::Mul(v, vinv_cap)));
|
||||
});
|
||||
}
|
||||
|
||||
static HWY_INLINE HWY_MAYBE_UNUSED void LogitsSoftCap(const float cap,
|
||||
float* HWY_RESTRICT x,
|
||||
size_t size) {
|
||||
const size_t size) {
|
||||
LogitsSoftCap(cap, x, size, size);
|
||||
}
|
||||
|
||||
|
|
@ -716,8 +697,8 @@ create_distribution(std::array<float, k>& top_k, float temperature) {
|
|||
using D = hn::ScalableTag<float>;
|
||||
const D d;
|
||||
|
||||
const auto one = hn::Set(d, 1.0f);
|
||||
const auto temperature_inv = hn::Div(one, hn::Set(d, temperature));
|
||||
const auto temperature_inv =
|
||||
hn::Div(hn::Set(d, 1.0f), hn::Set(d, temperature));
|
||||
|
||||
hn::Transform(d, top_k.data(), top_k.size(),
|
||||
[&temperature_inv](D d, hn::Vec<D> v) HWY_ATTR {
|
||||
|
|
|
|||
Loading…
Reference in New Issue