mirror of https://github.com/google/gemma.cpp.git
Compare commits
57 Commits
dfd4e16ae8
...
2c038e1285
| Author | SHA1 | Date |
|---|---|---|
|
|
2c038e1285 | |
|
|
8ed22e52bf | |
|
|
19017fdb6d | |
|
|
28ca001d5e | |
|
|
429eb78512 | |
|
|
3d72f17261 | |
|
|
6eeef2e2d9 | |
|
|
2a71333c8a | |
|
|
9a2682d544 | |
|
|
bafb8382f8 | |
|
|
0afa480d90 | |
|
|
4a6173d929 | |
|
|
564937ede6 | |
|
|
2829ef17ad | |
|
|
59ebecce22 | |
|
|
12fb2f05cf | |
|
|
8f04a8346d | |
|
|
f8ccb8e37c | |
|
|
374fd7478a | |
|
|
afaca4efa8 | |
|
|
befe9fb07e | |
|
|
6a78a23f4c | |
|
|
f608337fef | |
|
|
aa0b113214 | |
|
|
5cb63346aa | |
|
|
27117cc39f | |
|
|
1d18c5a129 | |
|
|
0816a1070d | |
|
|
7a12e29027 | |
|
|
e8f59bb411 | |
|
|
9e0ac5de34 | |
|
|
2d4de6b08b | |
|
|
75eca87039 | |
|
|
b27d8d6b92 | |
|
|
ea45d7c4d7 | |
|
|
e8d29792ac | |
|
|
3bf22abb22 | |
|
|
ca971ef50f | |
|
|
e9a0caed87 | |
|
|
38f1ea9b80 | |
|
|
a8ceb75f43 | |
|
|
a939b5fc9f | |
|
|
05e7e2b2bb | |
|
|
4ef3da733a | |
|
|
2c5706f159 | |
|
|
03284d752e | |
|
|
342e998cb6 | |
|
|
e541707caa | |
|
|
4e960d67f6 | |
|
|
809bd0709d | |
|
|
54120a5571 | |
|
|
881eeffe0a | |
|
|
da91f4c4be | |
|
|
827fec1904 | |
|
|
2099b37732 | |
|
|
a982ec1287 | |
|
|
9ca662dc14 |
|
|
@ -0,0 +1,6 @@
|
||||||
|
{
|
||||||
|
"cmake.configureOnOpen": false,
|
||||||
|
"files.associations": {
|
||||||
|
"array": "cpp"
|
||||||
|
}
|
||||||
|
}
|
||||||
87
BUILD.bazel
87
BUILD.bazel
|
|
@ -22,9 +22,7 @@ exports_files(["LICENSE"])
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "ops",
|
name = "ops",
|
||||||
hdrs = [
|
hdrs = ["gemma/ops.h"],
|
||||||
"ops.h",
|
|
||||||
],
|
|
||||||
deps = [
|
deps = [
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"@hwy//:algo",
|
"@hwy//:algo",
|
||||||
|
|
@ -41,42 +39,34 @@ cc_library(
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "ops_test",
|
name = "ops_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["ops_test.cc"],
|
srcs = ["gemma/ops_test.cc"],
|
||||||
local_defines = ["HWY_IS_TEST"],
|
local_defines = ["HWY_IS_TEST"],
|
||||||
# for test_suite.
|
# for test_suite.
|
||||||
tags = ["hwy_ops_test"],
|
tags = ["hwy_ops_test"],
|
||||||
deps = [
|
deps = [
|
||||||
":ops",
|
":ops",
|
||||||
"@googletest//:gtest_main",
|
"@googletest//:gtest_main",
|
||||||
|
"//compression:compress",
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
"@hwy//:hwy_test_util",
|
"@hwy//:hwy_test_util",
|
||||||
],
|
"@hwy//:thread_pool",
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "args",
|
|
||||||
hdrs = [
|
|
||||||
"util/args.h",
|
|
||||||
],
|
|
||||||
deps = [
|
|
||||||
"@hwy//:hwy",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "gemma_lib",
|
name = "gemma_lib",
|
||||||
srcs = [
|
srcs = [
|
||||||
"gemma.cc",
|
"gemma/gemma.cc",
|
||||||
],
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"configs.h",
|
"gemma/configs.h",
|
||||||
"gemma.h",
|
"gemma/gemma.h",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":args",
|
|
||||||
":ops",
|
":ops",
|
||||||
# "//base",
|
# "//base",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
|
"//compression:io",
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
"@hwy//:matvec",
|
"@hwy//:matvec",
|
||||||
"@hwy//:nanobenchmark", # timer
|
"@hwy//:nanobenchmark", # timer
|
||||||
|
|
@ -86,9 +76,29 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "args",
|
||||||
|
hdrs = ["util/args.h"],
|
||||||
|
deps = [
|
||||||
|
"//compression:io",
|
||||||
|
"@hwy//:hwy",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "app",
|
||||||
|
hdrs = ["util/app.h"],
|
||||||
|
deps = [
|
||||||
|
":args",
|
||||||
|
":gemma_lib",
|
||||||
|
"//compression:io",
|
||||||
|
"@hwy//:hwy",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "gemma_test",
|
name = "gemma_test",
|
||||||
srcs = ["gemma_test.cc"],
|
srcs = ["gemma/gemma_test.cc"],
|
||||||
# Requires model files
|
# Requires model files
|
||||||
tags = [
|
tags = [
|
||||||
"local",
|
"local",
|
||||||
|
|
@ -105,23 +115,9 @@ cc_test(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "app",
|
|
||||||
hdrs = [
|
|
||||||
"util/app.h",
|
|
||||||
],
|
|
||||||
deps = [
|
|
||||||
":args",
|
|
||||||
":gemma_lib",
|
|
||||||
"@hwy//:hwy",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_binary(
|
cc_binary(
|
||||||
name = "gemma",
|
name = "gemma",
|
||||||
srcs = [
|
srcs = ["gemma/run.cc"],
|
||||||
"run.cc",
|
|
||||||
],
|
|
||||||
deps = [
|
deps = [
|
||||||
":app",
|
":app",
|
||||||
":args",
|
":args",
|
||||||
|
|
@ -137,9 +133,7 @@ cc_binary(
|
||||||
|
|
||||||
cc_binary(
|
cc_binary(
|
||||||
name = "compress_weights",
|
name = "compress_weights",
|
||||||
srcs = [
|
srcs = ["gemma/compress_weights.cc"],
|
||||||
"compress_weights.cc",
|
|
||||||
],
|
|
||||||
deps = [
|
deps = [
|
||||||
":args",
|
":args",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
|
|
@ -154,8 +148,25 @@ cc_binary(
|
||||||
|
|
||||||
cc_binary(
|
cc_binary(
|
||||||
name = "benchmark",
|
name = "benchmark",
|
||||||
|
srcs = ["gemma/benchmark.cc"],
|
||||||
|
deps = [
|
||||||
|
":app",
|
||||||
|
":args",
|
||||||
|
":gemma_lib",
|
||||||
|
# "//base",
|
||||||
|
"//compression:compress",
|
||||||
|
"@hwy//:hwy",
|
||||||
|
"@hwy//:nanobenchmark",
|
||||||
|
"@hwy//:profiler",
|
||||||
|
"@hwy//:thread_pool",
|
||||||
|
"@nlohmann_json//:json",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_binary(
|
||||||
|
name = "debug_prompt",
|
||||||
srcs = [
|
srcs = [
|
||||||
"benchmark.cc",
|
"debug_prompt.cc",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":app",
|
":app",
|
||||||
|
|
|
||||||
|
|
@ -34,16 +34,22 @@ FetchContent_Declare(json GIT_REPOSITORY https://github.com/nlohmann/json.git GI
|
||||||
FetchContent_MakeAvailable(json)
|
FetchContent_MakeAvailable(json)
|
||||||
|
|
||||||
set(SOURCES
|
set(SOURCES
|
||||||
gemma.cc
|
|
||||||
compression/blob_store.cc
|
compression/blob_store.cc
|
||||||
compression/blob_store.h
|
compression/blob_store.h
|
||||||
compression/compress.h
|
compression/compress.h
|
||||||
compression/compress-inl.h
|
compression/compress-inl.h
|
||||||
|
compression/io_win.cc
|
||||||
|
compression/io.cc
|
||||||
|
compression/io.h
|
||||||
compression/nuq.h
|
compression/nuq.h
|
||||||
compression/nuq-inl.h
|
compression/nuq-inl.h
|
||||||
compression/sfp.h
|
compression/sfp.h
|
||||||
compression/sfp-inl.h
|
compression/sfp-inl.h
|
||||||
compression/test_util.h
|
compression/test_util.h
|
||||||
|
gemma/configs.h
|
||||||
|
gemma/gemma.cc
|
||||||
|
gemma/gemma.h
|
||||||
|
gemma/ops.h
|
||||||
util/app.h
|
util/app.h
|
||||||
util/args.h
|
util/args.h
|
||||||
)
|
)
|
||||||
|
|
@ -72,26 +78,31 @@ set_property(TARGET libgemma PROPERTY CXX_STANDARD 17)
|
||||||
set_target_properties(libgemma PROPERTIES PREFIX "")
|
set_target_properties(libgemma PROPERTIES PREFIX "")
|
||||||
set_property(TARGET libgemma PROPERTY POSITION_INDEPENDENT_CODE ON)
|
set_property(TARGET libgemma PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||||
target_include_directories(libgemma PUBLIC ./)
|
target_include_directories(libgemma PUBLIC ./)
|
||||||
target_link_libraries(libgemma hwy hwy_contrib sentencepiece)
|
target_link_libraries(libgemma hwy hwy_contrib sentencepiece-static)
|
||||||
target_include_directories(libgemma PUBLIC ${sentencepiece_SOURCE_DIR})
|
target_include_directories(libgemma PUBLIC ${sentencepiece_SOURCE_DIR})
|
||||||
target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
|
target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
|
||||||
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
|
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
|
||||||
|
install(TARGETS libgemma DESTINATION lib)
|
||||||
|
|
||||||
# Executable Target
|
# Executable Target
|
||||||
|
|
||||||
add_executable(gemma run.cc)
|
add_executable(gemma gemma/run.cc)
|
||||||
target_link_libraries(gemma libgemma hwy hwy_contrib)
|
target_link_libraries(gemma libgemma hwy hwy_contrib)
|
||||||
|
install(TARGETS gemma DESTINATION bin)
|
||||||
|
|
||||||
add_executable(benchmark benchmark.cc)
|
add_executable(benchmark gemma/benchmark.cc)
|
||||||
target_link_libraries(benchmark libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
|
target_link_libraries(benchmark libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
|
||||||
|
|
||||||
|
add_executable(debug_prompt debug_prompt.cc)
|
||||||
|
target_link_libraries(debug_prompt libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
|
||||||
|
|
||||||
## Tests
|
## Tests
|
||||||
set(GEMMA_ENABLE_TESTS OFF CACHE BOOL "Enable Gemma tests")
|
set(GEMMA_ENABLE_TESTS OFF CACHE BOOL "Enable Gemma tests")
|
||||||
if (GEMMA_ENABLE_TESTS)
|
if (GEMMA_ENABLE_TESTS)
|
||||||
|
|
||||||
set(GEMMA_TEST_FILES
|
set(GEMMA_TEST_FILES
|
||||||
ops_test.cc
|
gemma/ops_test.cc
|
||||||
gemma_test.cc
|
gemma/gemma_test.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
foreach (TESTFILE IN LISTS GEMMA_TEST_FILES)
|
foreach (TESTFILE IN LISTS GEMMA_TEST_FILES)
|
||||||
|
|
@ -112,5 +123,5 @@ endif() # GEMMA_ENABLE_TESTS
|
||||||
|
|
||||||
## Tools
|
## Tools
|
||||||
|
|
||||||
add_executable(compress_weights compress_weights.cc)
|
add_executable(compress_weights gemma/compress_weights.cc)
|
||||||
target_link_libraries(compress_weights libgemma hwy hwy_contrib)
|
target_link_libraries(compress_weights libgemma hwy hwy_contrib)
|
||||||
|
|
|
||||||
|
|
@ -83,6 +83,23 @@ A `.clang-format` configuration is provided with our defaults, please run source
|
||||||
files through `clang-format` (or a formatter that produces equivalent behavior)
|
files through `clang-format` (or a formatter that produces equivalent behavior)
|
||||||
before finalizing PR for submission.
|
before finalizing PR for submission.
|
||||||
|
|
||||||
|
## Converting weights
|
||||||
|
|
||||||
|
We use a stripped down binary blob (.sbs) artifact to accelerate weight loading
|
||||||
|
in C++. These files can be downloaded directly from Kaggle and HuggingFace. You
|
||||||
|
can also convert Pytorch or Keras checkpoints to .sbs, but most end users should
|
||||||
|
not have to do this.
|
||||||
|
|
||||||
|
If starting with Keras, first run this script to convert to Pytorch:
|
||||||
|
https://github.com/keras-team/keras-nlp/blob/master/tools/gemma/export_gemma_to_torch_xla.py
|
||||||
|
|
||||||
|
From Pytorch, use the following script to generate uncompressed weights:
|
||||||
|
https://github.com/google/gemma.cpp/blob/dev/util/convert_weights.py
|
||||||
|
|
||||||
|
Then run gemma/compress_weights.cc (Bazel target :compress_weights), specifying
|
||||||
|
the resulting file as `--weights` and the desired .sbs name as the
|
||||||
|
`--compressed_weights`.
|
||||||
|
|
||||||
## Compile-Time Flags (Advanced)
|
## Compile-Time Flags (Advanced)
|
||||||
|
|
||||||
There are several compile-time flags to be aware of (note these may or may not
|
There are several compile-time flags to be aware of (note these may or may not
|
||||||
|
|
@ -169,9 +186,9 @@ inference path of the Gemma model.
|
||||||
The sentencepiece library we depend on requires some additional work to build
|
The sentencepiece library we depend on requires some additional work to build
|
||||||
with the Bazel build system. First, it does not export its BUILD file, so we
|
with the Bazel build system. First, it does not export its BUILD file, so we
|
||||||
provide `bazel/sentencepiece.bazel`. Second, it ships with a vendored subset of
|
provide `bazel/sentencepiece.bazel`. Second, it ships with a vendored subset of
|
||||||
the Abseil library. `bazel/com_google_sentencepiece.patch` changes the code to
|
the Abseil library. `bazel/sentencepiece.patch` changes the code to support
|
||||||
support Abseil as a standalone dependency without third_party/ prefixes, similar
|
Abseil as a standalone dependency without third_party/ prefixes, similar to the
|
||||||
to the transforms we apply to Gemma via Copybara.
|
transforms we apply to Gemma via Copybara.
|
||||||
|
|
||||||
## Discord
|
## Discord
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ http_archive(
|
||||||
strip_prefix = "sentencepiece-0.1.96",
|
strip_prefix = "sentencepiece-0.1.96",
|
||||||
urls = ["https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip"],
|
urls = ["https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip"],
|
||||||
build_file = "@//bazel:sentencepiece.bazel",
|
build_file = "@//bazel:sentencepiece.bazel",
|
||||||
patches = ["@//bazel:com_google_sentencepiece.patch"],
|
patches = ["@//bazel:sentencepiece.patch"],
|
||||||
patch_args = ["-p1"],
|
patch_args = ["-p1"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -206,6 +206,12 @@ bazel build -c opt --cxxopt=-std=c++20 :gemma
|
||||||
|
|
||||||
If the build is successful, you should now have a `gemma` executable in the `bazel-bin/` directory.
|
If the build is successful, you should now have a `gemma` executable in the `bazel-bin/` directory.
|
||||||
|
|
||||||
|
#### Make
|
||||||
|
|
||||||
|
If you prefer Makefiles, @jart has made one available here:
|
||||||
|
|
||||||
|
https://github.com/jart/gemma3/blob/main/Makefile
|
||||||
|
|
||||||
### Step 4: Run
|
### Step 4: Run
|
||||||
|
|
||||||
You can now run `gemma` from inside the `build/` directory.
|
You can now run `gemma` from inside the `build/` directory.
|
||||||
|
|
@ -252,7 +258,7 @@ here provide a C++ implementation of this model based on the paper.
|
||||||
|
|
||||||
To use the recurrent version of Gemma included in this repository, build the
|
To use the recurrent version of Gemma included in this repository, build the
|
||||||
gemma binary as noted above in Step 3. Download the compressed weights and
|
gemma binary as noted above in Step 3. Download the compressed weights and
|
||||||
tokenizer from
|
tokenizer from the RecurrentGemma
|
||||||
[Kaggle](https://www.kaggle.com/models/google/recurrentgemma/gemmaCpp) as in
|
[Kaggle](https://www.kaggle.com/models/google/recurrentgemma/gemmaCpp) as in
|
||||||
Step 1, and run the binary as follows:
|
Step 1, and run the binary as follows:
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
# Required for referencing bazel:com_google_sentencepiece.patch
|
# Required for referencing bazel:sentencepiece.patch
|
||||||
package(
|
package(
|
||||||
default_applicable_licenses = ["//:license"],
|
default_applicable_licenses = ["//:license"],
|
||||||
default_visibility = ["//:__subpackages__"],
|
default_visibility = ["//:__subpackages__"],
|
||||||
|
|
|
||||||
|
|
@ -11,14 +11,24 @@ package(
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "blob_store",
|
name = "io",
|
||||||
srcs = [
|
srcs = [
|
||||||
"blob_store.cc",
|
"io.cc",
|
||||||
],
|
# Placeholder for io backend, do not remove
|
||||||
hdrs = [
|
|
||||||
"blob_store.h",
|
|
||||||
],
|
],
|
||||||
|
hdrs = ["io.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
# Placeholder for io deps, do not remove
|
||||||
|
"@hwy//:hwy",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "blob_store",
|
||||||
|
srcs = ["blob_store.cc"],
|
||||||
|
hdrs = ["blob_store.h"],
|
||||||
|
deps = [
|
||||||
|
":io",
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
"@hwy//:thread_pool",
|
"@hwy//:thread_pool",
|
||||||
],
|
],
|
||||||
|
|
@ -39,7 +49,23 @@ cc_library(
|
||||||
name = "distortion",
|
name = "distortion",
|
||||||
hdrs = ["distortion.h"],
|
hdrs = ["distortion.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":stats",
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
|
"@hwy//hwy/contrib/sort:vqsort",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "distortion_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["distortion_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":distortion",
|
||||||
|
":test_util",
|
||||||
|
"@googletest//:gtest_main",
|
||||||
|
"@hwy//:hwy",
|
||||||
|
"@hwy//:hwy_test_util",
|
||||||
|
"@hwy//:nanobenchmark", # Unpredictable1
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -56,12 +82,8 @@ cc_library(
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "sfp",
|
name = "sfp",
|
||||||
hdrs = [
|
hdrs = ["sfp.h"],
|
||||||
"sfp.h",
|
textual_hdrs = ["sfp-inl.h"],
|
||||||
],
|
|
||||||
textual_hdrs = [
|
|
||||||
"sfp-inl.h",
|
|
||||||
],
|
|
||||||
deps = [
|
deps = [
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
],
|
],
|
||||||
|
|
@ -88,12 +110,8 @@ cc_test(
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "nuq",
|
name = "nuq",
|
||||||
hdrs = [
|
hdrs = ["nuq.h"],
|
||||||
"nuq.h",
|
textual_hdrs = ["nuq-inl.h"],
|
||||||
],
|
|
||||||
textual_hdrs = [
|
|
||||||
"nuq-inl.h",
|
|
||||||
],
|
|
||||||
deps = [
|
deps = [
|
||||||
":sfp",
|
":sfp",
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
|
|
@ -134,6 +152,7 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":blob_store",
|
":blob_store",
|
||||||
":distortion",
|
":distortion",
|
||||||
|
":io",
|
||||||
":nuq",
|
":nuq",
|
||||||
":sfp",
|
":sfp",
|
||||||
":stats",
|
":stats",
|
||||||
|
|
@ -146,9 +165,7 @@ cc_library(
|
||||||
# For internal experimentation
|
# For internal experimentation
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "analyze",
|
name = "analyze",
|
||||||
textual_hdrs = [
|
textual_hdrs = ["analyze.h"],
|
||||||
"analyze.h",
|
|
||||||
],
|
|
||||||
deps = [
|
deps = [
|
||||||
":distortion",
|
":distortion",
|
||||||
":nuq",
|
":nuq",
|
||||||
|
|
|
||||||
|
|
@ -26,11 +26,8 @@
|
||||||
#include <cstdlib> // std::abs
|
#include <cstdlib> // std::abs
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/distortion.h"
|
#include "compression/distortion.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/nuq.h"
|
#include "compression/nuq.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/stats.h"
|
#include "compression/stats.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
@ -46,9 +43,7 @@
|
||||||
#define THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
|
#define THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/nuq-inl.h"
|
#include "compression/nuq-inl.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/sfp-inl.h"
|
#include "compression/sfp-inl.h"
|
||||||
#include "hwy/contrib/sort/vqsort-inl.h"
|
#include "hwy/contrib/sort/vqsort-inl.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
|
|
|
||||||
|
|
@ -13,89 +13,21 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
// Request POSIX 2008, including `pread()` and `posix_fadvise()`.
|
|
||||||
#if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700
|
|
||||||
#undef _XOPEN_SOURCE
|
|
||||||
#define _XOPEN_SOURCE 700
|
|
||||||
#endif
|
|
||||||
#if !defined(_POSIX_C_SOURCE) || _POSIX_C_SOURCE < 200809
|
|
||||||
#define _POSIX_C_SOURCE 200809
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Make `off_t` 64-bit even on 32-bit systems. Works for Android >= r15c.
|
|
||||||
#undef _FILE_OFFSET_BITS
|
|
||||||
#define _FILE_OFFSET_BITS 64
|
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/blob_store.h"
|
#include "compression/blob_store.h"
|
||||||
|
|
||||||
#include <fcntl.h> // open
|
#include <stddef.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
#include <stdio.h> // SEEK_END - unistd isn't enough for IDE.
|
|
||||||
#include <sys/stat.h> // O_RDONLY
|
|
||||||
#if HWY_OS_WIN
|
|
||||||
#include <fileapi.h>
|
|
||||||
#include <io.h> // read, write, close
|
|
||||||
#else
|
|
||||||
#include <unistd.h> // read, write, close
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "compression/io.h"
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/detect_compiler_arch.h"
|
#include "hwy/detect_compiler_arch.h"
|
||||||
|
|
||||||
namespace {
|
|
||||||
#if HWY_OS_WIN
|
|
||||||
|
|
||||||
// pread is not supported on Windows
|
|
||||||
static int64_t pread(int fd, void* buf, uint64_t size, uint64_t offset) {
|
|
||||||
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
|
|
||||||
if (file == INVALID_HANDLE_VALUE) {
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
OVERLAPPED overlapped = {0};
|
|
||||||
overlapped.Offset = offset & 0xFFFFFFFF;
|
|
||||||
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
|
|
||||||
|
|
||||||
DWORD bytes_read;
|
|
||||||
if (!ReadFile(file, buf, size, &bytes_read, &overlapped)) {
|
|
||||||
if (GetLastError() != ERROR_HANDLE_EOF) {
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return bytes_read;
|
|
||||||
}
|
|
||||||
|
|
||||||
// pwrite is not supported on Windows
|
|
||||||
static int64_t pwrite(int fd, const void* buf, uint64_t size, uint64_t offset) {
|
|
||||||
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
|
|
||||||
if (file == INVALID_HANDLE_VALUE) {
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
OVERLAPPED overlapped = {0};
|
|
||||||
overlapped.Offset = offset & 0xFFFFFFFF;
|
|
||||||
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
|
|
||||||
|
|
||||||
DWORD bytes_written;
|
|
||||||
if (!WriteFile(file, buf, size, &bytes_written, &overlapped)) {
|
|
||||||
if (GetLastError() != ERROR_HANDLE_EOF) {
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return bytes_written;
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
hwy::uint128_t MakeKey(const char* string) {
|
hwy::uint128_t MakeKey(const char* string) {
|
||||||
|
|
@ -132,61 +64,6 @@ void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data,
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
struct IO {
|
|
||||||
// Returns size in bytes or 0.
|
|
||||||
static uint64_t FileSize(const char* filename) {
|
|
||||||
int fd = open(filename, O_RDONLY);
|
|
||||||
if (fd < 0) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
#if HWY_OS_WIN
|
|
||||||
const int64_t size = _lseeki64(fd, 0, SEEK_END);
|
|
||||||
HWY_ASSERT(close(fd) != -1);
|
|
||||||
if (size < 0) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
static_assert(sizeof(off_t) == 8, "64-bit off_t required");
|
|
||||||
const off_t size = lseek(fd, 0, SEEK_END);
|
|
||||||
HWY_ASSERT(close(fd) != -1);
|
|
||||||
if (size == static_cast<off_t>(-1)) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
return static_cast<uint64_t>(size);
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool Read(int fd, uint64_t offset, uint64_t size, void* to) {
|
|
||||||
uint8_t* bytes = reinterpret_cast<uint8_t*>(to);
|
|
||||||
uint64_t pos = 0;
|
|
||||||
for (;;) {
|
|
||||||
// pread seems to be faster than lseek + read when parallelized.
|
|
||||||
const auto bytes_read = pread(fd, bytes + pos, size - pos, offset + pos);
|
|
||||||
if (bytes_read <= 0) break;
|
|
||||||
pos += bytes_read;
|
|
||||||
HWY_ASSERT(pos <= size);
|
|
||||||
if (pos == size) break;
|
|
||||||
}
|
|
||||||
return pos == size; // success if managed to read desired size
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool Write(const void* from, uint64_t size, uint64_t offset, int fd) {
|
|
||||||
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(from);
|
|
||||||
uint64_t pos = 0;
|
|
||||||
for (;;) {
|
|
||||||
const auto bytes_written =
|
|
||||||
pwrite(fd, bytes + pos, size - pos, offset + pos);
|
|
||||||
if (bytes_written <= 0) break;
|
|
||||||
pos += bytes_written;
|
|
||||||
HWY_ASSERT(pos <= size);
|
|
||||||
if (pos == size) break;
|
|
||||||
}
|
|
||||||
return pos == size; // success if managed to write desired size
|
|
||||||
}
|
|
||||||
}; // IO
|
|
||||||
|
|
||||||
static_assert(HWY_IS_LITTLE_ENDIAN, "Assumes little endian");
|
static_assert(HWY_IS_LITTLE_ENDIAN, "Assumes little endian");
|
||||||
|
|
||||||
// On-disk representation (little-endian).
|
// On-disk representation (little-endian).
|
||||||
|
|
@ -323,26 +200,13 @@ class BlobStore {
|
||||||
};
|
};
|
||||||
#pragma pack(pop)
|
#pragma pack(pop)
|
||||||
|
|
||||||
BlobError BlobReader::Open(const char* filename) {
|
BlobError BlobReader::Open(const Path& filename) {
|
||||||
#if HWY_OS_WIN
|
file_ = OpenFileOrNull(filename, "r");
|
||||||
DWORD flags = FILE_ATTRIBUTE_NORMAL | FILE_FLAG_SEQUENTIAL_SCAN;
|
if (!file_) return __LINE__;
|
||||||
HANDLE file = CreateFileA(filename, GENERIC_READ, FILE_SHARE_READ, nullptr,
|
|
||||||
OPEN_EXISTING, flags, nullptr);
|
|
||||||
if (file == INVALID_HANDLE_VALUE) return __LINE__;
|
|
||||||
fd_ = _open_osfhandle(reinterpret_cast<intptr_t>(file), _O_RDONLY);
|
|
||||||
#else
|
|
||||||
fd_ = open(filename, O_RDONLY);
|
|
||||||
#endif
|
|
||||||
if (fd_ < 0) return __LINE__;
|
|
||||||
|
|
||||||
#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21)
|
|
||||||
// Doubles the readahead window, which seems slightly faster when cached.
|
|
||||||
(void)posix_fadvise(fd_, 0, 0, POSIX_FADV_SEQUENTIAL);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Read first part of header to get actual size.
|
// Read first part of header to get actual size.
|
||||||
BlobStore bs;
|
BlobStore bs;
|
||||||
if (!IO::Read(fd_, 0, sizeof(bs), &bs)) return __LINE__;
|
if (!file_->Read(0, sizeof(bs), &bs)) return __LINE__;
|
||||||
const size_t padded_size = bs.PaddedHeaderSize();
|
const size_t padded_size = bs.PaddedHeaderSize();
|
||||||
HWY_ASSERT(padded_size >= sizeof(bs));
|
HWY_ASSERT(padded_size >= sizeof(bs));
|
||||||
|
|
||||||
|
|
@ -354,18 +218,11 @@ BlobError BlobReader::Open(const char* filename) {
|
||||||
hwy::CopySameSize(&bs, blob_store_.get());
|
hwy::CopySameSize(&bs, blob_store_.get());
|
||||||
// Read the rest of the header, but not the full file.
|
// Read the rest of the header, but not the full file.
|
||||||
uint8_t* bytes = reinterpret_cast<uint8_t*>(blob_store_.get());
|
uint8_t* bytes = reinterpret_cast<uint8_t*>(blob_store_.get());
|
||||||
if (!IO::Read(fd_, sizeof(bs), padded_size - sizeof(bs),
|
if (!file_->Read(sizeof(bs), padded_size - sizeof(bs), bytes + sizeof(bs))) {
|
||||||
bytes + sizeof(bs))) {
|
|
||||||
return __LINE__;
|
return __LINE__;
|
||||||
}
|
}
|
||||||
|
|
||||||
return blob_store_->CheckValidity(IO::FileSize(filename));
|
return blob_store_->CheckValidity(file_->FileSize());
|
||||||
}
|
|
||||||
|
|
||||||
BlobReader::~BlobReader() {
|
|
||||||
if (fd_ >= 0) {
|
|
||||||
HWY_ASSERT(close(fd_) != -1);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) {
|
BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) {
|
||||||
|
|
@ -392,13 +249,13 @@ BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) {
|
||||||
// between consecutive runs.
|
// between consecutive runs.
|
||||||
// - memory-mapped I/O is less predictable and adds noise to measurements.
|
// - memory-mapped I/O is less predictable and adds noise to measurements.
|
||||||
BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
|
BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
|
||||||
const int fd = fd_;
|
File* pfile = file_.get(); // not owned
|
||||||
const auto& requests = requests_;
|
const auto& requests = requests_;
|
||||||
std::atomic_flag err = ATOMIC_FLAG_INIT;
|
std::atomic_flag err = ATOMIC_FLAG_INIT;
|
||||||
// >5x speedup from parallel reads when cached.
|
// >5x speedup from parallel reads when cached.
|
||||||
pool.Run(0, requests.size(),
|
pool.Run(0, requests.size(),
|
||||||
[fd, &requests, &err](uint64_t i, size_t /*thread*/) {
|
[pfile, &requests, &err](uint64_t i, size_t /*thread*/) {
|
||||||
if (!IO::Read(fd, requests[i].offset, requests[i].size,
|
if (!pfile->Read(requests[i].offset, requests[i].size,
|
||||||
requests[i].data)) {
|
requests[i].data)) {
|
||||||
err.test_and_set();
|
err.test_and_set();
|
||||||
}
|
}
|
||||||
|
|
@ -407,8 +264,7 @@ BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool,
|
BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
|
||||||
const char* filename) const {
|
|
||||||
HWY_ASSERT(keys_.size() == blobs_.size());
|
HWY_ASSERT(keys_.size() == blobs_.size());
|
||||||
|
|
||||||
// Concatenate blobs in memory.
|
// Concatenate blobs in memory.
|
||||||
|
|
@ -419,26 +275,18 @@ BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool,
|
||||||
keys_.data(), blobs_.data(), keys_.size(), bs.get());
|
keys_.data(), blobs_.data(), keys_.size(), bs.get());
|
||||||
|
|
||||||
// Create/replace existing file.
|
// Create/replace existing file.
|
||||||
#if HWY_OS_WIN
|
std::unique_ptr<File> file = OpenFileOrNull(filename, "w+");
|
||||||
DWORD flags = FILE_ATTRIBUTE_NORMAL;
|
if (!file) return __LINE__;
|
||||||
HANDLE file = CreateFileA(filename, GENERIC_WRITE, 0, nullptr, CREATE_ALWAYS,
|
File* pfile = file.get(); // not owned
|
||||||
flags, nullptr);
|
|
||||||
if (file == INVALID_HANDLE_VALUE) return __LINE__;
|
|
||||||
const int fd = _open_osfhandle(reinterpret_cast<intptr_t>(file), _O_WRONLY);
|
|
||||||
#else
|
|
||||||
const int fd = open(filename, O_CREAT | O_RDWR | O_TRUNC, 0644);
|
|
||||||
#endif
|
|
||||||
if (fd < 0) return __LINE__;
|
|
||||||
|
|
||||||
std::atomic_flag err = ATOMIC_FLAG_INIT;
|
std::atomic_flag err = ATOMIC_FLAG_INIT;
|
||||||
pool.Run(0, requests.size(),
|
pool.Run(0, requests.size(),
|
||||||
[fd, &requests, &err](uint64_t i, size_t /*thread*/) {
|
[pfile, &requests, &err](uint64_t i, size_t /*thread*/) {
|
||||||
if (!IO::Write(requests[i].data, requests[i].size,
|
if (!pfile->Write(requests[i].data, requests[i].size,
|
||||||
requests[i].offset, fd)) {
|
requests[i].offset)) {
|
||||||
err.test_and_set();
|
err.test_and_set();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
HWY_ASSERT(close(fd) != -1);
|
|
||||||
if (err.test_and_set()) return __LINE__;
|
if (err.test_and_set()) return __LINE__;
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -19,8 +19,10 @@
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "compression/io.h"
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h" // hwy::uint128_t
|
#include "hwy/base.h" // hwy::uint128_t
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
@ -59,10 +61,10 @@ struct BlobIO {
|
||||||
class BlobReader {
|
class BlobReader {
|
||||||
public:
|
public:
|
||||||
BlobReader() { requests_.reserve(500); }
|
BlobReader() { requests_.reserve(500); }
|
||||||
~BlobReader();
|
~BlobReader() = default;
|
||||||
|
|
||||||
// Opens `filename` and reads its header.
|
// Opens `filename` and reads its header.
|
||||||
BlobError Open(const char* filename);
|
BlobError Open(const Path& filename);
|
||||||
|
|
||||||
// Enqueues read requests if `key` is found and its size matches `size`.
|
// Enqueues read requests if `key` is found and its size matches `size`.
|
||||||
BlobError Enqueue(hwy::uint128_t key, void* data, size_t size);
|
BlobError Enqueue(hwy::uint128_t key, void* data, size_t size);
|
||||||
|
|
@ -73,7 +75,7 @@ class BlobReader {
|
||||||
private:
|
private:
|
||||||
BlobStorePtr blob_store_; // holds header, not the entire file
|
BlobStorePtr blob_store_; // holds header, not the entire file
|
||||||
std::vector<BlobIO> requests_;
|
std::vector<BlobIO> requests_;
|
||||||
int fd_ = 0;
|
std::unique_ptr<File> file_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class BlobWriter {
|
class BlobWriter {
|
||||||
|
|
@ -84,7 +86,7 @@ class BlobWriter {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stores all blobs to disk in the given order with padding for alignment.
|
// Stores all blobs to disk in the given order with padding for alignment.
|
||||||
BlobError WriteAll(hwy::ThreadPool& pool, const char* filename) const;
|
BlobError WriteAll(hwy::ThreadPool& pool, const Path& filename);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<hwy::uint128_t> keys_;
|
std::vector<hwy::uint128_t> keys_;
|
||||||
|
|
|
||||||
|
|
@ -23,11 +23,8 @@
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/blob_store.h"
|
#include "compression/blob_store.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/compress.h"
|
#include "compression/compress.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/distortion.h"
|
#include "compression/distortion.h"
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
@ -44,9 +41,7 @@
|
||||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESS_TOGGLE
|
#define THIRD_PARTY_GEMMA_CPP_COMPRESS_TOGGLE
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/nuq-inl.h"
|
#include "compression/nuq-inl.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/sfp-inl.h"
|
#include "compression/sfp-inl.h"
|
||||||
#include "hwy/contrib/dot/dot-inl.h"
|
#include "hwy/contrib/dot/dot-inl.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
|
|
@ -63,6 +58,7 @@ struct CompressTraits {};
|
||||||
template <>
|
template <>
|
||||||
struct CompressTraits<float> {
|
struct CompressTraits<float> {
|
||||||
using MatT = float;
|
using MatT = float;
|
||||||
|
static constexpr bool kSupportsEvenOdd = false;
|
||||||
|
|
||||||
template <class DF, HWY_IF_F32_D(DF)>
|
template <class DF, HWY_IF_F32_D(DF)>
|
||||||
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
|
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
|
||||||
|
|
@ -116,6 +112,7 @@ struct CompressTraits<float> {
|
||||||
template <>
|
template <>
|
||||||
struct CompressTraits<hwy::bfloat16_t> {
|
struct CompressTraits<hwy::bfloat16_t> {
|
||||||
using MatT = hwy::bfloat16_t;
|
using MatT = hwy::bfloat16_t;
|
||||||
|
static constexpr bool kSupportsEvenOdd = true;
|
||||||
|
|
||||||
template <class DF, HWY_IF_F32_D(DF)>
|
template <class DF, HWY_IF_F32_D(DF)>
|
||||||
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
|
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
|
||||||
|
|
@ -224,11 +221,59 @@ struct CompressTraits<hwy::bfloat16_t> {
|
||||||
// bf16*bf16.
|
// bf16*bf16.
|
||||||
return hn::Dot::Compute<kAssumptions>(d_vec, vec_aligned, in + in_ofs, num);
|
return hn::Dot::Compute<kAssumptions>(d_vec, vec_aligned, in + in_ofs, num);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Computes the dot product of an even-odd deinterleaved, f32 `vec_aligned`
|
||||||
|
// and a column- major matrix `in`. `vec_aligned` should be aligned and
|
||||||
|
// alternate even-indexed `hn::Lanes(df32)` elements followed by odd-indexed
|
||||||
|
// `hn::Lanes(df32)` elements.
|
||||||
|
template <class DF, HWY_IF_F32_D(DF)>
|
||||||
|
static HWY_INLINE float DotEO(
|
||||||
|
const DF df32, const hwy::bfloat16_t* HWY_RESTRICT in, size_t in_ofs,
|
||||||
|
const float* HWY_RESTRICT vec_aligned, size_t num) {
|
||||||
|
HWY_DASSERT(num >= (hn::Lanes(df32) * 2) && (num % (hn::Lanes(df32) * 2)) == 0);
|
||||||
|
HWY_DASSERT((in_ofs % (hn::Lanes(df32) * 2)) == 0);
|
||||||
|
HWY_DASSERT(hn::IsAligned(df32, vec_aligned));
|
||||||
|
|
||||||
|
const hn::Repartition<hwy::bfloat16_t, DF> dbf16;
|
||||||
|
using VF32 = decltype(Zero(df32));
|
||||||
|
const size_t N = Lanes(dbf16);
|
||||||
|
|
||||||
|
VF32 sum0 = Zero(df32);
|
||||||
|
VF32 sum1 = Zero(df32);
|
||||||
|
VF32 sum2 = Zero(df32);
|
||||||
|
VF32 sum3 = Zero(df32);
|
||||||
|
|
||||||
|
const hn::RebindToUnsigned<decltype(df32)> du32;
|
||||||
|
using VU32 = hn::VFromD<decltype(du32)>;
|
||||||
|
const VU32 odd = Set(du32, 0xFFFF0000u);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < num; /* i += 2 * N */) {
|
||||||
|
const auto interleaved0 = hn::LoadU(dbf16, in + in_ofs + i);
|
||||||
|
const VF32 ae0 = Load(df32, vec_aligned + i);
|
||||||
|
const VF32 ao0 = Load(df32, vec_aligned + i + (N / 2));
|
||||||
|
sum0 = hn::MulAdd(ae0, hn::PromoteEvenTo(df32, interleaved0), sum0);
|
||||||
|
sum1 = hn::MulAdd(ao0, hn::PromoteOddTo(df32, interleaved0), sum1);
|
||||||
|
i += N;
|
||||||
|
|
||||||
|
const auto interleaved1 = hn::LoadU(dbf16, in + in_ofs + i);
|
||||||
|
const VF32 ae1 = Load(df32, vec_aligned + i);
|
||||||
|
const VF32 ao1 = Load(df32, vec_aligned + i + (N / 2));
|
||||||
|
sum2 = hn::MulAdd(ae1, hn::PromoteEvenTo(df32, interleaved1), sum2);
|
||||||
|
sum3 = hn::MulAdd(ao1, hn::PromoteOddTo(df32, interleaved1), sum3);
|
||||||
|
i += N;
|
||||||
|
}
|
||||||
|
|
||||||
|
sum0 = Add(sum0, sum1);
|
||||||
|
sum2 = Add(sum2, sum3);
|
||||||
|
sum0 = Add(sum0, sum2);
|
||||||
|
return ReduceSum(df32, sum0);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct CompressTraits<SfpStream> {
|
struct CompressTraits<SfpStream> {
|
||||||
using MatT = SfpStream;
|
using MatT = SfpStream;
|
||||||
|
static constexpr bool kSupportsEvenOdd = false;
|
||||||
|
|
||||||
template <class DF, HWY_IF_F32_D(DF)>
|
template <class DF, HWY_IF_F32_D(DF)>
|
||||||
static HWY_INLINE void Compress(DF df, const float* in, size_t num,
|
static HWY_INLINE void Compress(DF df, const float* in, size_t num,
|
||||||
|
|
@ -278,6 +323,7 @@ struct CompressTraits<SfpStream> {
|
||||||
template <>
|
template <>
|
||||||
struct CompressTraits<NuqStream> {
|
struct CompressTraits<NuqStream> {
|
||||||
using MatT = NuqStream;
|
using MatT = NuqStream;
|
||||||
|
static constexpr bool kSupportsEvenOdd = false;
|
||||||
|
|
||||||
template <class DF, HWY_IF_F32_D(DF)>
|
template <class DF, HWY_IF_F32_D(DF)>
|
||||||
static HWY_INLINE void Compress(DF df, const float* in, size_t num,
|
static HWY_INLINE void Compress(DF df, const float* in, size_t num,
|
||||||
|
|
@ -430,16 +476,22 @@ HWY_INLINE float Dot(DF df, const ArrayT& compressed, size_t compressed_ofs,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns dot product with `vec_aligned` of length `num`.
|
// Returns dot product with `vec_aligned` of length `num`.
|
||||||
template <class DF, typename MatT, size_t kCapacity, typename VecT>
|
template <bool kVecEO, class DF, typename MatT, size_t kCapacity, typename VecT>
|
||||||
HWY_INLINE float Dot(DF df, const CompressedArray<MatT, kCapacity>& compressed,
|
HWY_INLINE float Dot(DF df, const CompressedArray<MatT, kCapacity>& compressed,
|
||||||
size_t compressed_ofs, const VecT* vec_aligned,
|
size_t compressed_ofs, const VecT* vec_aligned,
|
||||||
size_t num) {
|
size_t num) {
|
||||||
HWY_DASSERT(compressed_ofs + num <= compressed.size());
|
HWY_DASSERT(compressed_ofs + num <= compressed.size());
|
||||||
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
|
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
|
||||||
using Traits = CompressTraits<MatT>;
|
using Traits = CompressTraits<MatT>;
|
||||||
return (compressed.scale() * Traits::Dot(df, compressed.size(),
|
float dot_result;
|
||||||
compressed.data(), compressed_ofs,
|
if constexpr (kVecEO) {
|
||||||
vec_aligned, num));
|
dot_result = Traits::DotEO(df, compressed.data(), compressed_ofs,
|
||||||
|
vec_aligned, num);
|
||||||
|
} else {
|
||||||
|
dot_result = Traits::Dot(df, compressed.size(), compressed.data(),
|
||||||
|
compressed_ofs, vec_aligned, num);
|
||||||
|
}
|
||||||
|
return compressed.scale() * dot_result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Callback used by ForeachTensor.
|
// Callback used by ForeachTensor.
|
||||||
|
|
@ -464,11 +516,11 @@ class Compressor {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void WriteAll(hwy::ThreadPool& pool, const char* blob_filename) {
|
void WriteAll(hwy::ThreadPool& pool, const Path& blob_filename) {
|
||||||
const BlobError err = writer_.WriteAll(pool, blob_filename);
|
const BlobError err = writer_.WriteAll(pool, blob_filename);
|
||||||
if (err != 0) {
|
if (err != 0) {
|
||||||
fprintf(stderr, "Failed to write blobs to %s (error %d)\n", blob_filename,
|
fprintf(stderr, "Failed to write blobs to %s (error %d)\n",
|
||||||
err);
|
blob_filename.path.c_str(), err);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -27,19 +27,15 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// IWYU pragma: begin_exports
|
// IWYU pragma: begin_exports
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/blob_store.h"
|
#include "compression/blob_store.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
#include "compression/io.h"
|
||||||
#include "compression/nuq.h"
|
#include "compression/nuq.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/sfp.h"
|
#include "compression/sfp.h"
|
||||||
// IWYU pragma: end_exports
|
// IWYU pragma: end_exports
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/distortion.h"
|
#include "compression/distortion.h"
|
||||||
#include "hwy/base.h" // hwy::bfloat16_t
|
#include "hwy/base.h" // hwy::bfloat16_t
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#if COMPRESS_STATS
|
#if COMPRESS_STATS
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/stats.h"
|
#include "compression/stats.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
@ -171,13 +167,13 @@ hwy::uint128_t CacheKey(const char* name) {
|
||||||
|
|
||||||
class CacheLoader {
|
class CacheLoader {
|
||||||
public:
|
public:
|
||||||
explicit CacheLoader(const char* blob_filename) {
|
explicit CacheLoader(const Path& blob_filename) {
|
||||||
err_ = reader_.Open(blob_filename);
|
err_ = reader_.Open(blob_filename);
|
||||||
if (err_ != 0) {
|
if (err_ != 0) {
|
||||||
fprintf(stderr,
|
fprintf(stderr,
|
||||||
"Cached compressed weights does not exist yet (code %d), "
|
"Cached compressed weights does not exist yet (code %d), "
|
||||||
"compressing weights and creating file: %s.\n",
|
"compressing weights and creating file: %s.\n",
|
||||||
err_, blob_filename);
|
err_, blob_filename.path.c_str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,85 +15,214 @@
|
||||||
|
|
||||||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_DISTORTION_H_
|
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_DISTORTION_H_
|
||||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_DISTORTION_H_
|
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_DISTORTION_H_
|
||||||
|
|
||||||
#include <math.h> // pow
|
#include <math.h> // pow
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "compression/stats.h"
|
||||||
|
#include "hwy/aligned_allocator.h" // HWY_ALIGNMENT
|
||||||
#include "hwy/base.h" // ScalarAbs
|
#include "hwy/base.h" // ScalarAbs
|
||||||
|
#include "hwy/contrib/sort/vqsort.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
// Returns `sum` and `err` such that `sum + err` is exactly equal to `a + b`,
|
||||||
|
// despite floating-point rounding. `sum` is already the best estimate, so do
|
||||||
|
// not actually add `err` to it. Knuth98/Moller65. Unlike Fast2Sum [Dekker71],
|
||||||
|
// this does not require any relative ordering of the exponents of a and b.
|
||||||
|
template <typename T>
|
||||||
|
static inline T TwoSum(T a, T b, T& err) {
|
||||||
|
const T sum = a + b;
|
||||||
|
const T a2 = sum - b;
|
||||||
|
const T b2 = sum - a2;
|
||||||
|
const T err_a = a - a2;
|
||||||
|
const T err_b = b - b2;
|
||||||
|
err = err_a + err_b;
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accumulates numbers with about twice the precision of T using 7 * n FLOPS.
|
||||||
|
// Rump/Ogita/Oishi08, Algorithm 6.11 in Handbook of Floating-Point Arithmetic.
|
||||||
|
template <typename T>
|
||||||
|
class CascadedSummation {
|
||||||
|
public:
|
||||||
|
void Notify(T t) {
|
||||||
|
T err;
|
||||||
|
sum_ = TwoSum(sum_, t, err);
|
||||||
|
sum_err_ += err;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Assimilate(const CascadedSummation& other) {
|
||||||
|
Notify(other.sum_);
|
||||||
|
sum_err_ += other.sum_err_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allows users to observe how much difference the extra precision made.
|
||||||
|
T Err() const { return sum_err_; }
|
||||||
|
|
||||||
|
// Returns the sum of all `t` passed to `Notify`.
|
||||||
|
T Total() const { return sum_ + sum_err_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
T sum_ = T{0};
|
||||||
|
T sum_err_ = T{0};
|
||||||
|
};
|
||||||
|
|
||||||
|
// Summarizes the error of a distortion (e.g. quantization) applied to a series
|
||||||
|
// of numbers.
|
||||||
|
// Users should check all four resulting metrics (NumExact, NumRoundedToZero,
|
||||||
|
// GeomeanValueDivL1, WeightedAverageL1) because each covers different aspects.
|
||||||
class DistortionStats {
|
class DistortionStats {
|
||||||
public:
|
public:
|
||||||
void Notify(float original, float distorted) {
|
void Notify(float original, float distorted) {
|
||||||
(void)padding_; // prevent unused member warning
|
(void)padding_; // prevent unused member warning
|
||||||
|
|
||||||
const double l1 = hwy::ScalarAbs(original - distorted);
|
const bool rounded_to_zero = (original != 0.0f) && (distorted == 0.0f);
|
||||||
|
// We expect original == 0 is not distorted (can be exactly represented).
|
||||||
|
HWY_ASSERT(original != 0.0f || distorted == 0.0f);
|
||||||
|
|
||||||
if (l1 > max_l1_) {
|
s_original_.Notify(original);
|
||||||
max_l1_ = l1;
|
const float l1f = hwy::ScalarAbs(original - distorted);
|
||||||
max_idx_ = n_;
|
const double l1 = static_cast<double>(l1f);
|
||||||
|
s_l1_.Notify(l1f);
|
||||||
|
b_l1_.Notify(HWY_MIN(99, static_cast<int>(l1f * 1E4)));
|
||||||
|
if (l1f != 0.0f) {
|
||||||
|
l1_.push_back(l1f);
|
||||||
|
}
|
||||||
|
sum_l1_.Notify(l1f);
|
||||||
|
if (rounded_to_zero) sum_l1_rounded_.Notify(l1f);
|
||||||
|
|
||||||
|
// Event counts
|
||||||
|
{
|
||||||
|
n_ += 1;
|
||||||
|
// Rounding (small) negative numbers to 0 does not influence dot products
|
||||||
|
// as much as an actual sign flip, so do not count them.
|
||||||
|
n_sign_flip_ +=
|
||||||
|
((original < 0.0f) != (distorted < 0.0f)) && !rounded_to_zero;
|
||||||
|
n_exact_ += (l1f == 0.0f);
|
||||||
|
n_rounded_to_zero += rounded_to_zero;
|
||||||
}
|
}
|
||||||
|
|
||||||
const double pow3 = l1 * l1 * l1;
|
// Signal to noise ratio (Shannon's channel capacity, NOT the L2-based and
|
||||||
sum_pow3_ += pow3;
|
// logarithmic PSNR) to estimate the ratios of original to the L1 norm.
|
||||||
sum_pow6_ += pow3 * pow3;
|
if (l1f != 0.0) { // prevent division by zero
|
||||||
n_ += 1;
|
const double snr =
|
||||||
|
1.0 + static_cast<double>(hwy::ScalarAbs(original)) / l1;
|
||||||
// Avoid division by zero, which happens when there is no error. NumExact()
|
// For numerical purposes (prevents overflow). A hierarchical geomean
|
||||||
// reports the number of times this happens.
|
|
||||||
if (l1 != 0.0) {
|
|
||||||
const double rel = 1.0 + hwy::ScalarAbs(original) / l1;
|
|
||||||
// Logarithm is required to prevent overflow. A hierarchical geomean
|
|
||||||
// could also work, but that is more complex and not necessarily better.
|
// could also work, but that is more complex and not necessarily better.
|
||||||
sum_log_rel_ += log(rel);
|
// We will return exp() of the arithmetic mean.
|
||||||
num_rel_ += 1;
|
sum_log_snr_ += log(snr);
|
||||||
|
num_snr_ += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Assimilate(const DistortionStats& other) {
|
void Assimilate(const DistortionStats& other) {
|
||||||
if (other.max_l1_ > max_l1_) {
|
s_original_.Assimilate(other.s_original_);
|
||||||
max_l1_ = other.max_l1_;
|
s_l1_.Assimilate(other.s_l1_);
|
||||||
max_idx_ = other.max_idx_;
|
b_l1_.Assimilate(other.b_l1_);
|
||||||
}
|
sum_l1_.Assimilate(other.sum_l1_);
|
||||||
|
sum_l1_rounded_.Assimilate(other.sum_l1_rounded_);
|
||||||
|
l1_.insert(l1_.end(), other.l1_.begin(), other.l1_.end());
|
||||||
|
|
||||||
sum_pow3_ += other.sum_pow3_;
|
|
||||||
sum_pow6_ += other.sum_pow6_;
|
|
||||||
n_ += other.n_;
|
n_ += other.n_;
|
||||||
|
n_sign_flip_ += other.n_sign_flip_;
|
||||||
|
n_exact_ += other.n_exact_;
|
||||||
|
n_rounded_to_zero += other.n_rounded_to_zero;
|
||||||
|
|
||||||
sum_log_rel_ += other.sum_log_rel_;
|
sum_log_snr_ += other.sum_log_snr_;
|
||||||
num_rel_ += other.num_rel_;
|
num_snr_ += other.num_snr_;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t NumExact() const { return n_ - num_rel_; }
|
size_t NumExact() const { return n_exact_; }
|
||||||
|
size_t NumSignFlip() const { return n_sign_flip_; }
|
||||||
|
size_t NumRoundedToZero() const { return n_rounded_to_zero; }
|
||||||
|
// Total absolute error.
|
||||||
|
double SumL1() const { return sum_l1_.Total(); }
|
||||||
|
// Total absolute error for numbers that were rounded to zero.
|
||||||
|
double SumL1Rounded() const { return sum_l1_rounded_.Total(); }
|
||||||
|
|
||||||
|
// Returns geomean of 1 + S/N (Shannon channel capacity). This is computed via
|
||||||
|
// the ratio of input magnitude to nonzero L1 norms. Higher is better.
|
||||||
double GeomeanValueDivL1() const {
|
double GeomeanValueDivL1() const {
|
||||||
if (num_rel_ == 0) return 0.0;
|
if (num_snr_ == 0) return 0.0;
|
||||||
return exp(sum_log_rel_ / static_cast<double>(num_rel_));
|
return exp(sum_log_snr_ / static_cast<double>(num_snr_));
|
||||||
}
|
}
|
||||||
|
|
||||||
double PNorm() const {
|
// Returns weighted average of nonzero L1 norms. Those further from the median
|
||||||
// p-norms are a compromise between max-norm (penalizes the largest error
|
// L1 norm are much more heavily weighted, such that this behaves more like
|
||||||
// without dilution, but does not notice any other errors) and L1 (all
|
// the L-infinity norm, but still includes all differences, not just the max.
|
||||||
// errors contribute, but large errors are diluted by smaller ones).
|
// Lower is better, magnitude depends on the input magnitude.
|
||||||
const double norm3 = pow(sum_pow3_ / static_cast<double>(n_), 1.0 / 3);
|
double WeightedAverageL1() const {
|
||||||
const double norm6 = pow(sum_pow6_ / static_cast<double>(n_), 1.0 / 6);
|
if (l1_.empty()) return 0.0f; // all exact
|
||||||
return 0.5 * (norm3 + norm6);
|
|
||||||
|
std::vector<float> weights(l1_); // copy so we can modify
|
||||||
|
const float median = [&weights]() {
|
||||||
|
const size_t mid = weights.size() / 2;
|
||||||
|
// We just want the median; partial sort is faster if available (v1.2).
|
||||||
|
#if HWY_MAJOR > 1 || HWY_MINOR >= 2
|
||||||
|
hwy::VQSelect(weights.data(), weights.size(), mid, hwy::SortAscending());
|
||||||
|
#else
|
||||||
|
hwy::VQSort(weights.data(), weights.size(), hwy::SortAscending());
|
||||||
|
#endif
|
||||||
|
return weights[mid];
|
||||||
|
}();
|
||||||
|
weights = l1_; // restore original order
|
||||||
|
|
||||||
|
// Replace with distance from median (might have too few samples for mode).
|
||||||
|
float max_abs = -1.0f;
|
||||||
|
for (float& d : weights) {
|
||||||
|
d = hwy::ScalarAbs(d - median);
|
||||||
|
max_abs = HWY_MAX(max_abs, d);
|
||||||
|
}
|
||||||
|
HWY_ASSERT(max_abs >= 0.0f);
|
||||||
|
// All equal - return the distance value to prevent division by zero.
|
||||||
|
if (max_abs == 0.0f) return median;
|
||||||
|
|
||||||
|
// Normalize to max difference and exponentiate.
|
||||||
|
const double inv_max = 1.0 / static_cast<double>(max_abs);
|
||||||
|
double sum_weights = 0.0;
|
||||||
|
for (float& w : weights) {
|
||||||
|
const double normalized = static_cast<double>(w) * inv_max;
|
||||||
|
const double amplified = exp(4.0 * normalized * normalized);
|
||||||
|
sum_weights += amplified;
|
||||||
|
w = static_cast<float>(amplified);
|
||||||
|
}
|
||||||
|
// At least 1.0 per weight, plus more for at least one weight because we
|
||||||
|
// verified via max_abs that not all are equal.
|
||||||
|
HWY_ASSERT(sum_weights > static_cast<double>(weights.size()));
|
||||||
|
|
||||||
|
// Return weighted average.
|
||||||
|
double weighted_sum = 0.0;
|
||||||
|
for (size_t i = 0; i < weights.size(); ++i) {
|
||||||
|
weighted_sum += l1_[i] * weights[i];
|
||||||
|
}
|
||||||
|
return weighted_sum / sum_weights;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t MaxIndex() const { return max_idx_; }
|
Stats& L1() { return s_l1_; }
|
||||||
double MaxL1() const { return max_l1_; }
|
Stats& Original() { return s_original_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
Stats s_original_;
|
||||||
|
Stats s_l1_;
|
||||||
|
Bins<100> b_l1_;
|
||||||
|
CascadedSummation<double> sum_l1_; // all
|
||||||
|
CascadedSummation<double> sum_l1_rounded_; // only if rounded_to_zero
|
||||||
|
std::vector<float> l1_;
|
||||||
|
|
||||||
|
// Event counts
|
||||||
size_t n_ = 0;
|
size_t n_ = 0;
|
||||||
size_t max_idx_ = 0; // index that had l1 = max_l1_.
|
size_t n_sign_flip_ = 0;
|
||||||
double max_l1_ = -1.0;
|
size_t n_exact_ = 0;
|
||||||
|
size_t n_rounded_to_zero = 0;
|
||||||
|
|
||||||
double sum_pow3_ = 0.0;
|
double sum_log_snr_ = 0.0;
|
||||||
double sum_pow6_ = 0.0;
|
size_t num_snr_ = 0;
|
||||||
|
|
||||||
double sum_log_rel_ = 0.0;
|
uint8_t padding_[HWY_ALIGNMENT]; // prevents false sharing
|
||||||
size_t num_rel_ = 0;
|
|
||||||
double padding_; // prevents false sharing
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,99 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "compression/distortion.h"
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include "compression/test_util.h"
|
||||||
|
#include "hwy/nanobenchmark.h"
|
||||||
|
#include "hwy/tests/hwy_gtest.h"
|
||||||
|
#include "hwy/tests/test_util-inl.h" // HWY_ASSERT_EQ
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
#if !HWY_TEST_STANDALONE
|
||||||
|
class DistortionTest : public testing::Test {};
|
||||||
|
#endif
|
||||||
|
|
||||||
|
TEST(DistortionTest, TestCascadedSummation) {
|
||||||
|
CascadedSummation<double> cs;
|
||||||
|
// Example from Priest92. Exact sum is 2.
|
||||||
|
const double kHuge = 9007199254740992.0 * hwy::Unpredictable1(); // 2^53
|
||||||
|
const double kNeg = -4503599627370495.0 * hwy::Unpredictable1(); // -(2^52-1)
|
||||||
|
const double kIn[6] = {kHuge, kHuge - 2.0, kNeg, kNeg, kNeg, kNeg};
|
||||||
|
for (double in : kIn) {
|
||||||
|
cs.Notify(in);
|
||||||
|
}
|
||||||
|
HWY_ASSERT_EQ(2.0, cs.Total());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Number of exact and rounded-to-zero matches expectations.
|
||||||
|
TEST(DistortionTest, TestCounts) {
|
||||||
|
// Arbitrary positive/negative original, zero distorted.
|
||||||
|
DistortionStats stats;
|
||||||
|
for (size_t i = 1; i < 10; ++i) {
|
||||||
|
stats.Notify(i / 100.0f, 0.0f);
|
||||||
|
stats.Notify(i / -100.0f, 0.0f);
|
||||||
|
}
|
||||||
|
HWY_ASSERT(stats.NumExact() == 0);
|
||||||
|
HWY_ASSERT(stats.NumRoundedToZero() == 18);
|
||||||
|
|
||||||
|
// Add some exact (same):
|
||||||
|
size_t num_exact = 0;
|
||||||
|
for (float x = 0.0f; x <= 1.5f; x += 0.25f) {
|
||||||
|
stats.Notify(x, x);
|
||||||
|
stats.Notify(-x, -x);
|
||||||
|
num_exact += 2;
|
||||||
|
}
|
||||||
|
HWY_ASSERT_EQ(num_exact, stats.NumExact());
|
||||||
|
HWY_ASSERT(stats.NumRoundedToZero() == 18); // unchanged
|
||||||
|
}
|
||||||
|
|
||||||
|
// Few large differences are diluted in SNR but not WeightedAverageL1.
|
||||||
|
TEST(DistortionTest, TestDilution) {
|
||||||
|
DistortionStats stats;
|
||||||
|
for (size_t i = 0; i < 100; ++i) {
|
||||||
|
stats.Notify(0.998f, 0.999f); // small
|
||||||
|
}
|
||||||
|
HWY_ASSERT(IsInside(900.0, 1000.0, stats.GeomeanValueDivL1()));
|
||||||
|
// All-equal WeightedSum is exact.
|
||||||
|
HWY_ASSERT(IsNear(0.001, stats.WeightedAverageL1()));
|
||||||
|
|
||||||
|
// Now add a large difference:
|
||||||
|
stats.Notify(1.875f - 0.0625f, 1.875f); // max magnitude, 3-bit mantissa
|
||||||
|
// .. WeightedAverageL1 is closer to it.
|
||||||
|
HWY_ASSERT(IsInside(0.020, 0.025, stats.WeightedAverageL1()));
|
||||||
|
|
||||||
|
// Add a small and large difference:
|
||||||
|
stats.Notify((1.75f - 0.125f) / 1024, 1.75f / 1024); // small, 2-bit mantissa
|
||||||
|
stats.Notify(-1.875f + 0.0625f, -1.875f); // larger negative
|
||||||
|
// .. SNR is still barely affected.
|
||||||
|
HWY_ASSERT(IsInside(890.0, 900.0, stats.GeomeanValueDivL1()));
|
||||||
|
// .. WeightedAverageL1 is higher after another large error.
|
||||||
|
HWY_ASSERT(IsInside(0.030, 0.035, stats.WeightedAverageL1()));
|
||||||
|
|
||||||
|
// With these inputs, none are exact nor round to zero.
|
||||||
|
HWY_ASSERT(stats.NumExact() == 0);
|
||||||
|
HWY_ASSERT(stats.NumRoundedToZero() == 0);
|
||||||
|
HWY_ASSERT_EQ(0.0, stats.SumL1Rounded());
|
||||||
|
HWY_ASSERT(IsInside(0.220, 0.23, stats.SumL1()));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace gcpp
|
||||||
|
|
||||||
|
HWY_TEST_MAIN();
|
||||||
|
|
@ -0,0 +1,121 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// Safe to be first, does not include POSIX headers.
|
||||||
|
#include "hwy/detect_compiler_arch.h"
|
||||||
|
// Only compile this file on non-Windows; it replaces io_win.cc. It is easier to
|
||||||
|
// check this in source code because we support multiple build systems.
|
||||||
|
#if !HWY_OS_WIN
|
||||||
|
|
||||||
|
// Request POSIX 2008, including `pread()` and `posix_fadvise()`.
|
||||||
|
#if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700
|
||||||
|
#undef _XOPEN_SOURCE
|
||||||
|
#define _XOPEN_SOURCE 700
|
||||||
|
#endif
|
||||||
|
#if !defined(_POSIX_C_SOURCE) || _POSIX_C_SOURCE < 200809
|
||||||
|
#define _POSIX_C_SOURCE 200809
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Make `off_t` 64-bit even on 32-bit systems. Works for Android >= r15c.
|
||||||
|
#undef _FILE_OFFSET_BITS
|
||||||
|
#define _FILE_OFFSET_BITS 64
|
||||||
|
|
||||||
|
#include <fcntl.h> // open
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stdio.h> // SEEK_END - unistd isn't enough for IDE.
|
||||||
|
#include <sys/stat.h> // O_RDONLY
|
||||||
|
#include <unistd.h> // read, write, close
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "compression/io.h"
|
||||||
|
#include "hwy/base.h" // HWY_ASSERT
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
class FilePosix : public File {
|
||||||
|
int fd_ = 0;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit FilePosix(int fd) : fd_(fd) { HWY_ASSERT(fd > 0); }
|
||||||
|
~FilePosix() override {
|
||||||
|
if (fd_ != 0) {
|
||||||
|
HWY_ASSERT(close(fd_) != -1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t FileSize() const override {
|
||||||
|
static_assert(sizeof(off_t) == 8, "64-bit off_t required");
|
||||||
|
const off_t size = lseek(fd_, 0, SEEK_END);
|
||||||
|
if (size < 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
return static_cast<uint64_t>(size);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Read(uint64_t offset, uint64_t size, void* to) const override {
|
||||||
|
uint8_t* bytes = reinterpret_cast<uint8_t*>(to);
|
||||||
|
uint64_t pos = 0;
|
||||||
|
for (;;) {
|
||||||
|
// pread seems to be faster than lseek + read when parallelized.
|
||||||
|
const auto bytes_read = pread(fd_, bytes + pos, size - pos, offset + pos);
|
||||||
|
if (bytes_read <= 0) break;
|
||||||
|
pos += bytes_read;
|
||||||
|
HWY_ASSERT(pos <= size);
|
||||||
|
if (pos == size) break;
|
||||||
|
}
|
||||||
|
return pos == size; // success if managed to read desired size
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Write(const void* from, uint64_t size, uint64_t offset) override {
|
||||||
|
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(from);
|
||||||
|
uint64_t pos = 0;
|
||||||
|
for (;;) {
|
||||||
|
const auto bytes_written =
|
||||||
|
pwrite(fd_, bytes + pos, size - pos, offset + pos);
|
||||||
|
if (bytes_written <= 0) break;
|
||||||
|
pos += bytes_written;
|
||||||
|
HWY_ASSERT(pos <= size);
|
||||||
|
if (pos == size) break;
|
||||||
|
}
|
||||||
|
return pos == size; // success if managed to write desired size
|
||||||
|
}
|
||||||
|
}; // FilePosix
|
||||||
|
|
||||||
|
HWY_MAYBE_UNUSED extern std::unique_ptr<File> OpenFileGoogle(
|
||||||
|
const Path& filename, const char* mode);
|
||||||
|
|
||||||
|
std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode) {
|
||||||
|
std::unique_ptr<File> file; // OpenFileGoogle omitted
|
||||||
|
if (file) return file;
|
||||||
|
|
||||||
|
const bool is_read = mode[0] != 'w';
|
||||||
|
const int flags = is_read ? O_RDONLY : O_CREAT | O_RDWR | O_TRUNC;
|
||||||
|
const int fd = open(filename.path.c_str(), flags, 0644);
|
||||||
|
if (fd < 0) return file;
|
||||||
|
|
||||||
|
#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21)
|
||||||
|
if (is_read) {
|
||||||
|
// Doubles the readahead window, which seems slightly faster when cached.
|
||||||
|
(void)posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
return std::make_unique<FilePosix>(fd);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
#endif // !HWY_OS_WIN
|
||||||
|
|
@ -0,0 +1,88 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_
|
||||||
|
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility> // std::move
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
// Forward-declare to break the circular dependency: OpenFileOrNull returns
|
||||||
|
// File and has a Path argument, and Path::Exists calls OpenFileOrNull. We
|
||||||
|
// prefer to define Exists inline because there are multiple io*.cc files.
|
||||||
|
struct Path;
|
||||||
|
|
||||||
|
// Abstract base class enables multiple I/O backends in the same binary.
|
||||||
|
class File {
|
||||||
|
public:
|
||||||
|
File() = default;
|
||||||
|
virtual ~File() = default;
|
||||||
|
|
||||||
|
// Noncopyable.
|
||||||
|
File(const File& other) = delete;
|
||||||
|
const File& operator=(const File& other) = delete;
|
||||||
|
|
||||||
|
// Returns size in bytes or 0.
|
||||||
|
virtual uint64_t FileSize() const = 0;
|
||||||
|
|
||||||
|
// Returns true if all the requested bytes were read.
|
||||||
|
virtual bool Read(uint64_t offset, uint64_t size, void* to) const = 0;
|
||||||
|
|
||||||
|
// Returns true if all the requested bytes were written.
|
||||||
|
virtual bool Write(const void* from, uint64_t size, uint64_t offset) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Returns nullptr on failure. `mode` is either "r" or "w+". This is not just
|
||||||
|
// named 'OpenFile' to avoid a conflict with Windows.h #define.
|
||||||
|
std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode);
|
||||||
|
|
||||||
|
// Wrapper for strings representing a path name. Differentiates vs. arbitrary
|
||||||
|
// strings and supports shortening for display purposes.
|
||||||
|
struct Path {
|
||||||
|
Path() {}
|
||||||
|
explicit Path(const char* p) : path(p) {}
|
||||||
|
explicit Path(std::string p) : path(std::move(p)) {}
|
||||||
|
|
||||||
|
Path& operator=(const char* other) {
|
||||||
|
path = other;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Shortened() const {
|
||||||
|
constexpr size_t kMaxLen = 48;
|
||||||
|
constexpr size_t kCutPoint = kMaxLen / 2 - 5;
|
||||||
|
if (path.size() > kMaxLen) {
|
||||||
|
return std::string(begin(path), begin(path) + kCutPoint) + " ... " +
|
||||||
|
std::string(end(path) - kCutPoint, end(path));
|
||||||
|
}
|
||||||
|
if (path.empty()) return "[no path specified]";
|
||||||
|
return path;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns whether the file existed when this was called.
|
||||||
|
bool Exists() const { return !!OpenFileOrNull(*this, "r"); }
|
||||||
|
|
||||||
|
std::string path;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_
|
||||||
|
|
@ -0,0 +1,115 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "hwy/detect_compiler_arch.h"
|
||||||
|
// Only compile this file on Windows; it replaces io.cc. It is easier to check
|
||||||
|
// this in source code because we support multiple build systems.
|
||||||
|
#if HWY_OS_WIN
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "compression/io.h"
|
||||||
|
#include "hwy/base.h" // HWY_ASSERT
|
||||||
|
#ifndef WIN32_LEAN_AND_MEAN
|
||||||
|
#define WIN32_LEAN_AND_MEAN
|
||||||
|
#endif
|
||||||
|
#ifndef VC_EXTRALEAN
|
||||||
|
#define VC_EXTRALEAN
|
||||||
|
#endif
|
||||||
|
#include <Windows.h>
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
class FileWin : public File {
|
||||||
|
HANDLE hFile_ = INVALID_HANDLE_VALUE;
|
||||||
|
|
||||||
|
public:
|
||||||
|
FileWin(HANDLE hFile) : hFile_(hFile) {
|
||||||
|
HWY_ASSERT(hFile != INVALID_HANDLE_VALUE);
|
||||||
|
}
|
||||||
|
~FileWin() override {
|
||||||
|
if (hFile_ != INVALID_HANDLE_VALUE) {
|
||||||
|
HWY_ASSERT(CloseHandle(hFile_) != 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t FileSize() const override {
|
||||||
|
DWORD hi;
|
||||||
|
const DWORD lo = GetFileSize(hFile_, &hi);
|
||||||
|
if (lo == INVALID_FILE_SIZE) return 0;
|
||||||
|
return (static_cast<uint64_t>(hi) << 32) | lo;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Read(uint64_t offset, uint64_t size, void* to) const override {
|
||||||
|
uint8_t* bytes = reinterpret_cast<uint8_t*>(to);
|
||||||
|
OVERLAPPED overlapped = {0};
|
||||||
|
// Loop is required because ReadFile[Ex] size argument is 32-bit.
|
||||||
|
while (size != 0) {
|
||||||
|
overlapped.Offset = offset & 0xFFFFFFFF;
|
||||||
|
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
|
||||||
|
const DWORD want =
|
||||||
|
static_cast<DWORD>(HWY_MIN(size, uint64_t{0xFFFFFFFF}));
|
||||||
|
DWORD got;
|
||||||
|
if (!ReadFile(hFile_, bytes, want, &got, &overlapped)) {
|
||||||
|
if (GetLastError() != ERROR_HANDLE_EOF) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
offset += got;
|
||||||
|
bytes += got;
|
||||||
|
size -= got;
|
||||||
|
}
|
||||||
|
return true; // read everything => success
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Write(const void* from, uint64_t size, uint64_t offset) override {
|
||||||
|
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(from);
|
||||||
|
OVERLAPPED overlapped = {0};
|
||||||
|
// Loop is required because WriteFile[Ex] size argument is 32-bit.
|
||||||
|
while (size != 0) {
|
||||||
|
overlapped.Offset = offset & 0xFFFFFFFF;
|
||||||
|
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
|
||||||
|
const DWORD want =
|
||||||
|
static_cast<DWORD>(HWY_MIN(size, uint64_t{0xFFFFFFFF}));
|
||||||
|
DWORD got;
|
||||||
|
if (!WriteFile(hFile_, bytes, want, &got, &overlapped)) {
|
||||||
|
if (GetLastError() != ERROR_HANDLE_EOF) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
offset += got;
|
||||||
|
bytes += got;
|
||||||
|
size -= got;
|
||||||
|
}
|
||||||
|
return true; // wrote everything => success
|
||||||
|
}
|
||||||
|
}; // FileWin
|
||||||
|
|
||||||
|
std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode) {
|
||||||
|
const bool is_read = mode[0] != 'w';
|
||||||
|
const DWORD flags =
|
||||||
|
FILE_ATTRIBUTE_NORMAL | (is_read ? FILE_FLAG_SEQUENTIAL_SCAN : 0);
|
||||||
|
const DWORD access = is_read ? GENERIC_READ : GENERIC_WRITE;
|
||||||
|
const DWORD share = is_read ? FILE_SHARE_READ : 0;
|
||||||
|
const DWORD create = is_read ? OPEN_EXISTING : CREATE_ALWAYS;
|
||||||
|
const HANDLE hFile = CreateFileA(filename.path.c_str(), access, share,
|
||||||
|
nullptr, create, flags, nullptr);
|
||||||
|
if (hFile == INVALID_HANDLE_VALUE) return std::unique_ptr<File>();
|
||||||
|
return std::make_unique<FileWin>(hFile);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
#endif // HWY_OS_WIN
|
||||||
|
|
@ -20,9 +20,7 @@
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/nuq.h"
|
#include "compression/nuq.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/sfp.h"
|
#include "compression/sfp.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
||||||
|
|
@ -37,7 +35,6 @@
|
||||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE
|
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/sfp-inl.h"
|
#include "compression/sfp-inl.h"
|
||||||
#include "hwy/contrib/sort/vqsort-inl.h"
|
#include "hwy/contrib/sort/vqsort-inl.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
|
|
@ -107,9 +104,8 @@ class NuqClustering {
|
||||||
// Callers are responsible for ignoring lanes where last < first.
|
// Callers are responsible for ignoring lanes where last < first.
|
||||||
HWY_DASSERT(first < kGroupSize);
|
HWY_DASSERT(first < kGroupSize);
|
||||||
HWY_DASSERT(last < kGroupSize);
|
HWY_DASSERT(last < kGroupSize);
|
||||||
const size_t len = last - first + 1;
|
const int len = static_cast<int>(last) - static_cast<int>(first) + 1;
|
||||||
const hn::Vec<DF> vlen =
|
const hn::Vec<DF> vlen = hn::Iota(df, static_cast<float>(len));
|
||||||
hn::Iota(df, static_cast<float>(static_cast<int>(len)));
|
|
||||||
|
|
||||||
const hn::Vec<DF> u_lo = hn::Set(df, cumsum_[first]);
|
const hn::Vec<DF> u_lo = hn::Set(df, cumsum_[first]);
|
||||||
const hn::Vec<DF> u_lo2 = hn::Set(df, cumsum2_[first]);
|
const hn::Vec<DF> u_lo2 = hn::Set(df, cumsum2_[first]);
|
||||||
|
|
@ -207,7 +203,7 @@ class NuqClustering {
|
||||||
for (size_t num_clusters = 1; num_clusters < kClusters; ++num_clusters) {
|
for (size_t num_clusters = 1; num_clusters < kClusters; ++num_clusters) {
|
||||||
// For each batch starting at `last`, one per lane:
|
// For each batch starting at `last`, one per lane:
|
||||||
for (size_t last = 0; last < kGroupSize; last += N) {
|
for (size_t last = 0; last < kGroupSize; last += N) {
|
||||||
VF min = cc(df, 0, last);
|
VF min = hn::LoadU(df, &D(0, last));
|
||||||
VI arg = hn::Zero(di);
|
VI arg = hn::Zero(di);
|
||||||
// For each j (start of rightmost cluster):
|
// For each j (start of rightmost cluster):
|
||||||
VI vj = k1;
|
VI vj = k1;
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,8 @@
|
||||||
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#include "compression/nuq.h"
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
@ -25,8 +27,10 @@
|
||||||
#include <algorithm> // std::shuffle
|
#include <algorithm> // std::shuffle
|
||||||
#include <random>
|
#include <random>
|
||||||
|
|
||||||
|
#include "compression/test_util.h"
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
#include "hwy/tests/test_util.h"
|
||||||
#include "hwy/timer.h"
|
#include "hwy/timer.h"
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
|
|
@ -35,12 +39,7 @@
|
||||||
// clang-format on
|
// clang-format on
|
||||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
// Other headers that include Highway must come after foreach_target.h
|
// Other headers that include Highway must come after foreach_target.h
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/nuq-inl.h"
|
#include "compression/nuq-inl.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/nuq.h"
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/test_util.h"
|
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
#include "hwy/tests/hwy_gtest.h"
|
#include "hwy/tests/hwy_gtest.h"
|
||||||
#include "hwy/tests/test_util-inl.h"
|
#include "hwy/tests/test_util-inl.h"
|
||||||
|
|
@ -117,12 +116,16 @@ struct TestPlateaus {
|
||||||
HWY_ASSERT(indices[i] < kClusters);
|
HWY_ASSERT(indices[i] < kClusters);
|
||||||
stats.Notify(in[i], centers[indices[i]]);
|
stats.Notify(in[i], centers[indices[i]]);
|
||||||
}
|
}
|
||||||
const float pnorm = stats.PNorm();
|
// Zero error.
|
||||||
const float snr = stats.GeomeanValueDivL1();
|
HWY_ASSERT_EQ(kGroupSize, stats.NumExact());
|
||||||
fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr,
|
HWY_ASSERT_EQ(0, stats.NumSignFlip());
|
||||||
stats.MaxIndex(), stats.MaxL1());
|
HWY_ASSERT_EQ(0, stats.NumRoundedToZero());
|
||||||
HWY_ASSERT(pnorm == 0.0f);
|
HWY_ASSERT_EQ(0.0, stats.SumL1());
|
||||||
HWY_ASSERT(snr == 0.0f);
|
HWY_ASSERT_EQ(0.0f, stats.GeomeanValueDivL1());
|
||||||
|
HWY_ASSERT_EQ(0.0f, stats.WeightedAverageL1());
|
||||||
|
// Input was symmetric and zero-mean.
|
||||||
|
HWY_ASSERT(gcpp::IsInside(-0.05, 0.05, stats.Original().Mean()));
|
||||||
|
HWY_ASSERT(gcpp::IsNear(0.0, stats.Original().Skewness()));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -160,16 +163,19 @@ struct TestRamp {
|
||||||
HWY_ASSERT(indices[i] < kClusters);
|
HWY_ASSERT(indices[i] < kClusters);
|
||||||
stats.Notify(in[i], centers[indices[i]]);
|
stats.Notify(in[i], centers[indices[i]]);
|
||||||
}
|
}
|
||||||
const float pnorm = stats.PNorm();
|
|
||||||
const float snr = stats.GeomeanValueDivL1();
|
|
||||||
fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr,
|
|
||||||
stats.MaxIndex(), stats.MaxL1());
|
|
||||||
static_assert(kGroupSize == 128 || kGroupSize == 256, "Update expected");
|
|
||||||
|
|
||||||
const float expected_pnorm = kGroupSize == 128 ? 2.08E-2f : 2.1E-2f;
|
// Low error.
|
||||||
const float expected_snr = kGroupSize == 128 ? 16.9f : 17.6f;
|
HWY_ASSERT_EQ(0, stats.NumExact());
|
||||||
HWY_ASSERT(expected_pnorm <= pnorm && pnorm < 1.02f * expected_pnorm);
|
HWY_ASSERT(stats.NumSignFlip() < 10);
|
||||||
HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr);
|
HWY_ASSERT_EQ(0, stats.NumRoundedToZero());
|
||||||
|
HWY_ASSERT_EQ(kGroupSize / kClusters / 4.0, stats.SumL1());
|
||||||
|
HWY_ASSERT(gcpp::IsInside(17.0, 18.0, stats.GeomeanValueDivL1()));
|
||||||
|
HWY_ASSERT(gcpp::IsInside(0.005, 0.010, stats.WeightedAverageL1()));
|
||||||
|
HWY_ASSERT(stats.L1().Max() <= 0.04f);
|
||||||
|
// Input was symmetric about 0.05.
|
||||||
|
HWY_ASSERT(gcpp::IsNear(0.05, stats.Original().Mean(), 0.01));
|
||||||
|
HWY_ASSERT(gcpp::IsNear(0.0, stats.Original().Skewness(), 1E-4));
|
||||||
|
static_assert(kGroupSize == 256, "Update expected");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -210,15 +216,16 @@ struct TestNormal {
|
||||||
HWY_ASSERT(indices[i] < kClusters);
|
HWY_ASSERT(indices[i] < kClusters);
|
||||||
stats.Notify(in[i], centers[indices[i]]);
|
stats.Notify(in[i], centers[indices[i]]);
|
||||||
}
|
}
|
||||||
const float pnorm = stats.PNorm();
|
|
||||||
const float snr = stats.GeomeanValueDivL1();
|
// Moderate error.
|
||||||
fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr,
|
HWY_ASSERT_EQ(0, stats.NumExact());
|
||||||
stats.MaxIndex(), stats.MaxL1());
|
HWY_ASSERT(stats.NumSignFlip() < kGroupSize / kClusters);
|
||||||
|
HWY_ASSERT_EQ(0, stats.NumRoundedToZero());
|
||||||
|
HWY_ASSERT(gcpp::IsInside(5.0, 6.0, stats.SumL1()));
|
||||||
|
HWY_ASSERT(gcpp::IsInside(12.7, 12.8, stats.GeomeanValueDivL1()));
|
||||||
|
HWY_ASSERT(gcpp::IsInside(0.036, 0.037, stats.WeightedAverageL1()));
|
||||||
|
HWY_ASSERT(stats.L1().Max() <= 0.10f);
|
||||||
static_assert(kGroupSize == 256, "Update expected");
|
static_assert(kGroupSize == 256, "Update expected");
|
||||||
const float expected_pnorm = 3.68E-2f;
|
|
||||||
const float expected_snr = 12.7f;
|
|
||||||
HWY_ASSERT(expected_pnorm <= pnorm && pnorm < 1.02f * expected_pnorm);
|
|
||||||
HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -238,10 +245,9 @@ struct TestOffset {
|
||||||
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
|
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
|
||||||
HWY_ASSERT(in && dec1 && dec2 && nuq);
|
HWY_ASSERT(in && dec1 && dec2 && nuq);
|
||||||
|
|
||||||
std::mt19937 rng(123);
|
hwy::RandomState rng;
|
||||||
std::normal_distribution<float> dist{0.001f, 0.3f};
|
|
||||||
for (size_t i = 0; i < total; ++i) {
|
for (size_t i = 0; i < total; ++i) {
|
||||||
in[i] = dist(rng);
|
in[i] = static_cast<float>(RandomGaussian(rng));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encode + decode everything
|
// Encode + decode everything
|
||||||
|
|
@ -281,11 +287,13 @@ struct TestStream {
|
||||||
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(num));
|
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(num));
|
||||||
HWY_ASSERT(in && out && nuq);
|
HWY_ASSERT(in && out && nuq);
|
||||||
|
|
||||||
std::mt19937 rng(123);
|
hwy::RandomState rng;
|
||||||
std::normal_distribution<float> dist{0.001f, 0.3f};
|
Stats in_stats;
|
||||||
for (size_t i = 0; i < num; ++i) {
|
for (size_t i = 0; i < num; ++i) {
|
||||||
in[i] = dist(rng);
|
in[i] = static_cast<float>(RandomGaussian(rng));
|
||||||
|
in_stats.Notify(in[i]);
|
||||||
}
|
}
|
||||||
|
VerifyGaussian(in_stats);
|
||||||
|
|
||||||
ClusterBuf buf;
|
ClusterBuf buf;
|
||||||
double elapsed = hwy::HighestValue<double>();
|
double elapsed = hwy::HighestValue<double>();
|
||||||
|
|
@ -314,15 +322,16 @@ struct TestStream {
|
||||||
for (size_t i = 0; i < num; ++i) {
|
for (size_t i = 0; i < num; ++i) {
|
||||||
stats.Notify(in[i], hwy::ConvertScalarTo<float>(out[i]));
|
stats.Notify(in[i], hwy::ConvertScalarTo<float>(out[i]));
|
||||||
}
|
}
|
||||||
const float pnorm = stats.PNorm();
|
|
||||||
const float snr = stats.GeomeanValueDivL1();
|
// Moderate error.
|
||||||
fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr,
|
HWY_ASSERT_EQ(0, stats.NumExact());
|
||||||
stats.MaxIndex(), stats.MaxL1());
|
HWY_ASSERT(stats.NumSignFlip() < num / kClusters);
|
||||||
static_assert(kGroupSize == 128 || kGroupSize == 256, "Update expected");
|
HWY_ASSERT_EQ(0, stats.NumRoundedToZero());
|
||||||
const float expected_pnorm = kGroupSize == 128 ? 3.44E-2f : 3.88E-2f;
|
HWY_ASSERT(gcpp::IsInside(23.0, 24.0, stats.SumL1()));
|
||||||
const float expected_snr = kGroupSize == 128 ? 15.0f : 13.3f;
|
HWY_ASSERT(gcpp::IsInside(13.0, 13.3, stats.GeomeanValueDivL1()));
|
||||||
HWY_ASSERT(expected_pnorm <= pnorm && pnorm < 1.02f * expected_pnorm);
|
HWY_ASSERT(gcpp::IsInside(0.034, 0.035, stats.WeightedAverageL1()));
|
||||||
HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr);
|
HWY_ASSERT(stats.L1().Max() <= 0.11f);
|
||||||
|
static_assert(kGroupSize == 256, "Update expected");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -351,9 +360,8 @@ struct TestDot {
|
||||||
hwy::RandomState rng;
|
hwy::RandomState rng;
|
||||||
Stats in_stats;
|
Stats in_stats;
|
||||||
for (size_t i = 0; i < num; ++i) {
|
for (size_t i = 0; i < num; ++i) {
|
||||||
const float r = static_cast<float>(RandomGaussian(rng));
|
in[i] = static_cast<float>(RandomGaussian(rng));
|
||||||
in_stats.Notify(r);
|
in_stats.Notify(in[i]);
|
||||||
in[i] = r;
|
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < num; ++i) {
|
for (size_t i = 0; i < num; ++i) {
|
||||||
const float r = static_cast<float>(RandomGaussian(rng));
|
const float r = static_cast<float>(RandomGaussian(rng));
|
||||||
|
|
@ -368,7 +376,7 @@ struct TestDot {
|
||||||
HWY_ASSERT(unused_clusters == 0);
|
HWY_ASSERT(unused_clusters == 0);
|
||||||
|
|
||||||
// Compute dot product without decompression.
|
// Compute dot product without decompression.
|
||||||
double actual = 0.0;
|
float actual = 0.0f;
|
||||||
double elapsed = hwy::HighestValue<double>();
|
double elapsed = hwy::HighestValue<double>();
|
||||||
for (size_t rep = 0; rep < 20; ++rep) {
|
for (size_t rep = 0; rep < 20; ++rep) {
|
||||||
hn::Vec<decltype(df)> sum0 = hn::Zero(df);
|
hn::Vec<decltype(df)> sum0 = hn::Zero(df);
|
||||||
|
|
@ -389,8 +397,8 @@ struct TestDot {
|
||||||
num * sizeof(in[0]) * 1E-6 / elapsed);
|
num * sizeof(in[0]) * 1E-6 / elapsed);
|
||||||
|
|
||||||
// Exact and decompressed dot products for comparison.
|
// Exact and decompressed dot products for comparison.
|
||||||
double exact = 0.0; // using original input
|
float exact = 0.0f; // using original input
|
||||||
double expected = 0.0; // using decoded NUQ
|
float expected = 0.0f; // using decoded NUQ
|
||||||
DistortionStats dec_stats;
|
DistortionStats dec_stats;
|
||||||
Stats ratios;
|
Stats ratios;
|
||||||
for (size_t i = 0; i < num; ++i) {
|
for (size_t i = 0; i < num; ++i) {
|
||||||
|
|
@ -402,24 +410,42 @@ struct TestDot {
|
||||||
ratios.Notify(exact / expected);
|
ratios.Notify(exact / expected);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
const bool isBF = sizeof(T) == 2;
|
||||||
const double dec_snr = dec_stats.GeomeanValueDivL1();
|
const double dec_snr = dec_stats.GeomeanValueDivL1();
|
||||||
|
const double dec_wl1 = dec_stats.WeightedAverageL1();
|
||||||
const double dot_snr = 1.0 / hwy::ScalarAbs(1.0 - ratios.GeometricMean());
|
const double dot_snr = 1.0 / hwy::ScalarAbs(1.0 - ratios.GeometricMean());
|
||||||
// exact and actual fluctuate due to the combination of NUQ imprecision,
|
// exact and actual fluctuate due to the combination of NUQ imprecision,
|
||||||
// and whether vec[i] is negative or positive, so this is quite loose.
|
// and whether vec[i] is negative or positive, so this is quite loose.
|
||||||
const float final_ratio = HWY_MIN(exact / actual, actual / exact);
|
const float final_ratio = HWY_MIN(exact / actual, actual / exact);
|
||||||
|
if (HWY_ONCE) {
|
||||||
fprintf(stderr, "ratios %s\n", ratios.ToString().c_str());
|
fprintf(stderr, "ratios %s\n", ratios.ToString().c_str());
|
||||||
fprintf(stderr,
|
fprintf(stderr,
|
||||||
"exact %.3f e2 %.4f actual %.4f final_ratio %.3f dec_snr %.2f "
|
"exact %.3f e2 %.4f actual %.4f final_ratio %.3f dec_snr %.2f "
|
||||||
"dot_snr %.2f\n",
|
"dot_snr %.2f dec_wl1 %.4f\n",
|
||||||
exact, expected, actual, final_ratio, dec_snr, dot_snr);
|
exact, expected, actual, final_ratio, dec_snr, dot_snr, dec_wl1);
|
||||||
|
}
|
||||||
// Final values are not too far apart.
|
// Final values are not too far apart.
|
||||||
HWY_ASSERT(0.88f <= final_ratio && final_ratio <= 1.0f);
|
HWY_ASSERT(gcpp::IsInside(0.88f, 1.0f, final_ratio));
|
||||||
// Decompressed and uncompressed dot should match exactly.
|
// Decompressed and uncompressed dot should match exactly.
|
||||||
HWY_ASSERT(hwy::ScalarAbs(expected - actual) < 1E-4f);
|
HWY_ASSERT(gcpp::IsNear(expected, actual, 1E-4f));
|
||||||
// dec[] is close to in[], but we already check that in TestStream.
|
// Geomean of ratios for each i should be very close to one.
|
||||||
HWY_ASSERT(dec_snr >= 13.0);
|
HWY_ASSERT(dot_snr >= (isBF ? 17.7 : 14.3));
|
||||||
// Geomean of ratios for each i is an approximation of the actual SNR.
|
|
||||||
HWY_ASSERT(dot_snr >= (sizeof(T) == 2 ? 17.0 : 14.0));
|
// dec[] is close to in[], but we already check that in TestStream with the
|
||||||
|
// same input distribution.
|
||||||
|
HWY_ASSERT(gcpp::IsNear(13.1, dec_snr, 0.1));
|
||||||
|
HWY_ASSERT(gcpp::IsNear(0.034, dec_wl1, 0.001));
|
||||||
|
HWY_ASSERT(gcpp::IsNear(23.5, dec_stats.SumL1(), 0.1));
|
||||||
|
HWY_ASSERT(dec_stats.NumSignFlip() < num / kClusters);
|
||||||
|
HWY_ASSERT_EQ(0, dec_stats.NumExact());
|
||||||
|
HWY_ASSERT_EQ(0, dec_stats.NumRoundedToZero());
|
||||||
|
HWY_ASSERT_EQ(0.0, dec_stats.SumL1Rounded());
|
||||||
|
// Absolute decode errors are in [0, 0.11], and somewhat right-tailed.
|
||||||
|
HWY_ASSERT(gcpp::IsInside(0.0f, 2E-5f, dec_stats.L1().Min()));
|
||||||
|
HWY_ASSERT(gcpp::IsInside(0.09f, 0.11f, dec_stats.L1().Max()));
|
||||||
|
HWY_ASSERT(gcpp::IsInside(0.02, 0.03, dec_stats.L1().Mean()));
|
||||||
|
HWY_ASSERT(gcpp::IsInside(1.0, 1.1, dec_stats.L1().Skewness()));
|
||||||
|
HWY_ASSERT(gcpp::IsInside(4.0, 5.0, dec_stats.L1().Kurtosis()));
|
||||||
static_assert(kGroupSize == 256, "Update expected*");
|
static_assert(kGroupSize == 256, "Update expected*");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,6 @@
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/sfp.h"
|
#include "compression/sfp.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,6 @@
|
||||||
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/sfp.h"
|
#include "compression/sfp.h"
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
@ -27,6 +26,7 @@
|
||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
|
|
||||||
|
#include "compression/test_util.h"
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/timer.h"
|
#include "hwy/timer.h"
|
||||||
|
|
@ -37,10 +37,7 @@
|
||||||
// clang-format on
|
// clang-format on
|
||||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
// Any highway.h must come after foreach_target.h
|
// Any highway.h must come after foreach_target.h
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/sfp-inl.h"
|
#include "compression/sfp-inl.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/test_util.h"
|
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
#include "hwy/tests/hwy_gtest.h"
|
#include "hwy/tests/hwy_gtest.h"
|
||||||
#include "hwy/tests/test_util-inl.h"
|
#include "hwy/tests/test_util-inl.h"
|
||||||
|
|
@ -307,10 +304,37 @@ struct TestEncDec {
|
||||||
sum += hwy::ConvertScalarTo<double>(hwy::ScalarAbs(in[i]));
|
sum += hwy::ConvertScalarTo<double>(hwy::ScalarAbs(in[i]));
|
||||||
stats.Notify(hwy::ConvertScalarTo<float>(in[i]), out);
|
stats.Notify(hwy::ConvertScalarTo<float>(in[i]), out);
|
||||||
}
|
}
|
||||||
const double avg = sum / num;
|
const double avg_in = sum / num;
|
||||||
fprintf(stderr, "Avg magnitude %.3E, p-norm %.3E snr %.2f @%zu = %.4E\n",
|
const double snr = stats.GeomeanValueDivL1();
|
||||||
avg, stats.PNorm(), stats.GeomeanValueDivL1(), stats.MaxIndex(),
|
const double wl1 = stats.WeightedAverageL1();
|
||||||
stats.MaxL1());
|
if (false) {
|
||||||
|
fprintf(stderr,
|
||||||
|
"Num inputs %zu, avg %.3E, exact %zu round0 %zu (sum %E) snr "
|
||||||
|
"%.2f wL1 %f\n",
|
||||||
|
num, avg_in, stats.NumExact(), stats.NumRoundedToZero(),
|
||||||
|
stats.SumL1Rounded(), snr, wl1);
|
||||||
|
}
|
||||||
|
HWY_ASSERT(stats.Original().Count() == stats.L1().Count());
|
||||||
|
// Inputs are in [-1.875, 1.875], symmetric, and heavy-tailed.
|
||||||
|
HWY_ASSERT(stats.Original().Min() == -1.875f);
|
||||||
|
HWY_ASSERT(stats.Original().Max() == 1.875f);
|
||||||
|
HWY_ASSERT(gcpp::IsInside(-1E-6, 1E-6, stats.Original().Mean()));
|
||||||
|
HWY_ASSERT(gcpp::IsInside(-1E-6, 1E-6, stats.Original().Skewness()));
|
||||||
|
HWY_ASSERT(gcpp::IsInside(80.0, 100.0, stats.Original().Kurtosis()));
|
||||||
|
// Absolute errors are in [0, 0.0625], and (heavy) right-tailed.
|
||||||
|
HWY_ASSERT(stats.L1().Min() == 0.0f);
|
||||||
|
HWY_ASSERT(stats.L1().Max() == 0.0625f);
|
||||||
|
HWY_ASSERT(gcpp::IsInside(4E-4, 5E-4, stats.L1().Mean()));
|
||||||
|
HWY_ASSERT(gcpp::IsInside(10.0, 15.0, stats.L1().Skewness()));
|
||||||
|
HWY_ASSERT(gcpp::IsInside(150.0, 200.0, stats.L1().Kurtosis()));
|
||||||
|
// SNR is low because many *tiny* numbers are rounded to zero.
|
||||||
|
HWY_ASSERT_EQ(3322, stats.NumRoundedToZero());
|
||||||
|
HWY_ASSERT(gcpp::IsInside(5E-6, 6E-6, stats.SumL1Rounded()));
|
||||||
|
HWY_ASSERT(gcpp::IsInside(1.880, 1.885, stats.SumL1()));
|
||||||
|
HWY_ASSERT_EQ(256, stats.NumExact());
|
||||||
|
HWY_ASSERT_EQ(0, stats.NumSignFlip());
|
||||||
|
HWY_ASSERT(gcpp::IsInside(2.70, 2.75, snr));
|
||||||
|
HWY_ASSERT(gcpp::IsInside(0.010, 0.011, wl1)); // = half of mean |x|.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -381,7 +405,7 @@ struct TestDot {
|
||||||
SfpCodec::Enc(d, in.get(), num, sfp.get());
|
SfpCodec::Enc(d, in.get(), num, sfp.get());
|
||||||
|
|
||||||
// Compute dot product without decompression.
|
// Compute dot product without decompression.
|
||||||
double actual = 0.0;
|
float actual = 0.0f;
|
||||||
double elapsed = hwy::HighestValue<double>();
|
double elapsed = hwy::HighestValue<double>();
|
||||||
for (size_t rep = 0; rep < 200; ++rep) {
|
for (size_t rep = 0; rep < 200; ++rep) {
|
||||||
hn::Vec<decltype(df)> sum0 = hn::Zero(df);
|
hn::Vec<decltype(df)> sum0 = hn::Zero(df);
|
||||||
|
|
@ -417,24 +441,41 @@ struct TestDot {
|
||||||
ratios.Notify(exact / expected);
|
ratios.Notify(exact / expected);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
const bool isBF = sizeof(T) == 2;
|
||||||
const double dec_snr = dec_stats.GeomeanValueDivL1();
|
const double dec_snr = dec_stats.GeomeanValueDivL1();
|
||||||
|
const double dec_wl1 = dec_stats.WeightedAverageL1();
|
||||||
const double dot_snr = 1.0 / hwy::ScalarAbs(1.0 - ratios.GeometricMean());
|
const double dot_snr = 1.0 / hwy::ScalarAbs(1.0 - ratios.GeometricMean());
|
||||||
// exact and actual fluctuate due to the combination of SFP imprecision,
|
// exact and actual fluctuate due to the combination of SFP imprecision,
|
||||||
// and whether vec[i] is negative or positive, so this is quite loose.
|
// and whether vec[i] is negative or positive, so this is quite loose.
|
||||||
const float final_ratio = HWY_MIN(exact / actual, actual / exact);
|
const float final_ratio = HWY_MIN(exact / actual, actual / exact);
|
||||||
|
if (HWY_ONCE) {
|
||||||
fprintf(stderr, "ratios %s\n", ratios.ToString().c_str());
|
fprintf(stderr, "ratios %s\n", ratios.ToString().c_str());
|
||||||
fprintf(stderr,
|
fprintf(stderr,
|
||||||
"exact %.3f e2 %.4f actual %.4f final_ratio %.3f dec_snr %.2f "
|
"exact %.3f e2 %.4f actual %.4f final_ratio %.3f dec_snr %.2f "
|
||||||
"dot_snr %.2f\n",
|
"dot_snr %.2f dec_wl1 %.5f\n",
|
||||||
exact, expected, actual, final_ratio, dec_snr, dot_snr);
|
exact, expected, actual, final_ratio, dec_snr, dot_snr, dec_wl1);
|
||||||
|
}
|
||||||
// Final values are not too far apart.
|
// Final values are not too far apart.
|
||||||
HWY_ASSERT(0.87f <= final_ratio && final_ratio <= 1.0f);
|
HWY_ASSERT(gcpp::IsInside(0.87f, 1.0f, final_ratio));
|
||||||
// Decompressed and uncompressed dot should match exactly.
|
// Decompressed and uncompressed dot should match exactly.
|
||||||
HWY_ASSERT(hwy::ScalarAbs(expected - actual) < 1E-4f);
|
HWY_ASSERT(gcpp::IsNear(expected, actual, 1E-4f));
|
||||||
// dec[] is close to in[], but we already check that in TestEncDec.
|
|
||||||
HWY_ASSERT(dec_snr >= 50.0);
|
|
||||||
// Geomean of ratios for each i should be very close to one.
|
// Geomean of ratios for each i should be very close to one.
|
||||||
HWY_ASSERT(dot_snr >= (sizeof(T) == 2 ? 70.0 : 1000.0));
|
HWY_ASSERT(dot_snr >= (isBF ? 70.0 : 1000.0));
|
||||||
|
|
||||||
|
// dec[] is close to in[]. We also check that in TestEncDec, but for much
|
||||||
|
// smaller input magnitudes.
|
||||||
|
HWY_ASSERT(gcpp::IsNear(isBF ? 51.0 : 64.0, dec_snr, 1.0));
|
||||||
|
HWY_ASSERT(gcpp::IsNear(isBF ? 0.013 : 0.012, dec_wl1, 0.001));
|
||||||
|
HWY_ASSERT(gcpp::IsNear(isBF ? 6.2 : 6.3, dec_stats.SumL1(), 0.1));
|
||||||
|
HWY_ASSERT_EQ(0, dec_stats.NumSignFlip());
|
||||||
|
HWY_ASSERT_EQ(0, dec_stats.NumRoundedToZero());
|
||||||
|
HWY_ASSERT_EQ(0.0, dec_stats.SumL1Rounded());
|
||||||
|
// Absolute decode errors are in [0, 5E-2], and somewhat right-tailed.
|
||||||
|
HWY_ASSERT(gcpp::IsInside(0.0f, 2E-6f, dec_stats.L1().Min()));
|
||||||
|
HWY_ASSERT(gcpp::IsInside(3E-2f, 5E-2f, dec_stats.L1().Max()));
|
||||||
|
HWY_ASSERT(gcpp::IsInside(4E-3, 7E-3, dec_stats.L1().Mean()));
|
||||||
|
HWY_ASSERT(gcpp::IsInside(1.8, 1.9, dec_stats.L1().Skewness()));
|
||||||
|
HWY_ASSERT(gcpp::IsInside(6.0, 7.0, dec_stats.L1().Kurtosis()));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,6 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/stats.h"
|
#include "compression/stats.h"
|
||||||
|
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,6 @@
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
|
@ -45,7 +44,7 @@ class Bins {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Print(const char* caption) {
|
void Print(const char* caption) const {
|
||||||
fprintf(stderr, "\n%s [%zu]\n", caption, N);
|
fprintf(stderr, "\n%s [%zu]\n", caption, N);
|
||||||
size_t last_nonzero = 0;
|
size_t last_nonzero = 0;
|
||||||
for (size_t i = N - 1; i < N; --i) {
|
for (size_t i = N - 1; i < N; --i) {
|
||||||
|
|
@ -77,8 +76,8 @@ class Stats {
|
||||||
void Notify(const float x) {
|
void Notify(const float x) {
|
||||||
++n_;
|
++n_;
|
||||||
|
|
||||||
min_ = std::min(min_, x);
|
min_ = HWY_MIN(min_, x);
|
||||||
max_ = std::max(max_, x);
|
max_ = HWY_MAX(max_, x);
|
||||||
|
|
||||||
product_ *= x;
|
product_ *= x;
|
||||||
|
|
||||||
|
|
@ -119,7 +118,7 @@ class Stats {
|
||||||
// Near zero for normal distributions; if positive on a unimodal distribution,
|
// Near zero for normal distributions; if positive on a unimodal distribution,
|
||||||
// the right tail is fatter. Assumes n_ is large.
|
// the right tail is fatter. Assumes n_ is large.
|
||||||
double SampleSkewness() const {
|
double SampleSkewness() const {
|
||||||
if (std::abs(m2_) < 1E-7) return 0.0;
|
if (hwy::ScalarAbs(m2_) < 1E-7) return 0.0;
|
||||||
return m3_ * std::sqrt(static_cast<double>(n_)) / std::pow(m2_, 1.5);
|
return m3_ * std::sqrt(static_cast<double>(n_)) / std::pow(m2_, 1.5);
|
||||||
}
|
}
|
||||||
// Corrected for bias (same as Wikipedia and Minitab but not Excel).
|
// Corrected for bias (same as Wikipedia and Minitab but not Excel).
|
||||||
|
|
@ -132,7 +131,7 @@ class Stats {
|
||||||
// Near zero for normal distributions; smaller values indicate fewer/smaller
|
// Near zero for normal distributions; smaller values indicate fewer/smaller
|
||||||
// outliers and larger indicates more/larger outliers. Assumes n_ is large.
|
// outliers and larger indicates more/larger outliers. Assumes n_ is large.
|
||||||
double SampleKurtosis() const {
|
double SampleKurtosis() const {
|
||||||
if (std::abs(m2_) < 1E-7) return 0.0;
|
if (hwy::ScalarAbs(m2_) < 1E-7) return 0.0;
|
||||||
return m4_ * n_ / (m2_ * m2_);
|
return m4_ * n_ / (m2_ * m2_);
|
||||||
}
|
}
|
||||||
// Corrected for bias (same as Wikipedia and Minitab but not Excel).
|
// Corrected for bias (same as Wikipedia and Minitab but not Excel).
|
||||||
|
|
|
||||||
|
|
@ -24,9 +24,7 @@
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
||||||
// IWYU pragma: begin_exports
|
// IWYU pragma: begin_exports
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/distortion.h"
|
#include "compression/distortion.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/stats.h"
|
#include "compression/stats.h"
|
||||||
#include "hwy/tests/test_util.h" // RandomState
|
#include "hwy/tests/test_util.h" // RandomState
|
||||||
// IWYU pragma: end_exports
|
// IWYU pragma: end_exports
|
||||||
|
|
@ -51,12 +49,26 @@ HWY_INLINE double RandomGaussian(hwy::RandomState& rng) {
|
||||||
return plus_minus_1 * std::sqrt(kReps / 3.0);
|
return plus_minus_1 * std::sqrt(kReps / 3.0);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Returns true if val is inside [min, max].
|
||||||
|
template <typename T>
|
||||||
|
static inline bool IsInside(T expected_min, T expected_max, T val) {
|
||||||
|
return expected_min <= val && val <= expected_max;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static inline bool IsNear(T expected, T val, T epsilon = T{1E-6}) {
|
||||||
|
return IsInside(expected - epsilon, expected + epsilon, val);
|
||||||
|
}
|
||||||
|
|
||||||
HWY_INLINE void VerifyGaussian(Stats& stats) {
|
HWY_INLINE void VerifyGaussian(Stats& stats) {
|
||||||
const double stddev = stats.StandardDeviation();
|
// Inputs are roughly [-1, 1] and symmetric about zero.
|
||||||
HWY_ASSERT(-0.01 <= stats.Mean() && stats.Mean() <= 0.01);
|
HWY_ASSERT(IsNear(-1.0f, stats.Min(), 0.10f));
|
||||||
HWY_ASSERT(0.30 <= stddev && stddev <= 0.35);
|
HWY_ASSERT(IsNear(+1.0f, stats.Max(), 0.10f));
|
||||||
HWY_ASSERT(-1.1 <= stats.Min() && stats.Min() <= -0.9);
|
HWY_ASSERT(IsInside(-2E-3, 2E-3, stats.Mean()));
|
||||||
HWY_ASSERT(0.9 <= stats.Max() && stats.Max() <= 1.1);
|
HWY_ASSERT(IsInside(-0.15, 0.15, stats.Skewness()));
|
||||||
|
// Near-Gaussian.
|
||||||
|
HWY_ASSERT(IsInside(0.30, 0.35, stats.StandardDeviation()));
|
||||||
|
HWY_ASSERT(IsNear(3.0, stats.Kurtosis(), 0.3));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,132 @@
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
#include <iostream>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "gemma/gemma.h"
|
||||||
|
#include "nlohmann/json.hpp"
|
||||||
|
#include "util/app.h"
|
||||||
|
#include "util/args.h"
|
||||||
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
||||||
|
using json = nlohmann::json;
|
||||||
|
|
||||||
|
class PromptArgs : public gcpp::ArgsBase<PromptArgs> {
|
||||||
|
public:
|
||||||
|
PromptArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||||
|
|
||||||
|
gcpp::Path layers_output;
|
||||||
|
std::string prompt;
|
||||||
|
|
||||||
|
template <class Visitor>
|
||||||
|
void ForEach(const Visitor& visitor) {
|
||||||
|
visitor(layers_output.path, "layers_output", std::string(""),
|
||||||
|
"Path to store layers output", 2);
|
||||||
|
visitor(prompt, "prompt", std::string(""), "Prompt to the model", 2);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
std::pair<std::string, int> QueryModel(
|
||||||
|
gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
|
||||||
|
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input,
|
||||||
|
gcpp::LayersOutputT* layers_output) {
|
||||||
|
std::vector<int> prompt;
|
||||||
|
HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
|
||||||
|
|
||||||
|
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
|
||||||
|
// if needed.
|
||||||
|
prompt.insert(prompt.begin(), 2);
|
||||||
|
std::string res;
|
||||||
|
size_t total_tokens = 0;
|
||||||
|
auto accept_token = [](int) { return true; };
|
||||||
|
std::mt19937 gen;
|
||||||
|
gen.seed(42);
|
||||||
|
|
||||||
|
auto stream_token = [&res, &total_tokens, &app,
|
||||||
|
tokenizer = model.Tokenizer()](int token, float) {
|
||||||
|
++total_tokens;
|
||||||
|
std::string token_text;
|
||||||
|
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text));
|
||||||
|
res += token_text;
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
if (app.verbosity >= 2) {
|
||||||
|
std::cout << args.max_tokens << " " << args.max_generated_tokens << " "
|
||||||
|
<< args.temperature;
|
||||||
|
}
|
||||||
|
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
||||||
|
args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
|
||||||
|
stream_token, accept_token, gen, app.verbosity, layers_output);
|
||||||
|
return {res, total_tokens};
|
||||||
|
}
|
||||||
|
|
||||||
|
class OutputJsonLogger {
|
||||||
|
public:
|
||||||
|
json json_output;
|
||||||
|
|
||||||
|
gcpp::LayersOutputT layers_output_log_f =
|
||||||
|
[this](int pos, const std::string& key, const float* values, size_t values_len) {
|
||||||
|
std::vector<float> v{values, values + values_len};
|
||||||
|
json_output[std::to_string(pos)][key] = v;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
/* Run this in the same way as gemma, p.ex.:
|
||||||
|
./debug_prompt --tokenizer tokenizer.spm --model 2b-it --weights \
|
||||||
|
2b-it-sfp.sbs --prompt "..." --layers_output [path]
|
||||||
|
*/
|
||||||
|
int main(int argc, char** argv) {
|
||||||
|
gcpp::LoaderArgs loader(argc, argv);
|
||||||
|
gcpp::InferenceArgs args(argc, argv); // inference
|
||||||
|
gcpp::AppArgs app(argc, argv);
|
||||||
|
PromptArgs prompt_args(argc, argv);
|
||||||
|
|
||||||
|
if (const char* error = loader.Validate()) {
|
||||||
|
HWY_ABORT("\nInvalid loader args: %s", error);
|
||||||
|
}
|
||||||
|
if (const char* error = args.Validate()) {
|
||||||
|
HWY_ABORT("\nInvalid inference args: %s", error);
|
||||||
|
}
|
||||||
|
const bool log_layers_output = !prompt_args.layers_output.path.empty();
|
||||||
|
OutputJsonLogger json_logger;
|
||||||
|
gcpp::LayersOutputT* layers_output =
|
||||||
|
log_layers_output ? &json_logger.layers_output_log_f : nullptr;
|
||||||
|
|
||||||
|
hwy::ThreadPool pool(app.num_threads);
|
||||||
|
// For many-core, pinning threads to cores helps.
|
||||||
|
if (app.num_threads > 10) {
|
||||||
|
gcpp::PinThreadToCore(app.num_threads - 1); // Main thread
|
||||||
|
|
||||||
|
pool.Run(0, pool.NumThreads(), [](uint64_t /*task*/, size_t thread) {
|
||||||
|
gcpp::PinThreadToCore(thread);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
|
||||||
|
auto kv_cache = CreateKVCache(loader.ModelType());
|
||||||
|
|
||||||
|
const std::string& prompt = prompt_args.prompt;
|
||||||
|
if (prompt.empty()) {
|
||||||
|
std::cout << "Please specify --prompt" << std::endl;
|
||||||
|
return EXIT_FAILURE;
|
||||||
|
}
|
||||||
|
const auto [answer, token_count] = QueryModel(
|
||||||
|
model, args, app, kv_cache, pool, prompt, layers_output);
|
||||||
|
std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
|
||||||
|
|
||||||
|
if (log_layers_output) {
|
||||||
|
std::ofstream output_f(prompt_args.layers_output.path, std::ofstream::out);
|
||||||
|
if (!output_f) {
|
||||||
|
std::cout << "Opening file failed" << std::endl;
|
||||||
|
return EXIT_FAILURE;
|
||||||
|
}
|
||||||
|
output_f << json_logger.json_output.dump();
|
||||||
|
if (!output_f) {
|
||||||
|
std::cout << "Writing to file failed" << std::endl;
|
||||||
|
return EXIT_FAILURE;
|
||||||
|
}
|
||||||
|
output_f.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
return EXIT_SUCCESS;
|
||||||
|
}
|
||||||
|
|
@ -15,12 +15,9 @@
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
#include "third_party/gemma_cpp/gemma.h"
|
||||||
#include "gemma.h"
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "util/app.h" // LoaderArgs
|
#include "util/app.h" // LoaderArgs
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "util/args.h"
|
#include "util/args.h"
|
||||||
|
|
||||||
std::vector<int> tokenize(const std::string& prompt_string,
|
std::vector<int> tokenize(const std::string& prompt_string,
|
||||||
|
|
|
||||||
|
|
@ -8,16 +8,13 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "nlohmann/json.hpp"
|
#include "nlohmann/json.hpp"
|
||||||
// copybara:import_next_line:gemma_cpp
|
#include "gemma/gemma.h"
|
||||||
#include "gemma.h"
|
#include "util/app.h"
|
||||||
|
#include "util/args.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
#include "hwy/timer.h"
|
#include "hwy/timer.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "util/app.h"
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "util/args.h"
|
|
||||||
|
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
|
|
||||||
|
|
@ -61,8 +58,7 @@ void LogSpeedStats(const double time_start, size_t total_tokens) {
|
||||||
|
|
||||||
std::pair<std::string, int> QueryModel(
|
std::pair<std::string, int> QueryModel(
|
||||||
gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
|
gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
|
||||||
gcpp::KVCache& kv_cache, hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
|
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input) {
|
||||||
const std::string& input) {
|
|
||||||
std::vector<int> prompt;
|
std::vector<int> prompt;
|
||||||
HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
|
HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
|
||||||
|
|
||||||
|
|
@ -93,7 +89,7 @@ std::pair<std::string, int> QueryModel(
|
||||||
}
|
}
|
||||||
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
||||||
args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
|
args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
|
||||||
inner_pool, stream_token, accept_token, gen, app.verbosity);
|
stream_token, accept_token, gen, app.verbosity);
|
||||||
if (app.verbosity >= 1) {
|
if (app.verbosity >= 1) {
|
||||||
LogSpeedStats(time_start, total_tokens);
|
LogSpeedStats(time_start, total_tokens);
|
||||||
}
|
}
|
||||||
|
|
@ -134,8 +130,7 @@ std::string ReadFile(const gcpp::Path& path) {
|
||||||
|
|
||||||
int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
|
int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
|
||||||
gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
|
gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
|
||||||
hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
|
hwy::ThreadPool& pool, const std::string& golden_path) {
|
||||||
const std::string& golden_path) {
|
|
||||||
const std::vector<std::pair<std::string, std::string>> queries_answers =
|
const std::vector<std::pair<std::string, std::string>> queries_answers =
|
||||||
load_goldens(golden_path);
|
load_goldens(golden_path);
|
||||||
int correct_answers = 0;
|
int correct_answers = 0;
|
||||||
|
|
@ -143,7 +138,7 @@ int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
|
||||||
const double time_start = hwy::platform::Now();
|
const double time_start = hwy::platform::Now();
|
||||||
for (const auto& [question, expected_answer] : queries_answers) {
|
for (const auto& [question, expected_answer] : queries_answers) {
|
||||||
const auto [answer, token_count] =
|
const auto [answer, token_count] =
|
||||||
QueryModel(model, args, app, kv_cache, inner_pool, pool, question);
|
QueryModel(model, args, app, kv_cache, pool, question);
|
||||||
total_tokens += token_count;
|
total_tokens += token_count;
|
||||||
if (answer.find(expected_answer) != std::string::npos) {
|
if (answer.find(expected_answer) != std::string::npos) {
|
||||||
correct_answers++;
|
correct_answers++;
|
||||||
|
|
@ -167,14 +162,13 @@ int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
|
||||||
|
|
||||||
int BenchmarkSummary(gcpp::Gemma& model, gcpp::InferenceArgs& args,
|
int BenchmarkSummary(gcpp::Gemma& model, gcpp::InferenceArgs& args,
|
||||||
gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
|
gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
|
||||||
hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
|
hwy::ThreadPool& pool, const gcpp::Path& text) {
|
||||||
const gcpp::Path& text) {
|
|
||||||
std::string prompt("Here is some text to summarize:\n");
|
std::string prompt("Here is some text to summarize:\n");
|
||||||
prompt.append(ReadFile(text));
|
prompt.append(ReadFile(text));
|
||||||
prompt.append("\nSummarize this text.\n");
|
prompt.append("\nSummarize this text.\n");
|
||||||
const double time_start = hwy::platform::Now();
|
const double time_start = hwy::platform::Now();
|
||||||
const auto [answer, token_count] =
|
const auto [answer, token_count] =
|
||||||
QueryModel(model, args, app, kv_cache, inner_pool, pool, prompt);
|
QueryModel(model, args, app, kv_cache, pool, prompt);
|
||||||
std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
|
std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
|
||||||
LogSpeedStats(time_start, token_count);
|
LogSpeedStats(time_start, token_count);
|
||||||
return EXIT_SUCCESS;
|
return EXIT_SUCCESS;
|
||||||
|
|
@ -182,8 +176,8 @@ int BenchmarkSummary(gcpp::Gemma& model, gcpp::InferenceArgs& args,
|
||||||
|
|
||||||
int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
|
int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
|
||||||
gcpp::InferenceArgs& args, gcpp::AppArgs& app,
|
gcpp::InferenceArgs& args, gcpp::AppArgs& app,
|
||||||
hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
|
hwy::ThreadPool& pool, const gcpp::Path& text,
|
||||||
const gcpp::Path& text, size_t batch_tokens) {
|
size_t batch_tokens) {
|
||||||
std::string input = ReadFile(text);
|
std::string input = ReadFile(text);
|
||||||
std::vector<int> prompt;
|
std::vector<int> prompt;
|
||||||
HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
|
HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
|
||||||
|
|
@ -200,7 +194,7 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
|
||||||
auto kv_cache = CreateKVCache(model_type);
|
auto kv_cache = CreateKVCache(model_type);
|
||||||
float entropy =
|
float entropy =
|
||||||
ComputeCrossEntropy(model, num_tokens, prompt_slice, kv_cache, pool,
|
ComputeCrossEntropy(model, num_tokens, prompt_slice, kv_cache, pool,
|
||||||
inner_pool, app.verbosity);
|
app.verbosity);
|
||||||
total_entropy += entropy;
|
total_entropy += entropy;
|
||||||
LogSpeedStats(time_start, pos + num_tokens);
|
LogSpeedStats(time_start, pos + num_tokens);
|
||||||
std::string text_slice;
|
std::string text_slice;
|
||||||
|
|
@ -214,8 +208,8 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
|
||||||
|
|
||||||
int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
|
int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
|
||||||
gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
|
gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
|
||||||
hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
|
hwy::ThreadPool& pool, const gcpp::Path& json_file,
|
||||||
const gcpp::Path& json_file, size_t max_questions) {
|
size_t max_questions) {
|
||||||
std::ifstream trivia_file(json_file.path);
|
std::ifstream trivia_file(json_file.path);
|
||||||
if (!trivia_file) {
|
if (!trivia_file) {
|
||||||
std::cout << "Could not load file: " << json_file.path << "\n"
|
std::cout << "Could not load file: " << json_file.path << "\n"
|
||||||
|
|
@ -228,7 +222,7 @@ int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
|
||||||
while (std::getline(trivia_file, line)) {
|
while (std::getline(trivia_file, line)) {
|
||||||
json data = json::parse(line);
|
json data = json::parse(line);
|
||||||
const auto [answer, token_count] = QueryModel(
|
const auto [answer, token_count] = QueryModel(
|
||||||
model, args, app, kv_cache, inner_pool, pool, data["question"]);
|
model, args, app, kv_cache, pool, data["question"]);
|
||||||
std::cout << answer << "\n";
|
std::cout << answer << "\n";
|
||||||
bool correct = false;
|
bool correct = false;
|
||||||
for (const std::string expected : data["answer"]["aliases"]) {
|
for (const std::string expected : data["answer"]["aliases"]) {
|
||||||
|
|
@ -266,7 +260,6 @@ int main(int argc, char** argv) {
|
||||||
HWY_ABORT("\nInvalid inference args: %s", error);
|
HWY_ABORT("\nInvalid inference args: %s", error);
|
||||||
}
|
}
|
||||||
|
|
||||||
hwy::ThreadPool inner_pool(0);
|
|
||||||
hwy::ThreadPool pool(app.num_threads);
|
hwy::ThreadPool pool(app.num_threads);
|
||||||
// For many-core, pinning threads to cores helps.
|
// For many-core, pinning threads to cores helps.
|
||||||
if (app.num_threads > 10) {
|
if (app.num_threads > 10) {
|
||||||
|
|
@ -283,17 +276,16 @@ int main(int argc, char** argv) {
|
||||||
if (!benchmark_args.goldens.path.empty()) {
|
if (!benchmark_args.goldens.path.empty()) {
|
||||||
const std::string golden_path =
|
const std::string golden_path =
|
||||||
benchmark_args.goldens.path + "/" + loader.model_type_str + ".txt";
|
benchmark_args.goldens.path + "/" + loader.model_type_str + ".txt";
|
||||||
return BenchmarkGoldens(model, args, app, kv_cache, inner_pool, pool,
|
return BenchmarkGoldens(model, args, app, kv_cache, pool, golden_path);
|
||||||
golden_path);
|
|
||||||
} else if (!benchmark_args.summarize_text.path.empty()) {
|
} else if (!benchmark_args.summarize_text.path.empty()) {
|
||||||
return BenchmarkSummary(model, args, app, kv_cache, inner_pool, pool,
|
return BenchmarkSummary(model, args, app, kv_cache, pool,
|
||||||
benchmark_args.summarize_text);
|
benchmark_args.summarize_text);
|
||||||
} else if (!benchmark_args.cross_entropy.path.empty()) {
|
} else if (!benchmark_args.cross_entropy.path.empty()) {
|
||||||
return BenchmarkCrossEntropy(model, loader.ModelType(), args, app,
|
return BenchmarkCrossEntropy(model, loader.ModelType(), args, app,
|
||||||
inner_pool, pool, benchmark_args.cross_entropy,
|
pool, benchmark_args.cross_entropy,
|
||||||
benchmark_args.batch_tokens);
|
benchmark_args.batch_tokens);
|
||||||
} else if (!benchmark_args.trivia_qa.path.empty()) {
|
} else if (!benchmark_args.trivia_qa.path.empty()) {
|
||||||
return BenchmarkTriviaQA(model, args, app, kv_cache, inner_pool, pool,
|
return BenchmarkTriviaQA(model, args, app, kv_cache, pool,
|
||||||
benchmark_args.trivia_qa,
|
benchmark_args.trivia_qa,
|
||||||
benchmark_args.max_questions);
|
benchmark_args.max_questions);
|
||||||
}
|
}
|
||||||
|
|
@ -18,12 +18,8 @@
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
#include "gemma/gemma.h" // Gemma
|
||||||
#include "gemma.h" // Gemma
|
|
||||||
// copybara:end
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "util/args.h"
|
#include "util/args.h"
|
||||||
// copybara:end
|
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
@ -59,7 +55,7 @@ struct Args : public ArgsBase<Args> {
|
||||||
return "Missing --compressed_weights flag, a file for the compressed "
|
return "Missing --compressed_weights flag, a file for the compressed "
|
||||||
"model.";
|
"model.";
|
||||||
}
|
}
|
||||||
if (!weights.exists()) {
|
if (!weights.Exists()) {
|
||||||
return "Can't open file specified with --weights flag.";
|
return "Can't open file specified with --weights flag.";
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
@ -74,7 +70,7 @@ struct Args : public ArgsBase<Args> {
|
||||||
template <class Visitor>
|
template <class Visitor>
|
||||||
void ForEach(const Visitor& visitor) {
|
void ForEach(const Visitor& visitor) {
|
||||||
visitor(weights, "weights", Path(),
|
visitor(weights, "weights", Path(),
|
||||||
"Path name of model weights (.sbs) file.\n"
|
"Path to model weights (.bin) file.\n"
|
||||||
" Required argument.");
|
" Required argument.");
|
||||||
visitor(model_type_str, "model", std::string(),
|
visitor(model_type_str, "model", std::string(),
|
||||||
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
|
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
|
||||||
|
|
@ -84,7 +80,7 @@ struct Args : public ArgsBase<Args> {
|
||||||
"gr2b-pt = griffin 2B parameters, pretrained\n "
|
"gr2b-pt = griffin 2B parameters, pretrained\n "
|
||||||
" Required argument.");
|
" Required argument.");
|
||||||
visitor(compressed_weights, "compressed_weights", Path(),
|
visitor(compressed_weights, "compressed_weights", Path(),
|
||||||
"Path name where compressed weights file will be written.\n"
|
"Path name where compressed weights (.sbs) file will be written.\n"
|
||||||
" Required argument.");
|
" Required argument.");
|
||||||
visitor(num_threads, "num_threads",
|
visitor(num_threads, "num_threads",
|
||||||
kDefaultNumThreads, // see ChooseNumThreads
|
kDefaultNumThreads, // see ChooseNumThreads
|
||||||
|
|
@ -15,8 +15,8 @@
|
||||||
|
|
||||||
// Model configurations
|
// Model configurations
|
||||||
|
|
||||||
#ifndef THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
|
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_
|
||||||
#define THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_
|
||||||
|
|
||||||
// Allow changing pre-allocated kv cache size as a compiler flag
|
// Allow changing pre-allocated kv cache size as a compiler flag
|
||||||
#ifndef GEMMA_MAX_SEQLEN
|
#ifndef GEMMA_MAX_SEQLEN
|
||||||
|
|
@ -28,11 +28,15 @@
|
||||||
#define GEMMA_TOPK 1
|
#define GEMMA_TOPK 1
|
||||||
#endif // !GEMMA_TOPK
|
#endif // !GEMMA_TOPK
|
||||||
|
|
||||||
|
// Allow changing upper bound on threads as a compiler flag
|
||||||
|
#ifndef GEMMA_MAX_THREADS
|
||||||
|
#define GEMMA_MAX_THREADS 128
|
||||||
|
#endif // !GEMMA_MAX_THREADS
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/sfp.h"
|
#include "compression/sfp.h"
|
||||||
#include "hwy/base.h" // hwy::bfloat16_t
|
#include "hwy/base.h" // hwy::bfloat16_t
|
||||||
|
|
||||||
|
|
@ -46,6 +50,7 @@ namespace gcpp {
|
||||||
|
|
||||||
static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
|
static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
|
||||||
static constexpr size_t kTopK = GEMMA_TOPK;
|
static constexpr size_t kTopK = GEMMA_TOPK;
|
||||||
|
static constexpr size_t kMaxThreads = GEMMA_MAX_THREADS;
|
||||||
|
|
||||||
enum class LayerAttentionType {
|
enum class LayerAttentionType {
|
||||||
kGemma,
|
kGemma,
|
||||||
|
|
@ -62,18 +67,36 @@ constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig(
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <size_t kNumLayers>
|
||||||
|
constexpr size_t NumLayersOfTypeBefore(
|
||||||
|
const std::array<LayerAttentionType, kNumLayers>& layers,
|
||||||
|
LayerAttentionType type, size_t num) {
|
||||||
|
size_t count = 0;
|
||||||
|
for (size_t i = 0; i < num; i++) {
|
||||||
|
if (layers[i] == type) count++;
|
||||||
|
}
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
struct ConfigGemma7B {
|
struct ConfigGemma7B {
|
||||||
static constexpr int kSeqLen = gcpp::kSeqLen;
|
static constexpr int kSeqLen = gcpp::kSeqLen;
|
||||||
static constexpr int kVocabSize = 256000;
|
static constexpr int kVocabSize = 256000;
|
||||||
static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
|
static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
|
||||||
FixedLayerConfig<28>(LayerAttentionType::kGemma);
|
FixedLayerConfig<28>(LayerAttentionType::kGemma);
|
||||||
static constexpr int kLayers = kLayerConfig.size();
|
static constexpr int kLayers = kLayerConfig.size();
|
||||||
|
static constexpr int kGemmaLayers =
|
||||||
|
NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
|
||||||
|
static constexpr int kGriffinLayers =
|
||||||
|
NumLayersOfTypeBefore(kLayerConfig,
|
||||||
|
LayerAttentionType::kGriffinRecurrentBlock,
|
||||||
|
kLayers);
|
||||||
static constexpr int kModelDim = 3072;
|
static constexpr int kModelDim = 3072;
|
||||||
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
|
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
|
||||||
static constexpr int kHeads = 16;
|
static constexpr int kHeads = 16;
|
||||||
static constexpr int kKVHeads = 16; // standard MHA
|
static constexpr int kKVHeads = 16; // standard MHA
|
||||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||||
static constexpr int kTopK = gcpp::kTopK;
|
static constexpr int kTopK = gcpp::kTopK;
|
||||||
|
static constexpr bool kAbsolutePE = false;
|
||||||
|
|
||||||
// SSM config.
|
// SSM config.
|
||||||
static constexpr int kConv1dWidth = 0;
|
static constexpr int kConv1dWidth = 0;
|
||||||
|
|
@ -92,12 +115,19 @@ struct ConfigGemma2B {
|
||||||
static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
|
static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
|
||||||
FixedLayerConfig<18>(LayerAttentionType::kGemma);
|
FixedLayerConfig<18>(LayerAttentionType::kGemma);
|
||||||
static constexpr int kLayers = kLayerConfig.size();
|
static constexpr int kLayers = kLayerConfig.size();
|
||||||
|
static constexpr int kGemmaLayers =
|
||||||
|
NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
|
||||||
|
static constexpr int kGriffinLayers =
|
||||||
|
NumLayersOfTypeBefore(kLayerConfig,
|
||||||
|
LayerAttentionType::kGriffinRecurrentBlock,
|
||||||
|
kLayers);
|
||||||
static constexpr int kModelDim = 2048;
|
static constexpr int kModelDim = 2048;
|
||||||
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
|
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
|
||||||
static constexpr int kHeads = 8;
|
static constexpr int kHeads = 8;
|
||||||
static constexpr int kKVHeads = 1;
|
static constexpr int kKVHeads = 1;
|
||||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||||
static constexpr int kTopK = gcpp::kTopK;
|
static constexpr int kTopK = gcpp::kTopK;
|
||||||
|
static constexpr bool kAbsolutePE = false;
|
||||||
|
|
||||||
// SSM config.
|
// SSM config.
|
||||||
static constexpr int kConv1dWidth = 0;
|
static constexpr int kConv1dWidth = 0;
|
||||||
|
|
@ -144,12 +174,19 @@ struct ConfigGriffin2B {
|
||||||
LayerAttentionType::kGriffinRecurrentBlock,
|
LayerAttentionType::kGriffinRecurrentBlock,
|
||||||
};
|
};
|
||||||
static constexpr int kLayers = kLayerConfig.size();
|
static constexpr int kLayers = kLayerConfig.size();
|
||||||
|
static constexpr int kGemmaLayers =
|
||||||
|
NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
|
||||||
|
static constexpr int kGriffinLayers =
|
||||||
|
NumLayersOfTypeBefore(kLayerConfig,
|
||||||
|
LayerAttentionType::kGriffinRecurrentBlock,
|
||||||
|
kLayers);
|
||||||
static constexpr int kModelDim = 2560;
|
static constexpr int kModelDim = 2560;
|
||||||
static constexpr int kFFHiddenDim = 7680;
|
static constexpr int kFFHiddenDim = 7680;
|
||||||
static constexpr int kHeads = 10;
|
static constexpr int kHeads = 10;
|
||||||
static constexpr int kKVHeads = 1;
|
static constexpr int kKVHeads = 1;
|
||||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||||
static constexpr int kTopK = gcpp::kTopK;
|
static constexpr int kTopK = gcpp::kTopK;
|
||||||
|
static constexpr bool kAbsolutePE = false;
|
||||||
|
|
||||||
// SSM config.
|
// SSM config.
|
||||||
static constexpr int kConv1dWidth = 4;
|
static constexpr int kConv1dWidth = 4;
|
||||||
|
|
@ -164,4 +201,4 @@ struct ConfigGriffin2B {
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -13,42 +13,42 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_
|
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
|
||||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_H_
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <random>
|
#include <random>
|
||||||
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
#include "compression/io.h" // Path
|
||||||
#include "compression/compress.h" // SfpStream/NuqStream
|
#include "gemma/configs.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "configs.h"
|
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h" // hwy::bfloat16_t
|
#include "hwy/base.h" // hwy::bfloat16_t
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "util/args.h" // Path
|
|
||||||
// copybara:import_next_line:sentencepiece
|
|
||||||
#include "src/sentencepiece_processor.h"
|
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
using GemmaWeightT = GEMMA_WEIGHT_T;
|
using GemmaWeightT = GEMMA_WEIGHT_T;
|
||||||
using EmbedderInputT = hwy::bfloat16_t;
|
using EmbedderInputT = hwy::bfloat16_t;
|
||||||
|
// Will be called for layers output with:
|
||||||
|
// - position in the tokens sequence
|
||||||
|
// - name of the data, p.ex. "tokens", "block.1", "final_norm"
|
||||||
|
// - pointer to the data array
|
||||||
|
// - size of the data array
|
||||||
|
using LayersOutputT =
|
||||||
|
std::function<void(int, const std::string&, const float*, size_t)>;
|
||||||
constexpr size_t kPrefillBatchSize = 16;
|
constexpr size_t kPrefillBatchSize = 16;
|
||||||
constexpr bool kSystemPrompt = false;
|
constexpr bool kSystemPrompt = false;
|
||||||
|
|
||||||
struct KVCache {
|
struct KVCache {
|
||||||
hwy::AlignedFreeUniquePtr<float[]>
|
hwy::AlignedFreeUniquePtr<float[]>
|
||||||
key_cache; // kSeqLen * kNumGemmaLayers * kKVHeads * kQKVDim
|
kv_cache; // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim * 2
|
||||||
hwy::AlignedFreeUniquePtr<float[]>
|
hwy::AlignedFreeUniquePtr<float[]>
|
||||||
value_cache; // kSeqLen * kNumGemmaLayers * kKVHeads * kQKVDim
|
conv1d_cache; // (kConv1dWidth - 1) * kModelDim * kGriffinLayers
|
||||||
hwy::AlignedFreeUniquePtr<float[]>
|
hwy::AlignedFreeUniquePtr<float[]>
|
||||||
conv1d_cache; // (kConv1dWidth - 1) * kModelDim * kNumGriffinLayers
|
rglru_cache; // kModelDim * kGriffinLayers
|
||||||
hwy::AlignedFreeUniquePtr<float[]>
|
|
||||||
rglru_cache; // kModelDim * kNumGriffinLayers
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Model variants: see configs.h for details.
|
// Model variants: see configs.h for details.
|
||||||
|
|
@ -71,6 +71,7 @@ struct GemmaInterface;
|
||||||
|
|
||||||
class GemmaTokenizer {
|
class GemmaTokenizer {
|
||||||
public:
|
public:
|
||||||
|
virtual ~GemmaTokenizer() = default;
|
||||||
virtual bool Encode(const std::string& input,
|
virtual bool Encode(const std::string& input,
|
||||||
std::vector<std::string>* pieces) const = 0;
|
std::vector<std::string>* pieces) const = 0;
|
||||||
virtual bool Encode(const std::string& input,
|
virtual bool Encode(const std::string& input,
|
||||||
|
|
@ -82,7 +83,7 @@ class GemmaTokenizer {
|
||||||
struct Gemma {
|
struct Gemma {
|
||||||
Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
|
Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
|
||||||
hwy::ThreadPool& pool);
|
hwy::ThreadPool& pool);
|
||||||
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
~Gemma(); // must be defined after the GemmaInterface dtor is defined.
|
||||||
const GemmaTokenizer* Tokenizer() const;
|
const GemmaTokenizer* Tokenizer() const;
|
||||||
std::unique_ptr<GemmaInterface> impl_;
|
std::unique_ptr<GemmaInterface> impl_;
|
||||||
};
|
};
|
||||||
|
|
@ -96,16 +97,17 @@ KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
|
||||||
using StreamFunc = std::function<bool(int, float)>;
|
using StreamFunc = std::function<bool(int, float)>;
|
||||||
using AcceptFunc = std::function<bool(int)>;
|
using AcceptFunc = std::function<bool(int)>;
|
||||||
|
|
||||||
|
// layers_output is optional; if set - it will be called with the activations
|
||||||
|
// output after applying each layer.
|
||||||
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
||||||
float temperature, const std::vector<int>& prompt,
|
float temperature, const std::vector<int>& prompt,
|
||||||
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
const StreamFunc& stream_token,
|
||||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||||
int verbosity);
|
int verbosity, LayersOutputT* layers_output = nullptr);
|
||||||
|
|
||||||
// Convenience function for the common case:
|
// Convenience function for the common case:
|
||||||
// - Bundle runtime parameters as RuntimeConfig
|
// - Bundle runtime parameters as RuntimeConfig
|
||||||
// - No threadpools within threadpools (inner_pool = dummy)
|
|
||||||
// - All tokens accepted
|
// - All tokens accepted
|
||||||
void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
|
void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
|
||||||
const std::vector<int>& prompt, size_t start_pos,
|
const std::vector<int>& prompt, size_t start_pos,
|
||||||
|
|
@ -117,11 +119,10 @@ void CompressWeights(gcpp::Model model, const Path& weights,
|
||||||
|
|
||||||
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
|
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
|
||||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
hwy::ThreadPool& pool, int verbosity);
|
||||||
int verbosity);
|
|
||||||
|
|
||||||
constexpr int EOS_ID = 1;
|
constexpr int EOS_ID = 1;
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
|
||||||
|
|
@ -13,14 +13,16 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
#include "gemma/gemma.h"
|
||||||
#include "gemma.h"
|
|
||||||
|
|
||||||
#include <thread>
|
#include <algorithm>
|
||||||
|
#include <iostream>
|
||||||
|
#include <random>
|
||||||
|
#include <string>
|
||||||
|
#include <thread> // NOLINT
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
#include "gemma/ops.h"
|
||||||
#include "ops.h"
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "util/args.h"
|
#include "util/args.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/tests/test_util-inl.h"
|
#include "hwy/tests/test_util-inl.h"
|
||||||
|
|
@ -34,7 +36,6 @@ class GemmaTest : public ::testing::Test {
|
||||||
: weights("./2b-it-mqa.sbs"),
|
: weights("./2b-it-mqa.sbs"),
|
||||||
tokenizer("./tokenizer.spm"),
|
tokenizer("./tokenizer.spm"),
|
||||||
pool(std::min<int>(20, (std::thread::hardware_concurrency() - 1) / 2)),
|
pool(std::min<int>(20, (std::thread::hardware_concurrency() - 1) / 2)),
|
||||||
inner_pool(0),
|
|
||||||
model_type(gcpp::Model::GEMMA_2B),
|
model_type(gcpp::Model::GEMMA_2B),
|
||||||
model(tokenizer, weights, model_type, pool) {
|
model(tokenizer, weights, model_type, pool) {
|
||||||
kv_cache = CreateKVCache(model_type);
|
kv_cache = CreateKVCache(model_type);
|
||||||
|
|
@ -58,8 +59,8 @@ class GemmaTest : public ::testing::Test {
|
||||||
gcpp::GenerateGemma(
|
gcpp::GenerateGemma(
|
||||||
model, /*max_tokens=*/3072, /*max_generated_tokens=*/2048,
|
model, /*max_tokens=*/3072, /*max_generated_tokens=*/2048,
|
||||||
/*temperature=*/1.0, prompt, /*start_pos=*/0, kv_cache, pool,
|
/*temperature=*/1.0, prompt, /*start_pos=*/0, kv_cache, pool,
|
||||||
inner_pool, stream_token,
|
stream_token, /*accept=*/[](int) { return true; }, gen,
|
||||||
/*accept=*/[](int) { return true; }, gen, /*verbosity=*/0);
|
/*verbosity=*/0);
|
||||||
std::string response_text;
|
std::string response_text;
|
||||||
HWY_ASSERT(model.Tokenizer()->Decode(response, &response_text));
|
HWY_ASSERT(model.Tokenizer()->Decode(response, &response_text));
|
||||||
return response_text;
|
return response_text;
|
||||||
|
|
@ -69,8 +70,7 @@ class GemmaTest : public ::testing::Test {
|
||||||
std::vector<int> prompt;
|
std::vector<int> prompt;
|
||||||
HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt));
|
HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt));
|
||||||
return gcpp::ComputeCrossEntropy(model, /*max_tokens=*/3072, prompt,
|
return gcpp::ComputeCrossEntropy(model, /*max_tokens=*/3072, prompt,
|
||||||
kv_cache, pool, inner_pool,
|
kv_cache, pool, /*verbosity=*/0) /
|
||||||
/*verbosity=*/0) /
|
|
||||||
prompt_string.size();
|
prompt_string.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -79,7 +79,7 @@ class GemmaTest : public ::testing::Test {
|
||||||
std::cout << "Question " << i + 1 << "\n\n";
|
std::cout << "Question " << i + 1 << "\n\n";
|
||||||
std::string response = GemmaReply(kQA[i][0]);
|
std::string response = GemmaReply(kQA[i][0]);
|
||||||
std::cout << response << "\n\n";
|
std::cout << response << "\n\n";
|
||||||
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos);
|
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -87,7 +87,6 @@ class GemmaTest : public ::testing::Test {
|
||||||
gcpp::Path tokenizer;
|
gcpp::Path tokenizer;
|
||||||
gcpp::KVCache kv_cache;
|
gcpp::KVCache kv_cache;
|
||||||
hwy::ThreadPool pool;
|
hwy::ThreadPool pool;
|
||||||
hwy::ThreadPool inner_pool;
|
|
||||||
gcpp::Model model_type = {};
|
gcpp::Model model_type = {};
|
||||||
gcpp::Gemma model;
|
gcpp::Gemma model;
|
||||||
};
|
};
|
||||||
|
|
@ -14,8 +14,9 @@
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
// Include guard for non-SIMD code.
|
// Include guard for non-SIMD code.
|
||||||
#ifndef THIRD_PARTY_GEMMA_CPP_OPS_H_
|
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
|
||||||
#define THIRD_PARTY_GEMMA_CPP_OPS_H_
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
|
|
@ -24,6 +25,7 @@
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <type_traits> // std::enable_if_t
|
#include <type_traits> // std::enable_if_t
|
||||||
|
|
||||||
|
#include "compression/compress.h" // CompressedArray
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
|
|
@ -43,7 +45,7 @@ static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_OPS_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
|
||||||
|
|
||||||
// Include guard for (potentially) SIMD code.
|
// Include guard for (potentially) SIMD code.
|
||||||
#if defined(THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE) == defined(HWY_TARGET_TOGGLE)
|
#if defined(THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE) == defined(HWY_TARGET_TOGGLE)
|
||||||
|
|
@ -53,7 +55,6 @@ static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
|
||||||
#define THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE
|
#define THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/compress-inl.h"
|
#include "compression/compress-inl.h"
|
||||||
#include "hwy/contrib/algo/transform-inl.h"
|
#include "hwy/contrib/algo/transform-inl.h"
|
||||||
#include "hwy/contrib/dot/dot-inl.h"
|
#include "hwy/contrib/dot/dot-inl.h"
|
||||||
|
|
@ -92,12 +93,60 @@ HWY_INLINE constexpr size_t RowsPerStrip() {
|
||||||
return kRowsPerStrip;
|
return kRowsPerStrip;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
|
||||||
|
// ops_test across instruction sets.
|
||||||
|
template <size_t kM, size_t kN, size_t kK>
|
||||||
|
HWY_INLINE void MatMul(const float* HWY_RESTRICT a, const float* HWY_RESTRICT b,
|
||||||
|
float* HWY_RESTRICT out) {
|
||||||
|
int i, j, k;
|
||||||
|
for (i = 0; i < kM; ++i) {
|
||||||
|
for (k = 0; k < kN; ++k) {
|
||||||
|
for (j = 0; j < kK; ++j) {
|
||||||
|
out[i * kK + j] += a[i * kN + k] * b[k * kK + j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
HWY_INLINE void ToEvenOddF32(const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
|
||||||
|
const size_t size, float* HWY_RESTRICT out) {
|
||||||
|
const hn::ScalableTag<float> df;
|
||||||
|
const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf16;
|
||||||
|
|
||||||
|
HWY_DASSERT(size % hn::Lanes(dbf16) == 0);
|
||||||
|
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
|
||||||
|
|
||||||
|
for (size_t i = 0; i < size; i += hn::Lanes(dbf16)) {
|
||||||
|
const auto interleaved = hn::LoadU(dbf16, vec_aligned + i);
|
||||||
|
hn::Store(hn::PromoteEvenTo(df, interleaved), df, out + i);
|
||||||
|
hn::Store(hn::PromoteOddTo(df, interleaved), df, out + i + hn::Lanes(df));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
HWY_INLINE void ToEvenOddF32(const float* HWY_RESTRICT vec_aligned,
|
||||||
|
const size_t size, float* HWY_RESTRICT out) {
|
||||||
|
const hn::ScalableTag<float> df;
|
||||||
|
using VF = hn::Vec<decltype(df)>;
|
||||||
|
|
||||||
|
HWY_DASSERT(size % (hn::Lanes(df) * 2) == 0);
|
||||||
|
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
|
||||||
|
|
||||||
|
VF vec0, vec1;
|
||||||
|
for (size_t i = 0; i < size; i += hn::Lanes(df) * 2) {
|
||||||
|
hn::LoadInterleaved2(df, vec_aligned + i, vec0, vec1);
|
||||||
|
hn::Store(vec0, df, out + i);
|
||||||
|
hn::Store(vec1, df, out + i + hn::Lanes(df));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Simple version without tiling nor threading.
|
// Simple version without tiling nor threading.
|
||||||
|
// even_odd is precomputed for the current thread.
|
||||||
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
|
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
|
||||||
typename VecT, typename AddT>
|
typename VecT, typename AddT>
|
||||||
HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
|
HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
|
||||||
const VecT* HWY_RESTRICT vec_aligned,
|
const VecT* HWY_RESTRICT vec_aligned,
|
||||||
const AddT* HWY_RESTRICT add,
|
const AddT* HWY_RESTRICT add,
|
||||||
|
float* HWY_RESTRICT even_odd,
|
||||||
float* HWY_RESTRICT out) {
|
float* HWY_RESTRICT out) {
|
||||||
PROFILER_ZONE("MatVecAddLoop");
|
PROFILER_ZONE("MatVecAddLoop");
|
||||||
const hn::ScalableTag<float> df;
|
const hn::ScalableTag<float> df;
|
||||||
|
|
@ -113,12 +162,40 @@ HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16
|
||||||
|
template <bool kAdd, size_t kOuter, size_t kInner, typename VecT, typename AddT,
|
||||||
|
size_t kCapacity>
|
||||||
|
HWY_INLINE void MatVecAddLoop(
|
||||||
|
const CompressedArray<hwy::bfloat16_t, kCapacity>& mat,
|
||||||
|
const size_t mat_ofs, const VecT* HWY_RESTRICT vec_aligned,
|
||||||
|
const AddT* HWY_RESTRICT add, float* HWY_RESTRICT even_odd,
|
||||||
|
float* HWY_RESTRICT out) {
|
||||||
|
PROFILER_ZONE("MatVecAddLoop");
|
||||||
|
constexpr bool kVecIsEvenOdd = true;
|
||||||
|
|
||||||
|
const hn::ScalableTag<float> df;
|
||||||
|
ToEvenOddF32(vec_aligned, kInner, even_odd);
|
||||||
|
for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
|
||||||
|
const size_t row_ofs = mat_ofs + idx_row * kInner;
|
||||||
|
if constexpr (kAdd) {
|
||||||
|
out[idx_row] = hwy::ConvertScalarTo<float>(add[idx_row]) +
|
||||||
|
Dot<kVecIsEvenOdd>(df, mat, row_ofs, even_odd, kInner);
|
||||||
|
} else {
|
||||||
|
out[idx_row] = Dot<kVecIsEvenOdd>(df, mat, row_ofs, even_odd, kInner);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// even_odd is precomputed for the current thread.
|
||||||
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
|
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
|
||||||
HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs,
|
HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs,
|
||||||
const VecT* HWY_RESTRICT vec_aligned,
|
const VecT* HWY_RESTRICT vec_aligned,
|
||||||
|
float* HWY_RESTRICT even_odd,
|
||||||
float* HWY_RESTRICT out) {
|
float* HWY_RESTRICT out) {
|
||||||
MatVecAddLoop<false, kOuter, kInner, ArrayT, VecT, VecT>(
|
MatVecAddLoop</*kAdd=*/false, kOuter, kInner>(
|
||||||
mat, mat_ofs, vec_aligned, /*add=*/nullptr, out);
|
mat, mat_ofs, vec_aligned, /*add=*/static_cast<VecT*>(nullptr), even_odd,
|
||||||
|
out);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Simple version without tiling nor threading, but two offsets/outputs.
|
// Simple version without tiling nor threading, but two offsets/outputs.
|
||||||
|
|
@ -156,7 +233,7 @@ HWY_INLINE void TwoOfsMatVecLoop(const ArrayT& mat, const size_t mat_ofs0,
|
||||||
const VecT* HWY_RESTRICT vec_aligned,
|
const VecT* HWY_RESTRICT vec_aligned,
|
||||||
float* HWY_RESTRICT out0,
|
float* HWY_RESTRICT out0,
|
||||||
float* HWY_RESTRICT out1) {
|
float* HWY_RESTRICT out1) {
|
||||||
TwoOfsMatVecAddLoop<false, kOuter, kInner, ArrayT, VecT, VecT>(
|
TwoOfsMatVecAddLoop</*kAdd=*/false, kOuter, kInner, ArrayT, VecT, VecT>(
|
||||||
mat, mat_ofs0, mat_ofs1, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr,
|
mat, mat_ofs0, mat_ofs1, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr,
|
||||||
out0, out1);
|
out0, out1);
|
||||||
}
|
}
|
||||||
|
|
@ -166,20 +243,23 @@ namespace detail {
|
||||||
// For each i = [0, num_rows), compute partial (length `num_cols`) dot product
|
// For each i = [0, num_rows), compute partial (length `num_cols`) dot product
|
||||||
// of row i with `vec_aligned` and add into `out[i]`. The upper-left coordinate
|
// of row i with `vec_aligned` and add into `out[i]`. The upper-left coordinate
|
||||||
// of the tile is r0, c0.
|
// of the tile is r0, c0.
|
||||||
template <class DF, typename ArrayT, typename VecT>
|
template <bool kVecEO, class DF, typename ArrayT, typename VecT>
|
||||||
HWY_INLINE void AccumulatePartialDotProducts(
|
HWY_INLINE void AccumulatePartialDotProducts(
|
||||||
DF df, const ArrayT& mat, size_t mat_ofs, size_t mat_stride, size_t r0,
|
DF df, const ArrayT& mat, size_t mat_ofs, size_t mat_stride, size_t r0,
|
||||||
size_t c0, size_t num_rows, size_t num_cols,
|
size_t c0, size_t num_rows, size_t num_cols,
|
||||||
const VecT* HWY_RESTRICT vec_aligned, float* HWY_RESTRICT out) {
|
const VecT* HWY_RESTRICT vec_aligned, float* HWY_RESTRICT out) {
|
||||||
for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
|
for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
|
||||||
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
|
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
|
||||||
out[idx_row] += Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
|
out[idx_row] +=
|
||||||
|
Dot<kVecEO>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Same as above, but sets out[i] to the first partial dot product +
|
// Same as AccumulatePartialDotProducts, but sets out[i] to the first partial
|
||||||
// init (if kInit), which avoids having to zero-initialize and accumulate.
|
// dot product + init (if kInit), which avoids having to zero-initialize and
|
||||||
template <bool kInit, class DF, typename ArrayT, typename VecT, typename InitT>
|
// accumulate.
|
||||||
|
template <bool kVecEO, bool kInit, class DF, typename ArrayT, typename VecT,
|
||||||
|
typename InitT>
|
||||||
HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
|
HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
|
||||||
size_t mat_ofs, size_t mat_stride,
|
size_t mat_ofs, size_t mat_stride,
|
||||||
size_t r0, size_t c0,
|
size_t r0, size_t c0,
|
||||||
|
|
@ -190,10 +270,12 @@ HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
|
||||||
for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
|
for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
|
||||||
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
|
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
|
||||||
if constexpr (kInit) {
|
if constexpr (kInit) {
|
||||||
out[idx_row] = hwy::ConvertScalarTo<float>(init[idx_row + r0]) +
|
out[idx_row] =
|
||||||
Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
|
hwy::ConvertScalarTo<float>(init[idx_row + r0]) +
|
||||||
|
Dot<kVecEO>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
|
||||||
} else {
|
} else {
|
||||||
out[idx_row] = Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
|
out[idx_row] =
|
||||||
|
Dot<kVecEO>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -202,7 +284,8 @@ HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
|
||||||
// horizontal strip of the entire matrix); the result is the full dot product
|
// horizontal strip of the entire matrix); the result is the full dot product
|
||||||
// for rows r in [r0, r0 + num_rows) + optionally the add vector, which we store
|
// for rows r in [r0, r0 + num_rows) + optionally the add vector, which we store
|
||||||
// into in out[r - r0].
|
// into in out[r - r0].
|
||||||
template <bool kAdd, class DF, typename ArrayT, typename VecT, typename AddT>
|
template <bool kVecEO, bool kAdd, class DF, typename ArrayT, typename VecT,
|
||||||
|
typename AddT>
|
||||||
HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
|
HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
|
||||||
size_t mat_ofs, size_t mat_stride,
|
size_t mat_ofs, size_t mat_stride,
|
||||||
size_t r0, size_t num_rows,
|
size_t r0, size_t num_rows,
|
||||||
|
|
@ -211,25 +294,66 @@ HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
|
||||||
float* HWY_RESTRICT out) {
|
float* HWY_RESTRICT out) {
|
||||||
// Tall and skinny: set `out` to the single dot product.
|
// Tall and skinny: set `out` to the single dot product.
|
||||||
if (mat_stride < MaxCols()) {
|
if (mat_stride < MaxCols()) {
|
||||||
SetFirstPartialDotProducts<kAdd>(df, mat, mat_ofs, mat_stride, r0, 0,
|
SetFirstPartialDotProducts<kVecEO, kAdd>(df, mat, mat_ofs, mat_stride, r0,
|
||||||
num_rows, mat_stride, vec_aligned, add,
|
0, num_rows, mat_stride,
|
||||||
out);
|
vec_aligned, add, out);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// We have at least MaxCols, so start by setting `out` to that:
|
// We have at least MaxCols, so start by setting `out` to that:
|
||||||
SetFirstPartialDotProducts<kAdd>(df, mat, mat_ofs, mat_stride, r0, 0,
|
SetFirstPartialDotProducts<kVecEO, kAdd>(df, mat, mat_ofs, mat_stride, r0, 0,
|
||||||
num_rows, MaxCols(), vec_aligned, add, out);
|
num_rows, MaxCols(), vec_aligned,
|
||||||
|
add, out);
|
||||||
// For further multiples of MaxCols, accumulate. Remainders handled below.
|
// For further multiples of MaxCols, accumulate. Remainders handled below.
|
||||||
size_t c0 = MaxCols();
|
size_t c0 = MaxCols();
|
||||||
for (; c0 <= mat_stride - MaxCols(); c0 += MaxCols()) {
|
for (; c0 <= mat_stride - MaxCols(); c0 += MaxCols()) {
|
||||||
AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows,
|
AccumulatePartialDotProducts<kVecEO>(df, mat, mat_ofs, mat_stride, r0, c0,
|
||||||
MaxCols(), vec_aligned, out);
|
num_rows, MaxCols(), vec_aligned, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (c0 < mat_stride) { // Final cols
|
if (c0 < mat_stride) { // Final cols
|
||||||
AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows,
|
AccumulatePartialDotProducts<kVecEO>(df, mat, mat_ofs, mat_stride, r0, c0,
|
||||||
mat_stride - c0, vec_aligned, out);
|
num_rows, mat_stride - c0, vec_aligned,
|
||||||
|
out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <bool kVecIsEvenOdd, bool kAdd, size_t kOuter, size_t kInner,
|
||||||
|
typename ArrayT, typename VecT, typename AddT>
|
||||||
|
HWY_INLINE void MatVecAddInner(const ArrayT& mat, const size_t mat_ofs,
|
||||||
|
const VecT* HWY_RESTRICT const vec_aligned,
|
||||||
|
const AddT* HWY_RESTRICT const add,
|
||||||
|
float* HWY_RESTRICT even_odd,
|
||||||
|
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
|
||||||
|
const hn::ScalableTag<float> df;
|
||||||
|
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
|
||||||
|
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
|
||||||
|
|
||||||
|
// Sanity check: each thread can write without race conditions.
|
||||||
|
if (HWY_IS_TSAN) {
|
||||||
|
pool.Run(
|
||||||
|
0, pool.NumWorkers(), [even_odd](uint64_t /*task*/, size_t thread) {
|
||||||
|
even_odd[thread * kInner] = -static_cast<float>(thread);
|
||||||
|
even_odd[thread * kInner + kInner - 1] = static_cast<float>(thread);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// For each entire strip.
|
||||||
|
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
|
||||||
|
PROFILER_ZONE("MatVec.lambda");
|
||||||
|
const size_t r0 = strip * kRowsPerStrip;
|
||||||
|
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
|
||||||
|
df, mat, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add,
|
||||||
|
out + r0);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Remaining rows
|
||||||
|
const size_t r0 = kNumStrips * kRowsPerStrip;
|
||||||
|
if (r0 < kOuter) {
|
||||||
|
PROFILER_ZONE("MatVec remainder");
|
||||||
|
const size_t num_rows = kOuter - r0;
|
||||||
|
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
|
||||||
|
df, mat, mat_ofs, kInner, r0, num_rows, vec_aligned, add, out + r0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -243,38 +367,32 @@ template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
|
||||||
HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
|
HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
|
||||||
const VecT* HWY_RESTRICT const vec_aligned,
|
const VecT* HWY_RESTRICT const vec_aligned,
|
||||||
const AddT* HWY_RESTRICT const add,
|
const AddT* HWY_RESTRICT const add,
|
||||||
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
|
float* HWY_RESTRICT even_odd, float* HWY_RESTRICT out,
|
||||||
|
hwy::ThreadPool& pool) {
|
||||||
PROFILER_ZONE("MatVecAdd");
|
PROFILER_ZONE("MatVecAdd");
|
||||||
|
|
||||||
const hn::ScalableTag<float> df;
|
#if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16
|
||||||
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
|
if constexpr (CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd &&
|
||||||
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
|
hwy::IsSameEither<VecT, float, hwy::bfloat16_t>()) {
|
||||||
|
ToEvenOddF32(vec_aligned, kInner, even_odd);
|
||||||
// For each entire strip.
|
detail::MatVecAddInner</*kVecIsEvenOdd=*/true, kAdd, kOuter, kInner>(
|
||||||
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
|
mat, mat_ofs, even_odd, add, even_odd, out, pool);
|
||||||
PROFILER_ZONE("MatVec.lambda");
|
return;
|
||||||
const size_t r0 = strip * kRowsPerStrip;
|
|
||||||
detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, kInner, r0,
|
|
||||||
kRowsPerStrip, vec_aligned, add,
|
|
||||||
out + r0);
|
|
||||||
});
|
|
||||||
|
|
||||||
// Remaining rows
|
|
||||||
const size_t r0 = kNumStrips * kRowsPerStrip;
|
|
||||||
if (r0 < kOuter) {
|
|
||||||
PROFILER_ZONE("MatVec remainder");
|
|
||||||
const size_t num_rows = kOuter - r0;
|
|
||||||
detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, kInner, r0,
|
|
||||||
num_rows, vec_aligned, add, out + r0);
|
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
detail::MatVecAddInner</*kVecIsEvenOdd=*/false, kAdd, kOuter, kInner>(
|
||||||
|
mat, mat_ofs, vec_aligned, add, even_odd, out, pool);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
|
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
|
||||||
HWY_INLINE void MatVec(const ArrayT& mat, const size_t mat_ofs,
|
HWY_INLINE void MatVec(const ArrayT& mat, const size_t mat_ofs,
|
||||||
const VecT* HWY_RESTRICT const vec_aligned,
|
const VecT* HWY_RESTRICT const vec_aligned,
|
||||||
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
|
float* HWY_RESTRICT even_odd, float* HWY_RESTRICT out,
|
||||||
MatVecAdd<false, kOuter, kInner, ArrayT, VecT, VecT>(
|
hwy::ThreadPool& pool) {
|
||||||
mat, mat_ofs, vec_aligned, /*add=*/nullptr, out, pool);
|
MatVecAdd</*kAdd=*/false, kOuter, kInner>(mat, mat_ofs, vec_aligned,
|
||||||
|
/*add=*/static_cast<VecT*>(nullptr),
|
||||||
|
even_odd, out, pool);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class D, HWY_IF_F32_D(D)>
|
template <class D, HWY_IF_F32_D(D)>
|
||||||
|
|
@ -396,16 +514,17 @@ HWY_NOINLINE void TwoMatVecAdd(
|
||||||
const hn::ScalableTag<float> df;
|
const hn::ScalableTag<float> df;
|
||||||
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
|
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
|
||||||
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
|
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
|
||||||
|
constexpr bool kVecIsEvenOdd = false;
|
||||||
|
|
||||||
// For each entire strip.
|
// For each entire strip.
|
||||||
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
|
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
|
||||||
PROFILER_ZONE("TwoMatVec.lambda");
|
PROFILER_ZONE("TwoMatVec.lambda");
|
||||||
const size_t r0 = strip * kRowsPerStrip;
|
const size_t r0 = strip * kRowsPerStrip;
|
||||||
detail::FullDotProductsForStrip<kAdd>(df, mat0, mat_ofs, kInner, r0,
|
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
|
||||||
kRowsPerStrip, vec_aligned, add0,
|
df, mat0, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add0,
|
||||||
out0 + r0);
|
out0 + r0);
|
||||||
detail::FullDotProductsForStrip<kAdd>(df, mat1, mat_ofs, kInner, r0,
|
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
|
||||||
kRowsPerStrip, vec_aligned, add1,
|
df, mat1, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add1,
|
||||||
out1 + r0);
|
out1 + r0);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -414,9 +533,9 @@ HWY_NOINLINE void TwoMatVecAdd(
|
||||||
if (r0 < kOuter) {
|
if (r0 < kOuter) {
|
||||||
PROFILER_ZONE("TwoMatVec remainder");
|
PROFILER_ZONE("TwoMatVec remainder");
|
||||||
const size_t num_rows = kOuter - r0;
|
const size_t num_rows = kOuter - r0;
|
||||||
detail::FullDotProductsForStrip<kAdd>(
|
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
|
||||||
df, mat0, mat_ofs, kInner, r0, num_rows, vec_aligned, add0, out0 + r0);
|
df, mat0, mat_ofs, kInner, r0, num_rows, vec_aligned, add0, out0 + r0);
|
||||||
detail::FullDotProductsForStrip<kAdd>(
|
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
|
||||||
df, mat1, mat_ofs, kInner, r0, num_rows, vec_aligned, add1, out1 + r0);
|
df, mat1, mat_ofs, kInner, r0, num_rows, vec_aligned, add1, out1 + r0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -427,7 +546,7 @@ HWY_NOINLINE void TwoMatVec(const ArrayT& mat0, const ArrayT& mat1,
|
||||||
const VecT* HWY_RESTRICT vec_aligned,
|
const VecT* HWY_RESTRICT vec_aligned,
|
||||||
float* HWY_RESTRICT out0, float* HWY_RESTRICT out1,
|
float* HWY_RESTRICT out0, float* HWY_RESTRICT out1,
|
||||||
hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool) {
|
||||||
TwoMatVecAdd<false, kOuter, kInner, ArrayT, VecT, VecT>(
|
TwoMatVecAdd</*kAdd=*/false, kOuter, kInner, ArrayT, VecT, VecT>(
|
||||||
mat0, mat1, mat_ofs, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr,
|
mat0, mat1, mat_ofs, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr,
|
||||||
out0, out1, pool);
|
out0, out1, pool);
|
||||||
}
|
}
|
||||||
|
|
@ -17,22 +17,27 @@
|
||||||
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <random>
|
#include <random>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "compression/compress.h"
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
#undef HWY_TARGET_INCLUDE
|
#undef HWY_TARGET_INCLUDE
|
||||||
#define HWY_TARGET_INCLUDE "ops_test.cc" //NOLINT
|
#define HWY_TARGET_INCLUDE "gemma/ops_test.cc" //NOLINT
|
||||||
// clang-format on
|
// clang-format on
|
||||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
#include "hwy/tests/test_util-inl.h"
|
#include "hwy/tests/test_util-inl.h"
|
||||||
// After highway.h
|
// After highway.h
|
||||||
// copybara:import_next_line:gemma_cpp
|
#include "gemma/ops.h"
|
||||||
#include "ops.h"
|
|
||||||
|
|
||||||
HWY_BEFORE_NAMESPACE();
|
HWY_BEFORE_NAMESPACE();
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
@ -373,15 +378,54 @@ CompressedArray<float, kOuter * kInner> GenerateMat(size_t offset) {
|
||||||
return mat;
|
return mat;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <size_t kOuter, size_t kInner>
|
||||||
|
CompressedArray<float, kOuter * kInner> GenerateZeroMat(size_t offset) {
|
||||||
|
hwy::ThreadPool pool(static_cast<size_t>(std::clamp(
|
||||||
|
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 4)));
|
||||||
|
gcpp::CompressWorkingSet ws;
|
||||||
|
CompressedArray<float, kOuter * kInner> mat;
|
||||||
|
std::array<float, kOuter * kInner> content;
|
||||||
|
|
||||||
|
pool.Run(0, kOuter, [&](const size_t i, size_t thread) {
|
||||||
|
for (size_t j = 0; j < kInner; j++) {
|
||||||
|
content[i * kInner + j] = 0.0f;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Compress(content, ws, mat, pool);
|
||||||
|
mat.set_scale(1.0f);
|
||||||
|
return mat;
|
||||||
|
}
|
||||||
|
|
||||||
template <size_t length>
|
template <size_t length>
|
||||||
hwy::AlignedFreeUniquePtr<float[]> GenerateVec(size_t offset) {
|
hwy::AlignedFreeUniquePtr<float[]> GenerateVec(size_t offset) {
|
||||||
hwy::AlignedFreeUniquePtr<float[]> vec = hwy::AllocateAligned<float>(length);
|
hwy::AlignedFreeUniquePtr<float[]> vec = hwy::AllocateAligned<float>(length);
|
||||||
|
HWY_ASSERT(vec);
|
||||||
for (size_t idx = 0; idx < length; idx++) {
|
for (size_t idx = 0; idx < length; idx++) {
|
||||||
vec[idx] = static_cast<float>(idx + offset);
|
vec[idx] = static_cast<float>(idx + offset);
|
||||||
}
|
}
|
||||||
return vec;
|
return vec;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A simple matrix multiplication. No optimization / tiling.
|
||||||
|
template <size_t kM, size_t kN, size_t kK>
|
||||||
|
hwy::AlignedFreeUniquePtr<float[]> SimpleMatMul(
|
||||||
|
const hwy::AlignedFreeUniquePtr<float[]>& a,
|
||||||
|
const hwy::AlignedFreeUniquePtr<float[]>& b) {
|
||||||
|
hwy::AlignedFreeUniquePtr<float[]> out = hwy::AllocateAligned<float>(kM * kK);
|
||||||
|
hwy::ZeroBytes(out.get(), kM * kK * sizeof(float));
|
||||||
|
|
||||||
|
int i, j, k;
|
||||||
|
for (i = 0; i < kM; ++i) {
|
||||||
|
for (j = 0; j < kK; ++j) {
|
||||||
|
for (k = 0; k < kN; ++k) {
|
||||||
|
out[i * kK + j] += a[i * kN + k] * b[k * kK + j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
template <size_t kOuter, size_t kInner>
|
template <size_t kOuter, size_t kInner>
|
||||||
hwy::AlignedFreeUniquePtr<float[]> SimpleMatVecAdd(
|
hwy::AlignedFreeUniquePtr<float[]> SimpleMatVecAdd(
|
||||||
const CompressedArray<float, kOuter * kInner>& mat,
|
const CompressedArray<float, kOuter * kInner>& mat,
|
||||||
|
|
@ -389,8 +433,9 @@ hwy::AlignedFreeUniquePtr<float[]> SimpleMatVecAdd(
|
||||||
const hwy::AlignedFreeUniquePtr<float[]>& add) {
|
const hwy::AlignedFreeUniquePtr<float[]>& add) {
|
||||||
hwy::AlignedFreeUniquePtr<float[]> uncompressed_mat =
|
hwy::AlignedFreeUniquePtr<float[]> uncompressed_mat =
|
||||||
hwy::AllocateAligned<float>(kOuter * kInner);
|
hwy::AllocateAligned<float>(kOuter * kInner);
|
||||||
Decompress(mat, 0, uncompressed_mat.get(), kOuter * kInner);
|
|
||||||
hwy::AlignedFreeUniquePtr<float[]> out = hwy::AllocateAligned<float>(kOuter);
|
hwy::AlignedFreeUniquePtr<float[]> out = hwy::AllocateAligned<float>(kOuter);
|
||||||
|
HWY_ASSERT(uncompressed_mat && out);
|
||||||
|
Decompress(mat, 0, uncompressed_mat.get(), kOuter * kInner);
|
||||||
for (size_t idx_row = 0; idx_row < kOuter; idx_row++) {
|
for (size_t idx_row = 0; idx_row < kOuter; idx_row++) {
|
||||||
out[idx_row] = add[idx_row];
|
out[idx_row] = add[idx_row];
|
||||||
for (size_t idx_col = 0; idx_col < kInner; idx_col++) {
|
for (size_t idx_col = 0; idx_col < kInner; idx_col++) {
|
||||||
|
|
@ -412,6 +457,52 @@ void AssertClose(const hwy::AlignedFreeUniquePtr<float[]>& a,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename MatT>
|
||||||
|
void AssertClose(const hwy::AlignedFreeUniquePtr<MatT[]>& expected,
|
||||||
|
const hwy::AlignedFreeUniquePtr<MatT[]>& actual, size_t num) {
|
||||||
|
for (size_t idx = 0; idx < num; idx++) {
|
||||||
|
double expected_value = hwy::ConvertScalarTo<double>(expected[idx]);
|
||||||
|
double actual_value = hwy::ConvertScalarTo<double>(actual[idx]);
|
||||||
|
|
||||||
|
const double tolerance =
|
||||||
|
expected_value * 20 * 1.0 / (1ULL << hwy::MantissaBits<MatT>());
|
||||||
|
|
||||||
|
if (!(expected_value - tolerance <= actual_value &&
|
||||||
|
actual_value <= expected_value + tolerance)) {
|
||||||
|
fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f\n", idx,
|
||||||
|
expected_value, idx, actual_value);
|
||||||
|
HWY_ASSERT(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestMatMul() {
|
||||||
|
hwy::ThreadPool pool(0);
|
||||||
|
constexpr size_t kM = 128 * 3; // 384
|
||||||
|
constexpr size_t kK = 128 * 5; // 640
|
||||||
|
constexpr size_t kN = 128 * 6; // 768
|
||||||
|
|
||||||
|
CompressedArray<float, kM * kN> a1 = GenerateMat<kM, kN>(0);
|
||||||
|
CompressedArray<float, kN * kK> b1 = GenerateMat<kN, kK>(0);
|
||||||
|
|
||||||
|
hwy::AlignedFreeUniquePtr<float[]> a = hwy::AllocateAligned<float>(kM * kN);
|
||||||
|
Decompress(a1, 0, a.get(), kM * kN);
|
||||||
|
|
||||||
|
hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kN * kK);
|
||||||
|
Decompress(b1, 0, b.get(), kN * kK);
|
||||||
|
|
||||||
|
hwy::AlignedFreeUniquePtr<float[]> expected_out1 =
|
||||||
|
SimpleMatMul<kM, kN, kK>(a, b);
|
||||||
|
|
||||||
|
CompressedArray<float, kM * kK> compressed_c = GenerateZeroMat<kM, kK>(0);
|
||||||
|
hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
|
||||||
|
Decompress(compressed_c, 0, c.get(), kM * kK);
|
||||||
|
|
||||||
|
MatMul<kM, kN, kK>(a.get(), b.get(), c.get());
|
||||||
|
|
||||||
|
AssertClose(expected_out1, c, kM * kK);
|
||||||
|
}
|
||||||
|
|
||||||
void TestMatVecAdd() {
|
void TestMatVecAdd() {
|
||||||
hwy::ThreadPool pool(0);
|
hwy::ThreadPool pool(0);
|
||||||
constexpr size_t kOuter = 128 * 3;
|
constexpr size_t kOuter = 128 * 3;
|
||||||
|
|
@ -419,27 +510,15 @@ void TestMatVecAdd() {
|
||||||
CompressedArray<float, kOuter * kInner> mat = GenerateMat<kOuter, kInner>(0);
|
CompressedArray<float, kOuter * kInner> mat = GenerateMat<kOuter, kInner>(0);
|
||||||
hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
|
hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
|
||||||
hwy::AlignedFreeUniquePtr<float[]> add = GenerateVec<kOuter>(0);
|
hwy::AlignedFreeUniquePtr<float[]> add = GenerateVec<kOuter>(0);
|
||||||
|
hwy::AlignedFreeUniquePtr<float[]> even_odd =
|
||||||
|
hwy::AllocateAligned<float>(kInner * pool.NumWorkers());
|
||||||
hwy::AlignedFreeUniquePtr<float[]> expected_out =
|
hwy::AlignedFreeUniquePtr<float[]> expected_out =
|
||||||
SimpleMatVecAdd<kOuter, kInner>(mat, vec, add);
|
SimpleMatVecAdd<kOuter, kInner>(mat, vec, add);
|
||||||
hwy::AlignedFreeUniquePtr<float[]> actual_out =
|
hwy::AlignedFreeUniquePtr<float[]> actual_out =
|
||||||
hwy::AllocateAligned<float>(kOuter);
|
hwy::AllocateAligned<float>(kOuter);
|
||||||
MatVecAdd<true, kOuter, kInner>(mat, 0, vec.get(), add.get(),
|
HWY_ASSERT(vec && add && even_odd && expected_out && actual_out);
|
||||||
actual_out.get(), pool);
|
MatVecAdd</*kAdd=*/true, kOuter, kInner>(
|
||||||
AssertClose<kOuter>(actual_out, expected_out);
|
mat, 0, vec.get(), add.get(), even_odd.get(), actual_out.get(), pool);
|
||||||
}
|
|
||||||
|
|
||||||
void TestMatVecAddLoop() {
|
|
||||||
constexpr size_t kOuter = 128 * 3;
|
|
||||||
constexpr size_t kInner = 128 * 5;
|
|
||||||
CompressedArray<float, kOuter * kInner> mat = GenerateMat<kOuter, kInner>(0);
|
|
||||||
hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
|
|
||||||
hwy::AlignedFreeUniquePtr<float[]> add = GenerateVec<kOuter>(0);
|
|
||||||
hwy::AlignedFreeUniquePtr<float[]> expected_out =
|
|
||||||
SimpleMatVecAdd<kOuter, kInner>(mat, vec, add);
|
|
||||||
hwy::AlignedFreeUniquePtr<float[]> actual_out =
|
|
||||||
hwy::AllocateAligned<float>(kOuter);
|
|
||||||
MatVecAddLoop<true, kOuter, kInner>(mat, 0, vec.get(), add.get(),
|
|
||||||
actual_out.get());
|
|
||||||
AssertClose<kOuter>(actual_out, expected_out);
|
AssertClose<kOuter>(actual_out, expected_out);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -460,6 +539,8 @@ void TestTwoMatVecAdd() {
|
||||||
hwy::AllocateAligned<float>(kOuter);
|
hwy::AllocateAligned<float>(kOuter);
|
||||||
hwy::AlignedFreeUniquePtr<float[]> actual_out1 =
|
hwy::AlignedFreeUniquePtr<float[]> actual_out1 =
|
||||||
hwy::AllocateAligned<float>(kOuter);
|
hwy::AllocateAligned<float>(kOuter);
|
||||||
|
HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 &&
|
||||||
|
expected_out1 && actual_out1);
|
||||||
TwoMatVecAdd<true, kOuter, kInner>(mat0, mat1, 0, vec.get(), add0.get(),
|
TwoMatVecAdd<true, kOuter, kInner>(mat0, mat1, 0, vec.get(), add0.get(),
|
||||||
add1.get(), actual_out0.get(),
|
add1.get(), actual_out0.get(),
|
||||||
actual_out1.get(), pool);
|
actual_out1.get(), pool);
|
||||||
|
|
@ -482,6 +563,8 @@ void TestTwoOfsMatVecAddLoop() {
|
||||||
hwy::AllocateAligned<float>(kOuter);
|
hwy::AllocateAligned<float>(kOuter);
|
||||||
hwy::AlignedFreeUniquePtr<float[]> actual_out1 =
|
hwy::AlignedFreeUniquePtr<float[]> actual_out1 =
|
||||||
hwy::AllocateAligned<float>(kOuter);
|
hwy::AllocateAligned<float>(kOuter);
|
||||||
|
HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 &&
|
||||||
|
expected_out1 && actual_out1);
|
||||||
TwoOfsMatVecAddLoop<true, kOuter, kInner>(mat, 0, 0, vec.get(), add0.get(),
|
TwoOfsMatVecAddLoop<true, kOuter, kInner>(mat, 0, 0, vec.get(), add0.get(),
|
||||||
add1.get(), actual_out0.get(),
|
add1.get(), actual_out0.get(),
|
||||||
actual_out1.get());
|
actual_out1.get());
|
||||||
|
|
@ -521,8 +604,8 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
|
||||||
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestMatMul);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAddLoop);
|
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoOfsMatVecAddLoop);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoOfsMatVecAddLoop);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid);
|
||||||
|
|
@ -23,20 +23,16 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// Placeholder for internal header, do not modify.
|
// Placeholder for internal header, do not modify.
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/compress.h"
|
#include "compression/compress.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
#include "gemma/gemma.h" // Gemma
|
||||||
#include "gemma.h" // Gemma
|
#include "util/app.h"
|
||||||
|
#include "util/args.h" // HasHelp
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
#include "hwy/per_target.h"
|
#include "hwy/per_target.h"
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
#include "hwy/timer.h"
|
#include "hwy/timer.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "util/app.h"
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "util/args.h" // HasHelp
|
|
||||||
|
|
||||||
static constexpr bool kVerboseLogTokens = false;
|
static constexpr bool kVerboseLogTokens = false;
|
||||||
|
|
||||||
|
|
@ -98,11 +94,10 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
|
||||||
|
|
||||||
void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
||||||
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool,
|
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
hwy::ThreadPool& inner_pool, const InferenceArgs& args,
|
const InferenceArgs& args, int verbosity,
|
||||||
int verbosity, const gcpp::AcceptFunc& accept_token,
|
const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
|
||||||
std::string& eot_line) {
|
|
||||||
PROFILER_ZONE("Gen.misc");
|
PROFILER_ZONE("Gen.misc");
|
||||||
int abs_pos = 0; // absolute token index over all turns
|
size_t abs_pos = 0; // absolute token index over all turns
|
||||||
int current_pos = 0; // token index within the current turn
|
int current_pos = 0; // token index within the current turn
|
||||||
int prompt_size{};
|
int prompt_size{};
|
||||||
|
|
||||||
|
|
@ -185,7 +180,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
||||||
// For instruction-tuned models: add control tokens.
|
// For instruction-tuned models: add control tokens.
|
||||||
prompt_string = "<start_of_turn>user\n" + prompt_string +
|
prompt_string = "<start_of_turn>user\n" + prompt_string +
|
||||||
"<end_of_turn>\n<start_of_turn>model\n";
|
"<end_of_turn>\n<start_of_turn>model\n";
|
||||||
if (abs_pos > 0) {
|
if (abs_pos != 0) {
|
||||||
// Prepend "<end_of_turn>" token if this is a multi-turn dialogue
|
// Prepend "<end_of_turn>" token if this is a multi-turn dialogue
|
||||||
// continuation.
|
// continuation.
|
||||||
prompt_string = "<end_of_turn>\n" + prompt_string;
|
prompt_string = "<end_of_turn>\n" + prompt_string;
|
||||||
|
|
@ -213,7 +208,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
||||||
|
|
||||||
const double time_start = hwy::platform::Now();
|
const double time_start = hwy::platform::Now();
|
||||||
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
||||||
args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool,
|
args.temperature, prompt, abs_pos, kv_cache, pool,
|
||||||
stream_token, accept_token, gen, verbosity);
|
stream_token, accept_token, gen, verbosity);
|
||||||
const double time_end = hwy::platform::Now();
|
const double time_end = hwy::platform::Now();
|
||||||
const double tok_sec = current_pos / (time_end - time_start);
|
const double tok_sec = current_pos / (time_end - time_start);
|
||||||
|
|
@ -233,7 +228,6 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
||||||
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||||
PROFILER_ZONE("Run.misc");
|
PROFILER_ZONE("Run.misc");
|
||||||
|
|
||||||
hwy::ThreadPool inner_pool(0);
|
|
||||||
hwy::ThreadPool pool(app.num_threads);
|
hwy::ThreadPool pool(app.num_threads);
|
||||||
// For many-core, pinning threads to cores helps.
|
// For many-core, pinning threads to cores helps.
|
||||||
if (app.num_threads > 10) {
|
if (app.num_threads > 10) {
|
||||||
|
|
@ -275,8 +269,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ReplGemma(
|
ReplGemma(
|
||||||
model, loader.ModelTraining(), kv_cache, pool, inner_pool, inference,
|
model, loader.ModelTraining(), kv_cache, pool, inference, app.verbosity,
|
||||||
app.verbosity,
|
|
||||||
/*accept_token=*/[](int) { return true; }, app.eot_line);
|
/*accept_token=*/[](int) { return true; }, app.eot_line);
|
||||||
}
|
}
|
||||||
|
|
||||||
28
util/app.h
28
util/app.h
|
|
@ -18,7 +18,6 @@
|
||||||
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
|
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
|
||||||
#define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
|
#define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
|
||||||
|
|
||||||
#include <iterator>
|
|
||||||
#if HWY_OS_LINUX
|
#if HWY_OS_LINUX
|
||||||
#include <sched.h>
|
#include <sched.h>
|
||||||
|
|
||||||
|
|
@ -32,13 +31,11 @@
|
||||||
#include <algorithm> // std::clamp
|
#include <algorithm> // std::clamp
|
||||||
#include <thread> // NOLINT>
|
#include <thread> // NOLINT>
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
#include "compression/io.h" // Path
|
||||||
#include "configs.h"
|
#include "gemma/configs.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
#include "gemma/gemma.h"
|
||||||
#include "gemma.h"
|
|
||||||
#include "hwy/base.h" // HWY_ASSERT
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "util/args.h"
|
#include "util/args.h"
|
||||||
|
#include "hwy/base.h" // HWY_ASSERT
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
@ -49,6 +46,10 @@ static inline const char* CompiledConfig() {
|
||||||
return "msan";
|
return "msan";
|
||||||
} else if (HWY_IS_TSAN) {
|
} else if (HWY_IS_TSAN) {
|
||||||
return "tsan";
|
return "tsan";
|
||||||
|
#if defined(HWY_IS_HWASAN)
|
||||||
|
} else if (HWY_IS_HWASAN) {
|
||||||
|
return "hwasan";
|
||||||
|
#endif
|
||||||
#if defined(HWY_IS_UBSAN)
|
#if defined(HWY_IS_UBSAN)
|
||||||
} else if (HWY_IS_UBSAN) {
|
} else if (HWY_IS_UBSAN) {
|
||||||
return "ubsan";
|
return "ubsan";
|
||||||
|
|
@ -84,8 +85,7 @@ class AppArgs : public ArgsBase<AppArgs> {
|
||||||
void ChooseNumThreads() {
|
void ChooseNumThreads() {
|
||||||
if (num_threads == kDefaultNumThreads) {
|
if (num_threads == kDefaultNumThreads) {
|
||||||
// This is a rough heuristic, replace with something better in the future.
|
// This is a rough heuristic, replace with something better in the future.
|
||||||
num_threads = static_cast<size_t>(std::clamp(
|
num_threads = GetSupportedThreadCount();
|
||||||
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -95,6 +95,12 @@ class AppArgs : public ArgsBase<AppArgs> {
|
||||||
ChooseNumThreads();
|
ChooseNumThreads();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static inline size_t GetSupportedThreadCount() {
|
||||||
|
return static_cast<size_t>(
|
||||||
|
std::clamp(static_cast<int>(std::thread::hardware_concurrency()) - 2, 1,
|
||||||
|
HWY_MIN(static_cast<int>(kMaxThreads), 18)));
|
||||||
|
}
|
||||||
|
|
||||||
Path log; // output
|
Path log; // output
|
||||||
int verbosity;
|
int verbosity;
|
||||||
size_t num_threads;
|
size_t num_threads;
|
||||||
|
|
@ -137,7 +143,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
if (tokenizer.path.empty()) {
|
if (tokenizer.path.empty()) {
|
||||||
return "Missing --tokenizer flag, a file for the tokenizer is required.";
|
return "Missing --tokenizer flag, a file for the tokenizer is required.";
|
||||||
}
|
}
|
||||||
if (!tokenizer.exists()) {
|
if (!tokenizer.Exists()) {
|
||||||
return "Can't open file specified with --tokenizer flag.";
|
return "Can't open file specified with --tokenizer flag.";
|
||||||
}
|
}
|
||||||
if (!compressed_weights.path.empty()) {
|
if (!compressed_weights.path.empty()) {
|
||||||
|
|
@ -152,7 +158,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
if (weights.path.empty()) {
|
if (weights.path.empty()) {
|
||||||
return "Missing --weights flag, a file for the model weights.";
|
return "Missing --weights flag, a file for the model weights.";
|
||||||
}
|
}
|
||||||
if (!weights.exists()) {
|
if (!weights.Exists()) {
|
||||||
return "Can't open file specified with --weights flag.";
|
return "Can't open file specified with --weights flag.";
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
||||||
39
util/args.h
39
util/args.h
|
|
@ -23,48 +23,11 @@
|
||||||
#include <algorithm> // std::transform
|
#include <algorithm> // std::transform
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "compression/io.h"
|
||||||
#include "hwy/base.h" // HWY_ABORT
|
#include "hwy/base.h" // HWY_ABORT
|
||||||
|
|
||||||
#if defined(_WIN32)
|
|
||||||
#include <io.h>
|
|
||||||
#define F_OK 0
|
|
||||||
#define access _access
|
|
||||||
#else
|
|
||||||
#include <unistd.h>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
// Wrapper for strings representing a path name. Differentiates vs. arbitrary
|
|
||||||
// strings and supports shortening for display purposes.
|
|
||||||
struct Path {
|
|
||||||
Path() {}
|
|
||||||
explicit Path(const char* p) : path(p) {}
|
|
||||||
|
|
||||||
Path& operator=(const char* other) {
|
|
||||||
path = other;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string Shortened() const {
|
|
||||||
constexpr size_t kMaxLen = 48;
|
|
||||||
constexpr size_t kCutPoint = kMaxLen / 2 - 5;
|
|
||||||
if (path.size() > kMaxLen) {
|
|
||||||
return std::string(begin(path), begin(path) + kCutPoint) + " ... " +
|
|
||||||
std::string(end(path) - kCutPoint, end(path));
|
|
||||||
}
|
|
||||||
if (path.empty()) return "[no path specified]";
|
|
||||||
return path;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Beware, TOCTOU.
|
|
||||||
bool exists() const {
|
|
||||||
return (access(path.c_str(), F_OK) == 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string path;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Args is a class that provides a ForEach member function which visits each of
|
// Args is a class that provides a ForEach member function which visits each of
|
||||||
// its member variables. ArgsBase provides functions called by Args to
|
// its member variables. ArgsBase provides functions called by Args to
|
||||||
// initialize values to their defaults (passed as an argument to the visitor),
|
// initialize values to their defaults (passed as an argument to the visitor),
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue