Add ops_test to BUILD, rename transformer_ops->ops, fix includes.

Also fix copybara. Refs #105

PiperOrigin-RevId: 619157071
This commit is contained in:
Jan Wassenberg 2024-03-26 05:36:58 -07:00 committed by Copybara-Service
parent 9f1595c110
commit c1d3c3284c
2 changed files with 23 additions and 3 deletions

View File

@ -21,7 +21,7 @@ licenses(["notice"])
exports_files(["LICENSE"]) exports_files(["LICENSE"])
cc_library( cc_library(
name = "transformer_ops", name = "ops",
hdrs = [ hdrs = [
"ops.h", "ops.h",
], ],
@ -38,6 +38,21 @@ cc_library(
], ],
) )
cc_test(
name = "ops_test",
size = "small",
srcs = ["ops_test.cc"],
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["hwy_ops_test"],
deps = [
":ops",
"@googletest//:gtest_main",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
],
)
cc_library( cc_library(
name = "args", name = "args",
hdrs = [ hdrs = [
@ -59,7 +74,7 @@ cc_library(
], ],
deps = [ deps = [
":args", ":args",
":transformer_ops", ":ops",
# "//base", # "//base",
"//compression:compress", "//compression:compress",
"@hwy//:hwy", "@hwy//:hwy",

View File

@ -17,6 +17,9 @@
#define HWY_DISABLED_TARGETS HWY_SCALAR #define HWY_DISABLED_TARGETS HWY_SCALAR
#endif #endif
#include <array>
#include <random>
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
@ -25,9 +28,10 @@
#define HWY_TARGET_INCLUDE "ops_test.cc" //NOLINT #define HWY_TARGET_INCLUDE "ops_test.cc" //NOLINT
// clang-format on // clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
// copybara:import_next_line:gemma_cpp
#include "hwy/highway.h" #include "hwy/highway.h"
#include "hwy/tests/test_util-inl.h" #include "hwy/tests/test_util-inl.h"
// After highway.h
// copybara:import_next_line:gemma_cpp
#include "ops.h" #include "ops.h"
HWY_BEFORE_NAMESPACE(); HWY_BEFORE_NAMESPACE();
@ -282,6 +286,7 @@ struct TestSoftmax {
template <class D> template <class D>
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
hwy::RandomState& rng) { hwy::RandomState& rng) {
if (count == 0) return; // *Softmax would assert
using T = hn::TFromD<D>; using T = hn::TFromD<D>;
hwy::AlignedFreeUniquePtr<T[]> px = hwy::AlignedFreeUniquePtr<T[]> px =