diff --git a/BUILD.bazel b/BUILD.bazel index 319421f..152657e 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -21,7 +21,7 @@ licenses(["notice"]) exports_files(["LICENSE"]) cc_library( - name = "transformer_ops", + name = "ops", hdrs = [ "ops.h", ], @@ -38,6 +38,21 @@ cc_library( ], ) +cc_test( + name = "ops_test", + size = "small", + srcs = ["ops_test.cc"], + local_defines = ["HWY_IS_TEST"], + # for test_suite. + tags = ["hwy_ops_test"], + deps = [ + ":ops", + "@googletest//:gtest_main", + "@hwy//:hwy", + "@hwy//:hwy_test_util", + ], +) + cc_library( name = "args", hdrs = [ @@ -59,7 +74,7 @@ cc_library( ], deps = [ ":args", - ":transformer_ops", + ":ops", # "//base", "//compression:compress", "@hwy//:hwy", diff --git a/ops_test.cc b/ops_test.cc index e19106b..f908c39 100644 --- a/ops_test.cc +++ b/ops_test.cc @@ -17,6 +17,9 @@ #define HWY_DISABLED_TARGETS HWY_SCALAR #endif +#include +#include + #include "hwy/aligned_allocator.h" #include "hwy/base.h" @@ -25,9 +28,10 @@ #define HWY_TARGET_INCLUDE "ops_test.cc" //NOLINT // clang-format on #include "hwy/foreach_target.h" // IWYU pragma: keep -// copybara:import_next_line:gemma_cpp #include "hwy/highway.h" #include "hwy/tests/test_util-inl.h" +// After highway.h +// copybara:import_next_line:gemma_cpp #include "ops.h" HWY_BEFORE_NAMESPACE(); @@ -282,6 +286,7 @@ struct TestSoftmax { template void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, hwy::RandomState& rng) { + if (count == 0) return; // *Softmax would assert using T = hn::TFromD; hwy::AlignedFreeUniquePtr px =