diff --git a/CMakeLists.txt b/CMakeLists.txt index 9efce80..844ddf2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -83,3 +83,35 @@ target_link_libraries(libgemma hwy hwy_contrib sentencepiece) target_include_directories(libgemma PRIVATE ${sentencepiece_SOURCE_DIR}) target_compile_definitions(libgemma PRIVATE $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX>) target_compile_options(libgemma PRIVATE $<$:-Wno-deprecated-declarations>) + +set(GEMMA_ENABLE_TESTS OFF CACHE BOOL "Enable Gemma tests") +if (GEMMA_ENABLE_TESTS) + +set(GEMMA_TEST_FILES + ops_test.cc +) + +include(FetchContent) +FetchContent_Declare( + googletest + URL https://github.com/google/googletest/archive/refs/tags/v1.14.0.zip +) +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) +FetchContent_MakeAvailable(googletest) +include(GoogleTest) + +foreach (TESTFILE IN LISTS GEMMA_TEST_FILES) + # The TESTNAME is the name without the extension or directory. + get_filename_component(TESTNAME ${TESTFILE} NAME_WE) + add_executable(${TESTNAME} ${TESTFILE}) + # Test all targets, not just the best/baseline. This changes the default + # policy to all-attainable; note that setting -DHWY_COMPILE_* directly can + # cause compile errors because only one may be set, and other CMakeLists.txt + # that include us may set them. + target_compile_options(${TESTNAME} PRIVATE -DHWY_IS_TEST=1) + + target_link_libraries(${TESTNAME} PRIVATE libgemma GTest::gtest_main hwy hwy_contrib hwy_test) + + gtest_discover_tests(${TESTNAME}) +endforeach () +endif() # GEMMA_ENABLE_TESTS diff --git a/ops.h b/ops.h index c7410d3..e8a58b7 100644 --- a/ops.h +++ b/ops.h @@ -634,6 +634,7 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, Foreach(d, x, mask_pos, vmin, [&vmax](const auto d, const auto value) HWY_ATTR { vmax = hn::Max(vmax, value); }); + vmax = hn::MaxOfLanes(d, vmax); // Subtract max (avoid precision loss for large exponents) and exponentiate. auto sum = hn::Zero(d); @@ -703,7 +704,7 @@ create_distribution(std::array& top_k, float temperature) { hn::Transform(d, top_k.data(), top_k.size(), [&temperature_inv](D d, hn::Vec v) HWY_ATTR { - return hn::Mul(hn::Exp(d, hn::Log(d, v)), temperature_inv); + return hn::Exp(d, hn::Mul(hn::Log(d, v), temperature_inv)); }); return std::discrete_distribution(std::begin(top_k), std::end(top_k)); diff --git a/ops_test.cc b/ops_test.cc new file mode 100644 index 0000000..e19106b --- /dev/null +++ b/ops_test.cc @@ -0,0 +1,374 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS HWY_SCALAR +#endif + +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "ops_test.cc" //NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep +// copybara:import_next_line:gemma_cpp +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" +#include "ops.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { + +namespace hn = hwy::HWY_NAMESPACE; + +template +struct ForeachCountAndMisalign { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) const { + hwy::RandomState rng; + const size_t N = Lanes(d); + const size_t misalignments[3] = {0, N / 4, 3 * N / 5}; + + for (size_t count = 0; count < 2 * N; ++count) { + for (size_t ma : misalignments) { + for (size_t mb : misalignments) { + Test()(d, count, ma, mb, rng); + } + } + } + } +}; + +template +T Random(hwy::RandomState& rng) { + const int32_t bits = static_cast(Random32(&rng)) & 1023; + const double val = (bits - 512) / 64.0; + // Clamp negative to zero for unsigned types. + return hwy::ConvertScalarTo( + HWY_MAX(hwy::ConvertScalarTo(hwy::LowestValue()), val)); +} + +HWY_NOINLINE void SourceAddFrom(const float* HWY_RESTRICT other, + float* HWY_RESTRICT x, size_t size) { + for (size_t i = 0; i < size; ++i) { + x[i] += other[i]; + } +} + +HWY_NOINLINE void SourceMulBy(const float* HWY_RESTRICT other, + float* HWY_RESTRICT x, size_t size, + size_t max_pos) { + HWY_DASSERT(max_pos <= size); + for (size_t i = 0; i < max_pos; ++i) { + x[i] *= other[i]; + } +} + +HWY_NOINLINE void SourceMulByConst(float c, float* HWY_RESTRICT x, size_t size, + size_t max_pos) { + for (size_t i = 0; i < max_pos; ++i) { + x[i] *= c; + } +} + +HWY_NOINLINE void SourceMulByConstAndAdd(float c, const float* HWY_RESTRICT x, + float* HWY_RESTRICT out, size_t size, + size_t max_pos) { + for (size_t i = 0; i < max_pos; ++i) { + out[i] += x[i] * c; + } +} + +HWY_NOINLINE void SourceSoftmax(float* HWY_RESTRICT x, size_t size, + size_t mask_pos) { + HWY_DASSERT(size != 0); + HWY_DASSERT(mask_pos <= size); + + namespace hn = hwy::HWY_NAMESPACE; + using D = hn::ScalableTag; + const D d; + const size_t N = hn::Lanes(d); + + const hn::Vec vmin = hn::Set(d, hwy::LowestValue()); + hn::Vec vmax = vmin; + size_t idx = 0; + if (mask_pos >= N) { + for (; idx <= mask_pos - N; idx += N) { + vmax = hn::Max(vmax, LoadU(d, x + idx)); + } + } + vmax = hn::Max(vmax, LoadNOr(vmin, d, x + idx, mask_pos - idx)); + vmax = hn::MaxOfLanes(d, vmax); // broadcast + + hn::Vec sum = hn::Zero(d); + idx = 0; + if (mask_pos >= N) { + for (; idx <= mask_pos - N; idx += N) { + const hn::Vec out = hn::Exp(d, hn::Sub(hn::LoadU(d, x + idx), vmax)); + sum = hn::Add(sum, out); + hn::StoreU(out, d, x + idx); + } + } + if (mask_pos > idx) { + const size_t remaining = mask_pos - idx; + const hn::Vec out = + hn::Exp(d, hn::Sub(hn::LoadN(d, x + idx, remaining), vmax)); + sum = hn::Add(sum, out); + hn::StoreN(out, d, x + idx, remaining); + } + + const float mul = 1.0f / hn::ReduceSum(d, sum); + SourceMulByConst(mul, x, size, mask_pos); +} + +template +HWY_NOINLINE std::discrete_distribution SourceCreateDistribution( + std::array& top_k, float temperature) { + // re-normalize distribution + for (size_t i = 0; i < k; ++i) { + top_k[i] = exp(log(top_k[i]) / temperature); + } + float denominator = 0.0f; + for (size_t i = 0; i < k; ++i) { + denominator += top_k[i]; + } + denominator = 1.0f / denominator; + MulByConst(denominator, top_k.data(), k); + return std::discrete_distribution(std::begin(top_k), std::end(top_k)); +} + +struct TestAddFrom { + template + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + hwy::RandomState& rng) { + using T = hn::TFromD; + + hwy::AlignedFreeUniquePtr px = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); + hwy::AlignedFreeUniquePtr pe = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); + hwy::AlignedFreeUniquePtr po = + hwy::AllocateAligned(HWY_MAX(1, misalign_b + count)); + HWY_ASSERT(px && pe && po); + + T* x = px.get() + misalign_a; + T* e = pe.get() + misalign_a; + T* o = po.get() + misalign_b; + + for (size_t i = 0; i < count; ++i) { + x[i] = Random(rng); + e[i] = x[i]; + o[i] = Random(rng); + } + + SourceAddFrom(o, e, count); + AddFrom(o, x, count); + + hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, + __LINE__); + } +}; + +struct TestMulBy { + template + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + hwy::RandomState& rng) { + using T = hn::TFromD; + + hwy::AlignedFreeUniquePtr px = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); + hwy::AlignedFreeUniquePtr pe = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); + hwy::AlignedFreeUniquePtr po = + hwy::AllocateAligned(HWY_MAX(1, misalign_b + count)); + HWY_ASSERT(px && pe && po); + + T* x = px.get() + misalign_a; + T* e = pe.get() + misalign_a; + T* o = po.get() + misalign_b; + + for (size_t i = 0; i < count; ++i) { + x[i] = Random(rng); + e[i] = x[i]; + o[i] = Random(rng); + } + + SourceMulBy(o, e, count, count); + MulBy(o, x, count, count); + + hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, + __LINE__); + } +}; + +struct TestMulByConstAndAdd { + template + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + hwy::RandomState& rng) { + using T = hn::TFromD; + + hwy::AlignedFreeUniquePtr px = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); + hwy::AlignedFreeUniquePtr pe = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); + hwy::AlignedFreeUniquePtr po = + hwy::AllocateAligned(HWY_MAX(1, misalign_b + count)); + HWY_ASSERT(px && pe && po); + + T* x = px.get() + misalign_a; + T* e = pe.get() + misalign_a; + T* o = po.get() + misalign_b; + + for (size_t i = 0; i < count; ++i) { + x[i] = Random(rng); + e[i] = x[i]; + o[i] = Random(rng); + } + T constant = Random(rng); + + SourceMulByConstAndAdd(constant, o, e, count, count); + MulByConstAndAdd(constant, o, x, count, count); + + hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, + __LINE__); + } +}; + +struct TestMulByConst { + template + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + hwy::RandomState& rng) { + using T = hn::TFromD; + + hwy::AlignedFreeUniquePtr px = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); + hwy::AlignedFreeUniquePtr pe = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); + HWY_ASSERT(px && pe); + + T* x = px.get() + misalign_a; + T* e = pe.get() + misalign_a; + + for (size_t i = 0; i < count; ++i) { + x[i] = Random(rng); + e[i] = x[i]; + } + T constant = Random(rng); + + SourceMulByConst(constant, e, count, count); + MulByConst(constant, x, count, count); + + hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, + __LINE__); + } +}; + +struct TestSoftmax { + template + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + hwy::RandomState& rng) { + using T = hn::TFromD; + + hwy::AlignedFreeUniquePtr px = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); + hwy::AlignedFreeUniquePtr pe = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); + HWY_ASSERT(px && pe); + + T* x = px.get() + misalign_a; + T* e = pe.get() + misalign_a; + + for (size_t i = 0; i < count; ++i) { + x[i] = Random(rng); + e[i] = x[i]; + } + + SourceSoftmax(e, count, count); + Softmax(x, count, count); + + hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, + __LINE__); + } +}; + +template +struct TestCreateDistribution { + void operator()(hwy::RandomState& rng) { + std::array x; + std::array e; + + for (size_t i = 0; i < k; ++i) { + x[i] = Random(rng); + e[i] = x[i]; + } + const float constant = Random(rng); + auto expected = SourceCreateDistribution(e, constant); + auto output = create_distribution(x, constant); + + AssertEqual(expected, output, hwy::TargetName(HWY_TARGET), __FILE__, + __LINE__); + } +}; + +void TestAllAddFrom() { + hn::ForPartialVectors>()(float()); +} + +void TestAllMulBy() { + hn::ForPartialVectors>()(float()); +} + +void TestAllMulByConst() { + hn::ForPartialVectors>()(float()); +} + +void TestAllMulByConstAndAdd() { + hn::ForPartialVectors>()( + float()); +} + +void TestAllSoftmax() { + hn::ForPartialVectors>()(float()); +} + +void TestAllCreateDistribution() { + TestCreateDistribution<2048>(); + TestCreateDistribution<5000>(); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace gcpp { +HWY_BEFORE_TEST(OpsTest); +HWY_EXPORT_AND_TEST_P(OpsTest, TestAllAddFrom); +HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulBy); +HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst); +HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd); +HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax); +HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution); +#ifdef HWY_AFTER_TEST +HWY_AFTER_TEST(); +#endif +} // namespace gcpp + +#endif