mirror of https://github.com/google/gemma.cpp.git
parent
af715d2436
commit
4a0d23f47e
29
ops.h
29
ops.h
|
|
@ -198,16 +198,16 @@ HWY_INLINE void MatVec(const CompressedArray<MatT, kCapacity>& mat,
|
|||
|
||||
template <class D, HWY_IF_F32_D(D)>
|
||||
static HWY_INLINE hn::Vec<D> Gelu(D d, hn::Vec<D> v) {
|
||||
const hn::Vec<D> kMul = Set(d, 0.044715f);
|
||||
const hn::Vec<D> kMul = hn::Set(d, 0.044715f);
|
||||
const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f);
|
||||
const hn::Vec<D> kHalf = Set(d, 0.5f);
|
||||
const hn::Vec<D> kHalf = hn::Set(d, 0.5f);
|
||||
|
||||
// tanh approximation matches training.
|
||||
const hn::Vec<D> v3 = hn::Mul(hn::Mul(v, v), v);
|
||||
const hn::Vec<D> arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v));
|
||||
// 0.5 * (1 + tanh) = MulAdd(0.5, tanh, 0.5).
|
||||
const hn::Vec<D> cdf = hn::MulAdd(kHalf, hn::Tanh(d, arg), kHalf);
|
||||
return Mul(v, cdf);
|
||||
return hn::Mul(v, cdf);
|
||||
}
|
||||
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x,
|
||||
|
|
@ -230,21 +230,22 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void GeluMulToBF16(
|
|||
size_t i = 0;
|
||||
if (size >= 2 * NF) {
|
||||
for (; i < size - 2 * NF; i += 2 * NF) {
|
||||
const VF mul0 = LoadU(df, mul + i);
|
||||
const VF mul1 = LoadU(df, mul + i + NF);
|
||||
const VF g0 = Mul(mul0, Gelu(df, LoadU(df, gelu_in + i)));
|
||||
const VF g1 = Mul(mul1, Gelu(df, LoadU(df, gelu_in + i + NF)));
|
||||
const VF mul0 = hn::LoadU(df, mul + i);
|
||||
const VF mul1 = hn::LoadU(df, mul + i + NF);
|
||||
const VF g0 = hn::Mul(mul0, Gelu(df, hn::LoadU(df, gelu_in + i)));
|
||||
const VF g1 = hn::Mul(mul1, Gelu(df, hn::LoadU(df, gelu_in + i + NF)));
|
||||
const hn::Vec<decltype(dbf)> bf = hn::OrderedDemote2To(dbf, g0, g1);
|
||||
StoreU(bf, dbf, out + i);
|
||||
hn::StoreU(bf, dbf, out + i);
|
||||
}
|
||||
}
|
||||
if (i != size) {
|
||||
const size_t remaining = size - i;
|
||||
const VF mul0 = LoadN(df, mul + i, remaining);
|
||||
const VF g0 = Mul(mul0, Gelu(df, LoadN(df, gelu_in + i, remaining)));
|
||||
const VF mul0 = hn::LoadN(df, mul + i, remaining);
|
||||
const VF g0 =
|
||||
hn::Mul(mul0, Gelu(df, hn::LoadN(df, gelu_in + i, remaining)));
|
||||
const hn::Half<decltype(dbf)> dbfh;
|
||||
const hn::Vec<decltype(dbfh)> bfh = hn::DemoteTo(dbfh, g0);
|
||||
StoreN(bfh, dbfh, out + i, remaining);
|
||||
hn::StoreN(bfh, dbfh, out + i, remaining);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -381,7 +382,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
|||
|
||||
constexpr float eps = 1e-6f;
|
||||
const float ss = SquaredL2(inout, size);
|
||||
const VF vss = Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
|
||||
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
|
||||
|
||||
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
|
||||
for (size_t i = 0; i < size; i += 2 * N32) {
|
||||
|
|
@ -409,7 +410,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
|||
|
||||
constexpr float eps = 1e-6f;
|
||||
const float ss = SquaredL2(x, size);
|
||||
const VF vss = Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
|
||||
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
|
||||
|
||||
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
|
||||
for (size_t i = 0; i < size; i += 2 * N32) {
|
||||
|
|
@ -436,7 +437,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
|||
|
||||
constexpr float eps = 1e-6f;
|
||||
const float ss = SquaredL2(x, size);
|
||||
const VF vss = Set(df32, 1.0f / sqrtf(ss / size + eps));
|
||||
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / size + eps));
|
||||
|
||||
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
|
||||
for (size_t i = 0; i < size; i += 2 * N32) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue