mirror of https://github.com/google/gemma.cpp.git
Support Bazel builds. Fixes #16
Also fix nuq/sfp-inl: warning, cast, and disable SCALAR PiperOrigin-RevId: 612704056
This commit is contained in:
parent
cd7468199c
commit
bb9b023502
|
|
@ -12,6 +12,7 @@ jobs:
|
|||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
# When adding another, also add to copybara's github_check_runs.
|
||||
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
|
||||
build_type: ['Release']
|
||||
preset: ['make', 'windows']
|
||||
|
|
@ -54,3 +55,21 @@ jobs:
|
|||
${{ github.workspace }}/build/${{ matrix.build_type }}/libgemma.lib
|
||||
${{ github.workspace }}/build/gemma
|
||||
${{ github.workspace }}/build/libgemma.a
|
||||
|
||||
bazel:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@63c24ba6bd7ba022e95695ff85de572c04a18142 # v2.7.0
|
||||
with:
|
||||
egress-policy: audit # cannot be block - runner does git checkout
|
||||
|
||||
- uses: actions/checkout@8ade135a41bc03ea155e62e844d188df1ea18608 # v4.0.0
|
||||
|
||||
- uses: bazelbuild/setup-bazelisk@b39c379c82683a5f25d34f0d062761f62693e0b2 # v3.0.0
|
||||
|
||||
- uses: actions/cache@ab5e6d0c87105b4c9c2047343972218f562e4319 # v4.0.1
|
||||
with:
|
||||
path: ~/.cache/bazel
|
||||
key: bazel-${{ runner.os }}
|
||||
- run: bazel build -c opt --cxxopt=-std=c++20 //...
|
||||
59
BUILD.bazel
59
BUILD.bazel
|
|
@ -25,21 +25,14 @@ cc_library(
|
|||
],
|
||||
deps = [
|
||||
"//compression:compress",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:algo",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:dot",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:hwy",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:math",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:matvec",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:profiler",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:thread_pool",
|
||||
"//hwy/contrib/sort:vqsort",
|
||||
"@hwy//:algo",
|
||||
"@hwy//:dot",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:math",
|
||||
"@hwy//:matvec",
|
||||
"@hwy//:profiler",
|
||||
"@hwy//:thread_pool",
|
||||
"@hwy//hwy/contrib/sort:vqsort",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -49,8 +42,7 @@ cc_library(
|
|||
"util/args.h",
|
||||
],
|
||||
deps = [
|
||||
# copybara:import_next_line:hwy
|
||||
"//:hwy",
|
||||
"@hwy//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -61,8 +53,7 @@ cc_library(
|
|||
],
|
||||
deps = [
|
||||
":args",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:hwy",
|
||||
"@hwy//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -78,19 +69,13 @@ cc_library(
|
|||
deps = [
|
||||
":args",
|
||||
":transformer_ops",
|
||||
"//base",
|
||||
"//compression:compress",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:hwy",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:matvec",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:nanobenchmark", # timer
|
||||
# copybara:import_next_line:hwy
|
||||
"//:profiler",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:thread_pool",
|
||||
":sentencepiece_processor",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:matvec",
|
||||
"@hwy//:nanobenchmark", # timer
|
||||
"@hwy//:profiler",
|
||||
"@hwy//:thread_pool",
|
||||
"@com_google_sentencepiece//:sentencepiece_processor",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -104,13 +89,9 @@ cc_binary(
|
|||
":args",
|
||||
":gemma_lib",
|
||||
"//compression:compress",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:hwy",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:nanobenchmark",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:profiler",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:thread_pool",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:nanobenchmark",
|
||||
"@hwy//:profiler",
|
||||
"@hwy//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -127,13 +127,13 @@ working with weights, kv cache and activations (e.g. you might have multiple kv
|
|||
caches and activations for a single set of weights) more directly rather than
|
||||
only using a Gemma object.
|
||||
|
||||
## Use the tokenizer in the Gemma object (or interact with the Tokenizer object directly)
|
||||
### Use the tokenizer in the Gemma object (or interact with the Tokenizer object directly)
|
||||
|
||||
The tokenizer has essentially two operations: call `Encode()` to go from
|
||||
string prompts to token id vectors, or `Decode()` to go from token id vector
|
||||
outputs from the model back to strings.
|
||||
|
||||
## The main entrypoint for generation is `GenerateGemma()`
|
||||
### The main entrypoint for generation is `GenerateGemma()`
|
||||
|
||||
Calling into `GenerateGemma` with a tokenized prompt will 1) mutate the
|
||||
activation values in `model` and 2) invoke StreamFunc - a lambda callback for
|
||||
|
|
@ -150,7 +150,7 @@ constrained decoding type of use cases where you want to force the generation
|
|||
to fit a grammar. If you're not doing this, you can send an empty lambda as a
|
||||
no-op which is what `run.cc` does.
|
||||
|
||||
## If you want to invoke the neural network forward function directly call the `Transformer()` function
|
||||
### If you want to invoke the neural network forward function directly call the `Transformer()` function
|
||||
|
||||
For high-level applications, you might only call `GenerateGemma()` and never
|
||||
interact directly with the neural network, but if you're doing something a bit
|
||||
|
|
@ -158,11 +158,20 @@ more custom you can call transformer which performs a single inference
|
|||
operation on a single token and mutates the Activations and the KVCache through
|
||||
the neural network computation.
|
||||
|
||||
## For low level operations, defining new architectures, call `ops.h` functions directly
|
||||
### For low level operations, defining new architectures, call `ops.h` functions directly
|
||||
|
||||
You use `ops.h` if you're writing other NN architectures or modifying the
|
||||
inference path of the Gemma model.
|
||||
|
||||
## Building with Bazel
|
||||
|
||||
The sentencepiece library we depend on requires some additional work to build
|
||||
with the Bazel build system. First, it does not export its BUILD file, so we
|
||||
provide `bazel/sentencepiece.bazel`. Second, it ships with a vendored subset of
|
||||
the Abseil library. `bazel/com_google_sentencepiece.patch` changes the code to
|
||||
support Abseil as a standalone dependency without third_party/ prefixes, similar
|
||||
to the transforms we apply to Gemma via Copybara.
|
||||
|
||||
## Discord
|
||||
|
||||
We're also trying out a Discord server for discussion here -
|
||||
|
|
|
|||
55
MODULE.bazel
55
MODULE.bazel
|
|
@ -3,12 +3,57 @@ module(
|
|||
version = "0.1.0",
|
||||
)
|
||||
|
||||
bazel_dep(
|
||||
name = "rules_license",
|
||||
version = "0.0.7",
|
||||
bazel_dep(name = "rules_license", version = "0.0.7")
|
||||
bazel_dep(name = "googletest", version = "1.14.0")
|
||||
|
||||
# Copied from Highway because Bazel does not load them transitively
|
||||
bazel_dep(name = "bazel_skylib", version = "1.4.1")
|
||||
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||
bazel_dep(name = "platforms", version = "0.0.7")
|
||||
|
||||
http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||
|
||||
http_archive(
|
||||
name = "hwy",
|
||||
urls = ["https://github.com/google/highway/archive/refs/tags/1.1.0.zip"],
|
||||
integrity = "sha256-zkJX2SwL4wQ0nHMsURW7MDLEf43vFSnqhSUsUM6eQmY=",
|
||||
strip_prefix = "highway-1.1.0",
|
||||
)
|
||||
|
||||
bazel_dep(
|
||||
http_archive(
|
||||
name = "com_google_sentencepiece",
|
||||
version = "0.1.96",
|
||||
sha256 = "8409b0126ebd62b256c685d5757150cf7fcb2b92a2f2b98efb3f38fc36719754",
|
||||
strip_prefix = "sentencepiece-0.1.96",
|
||||
urls = ["https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip"],
|
||||
build_file = "@//bazel:sentencepiece.bazel",
|
||||
patches = ["@//bazel:com_google_sentencepiece.patch"],
|
||||
patch_args = ["-p1"],
|
||||
)
|
||||
|
||||
# For sentencepiece
|
||||
http_archive(
|
||||
name = "darts_clone",
|
||||
build_file_content = """
|
||||
licenses(["notice"])
|
||||
exports_files(["LICENSE"])
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
cc_library(
|
||||
name = "darts_clone",
|
||||
hdrs = [
|
||||
"include/darts.h",
|
||||
],
|
||||
)
|
||||
""",
|
||||
sha256 = "c97f55d05c98da6fcaf7f9ecc6a6dc6bc5b18b8564465f77abff8879d446491c",
|
||||
strip_prefix = "darts-clone-e40ce4627526985a7767444b6ed6893ab6ff8983",
|
||||
urls = [
|
||||
"https://github.com/s-yata/darts-clone/archive/e40ce4627526985a7767444b6ed6893ab6ff8983.zip",
|
||||
],
|
||||
)
|
||||
# ABSL on 2023-10-18
|
||||
http_archive(
|
||||
name = "com_google_absl",
|
||||
sha256 = "f841f78243f179326f2a80b719f2887c38fe226d288ecdc46e2aa091e6aa43bc",
|
||||
strip_prefix = "abseil-cpp-9687a8ea750bfcddf790372093245a1d041b21a3",
|
||||
urls = ["https://github.com/abseil/abseil-cpp/archive//9687a8ea750bfcddf790372093245a1d041b21a3.tar.gz"],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -175,6 +175,14 @@ cmake --build --preset windows -j [number of parallel threads to use]
|
|||
|
||||
If the build is successful, you should now have a `gemma.exe` executable in the `build/` directory.
|
||||
|
||||
#### Bazel
|
||||
|
||||
```sh
|
||||
bazel build -c opt --cxxopt=-std=c++20 :gemma
|
||||
```
|
||||
|
||||
If the build is successful, you should now have a `gemma` executable in the `bazel-bin/` directory.
|
||||
|
||||
### Step 4: Run
|
||||
|
||||
You can now run `gemma` from inside the `build/` directory (or from `bazel-bin/` if you built with Bazel).
|
||||
|
|
|
|||
24
WORKSPACE
24
WORKSPACE
|
|
@ -1,24 +1,4 @@
|
|||
workspace(name = "gemma")
|
||||
|
||||
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||
load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
|
||||
|
||||
maybe(
|
||||
http_archive,
|
||||
name = "rules_license",
|
||||
sha256 = "4531deccb913639c30e5c7512a054d5d875698daeb75d8cf90f284375fe7c360",
|
||||
urls = [
|
||||
"https://github.com/bazelbuild/rules_license/releases/download/0.0.7/rules_license-0.0.7.tar.gz",
|
||||
],
|
||||
)
|
||||
|
||||
maybe(
|
||||
http_archive,
|
||||
name = "com_google_sentencepiece",
|
||||
sha256 = "8409b0126ebd62b256c685d5757150cf7fcb2b92a2f2b98efb3f38fc36719754",
|
||||
strip_prefix = "sentencepiece-0.1.96",
|
||||
urls = ["https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip"],
|
||||
build_file = "@//third_party:sentencepiece.bazel",
|
||||
patches = ["@//third_party:com_google_sentencepiece.patch"],
|
||||
patch_args = ["-p1"],
|
||||
)
|
||||
# This file marks the root of the Bazel workspace.
|
||||
# See MODULE.bazel for external dependencies setup.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,4 @@
|
|||
package(
|
||||
default_applicable_licenses = ["//:license"],
|
||||
default_visibility = ["//:__subpackages__"],
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,97 @@
|
|||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
features = [
|
||||
"layering_check",
|
||||
"parse_headers",
|
||||
],
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2, BSD, MIT
|
||||
|
||||
proto_library(
|
||||
name = "sentencepiece_proto",
|
||||
srcs = ["src/sentencepiece.proto"],
|
||||
)
|
||||
|
||||
cc_proto_library(
|
||||
name = "sentencepiece_cc_proto",
|
||||
deps = [":sentencepiece_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "sentencepiece_model_proto",
|
||||
srcs = ["src/sentencepiece_model.proto"],
|
||||
)
|
||||
|
||||
cc_proto_library(
|
||||
name = "sentencepiece_model_cc_proto",
|
||||
deps = [":sentencepiece_model_proto"],
|
||||
)
|
||||
|
||||
genrule(
|
||||
name = "config_h",
|
||||
srcs = ["config.h.in"],
|
||||
outs = ["config.h"],
|
||||
cmd = "cp $< $@",
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "common",
|
||||
hdrs = [
|
||||
"config.h",
|
||||
"src/common.h",
|
||||
],
|
||||
deps = [
|
||||
"@com_google_absl//absl/base",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "sentencepiece_processor",
|
||||
srcs = [
|
||||
"src/bpe_model.cc",
|
||||
"src/char_model.cc",
|
||||
"src/error.cc",
|
||||
"src/filesystem.cc",
|
||||
"src/model_factory.cc",
|
||||
"src/model_interface.cc",
|
||||
"src/normalizer.cc",
|
||||
"src/sentencepiece_processor.cc",
|
||||
"src/unigram_model.cc",
|
||||
"src/util.cc",
|
||||
"src/word_model.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"src/bpe_model.h",
|
||||
"src/char_model.h",
|
||||
"src/filesystem.h",
|
||||
"src/freelist.h",
|
||||
"src/model_factory.h",
|
||||
"src/model_interface.h",
|
||||
"src/normalizer.h",
|
||||
"src/sentencepiece_processor.h",
|
||||
"src/trainer_interface.h",
|
||||
"src/unigram_model.h",
|
||||
"src/util.h",
|
||||
"src/word_model.h",
|
||||
],
|
||||
defines = ["_USE_TF_STRING_VIEW"],
|
||||
includes = [
|
||||
".",
|
||||
"src",
|
||||
],
|
||||
linkstatic = 1,
|
||||
deps =
|
||||
[
|
||||
":common",
|
||||
":sentencepiece_cc_proto",
|
||||
":sentencepiece_model_cc_proto",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@darts_clone",
|
||||
],
|
||||
)
|
||||
|
|
@ -1,10 +1,10 @@
|
|||
# Weight compression, I/O and analysis
|
||||
|
||||
package(
|
||||
default_applicable_licenses = ["//third_party/gemma_cpp:license"],
|
||||
default_applicable_licenses = ["//:license"],
|
||||
default_visibility = [
|
||||
"//learning/gemini/prod/contrib/gemini_cpp:__subpackages__",
|
||||
"//third_party/gemma_cpp:__subpackages__",
|
||||
"//:__subpackages__",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -17,10 +17,8 @@ cc_library(
|
|||
"blob_store.h",
|
||||
],
|
||||
deps = [
|
||||
# copybara:import_next_line:hwy
|
||||
"//:hwy",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:thread_pool",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -34,8 +32,7 @@ cc_library(
|
|||
"stats.h",
|
||||
],
|
||||
deps = [
|
||||
# copybara:import_next_line:hwy
|
||||
"//:hwy",
|
||||
"@hwy//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -48,8 +45,7 @@ cc_library(
|
|||
"sfp-inl.h",
|
||||
],
|
||||
deps = [
|
||||
# copybara:import_next_line:hwy
|
||||
"//:hwy",
|
||||
"@hwy//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -65,15 +61,11 @@ cc_test(
|
|||
deps = [
|
||||
":sfp",
|
||||
":stats",
|
||||
"//testing/base/public:gunit_main_no_google3",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:hwy",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:hwy_test_util",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:nanobenchmark",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:thread_pool",
|
||||
"@googletest//:gtest_main",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:hwy_test_util",
|
||||
"@hwy//:nanobenchmark",
|
||||
"@hwy//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -87,9 +79,8 @@ cc_library(
|
|||
],
|
||||
deps = [
|
||||
":sfp",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:hwy",
|
||||
"//third_party/highway/hwy/contrib/sort:vqsort",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//hwy/contrib/sort:vqsort",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -106,13 +97,10 @@ cc_test(
|
|||
":nuq",
|
||||
":sfp",
|
||||
":stats",
|
||||
"//testing/base/public:gunit_main_no_google3",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:hwy",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:hwy_test_util",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:nanobenchmark",
|
||||
"@googletest//:gtest_main",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:hwy_test_util",
|
||||
"@hwy//:nanobenchmark",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -131,12 +119,9 @@ cc_library(
|
|||
":nuq",
|
||||
":sfp",
|
||||
":stats",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:dot",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:hwy",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:thread_pool",
|
||||
"@hwy//:dot",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -150,12 +135,9 @@ cc_library(
|
|||
":nuq",
|
||||
":sfp",
|
||||
":stats",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:hwy",
|
||||
# copybara:import_next_line:hwy
|
||||
"//:nanobenchmark", # timer
|
||||
# copybara:import_next_line:hwy
|
||||
"//:thread_pool",
|
||||
"//third_party/highway/hwy/contrib/sort:vqsort",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:nanobenchmark", # timer
|
||||
"@hwy//:thread_pool",
|
||||
"@hwy//hwy/contrib/sort:vqsort",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -470,9 +470,7 @@ class NuqCodec {
|
|||
static HWY_INLINE size_t Enc(DF df, const float* const in, const size_t num,
|
||||
ClusterBuf& buf, const size_t out_capacity,
|
||||
NuqStream* const out, const size_t out_ofs) {
|
||||
const hn::Repartition<uint8_t, DF> d8;
|
||||
const hn::Repartition<uint16_t, DF> d16;
|
||||
using V8 = hn::Vec<decltype(d8)>;
|
||||
using V16 = hn::Vec<decltype(d16)>;
|
||||
|
||||
const size_t N16 = hn::Lanes(d16);
|
||||
|
|
|
|||
|
|
@ -13,6 +13,11 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// SFP uses ConcatEven/Odd, which the HWY_SCALAR target does not support.
// Disable HWY_SCALAR so that HWY_EMU128 is used instead.
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
||||
#endif
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
|
@ -23,9 +28,10 @@
|
|||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
// clang-format off
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE \
|
||||
"third_party/gemma_cpp/compression/nuq_test.cc" // NOLINT
|
||||
#define HWY_TARGET_INCLUDE "compression/nuq_test.cc" // NOLINT
|
||||
// clang-format on
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
// Other headers that include Highway must come after foreach_target.h
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
|
|
|
|||
|
|
@ -13,6 +13,11 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// We use ConcatEven/Odd, which the HWY_SCALAR target does not support.
// Disable HWY_SCALAR so that HWY_EMU128 is used instead.
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
||||
#endif
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/sfp.h"
|
||||
|
||||
|
|
@ -27,9 +32,10 @@
|
|||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
// clang-format off
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE \
|
||||
"third_party/gemma_cpp/compression/sfp_test.cc" // NOLINT
|
||||
#define HWY_TARGET_INCLUDE "compression/sfp_test.cc" // NOLINT
|
||||
// clang-format on
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
// Any highway.h must come after foreach_target.h
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
|
|
@ -301,7 +307,7 @@ struct TestEncDec {
|
|||
for (size_t i = 0; i < num; ++i) {
|
||||
const float out = hwy::F32FromBF16(dec[i]);
|
||||
sum += hwy::ConvertScalarTo<double>(hwy::ScalarAbs(in[i]));
|
||||
stats.Notify(in[i], out);
|
||||
stats.Notify(hwy::ConvertScalarTo<float>(in[i]), out);
|
||||
}
|
||||
const double avg = sum / num;
|
||||
fprintf(stderr, "Avg magnitude %.3E, p-norm %.3E snr %.2f @%zu = %.4E\n",
|
||||
|
|
|
|||
Loading…
Reference in New Issue