diff --git a/ops.h b/ops.h index 5539892..db2ae4f 100644 --- a/ops.h +++ b/ops.h @@ -198,16 +198,16 @@ HWY_INLINE void MatVec(const CompressedArray& mat, template static HWY_INLINE hn::Vec Gelu(D d, hn::Vec v) { - const hn::Vec kMul = Set(d, 0.044715f); + const hn::Vec kMul = hn::Set(d, 0.044715f); const hn::Vec kSqrt2OverPi = hn::Set(d, 0.797884560804236f); - const hn::Vec kHalf = Set(d, 0.5f); + const hn::Vec kHalf = hn::Set(d, 0.5f); // tanh approximation matches training. const hn::Vec v3 = hn::Mul(hn::Mul(v, v), v); const hn::Vec arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v)); // 0.5 * (1 + tan) = MulAdd(0.5, tan, 0.5). const hn::Vec cdf = hn::MulAdd(kHalf, hn::Tanh(d, arg), kHalf); - return Mul(v, cdf); + return hn::Mul(v, cdf); } static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x, @@ -230,21 +230,22 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void GeluMulToBF16( size_t i = 0; if (size >= 2 * NF) { for (; i < size - 2 * NF; i += 2 * NF) { - const VF mul0 = LoadU(df, mul + i); - const VF mul1 = LoadU(df, mul + i + NF); - const VF g0 = Mul(mul0, Gelu(df, LoadU(df, gelu_in + i))); - const VF g1 = Mul(mul1, Gelu(df, LoadU(df, gelu_in + i + NF))); + const VF mul0 = hn::LoadU(df, mul + i); + const VF mul1 = hn::LoadU(df, mul + i + NF); + const VF g0 = hn::Mul(mul0, Gelu(df, hn::LoadU(df, gelu_in + i))); + const VF g1 = hn::Mul(mul1, Gelu(df, hn::LoadU(df, gelu_in + i + NF))); const hn::Vec bf = hn::OrderedDemote2To(dbf, g0, g1); - StoreU(bf, dbf, out + i); + hn::StoreU(bf, dbf, out + i); } } if (i != size) { const size_t remaining = size - i; - const VF mul0 = LoadN(df, mul + i, remaining); - const VF g0 = Mul(mul0, Gelu(df, LoadN(df, gelu_in + i, remaining))); + const VF mul0 = hn::LoadN(df, mul + i, remaining); + const VF g0 = + hn::Mul(mul0, Gelu(df, hn::LoadN(df, gelu_in + i, remaining))); const hn::Half dbfh; const hn::Vec bfh = hn::DemoteTo(dbfh, g0); - StoreN(bfh, dbfh, out + i, remaining); + hn::StoreN(bfh, dbfh, out + i, remaining); } } @@ -381,7 +382,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( constexpr float eps = 1e-6f; const float ss = SquaredL2(inout, size); - const VF vss = Set(df32, 1.0f / sqrtf(ss / static_cast(size) + eps)); + const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast(size) + eps)); HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); for (size_t i = 0; i < size; i += 2 * N32) { @@ -409,7 +410,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( constexpr float eps = 1e-6f; const float ss = SquaredL2(x, size); - const VF vss = Set(df32, 1.0f / sqrtf(ss / static_cast(size) + eps)); + const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast(size) + eps)); HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); for (size_t i = 0; i < size; i += 2 * N32) { @@ -436,7 +437,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( constexpr float eps = 1e-6f; const float ss = SquaredL2(x, size); - const VF vss = Set(df32, 1.0f / sqrtf(ss / size + eps)); + const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / size + eps)); HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); for (size_t i = 0; i < size; i += 2 * N32) {