// Copyright 2023 Google LLC // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef HWY_DISABLED_TARGETS #define HWY_DISABLED_TARGETS HWY_SCALAR #endif #include "hwy/aligned_allocator.h" #include "hwy/base.h" // clang-format off #undef HWY_TARGET_INCLUDE #define HWY_TARGET_INCLUDE "ops_test.cc" //NOLINT // clang-format on #include "hwy/foreach_target.h" // IWYU pragma: keep // copybara:import_next_line:gemma_cpp #include "hwy/highway.h" #include "hwy/tests/test_util-inl.h" #include "ops.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; template struct ForeachCountAndMisalign { template HWY_NOINLINE void operator()(T /*unused*/, D d) const { hwy::RandomState rng; const size_t N = Lanes(d); const size_t misalignments[3] = {0, N / 4, 3 * N / 5}; for (size_t count = 0; count < 2 * N; ++count) { for (size_t ma : misalignments) { for (size_t mb : misalignments) { Test()(d, count, ma, mb, rng); } } } } }; template T Random(hwy::RandomState& rng) { const int32_t bits = static_cast(Random32(&rng)) & 1023; const double val = (bits - 512) / 64.0; // Clamp negative to zero for unsigned types. return hwy::ConvertScalarTo( HWY_MAX(hwy::ConvertScalarTo(hwy::LowestValue()), val)); } HWY_NOINLINE void SourceAddFrom(const float* HWY_RESTRICT other, float* HWY_RESTRICT x, size_t size) { for (size_t i = 0; i < size; ++i) { x[i] += other[i]; } } HWY_NOINLINE void SourceMulBy(const float* HWY_RESTRICT other, float* HWY_RESTRICT x, size_t size, size_t max_pos) { HWY_DASSERT(max_pos <= size); for (size_t i = 0; i < max_pos; ++i) { x[i] *= other[i]; } } HWY_NOINLINE void SourceMulByConst(float c, float* HWY_RESTRICT x, size_t size, size_t max_pos) { for (size_t i = 0; i < max_pos; ++i) { x[i] *= c; } } HWY_NOINLINE void SourceMulByConstAndAdd(float c, const float* HWY_RESTRICT x, float* HWY_RESTRICT out, size_t size, size_t max_pos) { for (size_t i = 0; i < max_pos; ++i) { out[i] += x[i] * c; } } HWY_NOINLINE void SourceSoftmax(float* HWY_RESTRICT x, size_t size, size_t mask_pos) { HWY_DASSERT(size != 0); HWY_DASSERT(mask_pos <= size); namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; const D d; const size_t N = hn::Lanes(d); const hn::Vec vmin = hn::Set(d, hwy::LowestValue()); hn::Vec vmax = vmin; size_t idx = 0; if (mask_pos >= N) { for (; idx <= mask_pos - N; idx += N) { vmax = hn::Max(vmax, LoadU(d, x + idx)); } } vmax = hn::Max(vmax, LoadNOr(vmin, d, x + idx, mask_pos - idx)); vmax = hn::MaxOfLanes(d, vmax); // broadcast hn::Vec sum = hn::Zero(d); idx = 0; if (mask_pos >= N) { for (; idx <= mask_pos - N; idx += N) { const hn::Vec out = hn::Exp(d, hn::Sub(hn::LoadU(d, x + idx), vmax)); sum = hn::Add(sum, out); hn::StoreU(out, d, x + idx); } } if (mask_pos > idx) { const size_t remaining = mask_pos - idx; const hn::Vec out = hn::Exp(d, hn::Sub(hn::LoadN(d, x + idx, remaining), vmax)); sum = hn::Add(sum, out); hn::StoreN(out, d, x + idx, remaining); } const float mul = 1.0f / hn::ReduceSum(d, sum); SourceMulByConst(mul, x, size, mask_pos); } template HWY_NOINLINE std::discrete_distribution SourceCreateDistribution( std::array& top_k, float temperature) { // re-normalize distribution for (size_t i = 0; i < k; ++i) { top_k[i] = exp(log(top_k[i]) / temperature); } float denominator = 0.0f; for (size_t i = 0; i < k; ++i) { denominator += top_k[i]; } denominator = 1.0f / denominator; MulByConst(denominator, top_k.data(), k); return std::discrete_distribution(std::begin(top_k), std::end(top_k)); } struct TestAddFrom { template void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, hwy::RandomState& rng) { using T = hn::TFromD; hwy::AlignedFreeUniquePtr px = hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); hwy::AlignedFreeUniquePtr pe = hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); hwy::AlignedFreeUniquePtr po = hwy::AllocateAligned(HWY_MAX(1, misalign_b + count)); HWY_ASSERT(px && pe && po); T* x = px.get() + misalign_a; T* e = pe.get() + misalign_a; T* o = po.get() + misalign_b; for (size_t i = 0; i < count; ++i) { x[i] = Random(rng); e[i] = x[i]; o[i] = Random(rng); } SourceAddFrom(o, e, count); AddFrom(o, x, count); hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, __LINE__); } }; struct TestMulBy { template void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, hwy::RandomState& rng) { using T = hn::TFromD; hwy::AlignedFreeUniquePtr px = hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); hwy::AlignedFreeUniquePtr pe = hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); hwy::AlignedFreeUniquePtr po = hwy::AllocateAligned(HWY_MAX(1, misalign_b + count)); HWY_ASSERT(px && pe && po); T* x = px.get() + misalign_a; T* e = pe.get() + misalign_a; T* o = po.get() + misalign_b; for (size_t i = 0; i < count; ++i) { x[i] = Random(rng); e[i] = x[i]; o[i] = Random(rng); } SourceMulBy(o, e, count, count); MulBy(o, x, count, count); hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, __LINE__); } }; struct TestMulByConstAndAdd { template void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, hwy::RandomState& rng) { using T = hn::TFromD; hwy::AlignedFreeUniquePtr px = hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); hwy::AlignedFreeUniquePtr pe = hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); hwy::AlignedFreeUniquePtr po = hwy::AllocateAligned(HWY_MAX(1, misalign_b + count)); HWY_ASSERT(px && pe && po); T* x = px.get() + misalign_a; T* e = pe.get() + misalign_a; T* o = po.get() + misalign_b; for (size_t i = 0; i < count; ++i) { x[i] = Random(rng); e[i] = x[i]; o[i] = Random(rng); } T constant = Random(rng); SourceMulByConstAndAdd(constant, o, e, count, count); MulByConstAndAdd(constant, o, x, count, count); hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, __LINE__); } }; struct TestMulByConst { template void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, hwy::RandomState& rng) { using T = hn::TFromD; hwy::AlignedFreeUniquePtr px = hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); hwy::AlignedFreeUniquePtr pe = hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); HWY_ASSERT(px && pe); T* x = px.get() + misalign_a; T* e = pe.get() + misalign_a; for (size_t i = 0; i < count; ++i) { x[i] = Random(rng); e[i] = x[i]; } T constant = Random(rng); SourceMulByConst(constant, e, count, count); MulByConst(constant, x, count, count); hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, __LINE__); } }; struct TestSoftmax { template void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, hwy::RandomState& rng) { using T = hn::TFromD; hwy::AlignedFreeUniquePtr px = hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); hwy::AlignedFreeUniquePtr pe = hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); HWY_ASSERT(px && pe); T* x = px.get() + misalign_a; T* e = pe.get() + misalign_a; for (size_t i = 0; i < count; ++i) { x[i] = Random(rng); e[i] = x[i]; } SourceSoftmax(e, count, count); Softmax(x, count, count); hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, __LINE__); } }; template struct TestCreateDistribution { void operator()(hwy::RandomState& rng) { std::array x; std::array e; for (size_t i = 0; i < k; ++i) { x[i] = Random(rng); e[i] = x[i]; } const float constant = Random(rng); auto expected = SourceCreateDistribution(e, constant); auto output = create_distribution(x, constant); AssertEqual(expected, output, hwy::TargetName(HWY_TARGET), __FILE__, __LINE__); } }; void TestAllAddFrom() { hn::ForPartialVectors>()(float()); } void TestAllMulBy() { hn::ForPartialVectors>()(float()); } void TestAllMulByConst() { hn::ForPartialVectors>()(float()); } void TestAllMulByConstAndAdd() { hn::ForPartialVectors>()( float()); } void TestAllSoftmax() { hn::ForPartialVectors>()(float()); } void TestAllCreateDistribution() { TestCreateDistribution<2048>(); TestCreateDistribution<5000>(); } // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp HWY_AFTER_NAMESPACE(); #if HWY_ONCE namespace gcpp { HWY_BEFORE_TEST(OpsTest); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllAddFrom); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulBy); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution); #ifdef HWY_AFTER_TEST HWY_AFTER_TEST(); #endif } // namespace gcpp #endif