mirror of https://github.com/google/gemma.cpp.git
375 lines
11 KiB
C++
375 lines
11 KiB
C++
// Copyright 2023 Google LLC
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#ifndef HWY_DISABLED_TARGETS
|
|
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
|
#endif
|
|
|
|
#include "hwy/aligned_allocator.h"
|
|
#include "hwy/base.h"
|
|
|
|
// clang-format off
|
|
#undef HWY_TARGET_INCLUDE
|
|
#define HWY_TARGET_INCLUDE "ops_test.cc" //NOLINT
|
|
// clang-format on
|
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
|
// copybara:import_next_line:gemma_cpp
|
|
#include "hwy/highway.h"
|
|
#include "hwy/tests/test_util-inl.h"
|
|
#include "ops.h"
|
|
|
|
HWY_BEFORE_NAMESPACE();
|
|
namespace gcpp {
|
|
namespace HWY_NAMESPACE {
|
|
|
|
namespace hn = hwy::HWY_NAMESPACE;
|
|
|
|
template <class Test>
|
|
struct ForeachCountAndMisalign {
|
|
template <typename T, class D>
|
|
HWY_NOINLINE void operator()(T /*unused*/, D d) const {
|
|
hwy::RandomState rng;
|
|
const size_t N = Lanes(d);
|
|
const size_t misalignments[3] = {0, N / 4, 3 * N / 5};
|
|
|
|
for (size_t count = 0; count < 2 * N; ++count) {
|
|
for (size_t ma : misalignments) {
|
|
for (size_t mb : misalignments) {
|
|
Test()(d, count, ma, mb, rng);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
T Random(hwy::RandomState& rng) {
|
|
const int32_t bits = static_cast<int32_t>(Random32(&rng)) & 1023;
|
|
const double val = (bits - 512) / 64.0;
|
|
// Clamp negative to zero for unsigned types.
|
|
return hwy::ConvertScalarTo<T>(
|
|
HWY_MAX(hwy::ConvertScalarTo<double>(hwy::LowestValue<T>()), val));
|
|
}
|
|
|
|
// Scalar reference for AddFrom: adds other[i] to x[i] for every i in
// [0, size). `other` is only read; `x` is updated in place.
HWY_NOINLINE void SourceAddFrom(const float* HWY_RESTRICT other,
                                float* HWY_RESTRICT x, size_t size) {
  size_t idx = 0;
  while (idx < size) {
    x[idx] = x[idx] + other[idx];
    ++idx;
  }
}
|
|
|
|
// Scalar reference for MulBy: multiplies x[i] by other[i] for every i in
// [0, max_pos). Elements of `x` at or beyond `max_pos` are not touched.
// Precondition (checked via HWY_DASSERT): max_pos <= size.
HWY_NOINLINE void SourceMulBy(const float* HWY_RESTRICT other,
                              float* HWY_RESTRICT x, size_t size,
                              size_t max_pos) {
  HWY_DASSERT(max_pos <= size);
  size_t idx = 0;
  while (idx < max_pos) {
    x[idx] = x[idx] * other[idx];
    ++idx;
  }
}
|
|
|
|
// Scalar reference for MulByConst: multiplies x[i] by `c` for every i in
// [0, max_pos). Elements of `x` at or beyond `max_pos` are not touched.
//
// c:       scale factor.
// x:       array of at least `size` floats, updated in place.
// size:    number of valid elements in `x`; only used for the bounds check.
// max_pos: number of leading elements to scale; must not exceed `size`.
HWY_NOINLINE void SourceMulByConst(float c, float* HWY_RESTRICT x, size_t size,
                                   size_t max_pos) {
  // Same precondition that SourceMulBy checks: never scale past the buffer.
  HWY_DASSERT(max_pos <= size);
  for (size_t i = 0; i < max_pos; ++i) {
    x[i] *= c;
  }
}
|
|
|
|
// Scalar reference for MulByConstAndAdd: out[i] += x[i] * c for every i in
// [0, max_pos). Elements of `out` at or beyond `max_pos` are not touched.
//
// c:       scale factor applied to `x`.
// x:       read-only input of at least `size` floats.
// out:     accumulator of at least `size` floats, updated in place.
// size:    number of valid elements; only used for the bounds check.
// max_pos: number of leading elements to process; must not exceed `size`.
HWY_NOINLINE void SourceMulByConstAndAdd(float c, const float* HWY_RESTRICT x,
                                         float* HWY_RESTRICT out, size_t size,
                                         size_t max_pos) {
  // Same precondition that SourceMulBy checks: never run past the buffers.
  HWY_DASSERT(max_pos <= size);
  for (size_t i = 0; i < max_pos; ++i) {
    out[i] += x[i] * c;
  }
}
|
|
|
|
// Hand-vectorized reference for Softmax over x[0, mask_pos), in place:
//   1. find the maximum of the first `mask_pos` elements,
//   2. replace each of them with exp(x[i] - max) while accumulating the sum,
//   3. scale them by 1 / sum via SourceMulByConst.
// Elements at or beyond `mask_pos` are neither read nor written.
// Preconditions (checked via HWY_DASSERT): size != 0 and mask_pos <= size.
HWY_NOINLINE void SourceSoftmax(float* HWY_RESTRICT x, size_t size,
                                size_t mask_pos) {
  HWY_DASSERT(size != 0);
  HWY_DASSERT(mask_pos <= size);

  namespace hn = hwy::HWY_NAMESPACE;
  using D = hn::ScalableTag<float>;
  using V = hn::Vec<D>;
  const D d;
  const size_t N = hn::Lanes(d);

  // Pass 1: running maximum. Whole vectors first, then the partial tail,
  // whose missing lanes are filled with the lowest float so they cannot win.
  const V lowest = hn::Set(d, hwy::LowestValue<float>());
  V max_v = lowest;
  size_t pos = 0;
  for (; pos + N <= mask_pos; pos += N) {
    max_v = hn::Max(max_v, hn::LoadU(d, x + pos));
  }
  max_v = hn::Max(max_v, hn::LoadNOr(lowest, d, x + pos, mask_pos - pos));
  max_v = hn::MaxOfLanes(d, max_v);  // broadcast

  // Pass 2: exponentiate (after subtracting the maximum) and accumulate.
  V sum_v = hn::Zero(d);
  for (pos = 0; pos + N <= mask_pos; pos += N) {
    const V exp_v = hn::Exp(d, hn::Sub(hn::LoadU(d, x + pos), max_v));
    sum_v = hn::Add(sum_v, exp_v);
    hn::StoreU(exp_v, d, x + pos);
  }
  const size_t remaining = mask_pos - pos;
  if (remaining != 0) {
    // NOTE(review): LoadN presumably zero-fills the lanes beyond `remaining`;
    // if so, those lanes add exp(0 - max) to the sum. Confirm this matches
    // the Softmax under test and is intended.
    const V exp_v =
        hn::Exp(d, hn::Sub(hn::LoadN(d, x + pos, remaining), max_v));
    sum_v = hn::Add(sum_v, exp_v);
    hn::StoreN(exp_v, d, x + pos, remaining);
  }

  // Pass 3: normalize by the reciprocal of the accumulated sum.
  const float inv_sum = 1.0f / hn::ReduceSum(d, sum_v);
  SourceMulByConst(inv_sum, x, size, mask_pos);
}
|
|
|
|
template <size_t k>
|
|
HWY_NOINLINE std::discrete_distribution<int> SourceCreateDistribution(
|
|
std::array<float, k>& top_k, float temperature) {
|
|
// re-normalize distribution
|
|
for (size_t i = 0; i < k; ++i) {
|
|
top_k[i] = exp(log(top_k[i]) / temperature);
|
|
}
|
|
float denominator = 0.0f;
|
|
for (size_t i = 0; i < k; ++i) {
|
|
denominator += top_k[i];
|
|
}
|
|
denominator = 1.0f / denominator;
|
|
MulByConst(denominator, top_k.data(), k);
|
|
return std::discrete_distribution<int>(std::begin(top_k), std::end(top_k));
|
|
}
|
|
|
|
struct TestAddFrom {
|
|
template <class D>
|
|
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
|
hwy::RandomState& rng) {
|
|
using T = hn::TFromD<D>;
|
|
|
|
hwy::AlignedFreeUniquePtr<T[]> px =
|
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
|
hwy::AlignedFreeUniquePtr<T[]> pe =
|
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
|
hwy::AlignedFreeUniquePtr<T[]> po =
|
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_b + count));
|
|
HWY_ASSERT(px && pe && po);
|
|
|
|
T* x = px.get() + misalign_a;
|
|
T* e = pe.get() + misalign_a;
|
|
T* o = po.get() + misalign_b;
|
|
|
|
for (size_t i = 0; i < count; ++i) {
|
|
x[i] = Random<T>(rng);
|
|
e[i] = x[i];
|
|
o[i] = Random<T>(rng);
|
|
}
|
|
|
|
SourceAddFrom(o, e, count);
|
|
AddFrom(o, x, count);
|
|
|
|
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
|
__LINE__);
|
|
}
|
|
};
|
|
|
|
struct TestMulBy {
|
|
template <class D>
|
|
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
|
hwy::RandomState& rng) {
|
|
using T = hn::TFromD<D>;
|
|
|
|
hwy::AlignedFreeUniquePtr<T[]> px =
|
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
|
hwy::AlignedFreeUniquePtr<T[]> pe =
|
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
|
hwy::AlignedFreeUniquePtr<T[]> po =
|
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_b + count));
|
|
HWY_ASSERT(px && pe && po);
|
|
|
|
T* x = px.get() + misalign_a;
|
|
T* e = pe.get() + misalign_a;
|
|
T* o = po.get() + misalign_b;
|
|
|
|
for (size_t i = 0; i < count; ++i) {
|
|
x[i] = Random<T>(rng);
|
|
e[i] = x[i];
|
|
o[i] = Random<T>(rng);
|
|
}
|
|
|
|
SourceMulBy(o, e, count, count);
|
|
MulBy(o, x, count, count);
|
|
|
|
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
|
__LINE__);
|
|
}
|
|
};
|
|
|
|
struct TestMulByConstAndAdd {
|
|
template <class D>
|
|
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
|
hwy::RandomState& rng) {
|
|
using T = hn::TFromD<D>;
|
|
|
|
hwy::AlignedFreeUniquePtr<T[]> px =
|
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
|
hwy::AlignedFreeUniquePtr<T[]> pe =
|
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
|
hwy::AlignedFreeUniquePtr<T[]> po =
|
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_b + count));
|
|
HWY_ASSERT(px && pe && po);
|
|
|
|
T* x = px.get() + misalign_a;
|
|
T* e = pe.get() + misalign_a;
|
|
T* o = po.get() + misalign_b;
|
|
|
|
for (size_t i = 0; i < count; ++i) {
|
|
x[i] = Random<T>(rng);
|
|
e[i] = x[i];
|
|
o[i] = Random<T>(rng);
|
|
}
|
|
T constant = Random<T>(rng);
|
|
|
|
SourceMulByConstAndAdd(constant, o, e, count, count);
|
|
MulByConstAndAdd(constant, o, x, count, count);
|
|
|
|
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
|
__LINE__);
|
|
}
|
|
};
|
|
|
|
struct TestMulByConst {
|
|
template <class D>
|
|
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
|
hwy::RandomState& rng) {
|
|
using T = hn::TFromD<D>;
|
|
|
|
hwy::AlignedFreeUniquePtr<T[]> px =
|
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
|
hwy::AlignedFreeUniquePtr<T[]> pe =
|
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
|
HWY_ASSERT(px && pe);
|
|
|
|
T* x = px.get() + misalign_a;
|
|
T* e = pe.get() + misalign_a;
|
|
|
|
for (size_t i = 0; i < count; ++i) {
|
|
x[i] = Random<T>(rng);
|
|
e[i] = x[i];
|
|
}
|
|
T constant = Random<T>(rng);
|
|
|
|
SourceMulByConst(constant, e, count, count);
|
|
MulByConst(constant, x, count, count);
|
|
|
|
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
|
__LINE__);
|
|
}
|
|
};
|
|
|
|
struct TestSoftmax {
|
|
template <class D>
|
|
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
|
hwy::RandomState& rng) {
|
|
using T = hn::TFromD<D>;
|
|
|
|
hwy::AlignedFreeUniquePtr<T[]> px =
|
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
|
hwy::AlignedFreeUniquePtr<T[]> pe =
|
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
|
HWY_ASSERT(px && pe);
|
|
|
|
T* x = px.get() + misalign_a;
|
|
T* e = pe.get() + misalign_a;
|
|
|
|
for (size_t i = 0; i < count; ++i) {
|
|
x[i] = Random<T>(rng);
|
|
e[i] = x[i];
|
|
}
|
|
|
|
SourceSoftmax(e, count, count);
|
|
Softmax(x, count, count);
|
|
|
|
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
|
__LINE__);
|
|
}
|
|
};
|
|
|
|
template <size_t k>
|
|
struct TestCreateDistribution {
|
|
void operator()(hwy::RandomState& rng) {
|
|
std::array<float, k> x;
|
|
std::array<float, k> e;
|
|
|
|
for (size_t i = 0; i < k; ++i) {
|
|
x[i] = Random<float>(rng);
|
|
e[i] = x[i];
|
|
}
|
|
const float constant = Random<float>(rng);
|
|
auto expected = SourceCreateDistribution(e, constant);
|
|
auto output = create_distribution(x, constant);
|
|
|
|
AssertEqual(expected, output, hwy::TargetName(HWY_TARGET), __FILE__,
|
|
__LINE__);
|
|
}
|
|
};
|
|
|
|
void TestAllAddFrom() {
|
|
hn::ForPartialVectors<ForeachCountAndMisalign<TestAddFrom>>()(float());
|
|
}
|
|
|
|
void TestAllMulBy() {
|
|
hn::ForPartialVectors<ForeachCountAndMisalign<TestMulBy>>()(float());
|
|
}
|
|
|
|
void TestAllMulByConst() {
|
|
hn::ForPartialVectors<ForeachCountAndMisalign<TestMulByConst>>()(float());
|
|
}
|
|
|
|
void TestAllMulByConstAndAdd() {
|
|
hn::ForPartialVectors<ForeachCountAndMisalign<TestMulByConstAndAdd>>()(
|
|
float());
|
|
}
|
|
|
|
void TestAllSoftmax() {
|
|
hn::ForPartialVectors<ForeachCountAndMisalign<TestSoftmax>>()(float());
|
|
}
|
|
|
|
// Entry point for the create_distribution checks with k = 2048 and k = 5000.
// NOTE(review): each statement below only constructs a temporary
// TestCreateDistribution<k> functor; its operator()(hwy::RandomState&) is
// never called, so no distribution is actually compared here. Confirm whether
// the functors were meant to be invoked with a RandomState.
void TestAllCreateDistribution() {
  constexpr size_t kSmallK = 2048;
  constexpr size_t kLargeK = 5000;
  TestCreateDistribution<kSmallK>();
  TestCreateDistribution<kLargeK>();
}
|
|
|
|
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
|
} // namespace HWY_NAMESPACE
|
|
} // namespace gcpp
|
|
HWY_AFTER_NAMESPACE();
|
|
|
|
#if HWY_ONCE
|
|
|
|
namespace gcpp {
|
|
HWY_BEFORE_TEST(OpsTest);
|
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllAddFrom);
|
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulBy);
|
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
|
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
|
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
|
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
|
|
#ifdef HWY_AFTER_TEST
|
|
HWY_AFTER_TEST();
|
|
#endif
|
|
} // namespace gcpp
|
|
|
|
#endif
|