mirror of https://github.com/google/gemma.cpp.git
Fix build for RPi, missing hn::. Refs #112, thanks long568
PiperOrigin-RevId: 617704418
This commit is contained in:
parent
ba86c8d590
commit
a135bc1e47
15
ops.h
15
ops.h
|
|
@@ -341,20 +341,21 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a,
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2(
|
static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2(
|
||||||
const float* HWY_RESTRICT a, size_t size) {
|
const float* HWY_RESTRICT a, size_t size) {
|
||||||
const hn::ScalableTag<float> d;
|
const hn::ScalableTag<float> d;
|
||||||
|
using V = hn::Vec<decltype(d)>;
|
||||||
const size_t N = hn::Lanes(d);
|
const size_t N = hn::Lanes(d);
|
||||||
HWY_DASSERT(size >= 2 * N);
|
HWY_DASSERT(size >= 2 * N);
|
||||||
HWY_DASSERT(size % (2 * N) == 0);
|
HWY_DASSERT(size % (2 * N) == 0);
|
||||||
|
|
||||||
auto sum0 = hn::Zero(d);
|
V sum0 = hn::Zero(d);
|
||||||
auto sum1 = hn::Zero(d);
|
V sum1 = hn::Zero(d);
|
||||||
for (size_t i = 0; i <= size - 2 * N; i += 2 * N) {
|
for (size_t i = 0; i <= size - 2 * N; i += 2 * N) {
|
||||||
const auto a0 = LoadU(d, a + i);
|
const V a0 = hn::LoadU(d, a + i);
|
||||||
sum0 = MulAdd(a0, a0, sum0);
|
sum0 = hn::MulAdd(a0, a0, sum0);
|
||||||
const auto a1 = LoadU(d, a + i + N);
|
const V a1 = hn::LoadU(d, a + i + N);
|
||||||
sum1 = MulAdd(a1, a1, sum1);
|
sum1 = hn::MulAdd(a1, a1, sum1);
|
||||||
}
|
}
|
||||||
|
|
||||||
return ReduceSum(d, Add(sum0, sum1));
|
return hn::ReduceSum(d, hn::Add(sum0, sum1));
|
||||||
}
|
}
|
||||||
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue