Merge pull request #105 from enum-class:improve_ops_utility

PiperOrigin-RevId: 618827910
This commit is contained in:
Copybara-Service 2024-03-25 07:00:45 -07:00
commit 9f1595c110
3 changed files with 483 additions and 73 deletions

View File

@ -83,3 +83,35 @@ target_link_libraries(libgemma hwy hwy_contrib sentencepiece)
target_include_directories(libgemma PRIVATE ${sentencepiece_SOURCE_DIR})
target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
set(GEMMA_ENABLE_TESTS OFF CACHE BOOL "Enable Gemma tests")
if (GEMMA_ENABLE_TESTS)
set(GEMMA_TEST_FILES
ops_test.cc
)
include(FetchContent)
FetchContent_Declare(
googletest
URL https://github.com/google/googletest/archive/refs/tags/v1.14.0.zip
)
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
FetchContent_MakeAvailable(googletest)
include(GoogleTest)
foreach (TESTFILE IN LISTS GEMMA_TEST_FILES)
# The TESTNAME is the name without the extension or directory.
get_filename_component(TESTNAME ${TESTFILE} NAME_WE)
add_executable(${TESTNAME} ${TESTFILE})
# Test all targets, not just the best/baseline. This changes the default
# policy to all-attainable; note that setting -DHWY_COMPILE_* directly can
# cause compile errors because only one may be set, and other CMakeLists.txt
# that include us may set them.
target_compile_options(${TESTNAME} PRIVATE -DHWY_IS_TEST=1)
target_link_libraries(${TESTNAME} PRIVATE libgemma GTest::gtest_main hwy hwy_contrib hwy_test)
gtest_discover_tests(${TESTNAME})
endforeach ()
endif() # GEMMA_ENABLE_TESTS

150
ops.h
View File

@ -552,47 +552,66 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul,
}
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
const float* HWY_RESTRICT other, float* HWY_RESTRICT x, size_t size) {
for (size_t i = 0; i < size; ++i) {
x[i] += other[i];
}
const float* HWY_RESTRICT other, float* HWY_RESTRICT x, const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
hn::Transform1(d, x, size, other,
[](const auto d, const auto x, const auto other)
HWY_ATTR { return hn::Add(x, other); });
}
static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other,
float* HWY_RESTRICT x, size_t size,
size_t max_pos) {
float* HWY_RESTRICT x, const size_t size,
const size_t max_pos) {
HWY_DASSERT(max_pos <= size);
for (size_t i = 0; i < max_pos; ++i) {
x[i] *= other[i];
}
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
hn::Transform1(d, x, max_pos, other,
[](const auto d, const auto x, const auto other)
HWY_ATTR { return hn::Mul(x, other); });
}
static HWY_INLINE HWY_MAYBE_UNUSED void MulBy(const float* HWY_RESTRICT other,
float* HWY_RESTRICT x,
size_t size) {
const size_t size) {
return MulBy(other, x, size, size);
}
static HWY_NOINLINE void MulByConst(float c, float* HWY_RESTRICT x, size_t size,
size_t max_pos) {
static HWY_NOINLINE void MulByConst(const float c, float* HWY_RESTRICT x,
const size_t size, const size_t max_pos) {
HWY_DASSERT(max_pos <= size);
for (size_t i = 0; i < max_pos; ++i) {
x[i] *= c;
}
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
const auto constant = hn::Set(d, c);
hn::Transform(d, x, max_pos,
[&constant](const auto d, const auto x)
HWY_ATTR { return hn::Mul(x, constant); });
}
static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(float c,
static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(const float c,
float* HWY_RESTRICT x,
size_t size) {
const size_t size) {
MulByConst(c, x, size, size);
}
static HWY_NOINLINE void MulByConstAndAdd(float c, const float* HWY_RESTRICT x,
float* HWY_RESTRICT out, size_t size,
size_t max_pos) {
for (size_t i = 0; i < max_pos; ++i) {
out[i] += x[i] * c;
}
static HWY_NOINLINE void MulByConstAndAdd(const float c,
const float* HWY_RESTRICT x,
float* HWY_RESTRICT out,
const size_t size,
const size_t max_pos) {
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
const auto constant = hn::Set(d, c);
hn::Transform1(
d, out, max_pos, x,
[&constant](const auto d, const auto out_element, const auto x_element)
HWY_ATTR { return hn::MulAdd(x_element, constant, out_element); });
}
static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
@ -601,49 +620,30 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
MulByConstAndAdd(c, x, out, size, size);
}
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, size_t size,
size_t mask_pos) {
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
const size_t mask_pos) {
HWY_DASSERT(size != 0);
HWY_DASSERT(mask_pos <= size);
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
const size_t N = hn::Lanes(d);
// Find max so we can subtract it below. Avoid hn::Foreach because SVE vectors
// cannot be lambda-captured.
// TODO(janwas): could be replaced with an hn::Accumulate algo.
const hn::Vec<D> vmin = hn::Set(d, hwy::LowestValue<float>());
hn::Vec<D> vmax = vmin;
size_t idx = 0;
if (mask_pos >= N) {
for (; idx <= mask_pos - N; idx += N) {
vmax = hn::Max(vmax, LoadU(d, x + idx));
}
}
vmax = hn::Max(vmax, LoadNOr(vmin, d, x + idx, mask_pos - idx));
vmax = hn::MaxOfLanes(d, vmax); // broadcast
const auto vmin = hn::Set(d, hwy::LowestValue<float>());
auto vmax = vmin;
Foreach(d, x, mask_pos, vmin,
[&vmax](const auto d, const auto value)
HWY_ATTR { vmax = hn::Max(vmax, value); });
vmax = hn::MaxOfLanes(d, vmax);
// Subtract max (avoid precision loss for large exponents) and exponentiate.
// Also avoid hn::Transform because the additional `sum` output vector cannot
// be captured by a lambda.
hn::Vec<D> sum = hn::Zero(d);
idx = 0;
if (mask_pos >= N) {
for (; idx <= mask_pos - N; idx += N) {
const hn::Vec<D> out = hn::Exp(d, hn::Sub(hn::LoadU(d, x + idx), vmax));
sum = hn::Add(sum, out);
hn::StoreU(out, d, x + idx);
}
}
if (mask_pos > idx) {
const size_t remaining = mask_pos - idx;
const hn::Vec<D> out =
hn::Exp(d, hn::Sub(hn::LoadN(d, x + idx, remaining), vmax));
sum = hn::Add(sum, out);
hn::StoreN(out, d, x + idx, remaining);
}
auto sum = hn::Zero(d);
hn::Transform(d, x, mask_pos,
[&sum, &vmax](const auto d, const auto value) HWY_ATTR {
const auto out = hn::Exp(d, hn::Sub(value, vmax));
sum = hn::Add(sum, out);
return out;
});
// Normalize to probability distribution
const float mul = 1.0f / hn::ReduceSum(d, sum);
@ -651,29 +651,30 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, size_t size,
}
static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x,
size_t size) {
const size_t size) {
Softmax(x, size, size);
}
static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
size_t size, size_t max_pos) {
const size_t size,
const size_t max_pos) {
HWY_DASSERT(max_pos <= size);
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
const float inv_cap = 1.0f / cap;
const auto vcap = hn::Set(d, cap);
const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
hn::Transform(d, x, size, [cap, inv_cap](D d, hn::Vec<D> v) HWY_ATTR {
return hn::Mul(hn::Set(d, cap),
hn::Tanh(d, hn::Mul(v, hn::Set(d, inv_cap))));
hn::Transform(d, x, size, [&vcap, &vinv_cap](D d, hn::Vec<D> v) HWY_ATTR {
return hn::Mul(vcap, hn::Tanh(d, hn::Mul(v, vinv_cap)));
});
}
static HWY_INLINE HWY_MAYBE_UNUSED void LogitsSoftCap(const float cap,
float* HWY_RESTRICT x,
size_t size) {
const size_t size) {
LogitsSoftCap(cap, x, size, size);
}
@ -694,15 +695,18 @@ template <size_t k>
static HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution<int>
create_distribution(std::array<float, k>& top_k, float temperature) {
// re-normalize distribution
for (size_t i = 0; i < k; ++i) {
top_k[i] = exp(log(top_k[i]) / temperature);
}
float denominator = 0.0f;
for (size_t i = 0; i < k; ++i) {
denominator += top_k[i];
}
denominator = 1.0f / denominator;
MulByConst(denominator, top_k.data(), k);
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
const auto temperature_inv =
hn::Div(hn::Set(d, 1.0f), hn::Set(d, temperature));
hn::Transform(d, top_k.data(), top_k.size(),
[&temperature_inv](D d, hn::Vec<D> v) HWY_ATTR {
return hn::Exp(d, hn::Mul(hn::Log(d, v), temperature_inv));
});
return std::discrete_distribution<int>(std::begin(top_k), std::end(top_k));
}

374
ops_test.cc Normal file
View File

@ -0,0 +1,374 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "ops_test.cc" //NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
// copybara:import_next_line:gemma_cpp
#include "hwy/highway.h"
#include "hwy/tests/test_util-inl.h"
#include "ops.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
template <class Test>
struct ForeachCountAndMisalign {
template <typename T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) const {
hwy::RandomState rng;
const size_t N = Lanes(d);
const size_t misalignments[3] = {0, N / 4, 3 * N / 5};
for (size_t count = 0; count < 2 * N; ++count) {
for (size_t ma : misalignments) {
for (size_t mb : misalignments) {
Test()(d, count, ma, mb, rng);
}
}
}
}
};
template <typename T>
T Random(hwy::RandomState& rng) {
const int32_t bits = static_cast<int32_t>(Random32(&rng)) & 1023;
const double val = (bits - 512) / 64.0;
// Clamp negative to zero for unsigned types.
return hwy::ConvertScalarTo<T>(
HWY_MAX(hwy::ConvertScalarTo<double>(hwy::LowestValue<T>()), val));
}
HWY_NOINLINE void SourceAddFrom(const float* HWY_RESTRICT other,
float* HWY_RESTRICT x, size_t size) {
for (size_t i = 0; i < size; ++i) {
x[i] += other[i];
}
}
HWY_NOINLINE void SourceMulBy(const float* HWY_RESTRICT other,
float* HWY_RESTRICT x, size_t size,
size_t max_pos) {
HWY_DASSERT(max_pos <= size);
for (size_t i = 0; i < max_pos; ++i) {
x[i] *= other[i];
}
}
HWY_NOINLINE void SourceMulByConst(float c, float* HWY_RESTRICT x, size_t size,
size_t max_pos) {
for (size_t i = 0; i < max_pos; ++i) {
x[i] *= c;
}
}
HWY_NOINLINE void SourceMulByConstAndAdd(float c, const float* HWY_RESTRICT x,
float* HWY_RESTRICT out, size_t size,
size_t max_pos) {
for (size_t i = 0; i < max_pos; ++i) {
out[i] += x[i] * c;
}
}
HWY_NOINLINE void SourceSoftmax(float* HWY_RESTRICT x, size_t size,
size_t mask_pos) {
HWY_DASSERT(size != 0);
HWY_DASSERT(mask_pos <= size);
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
const size_t N = hn::Lanes(d);
const hn::Vec<D> vmin = hn::Set(d, hwy::LowestValue<float>());
hn::Vec<D> vmax = vmin;
size_t idx = 0;
if (mask_pos >= N) {
for (; idx <= mask_pos - N; idx += N) {
vmax = hn::Max(vmax, LoadU(d, x + idx));
}
}
vmax = hn::Max(vmax, LoadNOr(vmin, d, x + idx, mask_pos - idx));
vmax = hn::MaxOfLanes(d, vmax); // broadcast
hn::Vec<D> sum = hn::Zero(d);
idx = 0;
if (mask_pos >= N) {
for (; idx <= mask_pos - N; idx += N) {
const hn::Vec<D> out = hn::Exp(d, hn::Sub(hn::LoadU(d, x + idx), vmax));
sum = hn::Add(sum, out);
hn::StoreU(out, d, x + idx);
}
}
if (mask_pos > idx) {
const size_t remaining = mask_pos - idx;
const hn::Vec<D> out =
hn::Exp(d, hn::Sub(hn::LoadN(d, x + idx, remaining), vmax));
sum = hn::Add(sum, out);
hn::StoreN(out, d, x + idx, remaining);
}
const float mul = 1.0f / hn::ReduceSum(d, sum);
SourceMulByConst(mul, x, size, mask_pos);
}
template <size_t k>
HWY_NOINLINE std::discrete_distribution<int> SourceCreateDistribution(
std::array<float, k>& top_k, float temperature) {
// re-normalize distribution
for (size_t i = 0; i < k; ++i) {
top_k[i] = exp(log(top_k[i]) / temperature);
}
float denominator = 0.0f;
for (size_t i = 0; i < k; ++i) {
denominator += top_k[i];
}
denominator = 1.0f / denominator;
MulByConst(denominator, top_k.data(), k);
return std::discrete_distribution<int>(std::begin(top_k), std::end(top_k));
}
struct TestAddFrom {
template <class D>
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
hwy::RandomState& rng) {
using T = hn::TFromD<D>;
hwy::AlignedFreeUniquePtr<T[]> px =
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
hwy::AlignedFreeUniquePtr<T[]> pe =
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
hwy::AlignedFreeUniquePtr<T[]> po =
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_b + count));
HWY_ASSERT(px && pe && po);
T* x = px.get() + misalign_a;
T* e = pe.get() + misalign_a;
T* o = po.get() + misalign_b;
for (size_t i = 0; i < count; ++i) {
x[i] = Random<T>(rng);
e[i] = x[i];
o[i] = Random<T>(rng);
}
SourceAddFrom(o, e, count);
AddFrom(o, x, count);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
__LINE__);
}
};
struct TestMulBy {
template <class D>
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
hwy::RandomState& rng) {
using T = hn::TFromD<D>;
hwy::AlignedFreeUniquePtr<T[]> px =
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
hwy::AlignedFreeUniquePtr<T[]> pe =
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
hwy::AlignedFreeUniquePtr<T[]> po =
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_b + count));
HWY_ASSERT(px && pe && po);
T* x = px.get() + misalign_a;
T* e = pe.get() + misalign_a;
T* o = po.get() + misalign_b;
for (size_t i = 0; i < count; ++i) {
x[i] = Random<T>(rng);
e[i] = x[i];
o[i] = Random<T>(rng);
}
SourceMulBy(o, e, count, count);
MulBy(o, x, count, count);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
__LINE__);
}
};
struct TestMulByConstAndAdd {
template <class D>
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
hwy::RandomState& rng) {
using T = hn::TFromD<D>;
hwy::AlignedFreeUniquePtr<T[]> px =
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
hwy::AlignedFreeUniquePtr<T[]> pe =
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
hwy::AlignedFreeUniquePtr<T[]> po =
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_b + count));
HWY_ASSERT(px && pe && po);
T* x = px.get() + misalign_a;
T* e = pe.get() + misalign_a;
T* o = po.get() + misalign_b;
for (size_t i = 0; i < count; ++i) {
x[i] = Random<T>(rng);
e[i] = x[i];
o[i] = Random<T>(rng);
}
T constant = Random<T>(rng);
SourceMulByConstAndAdd(constant, o, e, count, count);
MulByConstAndAdd(constant, o, x, count, count);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
__LINE__);
}
};
struct TestMulByConst {
template <class D>
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
hwy::RandomState& rng) {
using T = hn::TFromD<D>;
hwy::AlignedFreeUniquePtr<T[]> px =
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
hwy::AlignedFreeUniquePtr<T[]> pe =
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
HWY_ASSERT(px && pe);
T* x = px.get() + misalign_a;
T* e = pe.get() + misalign_a;
for (size_t i = 0; i < count; ++i) {
x[i] = Random<T>(rng);
e[i] = x[i];
}
T constant = Random<T>(rng);
SourceMulByConst(constant, e, count, count);
MulByConst(constant, x, count, count);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
__LINE__);
}
};
struct TestSoftmax {
template <class D>
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
hwy::RandomState& rng) {
using T = hn::TFromD<D>;
hwy::AlignedFreeUniquePtr<T[]> px =
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
hwy::AlignedFreeUniquePtr<T[]> pe =
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
HWY_ASSERT(px && pe);
T* x = px.get() + misalign_a;
T* e = pe.get() + misalign_a;
for (size_t i = 0; i < count; ++i) {
x[i] = Random<T>(rng);
e[i] = x[i];
}
SourceSoftmax(e, count, count);
Softmax(x, count, count);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
__LINE__);
}
};
template <size_t k>
struct TestCreateDistribution {
void operator()(hwy::RandomState& rng) {
std::array<float, k> x;
std::array<float, k> e;
for (size_t i = 0; i < k; ++i) {
x[i] = Random<float>(rng);
e[i] = x[i];
}
const float constant = Random<float>(rng);
auto expected = SourceCreateDistribution(e, constant);
auto output = create_distribution(x, constant);
AssertEqual(expected, output, hwy::TargetName(HWY_TARGET), __FILE__,
__LINE__);
}
};
void TestAllAddFrom() {
hn::ForPartialVectors<ForeachCountAndMisalign<TestAddFrom>>()(float());
}
void TestAllMulBy() {
hn::ForPartialVectors<ForeachCountAndMisalign<TestMulBy>>()(float());
}
void TestAllMulByConst() {
hn::ForPartialVectors<ForeachCountAndMisalign<TestMulByConst>>()(float());
}
void TestAllMulByConstAndAdd() {
hn::ForPartialVectors<ForeachCountAndMisalign<TestMulByConstAndAdd>>()(
float());
}
void TestAllSoftmax() {
hn::ForPartialVectors<ForeachCountAndMisalign<TestSoftmax>>()(float());
}
void TestAllCreateDistribution() {
TestCreateDistribution<2048>();
TestCreateDistribution<5000>();
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_BEFORE_TEST(OpsTest);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllAddFrom);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulBy);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
#ifdef HWY_AFTER_TEST
HWY_AFTER_TEST();
#endif
} // namespace gcpp
#endif