mirror of https://github.com/google/gemma.cpp.git
Add ops_test to BUILD, rename transformer_ops->ops, fix includes.
Also fix copybara. Refs #105 PiperOrigin-RevId: 619157071
This commit is contained in:
parent
9f1595c110
commit
c1d3c3284c
19
BUILD.bazel
19
BUILD.bazel
|
|
@ -21,7 +21,7 @@ licenses(["notice"])
|
||||||
exports_files(["LICENSE"])
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "transformer_ops",
|
name = "ops",
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"ops.h",
|
"ops.h",
|
||||||
],
|
],
|
||||||
|
|
@ -38,6 +38,21 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "ops_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["ops_test.cc"],
|
||||||
|
local_defines = ["HWY_IS_TEST"],
|
||||||
|
# for test_suite.
|
||||||
|
tags = ["hwy_ops_test"],
|
||||||
|
deps = [
|
||||||
|
":ops",
|
||||||
|
"@googletest//:gtest_main",
|
||||||
|
"@hwy//:hwy",
|
||||||
|
"@hwy//:hwy_test_util",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "args",
|
name = "args",
|
||||||
hdrs = [
|
hdrs = [
|
||||||
|
|
@ -59,7 +74,7 @@ cc_library(
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":args",
|
":args",
|
||||||
":transformer_ops",
|
":ops",
|
||||||
# "//base",
|
# "//base",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,9 @@
|
||||||
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#include <array>
|
||||||
|
#include <random>
|
||||||
|
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
||||||
|
|
@ -25,9 +28,10 @@
|
||||||
#define HWY_TARGET_INCLUDE "ops_test.cc" //NOLINT
|
#define HWY_TARGET_INCLUDE "ops_test.cc" //NOLINT
|
||||||
// clang-format on
|
// clang-format on
|
||||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
#include "hwy/tests/test_util-inl.h"
|
#include "hwy/tests/test_util-inl.h"
|
||||||
|
// After highway.h
|
||||||
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "ops.h"
|
#include "ops.h"
|
||||||
|
|
||||||
HWY_BEFORE_NAMESPACE();
|
HWY_BEFORE_NAMESPACE();
|
||||||
|
|
@ -282,6 +286,7 @@ struct TestSoftmax {
|
||||||
template <class D>
|
template <class D>
|
||||||
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
||||||
hwy::RandomState& rng) {
|
hwy::RandomState& rng) {
|
||||||
|
if (count == 0) return; // *Softmax would assert
|
||||||
using T = hn::TFromD<D>;
|
using T = hn::TFromD<D>;
|
||||||
|
|
||||||
hwy::AlignedFreeUniquePtr<T[]> px =
|
hwy::AlignedFreeUniquePtr<T[]> px =
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue