Add ops_test to BUILD, rename transformer_ops->ops, fix includes.

Also fix copybara. Refs #105

PiperOrigin-RevId: 619157071
This commit is contained in:
Jan Wassenberg 2024-03-26 05:36:58 -07:00 committed by Copybara-Service
parent 9f1595c110
commit c1d3c3284c
2 changed files with 23 additions and 3 deletions

View File

@ -21,7 +21,7 @@ licenses(["notice"])
exports_files(["LICENSE"])
cc_library(
name = "transformer_ops",
name = "ops",
hdrs = [
"ops.h",
],
@ -38,6 +38,21 @@ cc_library(
],
)
cc_test(
name = "ops_test",
size = "small",
srcs = ["ops_test.cc"],
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["hwy_ops_test"],
deps = [
":ops",
"@googletest//:gtest_main",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
],
)
cc_library(
name = "args",
hdrs = [
@ -59,7 +74,7 @@ cc_library(
],
deps = [
":args",
":transformer_ops",
":ops",
# "//base",
"//compression:compress",
"@hwy//:hwy",

View File

@ -17,6 +17,9 @@
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif
#include <array>
#include <random>
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
@ -25,9 +28,10 @@
#define HWY_TARGET_INCLUDE "ops_test.cc" //NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
// copybara:import_next_line:gemma_cpp
#include "hwy/highway.h"
#include "hwy/tests/test_util-inl.h"
// After highway.h
// copybara:import_next_line:gemma_cpp
#include "ops.h"
HWY_BEFORE_NAMESPACE();
@ -282,6 +286,7 @@ struct TestSoftmax {
template <class D>
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
hwy::RandomState& rng) {
if (count == 0) return; // *Softmax would assert
using T = hn::TFromD<D>;
hwy::AlignedFreeUniquePtr<T[]> px =