mirror of https://github.com/google/gemma.cpp.git
Fixes #37 (lambda issues): lambdas were missing HWY_ATTR, and SVE vectors cannot be captured by lambdas as in/out values.
PiperOrigin-RevId: 610260610
This commit is contained in:
parent
1243be71c4
commit
6a3085828f
55
ops.h
55
ops.h
|
|
@ -214,7 +214,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x,
|
|||
size_t size) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
hn::Transform(D(), x, size, [](D d, hn::Vec<D> v) { return Gelu(d, v); });
|
||||
hn::Transform(D(), x, size,
|
||||
[](D d, hn::Vec<D> v) HWY_ATTR { return Gelu(d, v); });
|
||||
}
|
||||
|
||||
// out[i] = BF(mul[i] * Gelu(gelu_in[i]))
|
||||
|
|
@ -567,22 +568,41 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, size_t size,
|
|||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
const D d;
|
||||
using V = hn::Vec<D>;
|
||||
const size_t N = hn::Lanes(d);
|
||||
|
||||
// Find max so we can subtract it below.
|
||||
const V vmin = hn::Set(d, hwy::LowestValue<float>());
|
||||
V max = vmin;
|
||||
hn::Foreach(d, x, mask_pos, vmin,
|
||||
[&max](D d, V v) { max = hn::Max(max, v); });
|
||||
max = hn::MaxOfLanes(d, max); // broadcast
|
||||
// Find max so we can subtract it below. Avoid hn::Foreach because SVE vectors
|
||||
// cannot be lambda-captured.
|
||||
// TODO(janwas): could be replaced with an hn::Accumulate algo.
|
||||
const hn::Vec<D> vmin = hn::Set(d, hwy::LowestValue<float>());
|
||||
hn::Vec<D> vmax = vmin;
|
||||
size_t idx = 0;
|
||||
if (mask_pos >= N) {
|
||||
for (; idx <= mask_pos - N; idx += N) {
|
||||
vmax = hn::Max(vmax, LoadU(d, x + idx));
|
||||
}
|
||||
}
|
||||
vmax = hn::Max(vmax, LoadNOr(vmin, d, x + idx, mask_pos - idx));
|
||||
vmax = hn::MaxOfLanes(d, vmax); // broadcast
|
||||
|
||||
// Subtract max (avoid precision loss for large exponents) and exponentiate.
|
||||
V sum = hn::Zero(d);
|
||||
hn::Transform(d, x, mask_pos, [&sum, max](D d, V v) {
|
||||
const V out = hn::Exp(d, hn::Sub(v, max));
|
||||
// Also avoid hn::Transform because the additional `sum` output vector cannot
|
||||
// be captured by a lambda.
|
||||
hn::Vec<D> sum = hn::Zero(d);
|
||||
idx = 0;
|
||||
if (mask_pos >= N) {
|
||||
for (; idx <= mask_pos - N; idx += N) {
|
||||
const hn::Vec<D> out = hn::Exp(d, hn::Sub(hn::LoadU(d, x + idx), vmax));
|
||||
sum = hn::Add(sum, out);
|
||||
hn::StoreU(out, d, x + idx);
|
||||
}
|
||||
}
|
||||
if (mask_pos > idx) {
|
||||
const size_t remaining = mask_pos - idx;
|
||||
const hn::Vec<D> out =
|
||||
hn::Exp(d, hn::Sub(hn::LoadN(d, x + idx, remaining), vmax));
|
||||
sum = hn::Add(sum, out);
|
||||
return out;
|
||||
});
|
||||
hn::StoreN(out, d, x + idx, remaining);
|
||||
}
|
||||
|
||||
// Normalize to probability distribution
|
||||
const float mul = 1.0f / hn::ReduceSum(d, sum);
|
||||
|
|
@ -601,13 +621,12 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
|
|||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
const D d;
|
||||
using V = hn::Vec<D>;
|
||||
|
||||
const V inv_cap = hn::Set(d, 1.0f / cap);
|
||||
const V vcap = hn::Set(d, cap);
|
||||
const float inv_cap = 1.0f / cap;
|
||||
|
||||
hn::Transform(d, x, size, [vcap, inv_cap](D d, hn::Vec<D> v) {
|
||||
return hn::Mul(vcap, hn::Tanh(d, hn::Mul(inv_cap, v)));
|
||||
hn::Transform(d, x, size, [cap, inv_cap](D d, hn::Vec<D> v) HWY_ATTR {
|
||||
return hn::Mul(hn::Set(d, cap),
|
||||
hn::Tanh(d, hn::Mul(v, hn::Set(d, inv_cap))));
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue