mirror of https://github.com/google/gemma.cpp.git
Merge branch 'dev' into rmsnorm
This commit is contained in:
commit
6712f07ee7
|
|
@ -1,4 +1,5 @@
|
||||||
FormatStyle: file
|
FormatStyle: file
|
||||||
|
WarningsAsErrors: "*"
|
||||||
Checks: "-*,\
|
Checks: "-*,\
|
||||||
abseil-*,\
|
abseil-*,\
|
||||||
-abseil-string-find-startswith,\
|
-abseil-string-find-startswith,\
|
||||||
|
|
@ -204,3 +205,6 @@ Checks: "-*,\
|
||||||
-readability-uppercase-literal-suffix,\
|
-readability-uppercase-literal-suffix,\
|
||||||
-readability-use-anyofallof
|
-readability-use-anyofallof
|
||||||
"
|
"
|
||||||
|
CheckOptions:
|
||||||
|
- { key: readability-identifier-naming.ConstexprVariableCase, value: CamelCase }
|
||||||
|
- { key: readability-identifier-naming.ConstexprVariablePrefix, value: k }
|
||||||
|
|
|
||||||
|
|
@ -341,7 +341,7 @@ BlobError BlobReader::Open(const char* filename) {
|
||||||
#endif
|
#endif
|
||||||
if (fd_ < 0) return __LINE__;
|
if (fd_ < 0) return __LINE__;
|
||||||
|
|
||||||
#if HWY_OS_LINUX
|
#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21)
|
||||||
// Doubles the readahead window, which seems slightly faster when cached.
|
// Doubles the readahead window, which seems slightly faster when cached.
|
||||||
(void)posix_fadvise(fd_, 0, 0, POSIX_FADV_SEQUENTIAL);
|
(void)posix_fadvise(fd_, 0, 0, POSIX_FADV_SEQUENTIAL);
|
||||||
#endif
|
#endif
|
||||||
|
|
|
||||||
30
ops.h
30
ops.h
|
|
@ -372,12 +372,30 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||||
const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
|
const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
|
||||||
float* HWY_RESTRICT out, size_t size) {
|
float* HWY_RESTRICT out, size_t size) {
|
||||||
constexpr float eps = 1e-6f;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
float ss = SquaredL2(x, size);
|
|
||||||
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
|
constexpr float kEps = 1e-6f;
|
||||||
for (size_t j = 0; j < size; j++) {
|
constexpr size_t kUnrollSize = 2;
|
||||||
// Note 1.0f centering here
|
|
||||||
out[j] = (1.0f + hwy::F32FromBF16(weight[j])) * (ss * x[j]);
|
const hn::ScalableTag<hwy::bfloat16_t> dbf;
|
||||||
|
const hn::Repartition<float, decltype(dbf)> df32;
|
||||||
|
const size_t N32 = hn::Lanes(df32);
|
||||||
|
|
||||||
|
const float ss = SquaredL2(x, size);
|
||||||
|
const auto vss =
|
||||||
|
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
|
||||||
|
|
||||||
|
HWY_DASSERT(size % (kUnrollSize * MaxLanes(df32)) == 0);
|
||||||
|
for (size_t i = 0; i < size; i += kUnrollSize * N32) {
|
||||||
|
const hn::Vec<decltype(dbf)> w16 = hn::LoadU(dbf, weight + i);
|
||||||
|
const auto w0 = hn::PromoteLowerTo(df32, w16);
|
||||||
|
const auto w1 = hn::PromoteUpperTo(df32, w16);
|
||||||
|
const auto m0 = hn::Mul(vss, hn::LoadU(df32, x + i));
|
||||||
|
const auto m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32));
|
||||||
|
|
||||||
|
// (1+weight) * m = m + weight*m = one FMA.
|
||||||
|
hn::StoreU(hn::MulAdd(m0, w0, m0), df32, out + i);
|
||||||
|
hn::StoreU(hn::MulAdd(m1, w1, m1), df32, out + i + N32);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue