Add missing hn::, fixes #25

PiperOrigin-RevId: 609914890
This commit is contained in:
Jan Wassenberg 2024-02-23 21:01:26 -08:00 committed by Dan Zheng
parent af715d2436
commit 4a0d23f47e
1 changed file with 15 additions and 14 deletions

29
ops.h
View File

@ -198,16 +198,16 @@ HWY_INLINE void MatVec(const CompressedArray<MatT, kCapacity>& mat,
// Returns the GELU activation of each lane of `v`, using the tanh-based
// approximation:
//   v * 0.5 * (1 + tanh(kSqrt2OverPi * (v + 0.044715 * v^3)))
// `d` is the Highway descriptor (tag) for the vector type; the
// HWY_IF_F32_D constraint presumably limits this overload to float32
// descriptors.
// NOTE(review): this listing is a diff rendering without +/- markers. Where a
// statement appears twice in a row, the first copy is the pre-change line
// (unqualified call) and the second is its hn::-qualified replacement.
template <class D, HWY_IF_F32_D(D)>
static HWY_INLINE hn::Vec<D> Gelu(D d, hn::Vec<D> v) {
// Coefficient of the cubic term inside the tanh argument.
const hn::Vec<D> kMul = Set(d, 0.044715f);
const hn::Vec<D> kMul = hn::Set(d, 0.044715f);
// Numerically sqrt(2 / pi), as the name indicates.
const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f);
const hn::Vec<D> kHalf = Set(d, 0.5f);
const hn::Vec<D> kHalf = hn::Set(d, 0.5f);
// The tanh approximation (rather than another GELU formulation) is used
// because it matches training.
const hn::Vec<D> v3 = hn::Mul(hn::Mul(v, v), v);
// arg = kSqrt2OverPi * (kMul * v^3 + v).
const hn::Vec<D> arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v));
// 0.5 * (1 + tanh) = MulAdd(0.5, tanh, 0.5).
const hn::Vec<D> cdf = hn::MulAdd(kHalf, hn::Tanh(d, arg), kHalf);
return Mul(v, cdf);
return hn::Mul(v, cdf);
}
static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x,
@ -230,21 +230,22 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void GeluMulToBF16(
size_t i = 0;
if (size >= 2 * NF) {
for (; i < size - 2 * NF; i += 2 * NF) {
const VF mul0 = LoadU(df, mul + i);
const VF mul1 = LoadU(df, mul + i + NF);
const VF g0 = Mul(mul0, Gelu(df, LoadU(df, gelu_in + i)));
const VF g1 = Mul(mul1, Gelu(df, LoadU(df, gelu_in + i + NF)));
const VF mul0 = hn::LoadU(df, mul + i);
const VF mul1 = hn::LoadU(df, mul + i + NF);
const VF g0 = hn::Mul(mul0, Gelu(df, hn::LoadU(df, gelu_in + i)));
const VF g1 = hn::Mul(mul1, Gelu(df, hn::LoadU(df, gelu_in + i + NF)));
const hn::Vec<decltype(dbf)> bf = hn::OrderedDemote2To(dbf, g0, g1);
StoreU(bf, dbf, out + i);
hn::StoreU(bf, dbf, out + i);
}
}
if (i != size) {
const size_t remaining = size - i;
const VF mul0 = LoadN(df, mul + i, remaining);
const VF g0 = Mul(mul0, Gelu(df, LoadN(df, gelu_in + i, remaining)));
const VF mul0 = hn::LoadN(df, mul + i, remaining);
const VF g0 =
hn::Mul(mul0, Gelu(df, hn::LoadN(df, gelu_in + i, remaining)));
const hn::Half<decltype(dbf)> dbfh;
const hn::Vec<decltype(dbfh)> bfh = hn::DemoteTo(dbfh, g0);
StoreN(bfh, dbfh, out + i, remaining);
hn::StoreN(bfh, dbfh, out + i, remaining);
}
}
@ -381,7 +382,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
constexpr float eps = 1e-6f;
const float ss = SquaredL2(inout, size);
const VF vss = Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) {
@ -409,7 +410,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
constexpr float eps = 1e-6f;
const float ss = SquaredL2(x, size);
const VF vss = Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) {
@ -436,7 +437,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
constexpr float eps = 1e-6f;
const float ss = SquaredL2(x, size);
const VF vss = Set(df32, 1.0f / sqrtf(ss / size + eps));
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / size + eps));
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) {