From 4a0d23f47ee36370e8db429648f58af6fdb9f953 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 23 Feb 2024 21:01:26 -0800 Subject: [PATCH] Add missing hn::, fixes #25 PiperOrigin-RevId: 609914890 --- ops.h | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/ops.h b/ops.h index 5539892..db2ae4f 100644 --- a/ops.h +++ b/ops.h @@ -198,16 +198,16 @@ HWY_INLINE void MatVec(const CompressedArray& mat, template static HWY_INLINE hn::Vec Gelu(D d, hn::Vec v) { - const hn::Vec kMul = Set(d, 0.044715f); + const hn::Vec kMul = hn::Set(d, 0.044715f); const hn::Vec kSqrt2OverPi = hn::Set(d, 0.797884560804236f); - const hn::Vec kHalf = Set(d, 0.5f); + const hn::Vec kHalf = hn::Set(d, 0.5f); // tanh approximation matches training. const hn::Vec v3 = hn::Mul(hn::Mul(v, v), v); const hn::Vec arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v)); // 0.5 * (1 + tan) = MulAdd(0.5, tan, 0.5). const hn::Vec cdf = hn::MulAdd(kHalf, hn::Tanh(d, arg), kHalf); - return Mul(v, cdf); + return hn::Mul(v, cdf); } static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x, @@ -230,21 +230,22 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void GeluMulToBF16( size_t i = 0; if (size >= 2 * NF) { for (; i < size - 2 * NF; i += 2 * NF) { - const VF mul0 = LoadU(df, mul + i); - const VF mul1 = LoadU(df, mul + i + NF); - const VF g0 = Mul(mul0, Gelu(df, LoadU(df, gelu_in + i))); - const VF g1 = Mul(mul1, Gelu(df, LoadU(df, gelu_in + i + NF))); + const VF mul0 = hn::LoadU(df, mul + i); + const VF mul1 = hn::LoadU(df, mul + i + NF); + const VF g0 = hn::Mul(mul0, Gelu(df, hn::LoadU(df, gelu_in + i))); + const VF g1 = hn::Mul(mul1, Gelu(df, hn::LoadU(df, gelu_in + i + NF))); const hn::Vec bf = hn::OrderedDemote2To(dbf, g0, g1); - StoreU(bf, dbf, out + i); + hn::StoreU(bf, dbf, out + i); } } if (i != size) { const size_t remaining = size - i; - const VF mul0 = LoadN(df, mul + i, remaining); - const VF g0 = Mul(mul0, Gelu(df, LoadN(df, gelu_in + i, remaining))); + const VF mul0 = hn::LoadN(df, mul + i, remaining); + const VF g0 = + hn::Mul(mul0, Gelu(df, hn::LoadN(df, gelu_in + i, remaining))); const hn::Half dbfh; const hn::Vec bfh = hn::DemoteTo(dbfh, g0); - StoreN(bfh, dbfh, out + i, remaining); + hn::StoreN(bfh, dbfh, out + i, remaining); } } @@ -381,7 +382,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( constexpr float eps = 1e-6f; const float ss = SquaredL2(inout, size); - const VF vss = Set(df32, 1.0f / sqrtf(ss / static_cast(size) + eps)); + const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast(size) + eps)); HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); for (size_t i = 0; i < size; i += 2 * N32) { @@ -409,7 +410,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( constexpr float eps = 1e-6f; const float ss = SquaredL2(x, size); - const VF vss = Set(df32, 1.0f / sqrtf(ss / static_cast(size) + eps)); + const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast(size) + eps)); HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); for (size_t i = 0; i < size; i += 2 * N32) { @@ -436,7 +437,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( constexpr float eps = 1e-6f; const float ss = SquaredL2(x, size); - const VF vss = Set(df32, 1.0f / sqrtf(ss / size + eps)); + const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / size + eps)); HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); for (size_t i = 0; i < size; i += 2 * N32) {