Compare commits

...

57 Commits

Author SHA1 Message Date
Wang Xinping 2c038e1285 work with cmake install 2024-05-03 23:44:12 +08:00
Copybara-Service 8ed22e52bf Merge pull request #177 from szabadka:gemma2
PiperOrigin-RevId: 630388843
2024-05-03 07:52:27 -07:00
Zoltan Szabadka 19017fdb6d Fix expression in DASSERT() 2024-05-03 13:54:20 +00:00
Phil Culliton 28ca001d5e Matmul and test functions
PiperOrigin-RevId: 630373984
2024-05-03 06:39:36 -07:00
Zoltan Szabadka 429eb78512 Remove unused vars. 2024-05-03 13:37:17 +00:00
Zoltan Szabadka 3d72f17261 Use more parallelism in attention block in prefill mode.
Move the loop over the tokens inside the attention block and
then create kHeads * num_tokens threads.

This helps multi-threaded speed only for the 2b gemma model, but for
consistency we move the loop over the tokens inside the griffin recurrent
layer and the FFW layer as well. This is also a preparation for using the
MatMul operation later.

Benchmark results (summarization with 1600 tokens for prefill
and essay writing with 500 tokens for generation):

```
                   Prefill speed
Num threads      BEFORE       AFTER
32               61.76 t/s    65.08 t/s
64               89.46 t/s    98.62 t/s
```
2024-05-03 13:23:07 +00:00
Copybara-Service 6eeef2e2d9 Merge pull request #166 from samkaufman:deinterleave-vecs
PiperOrigin-RevId: 630360778
2024-05-03 05:23:31 -07:00
Copybara-Service 2a71333c8a Merge pull request #176 from szabadka:gemma3
PiperOrigin-RevId: 630131001
2024-05-02 11:41:05 -07:00
Zoltan Szabadka 9a2682d544 Use more parallelism in the QKV projections of the MHA block.
We compute all three projections with one MatVec and then copy
the kv part to the cache.

Benchmark results for 7b-it model that uses MHA blocks (summarization with
1600 tokens for prefill and essay writing with 500 tokens for generation):

```
                   Prefill speed                Generation speed
Num threads      BEFORE       AFTER            BEFORE       AFTER
32               13.75 t/s    14.80 t/s       9.22 t/s     9.77 t/s
64               19.89 t/s    24.83 t/s      12.46 t/s    13.66 t/s
```
2024-05-02 13:46:45 +00:00
Copybara-Service bafb8382f8 Merge pull request #175 from szabadka:gemma2
PiperOrigin-RevId: 630044058
2024-05-02 06:27:15 -07:00
Zoltan Szabadka 0afa480d90 Use more parallelism in the final output of the attention block.
We use MatVec instead of MatVecLoop for the per-head dense layers,
because we can parallelize more on the rows of the matrix than
on the number of heads. This will be even more efficient after
we rearrange the weights and can have a single MatVec operation.

Benchmark results (summarization with 1600 tokens for prefill
and essay writing with 500 tokens for generation):

```
                   Prefill speed                Generation speed
Num threads      BEFORE       AFTER            BEFORE       AFTER
32               58.24 t/s    61.79 t/s      32.11 t/s    32.62 t/s
64               83.62 t/s    92.00 t/s      41.10 t/s    41.80 t/s
```
2024-05-02 09:30:07 +00:00
Sam Kaufman 4a6173d929 Remove unused vars. 2024-05-02 00:41:44 -07:00
Sam Kaufman 564937ede6 Merge branch 'dev' into deinterleave-vecs 2024-04-30 16:23:04 -07:00
Sam Kaufman 2829ef17ad Check for HWY_NATIVE_DOT_BF16. 2024-04-30 15:19:28 -07:00
Sam Kaufman 59ebecce22 Fix: specialized MatVecAdd was never called. 2024-04-30 15:17:27 -07:00
Jan Wassenberg 12fb2f05cf Add per-thread even_odd storage for #166.
Also inline ProjQ and ProjKV lambdas,
add missing includes/deps for ops_test.

PiperOrigin-RevId: 629460608
2024-04-30 10:42:23 -07:00
Copybara-Service 8f04a8346d Merge pull request #172 from szabadka:gemma2
PiperOrigin-RevId: 629438917
2024-04-30 09:33:38 -07:00
Zoltan Szabadka f8ccb8e37c Fix kv offset computation for MHA config. 2024-04-30 16:19:14 +00:00
Copybara-Service 374fd7478a Merge pull request #170 from szabadka:gemma2
PiperOrigin-RevId: 629408279
2024-04-30 07:40:30 -07:00
Zoltan Szabadka afaca4efa8 Use more parallelism in the QKV projections in MQA mode.
Instead of MatVecLoop we use MatVec, and we concatenate k and v into a
single vector of length 2 * kQKVDim so that the K and V projections can be
computed with one MatVec operation.

Benchmark results (summarization with 1600 tokens for prefill
and essay writing with 500 tokens for generation):

```
                   Prefill speed                Generation speed
Num threads      BEFORE       AFTER            BEFORE       AFTER
4                 9.81 t/s     9.96 t/s       8.39 t/s     8.46 t/s
18               31.50 t/s    36.67 t/s      23.10 t/s    25.83 t/s
32               45.36 t/s    58.91 t/s      27.60 t/s    31.25 t/s
64               57.72 t/s    80.64 t/s      35.40 t/s    39.76 t/s
```
2024-04-30 13:10:14 +00:00
Copybara-Service befe9fb07e Merge pull request #167 from szabadka:gemma2
PiperOrigin-RevId: 629325219
2024-04-30 01:00:37 -07:00
Sam Kaufman 6a78a23f4c Abstracted some MatVecAdd spec. dupes. 2024-04-29 16:23:38 -07:00
Sam Kaufman f608337fef Remove Bf16ToF32EO and use PromoteEvenTo and PromoteOddTo. 2024-04-29 14:13:07 -07:00
Sam Kaufman aa0b113214 (VecT*) to static_cast<VecT*>. 2024-04-29 12:53:47 -07:00
Sam Kaufman 5cb63346aa supports_eo -> kSupportsEvenOdd 2024-04-29 12:51:35 -07:00
Zoltan Szabadka 27117cc39f Simplify threading: remove the use of inner_pool.
We only used inner_pool in the prefill FFW function, and there we
can achieve sufficient parallelism on the rows of the matrix-vector
multiplications.

Benchmark results on a 1600-token summarization task:

```
               Prefill speed
Num threads    BEFORE         AFTER
4               9.24 t/s       9.76 t/s
18             31.41 t/s      31.16 t/s
32             31.41 t/s      45.13 t/s
64             31.03 t/s      57.85 t/s
```
2024-04-29 16:07:30 +00:00
Paul Chang 1d18c5a129 Improve documentation for compress_weights flags
PiperOrigin-RevId: 629053191
2024-04-29 06:49:50 -07:00
Sam Kaufman 0816a1070d Even-odd layout MatVecs for bf16 weights. 2024-04-28 20:09:25 -07:00
Jan Wassenberg 7a12e29027 Add error-checking for py binding, add missing include+hwasan check
PiperOrigin-RevId: 628453112
2024-04-26 10:59:41 -07:00
Paul Chang e8f59bb411 Fix underflow in NUQ ClusterCost()
PiperOrigin-RevId: 628137904
2024-04-25 11:28:51 -07:00
Phil Culliton 9e0ac5de34 Update Clif wrapper to work with latest gemma.cpp and add simple example
PiperOrigin-RevId: 628134201
2024-04-25 11:17:16 -07:00
Paul Chang 2d4de6b08b Support absolute positional embeddings from vanilla transformer
PiperOrigin-RevId: 628100831
2024-04-25 09:32:14 -07:00
Paul Chang 75eca87039 Simplify prefill early-exit (originally Merge #156)
PiperOrigin-RevId: 627788524
2024-04-24 11:11:42 -07:00
Copybara-Service b27d8d6b92 Merge pull request #156 from zeerd:dev
PiperOrigin-RevId: 627706909
2024-04-24 06:19:14 -07:00
Charles Chan ea45d7c4d7 Use a lambda to split the function, and make stream_token able to break prefill, too 2024-04-23 22:55:01 +08:00
Paul Chang e8d29792ac New token validity assertions, improve prompt truncation warning
PiperOrigin-RevId: 627376194
2024-04-23 07:05:59 -07:00
Jan Wassenberg 3bf22abb22 Fix sign comparison warnings
PiperOrigin-RevId: 627299902
2024-04-23 01:16:51 -07:00
Jan Wassenberg ca971ef50f Document weight conversion
PiperOrigin-RevId: 626957718
2024-04-22 01:58:30 -07:00
Jan Wassenberg e9a0caed87 Further improve IO, enable multiple backends without -D.
Move Path into io.h and use for opening files.
Removes dependency of gemma_lib on args.
Separate Windows codepath instead of emulating POSIX functions.

Plus lint fixes.

PiperOrigin-RevId: 626279004
2024-04-19 00:40:29 -07:00
Paul Chang 38f1ea9b80 Eliminate redundant copies of TokenString()
Move this function outside of HWY_NAMESPACE since it doesn't need to be
optimized for any particular architecture.

PiperOrigin-RevId: 626098641
2024-04-18 11:31:50 -07:00
Jan Wassenberg a8ceb75f43 Improved IO abstraction layer
Move to unique_ptr-like File class.
Move `if OS_WIN` into wrapper functions.
exists -> Exists.

PiperOrigin-RevId: 625923056
2024-04-17 23:15:07 -07:00
Jan Wassenberg a939b5fc9f Update distortion.h to weighted average, add distortion_test.
More thorough checks in sfp_test and nuq_test.
nuq_test: use deterministic input generator.

PiperOrigin-RevId: 625602019
2024-04-17 01:44:19 -07:00
Copybara-Service 05e7e2b2bb Merge pull request #145 from atorero:dev
PiperOrigin-RevId: 624221085
2024-04-12 10:27:18 -07:00
Andrey Mikhaylov 4ef3da733a Fixed minor things and added comments. 2024-04-12 15:39:16 +00:00
Andrey Mikhaylov 2c5706f159 Add comments regarding layers output usage. 2024-04-12 15:39:16 +00:00
Andrey Mikhaylov 03284d752e Added layers output functionality to gemma and a binary debug_output to save the outputs to a json file. 2024-04-12 15:39:16 +00:00
Copybara-Service 342e998cb6 Merge pull request #142 from ufownl:refactor/data_structures
PiperOrigin-RevId: 623503486
2024-04-10 08:35:18 -07:00
RangerUFO e541707caa Rename the fields of Griffin weights 2024-04-10 21:04:31 +08:00
RangerUFO 4e960d67f6 Fix typos 2024-04-10 20:38:18 +08:00
RangerUFO 809bd0709d Refactor data structures to reduce memory usage 2024-04-10 19:35:23 +08:00
Jan Wassenberg 54120a5571 Mention Makefile contributed by @jart
PiperOrigin-RevId: 623436818
2024-04-10 03:21:10 -07:00
Jan Wassenberg 881eeffe0a Lint fixes: strcat, includes, arg naming
PiperOrigin-RevId: 623435210
2024-04-10 03:12:41 -07:00
Copybara-Service da91f4c4be Merge pull request #137 from zond:main
PiperOrigin-RevId: 623255639
2024-04-09 12:57:57 -07:00
Copybara-Service 827fec1904 Merge pull request #139 from ufownl:feature/public_layers
PiperOrigin-RevId: 623254705
2024-04-09 12:54:23 -07:00
RangerUFO 2099b37732 Change `NumGemmaLayers` and `NumGriffinLayers` to constants in configs 2024-04-09 20:44:41 +08:00
Jan Wassenberg a982ec1287 Move code to gemma/ so we can remove error-prone copybara: comments.
Also fix includes and Lint warnings.

PiperOrigin-RevId: 623127487
2024-04-09 04:45:42 -07:00
zond 9ca662dc14 Clarified README
Made it more visible that the recurrent weights are at a different Kaggle page.
2024-04-09 09:58:47 +02:00
39 changed files with 2041 additions and 1087 deletions

.vscode/settings.json (new file, vendored)

@@ -0,0 +1,6 @@
+{
+  "cmake.configureOnOpen": false,
+  "files.associations": {
+    "array": "cpp"
+  }
+}


@@ -22,9 +22,7 @@ exports_files(["LICENSE"])
 cc_library(
     name = "ops",
-    hdrs = [
-        "ops.h",
-    ],
+    hdrs = ["gemma/ops.h"],
     deps = [
         "//compression:compress",
         "@hwy//:algo",
@@ -41,42 +39,34 @@ cc_library(
 cc_test(
     name = "ops_test",
     size = "small",
-    srcs = ["ops_test.cc"],
+    srcs = ["gemma/ops_test.cc"],
     local_defines = ["HWY_IS_TEST"],
     # for test_suite.
     tags = ["hwy_ops_test"],
     deps = [
         ":ops",
         "@googletest//:gtest_main",
+        "//compression:compress",
         "@hwy//:hwy",
         "@hwy//:hwy_test_util",
+        "@hwy//:thread_pool",
     ],
 )

-cc_library(
-    name = "args",
-    hdrs = [
-        "util/args.h",
-    ],
-    deps = [
-        "@hwy//:hwy",
-    ],
-)
-
 cc_library(
     name = "gemma_lib",
     srcs = [
-        "gemma.cc",
+        "gemma/gemma.cc",
     ],
     hdrs = [
-        "configs.h",
-        "gemma.h",
+        "gemma/configs.h",
+        "gemma/gemma.h",
     ],
     deps = [
-        ":args",
         ":ops",
         # "//base",
         "//compression:compress",
+        "//compression:io",
         "@hwy//:hwy",
         "@hwy//:matvec",
         "@hwy//:nanobenchmark",  # timer
@@ -86,9 +76,29 @@ cc_library(
     ],
 )

+cc_library(
+    name = "args",
+    hdrs = ["util/args.h"],
+    deps = [
+        "//compression:io",
+        "@hwy//:hwy",
+    ],
+)
+
+cc_library(
+    name = "app",
+    hdrs = ["util/app.h"],
+    deps = [
+        ":args",
+        ":gemma_lib",
+        "//compression:io",
+        "@hwy//:hwy",
+    ],
+)
+
 cc_test(
     name = "gemma_test",
-    srcs = ["gemma_test.cc"],
+    srcs = ["gemma/gemma_test.cc"],
     # Requires model files
     tags = [
         "local",
@@ -105,23 +115,9 @@ cc_test(
     ],
 )

-cc_library(
-    name = "app",
-    hdrs = [
-        "util/app.h",
-    ],
-    deps = [
-        ":args",
-        ":gemma_lib",
-        "@hwy//:hwy",
-    ],
-)
-
 cc_binary(
     name = "gemma",
-    srcs = [
-        "run.cc",
-    ],
+    srcs = ["gemma/run.cc"],
     deps = [
         ":app",
         ":args",
@@ -137,9 +133,7 @@ cc_binary(
 cc_binary(
     name = "compress_weights",
-    srcs = [
-        "compress_weights.cc",
-    ],
+    srcs = ["gemma/compress_weights.cc"],
     deps = [
         ":args",
         ":gemma_lib",
@@ -154,8 +148,25 @@ cc_binary(
 cc_binary(
     name = "benchmark",
+    srcs = ["gemma/benchmark.cc"],
+    deps = [
+        ":app",
+        ":args",
+        ":gemma_lib",
+        # "//base",
+        "//compression:compress",
+        "@hwy//:hwy",
+        "@hwy//:nanobenchmark",
+        "@hwy//:profiler",
+        "@hwy//:thread_pool",
+        "@nlohmann_json//:json",
+    ],
+)
+
+cc_binary(
+    name = "debug_prompt",
     srcs = [
-        "benchmark.cc",
+        "debug_prompt.cc",
     ],
     deps = [
         ":app",


@@ -34,16 +34,22 @@ FetchContent_Declare(json GIT_REPOSITORY https://github.com/nlohmann/json.git GI
 FetchContent_MakeAvailable(json)

 set(SOURCES
-  gemma.cc
   compression/blob_store.cc
   compression/blob_store.h
   compression/compress.h
   compression/compress-inl.h
+  compression/io_win.cc
+  compression/io.cc
+  compression/io.h
   compression/nuq.h
   compression/nuq-inl.h
   compression/sfp.h
   compression/sfp-inl.h
   compression/test_util.h
+  gemma/configs.h
+  gemma/gemma.cc
+  gemma/gemma.h
+  gemma/ops.h
   util/app.h
   util/args.h
   )
@@ -72,26 +78,31 @@ set_property(TARGET libgemma PROPERTY CXX_STANDARD 17)
 set_target_properties(libgemma PROPERTIES PREFIX "")
 set_property(TARGET libgemma PROPERTY POSITION_INDEPENDENT_CODE ON)
 target_include_directories(libgemma PUBLIC ./)
-target_link_libraries(libgemma hwy hwy_contrib sentencepiece)
+target_link_libraries(libgemma hwy hwy_contrib sentencepiece-static)
 target_include_directories(libgemma PUBLIC ${sentencepiece_SOURCE_DIR})
 target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
 target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
+install(TARGETS libgemma DESTINATION lib)

 # Executable Target
-add_executable(gemma run.cc)
+add_executable(gemma gemma/run.cc)
 target_link_libraries(gemma libgemma hwy hwy_contrib)
+install(TARGETS gemma DESTINATION bin)

-add_executable(benchmark benchmark.cc)
+add_executable(benchmark gemma/benchmark.cc)
 target_link_libraries(benchmark libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)

+add_executable(debug_prompt debug_prompt.cc)
+target_link_libraries(debug_prompt libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
+
 ## Tests
 set(GEMMA_ENABLE_TESTS OFF CACHE BOOL "Enable Gemma tests")
 if (GEMMA_ENABLE_TESTS)
 set(GEMMA_TEST_FILES
-  ops_test.cc
-  gemma_test.cc
+  gemma/ops_test.cc
+  gemma/gemma_test.cc
 )
 foreach (TESTFILE IN LISTS GEMMA_TEST_FILES)
@@ -112,5 +123,5 @@ endif() # GEMMA_ENABLE_TESTS
 ## Tools
-add_executable(compress_weights compress_weights.cc)
+add_executable(compress_weights gemma/compress_weights.cc)
 target_link_libraries(compress_weights libgemma hwy hwy_contrib)


@@ -83,6 +83,23 @@ A `.clang-format` configuration is provided with our defaults, please run source
 files through `clang-format` (or a formatter that produces equivalent behavior)
 before finalizing PR for submission.

+## Converting weights
+
+We use a stripped down binary blob (.sbs) artifact to accelerate weight loading
+in C++. These files can be downloaded directly from Kaggle and HuggingFace. You
+can also convert Pytorch or Keras checkpoints to .sbs, but most end users should
+not have to do this.
+
+If starting with Keras, first run this script to convert to Pytorch:
+https://github.com/keras-team/keras-nlp/blob/master/tools/gemma/export_gemma_to_torch_xla.py
+
+From Pytorch, use the following script to generate uncompressed weights:
+https://github.com/google/gemma.cpp/blob/dev/util/convert_weights.py
+
+Then run gemma/compress_weights.cc (Bazel target :compress_weights), specifying
+the resulting file as `--weights` and the desired .sbs name as the
+`--compressed_weights`.
+
 ## Compile-Time Flags (Advanced)

 There are several compile-time flags to be aware of (note these may or may not
@@ -169,9 +186,9 @@ inference path of the Gemma model.
 The sentencepiece library we depend on requires some additional work to build
 with the Bazel build system. First, it does not export its BUILD file, so we
 provide `bazel/sentencepiece.bazel`. Second, it ships with a vendored subset of
-the Abseil library. `bazel/com_google_sentencepiece.patch` changes the code to
-support Abseil as a standalone dependency without third_party/ prefixes, similar
-to the transforms we apply to Gemma via Copybara.
+the Abseil library. `bazel/sentencepiece.patch` changes the code to support
+Abseil as a standalone dependency without third_party/ prefixes, similar to the
+transforms we apply to Gemma via Copybara.

 ## Discord


@@ -33,7 +33,7 @@ http_archive(
     strip_prefix = "sentencepiece-0.1.96",
     urls = ["https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip"],
     build_file = "@//bazel:sentencepiece.bazel",
-    patches = ["@//bazel:com_google_sentencepiece.patch"],
+    patches = ["@//bazel:sentencepiece.patch"],
     patch_args = ["-p1"],
 )


@@ -206,6 +206,12 @@ bazel build -c opt --cxxopt=-std=c++20 :gemma
 If the build is successful, you should now have a `gemma` executable in the `bazel-bin/` directory.

+#### Make
+
+If you prefer Makefiles, @jart has made one available here:
+https://github.com/jart/gemma3/blob/main/Makefile
+
 ### Step 4: Run

 You can now run `gemma` from inside the `build/` directory.
@@ -252,7 +258,7 @@ here provide a C++ implementation of this model based on the paper.
 To use the recurrent version of Gemma included in this repository, build the
 gemma binary as noted above in Step 3. Download the compressed weights and
-tokenizer from
+tokenizer from the RecurrentGemma
 [Kaggle](https://www.kaggle.com/models/google/recurrentgemma/gemmaCpp) as in
 Step 1, and run the binary as follows:


@@ -1,4 +1,4 @@
-# Required for referencing bazel:com_google_sentencepiece.patch
+# Required for referencing bazel:sentencepiece.patch
 package(
     default_applicable_licenses = ["//:license"],
     default_visibility = ["//:__subpackages__"],


@@ -11,14 +11,24 @@ package(
 )

 cc_library(
-    name = "blob_store",
+    name = "io",
     srcs = [
-        "blob_store.cc",
+        "io.cc",
+        # Placeholder for io backend, do not remove
     ],
-    hdrs = [
-        "blob_store.h",
-    ],
+    hdrs = ["io.h"],
     deps = [
+        # Placeholder for io deps, do not remove
+        "@hwy//:hwy",
+    ],
+)
+
+cc_library(
+    name = "blob_store",
+    srcs = ["blob_store.cc"],
+    hdrs = ["blob_store.h"],
+    deps = [
+        ":io",
         "@hwy//:hwy",
         "@hwy//:thread_pool",
     ],
@@ -39,7 +49,23 @@ cc_library(
     name = "distortion",
     hdrs = ["distortion.h"],
     deps = [
+        ":stats",
         "@hwy//:hwy",
+        "@hwy//hwy/contrib/sort:vqsort",
+    ],
+)
+
+cc_test(
+    name = "distortion_test",
+    size = "small",
+    srcs = ["distortion_test.cc"],
+    deps = [
+        ":distortion",
+        ":test_util",
+        "@googletest//:gtest_main",
+        "@hwy//:hwy",
+        "@hwy//:hwy_test_util",
+        "@hwy//:nanobenchmark",  # Unpredictable1
     ],
 )
@@ -56,12 +82,8 @@ cc_library(
 cc_library(
     name = "sfp",
-    hdrs = [
-        "sfp.h",
-    ],
-    textual_hdrs = [
-        "sfp-inl.h",
-    ],
+    hdrs = ["sfp.h"],
+    textual_hdrs = ["sfp-inl.h"],
     deps = [
         "@hwy//:hwy",
     ],
@@ -88,12 +110,8 @@ cc_test(
 cc_library(
     name = "nuq",
-    hdrs = [
-        "nuq.h",
-    ],
-    textual_hdrs = [
-        "nuq-inl.h",
-    ],
+    hdrs = ["nuq.h"],
+    textual_hdrs = ["nuq-inl.h"],
     deps = [
         ":sfp",
         "@hwy//:hwy",
@@ -134,6 +152,7 @@ cc_library(
     deps = [
         ":blob_store",
         ":distortion",
+        ":io",
         ":nuq",
         ":sfp",
         ":stats",
@@ -146,9 +165,7 @@ cc_library(
 # For internal experimentation
 cc_library(
     name = "analyze",
-    textual_hdrs = [
-        "analyze.h",
-    ],
+    textual_hdrs = ["analyze.h"],
     deps = [
         ":distortion",
         ":nuq",


@@ -26,11 +26,8 @@
 #include <cstdlib>  // std::abs
 #include <vector>

-// copybara:import_next_line:gemma_cpp
 #include "compression/distortion.h"
-// copybara:import_next_line:gemma_cpp
 #include "compression/nuq.h"
-// copybara:import_next_line:gemma_cpp
 #include "compression/stats.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -46,9 +43,7 @@
 #define THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
 #endif

-// copybara:import_next_line:gemma_cpp
 #include "compression/nuq-inl.h"
-// copybara:import_next_line:gemma_cpp
 #include "compression/sfp-inl.h"
 #include "hwy/contrib/sort/vqsort-inl.h"
 #include "hwy/highway.h"


@ -13,89 +13,21 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Request POSIX 2008, including `pread()` and `posix_fadvise()`.
#if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700
#undef _XOPEN_SOURCE
#define _XOPEN_SOURCE 700
#endif
#if !defined(_POSIX_C_SOURCE) || _POSIX_C_SOURCE < 200809
#define _POSIX_C_SOURCE 200809
#endif
// Make `off_t` 64-bit even on 32-bit systems. Works for Android >= r15c.
#undef _FILE_OFFSET_BITS
#define _FILE_OFFSET_BITS 64
// copybara:import_next_line:gemma_cpp
#include "compression/blob_store.h" #include "compression/blob_store.h"
#include <fcntl.h> // open #include <stddef.h>
#include <stdint.h> #include <stdint.h>
#include <stdio.h> // SEEK_END - unistd isn't enough for IDE.
#include <sys/stat.h> // O_RDONLY
#if HWY_OS_WIN
#include <fileapi.h>
#include <io.h> // read, write, close
#else
#include <unistd.h> // read, write, close
#endif
#include <atomic> #include <atomic>
#include <memory>
#include <vector> #include <vector>
#include "compression/io.h"
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/detect_compiler_arch.h" #include "hwy/detect_compiler_arch.h"
namespace {
#if HWY_OS_WIN
// pread is not supported on Windows
static int64_t pread(int fd, void* buf, uint64_t size, uint64_t offset) {
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
if (file == INVALID_HANDLE_VALUE) {
return -1;
}
OVERLAPPED overlapped = {0};
overlapped.Offset = offset & 0xFFFFFFFF;
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
DWORD bytes_read;
if (!ReadFile(file, buf, size, &bytes_read, &overlapped)) {
if (GetLastError() != ERROR_HANDLE_EOF) {
return -1;
}
}
return bytes_read;
}
// pwrite is not supported on Windows
static int64_t pwrite(int fd, const void* buf, uint64_t size, uint64_t offset) {
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
if (file == INVALID_HANDLE_VALUE) {
return -1;
}
OVERLAPPED overlapped = {0};
overlapped.Offset = offset & 0xFFFFFFFF;
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
DWORD bytes_written;
if (!WriteFile(file, buf, size, &bytes_written, &overlapped)) {
if (GetLastError() != ERROR_HANDLE_EOF) {
return -1;
}
}
return bytes_written;
}
#endif
} // namespace
namespace gcpp { namespace gcpp {
hwy::uint128_t MakeKey(const char* string) { hwy::uint128_t MakeKey(const char* string) {
@ -132,61 +64,6 @@ void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data,
} }
} // namespace } // namespace
struct IO {
// Returns size in bytes or 0.
static uint64_t FileSize(const char* filename) {
int fd = open(filename, O_RDONLY);
if (fd < 0) {
return 0;
}
#if HWY_OS_WIN
const int64_t size = _lseeki64(fd, 0, SEEK_END);
HWY_ASSERT(close(fd) != -1);
if (size < 0) {
return 0;
}
#else
static_assert(sizeof(off_t) == 8, "64-bit off_t required");
const off_t size = lseek(fd, 0, SEEK_END);
HWY_ASSERT(close(fd) != -1);
if (size == static_cast<off_t>(-1)) {
return 0;
}
#endif
return static_cast<uint64_t>(size);
}
static bool Read(int fd, uint64_t offset, uint64_t size, void* to) {
uint8_t* bytes = reinterpret_cast<uint8_t*>(to);
uint64_t pos = 0;
for (;;) {
// pread seems to be faster than lseek + read when parallelized.
const auto bytes_read = pread(fd, bytes + pos, size - pos, offset + pos);
if (bytes_read <= 0) break;
pos += bytes_read;
HWY_ASSERT(pos <= size);
if (pos == size) break;
}
return pos == size; // success if managed to read desired size
}
static bool Write(const void* from, uint64_t size, uint64_t offset, int fd) {
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(from);
uint64_t pos = 0;
for (;;) {
const auto bytes_written =
pwrite(fd, bytes + pos, size - pos, offset + pos);
if (bytes_written <= 0) break;
pos += bytes_written;
HWY_ASSERT(pos <= size);
if (pos == size) break;
}
return pos == size; // success if managed to write desired size
}
}; // IO
static_assert(HWY_IS_LITTLE_ENDIAN, "Assumes little endian"); static_assert(HWY_IS_LITTLE_ENDIAN, "Assumes little endian");
// On-disk representation (little-endian). // On-disk representation (little-endian).
@ -323,26 +200,13 @@ class BlobStore {
}; };
#pragma pack(pop) #pragma pack(pop)
BlobError BlobReader::Open(const char* filename) { BlobError BlobReader::Open(const Path& filename) {
#if HWY_OS_WIN file_ = OpenFileOrNull(filename, "r");
DWORD flags = FILE_ATTRIBUTE_NORMAL | FILE_FLAG_SEQUENTIAL_SCAN; if (!file_) return __LINE__;
HANDLE file = CreateFileA(filename, GENERIC_READ, FILE_SHARE_READ, nullptr,
OPEN_EXISTING, flags, nullptr);
if (file == INVALID_HANDLE_VALUE) return __LINE__;
fd_ = _open_osfhandle(reinterpret_cast<intptr_t>(file), _O_RDONLY);
#else
fd_ = open(filename, O_RDONLY);
#endif
if (fd_ < 0) return __LINE__;
#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21)
// Doubles the readahead window, which seems slightly faster when cached.
(void)posix_fadvise(fd_, 0, 0, POSIX_FADV_SEQUENTIAL);
#endif
// Read first part of header to get actual size. // Read first part of header to get actual size.
BlobStore bs; BlobStore bs;
if (!IO::Read(fd_, 0, sizeof(bs), &bs)) return __LINE__; if (!file_->Read(0, sizeof(bs), &bs)) return __LINE__;
const size_t padded_size = bs.PaddedHeaderSize(); const size_t padded_size = bs.PaddedHeaderSize();
HWY_ASSERT(padded_size >= sizeof(bs)); HWY_ASSERT(padded_size >= sizeof(bs));
@ -354,18 +218,11 @@ BlobError BlobReader::Open(const char* filename) {
hwy::CopySameSize(&bs, blob_store_.get()); hwy::CopySameSize(&bs, blob_store_.get());
// Read the rest of the header, but not the full file. // Read the rest of the header, but not the full file.
uint8_t* bytes = reinterpret_cast<uint8_t*>(blob_store_.get()); uint8_t* bytes = reinterpret_cast<uint8_t*>(blob_store_.get());
if (!IO::Read(fd_, sizeof(bs), padded_size - sizeof(bs), if (!file_->Read(sizeof(bs), padded_size - sizeof(bs), bytes + sizeof(bs))) {
bytes + sizeof(bs))) {
return __LINE__; return __LINE__;
} }
return blob_store_->CheckValidity(IO::FileSize(filename)); return blob_store_->CheckValidity(file_->FileSize());
}
BlobReader::~BlobReader() {
if (fd_ >= 0) {
HWY_ASSERT(close(fd_) != -1);
}
} }
BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) { BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) {
@ -392,13 +249,13 @@ BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) {
// between consecutive runs. // between consecutive runs.
// - memory-mapped I/O is less predictable and adds noise to measurements. // - memory-mapped I/O is less predictable and adds noise to measurements.
BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) { BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
const int fd = fd_; File* pfile = file_.get(); // not owned
const auto& requests = requests_; const auto& requests = requests_;
std::atomic_flag err = ATOMIC_FLAG_INIT; std::atomic_flag err = ATOMIC_FLAG_INIT;
// >5x speedup from parallel reads when cached. // >5x speedup from parallel reads when cached.
pool.Run(0, requests.size(), pool.Run(0, requests.size(),
[fd, &requests, &err](uint64_t i, size_t /*thread*/) { [pfile, &requests, &err](uint64_t i, size_t /*thread*/) {
if (!IO::Read(fd, requests[i].offset, requests[i].size, if (!pfile->Read(requests[i].offset, requests[i].size,
requests[i].data)) { requests[i].data)) {
err.test_and_set(); err.test_and_set();
} }
@ -407,8 +264,7 @@ BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
return 0; return 0;
} }
BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
const char* filename) const {
HWY_ASSERT(keys_.size() == blobs_.size()); HWY_ASSERT(keys_.size() == blobs_.size());
  // Concatenate blobs in memory.
@ -419,26 +275,18 @@ BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool,
      keys_.data(), blobs_.data(), keys_.size(), bs.get());

  // Create/replace existing file.
-#if HWY_OS_WIN
-  DWORD flags = FILE_ATTRIBUTE_NORMAL;
-  HANDLE file = CreateFileA(filename, GENERIC_WRITE, 0, nullptr, CREATE_ALWAYS,
-                            flags, nullptr);
-  if (file == INVALID_HANDLE_VALUE) return __LINE__;
-  const int fd = _open_osfhandle(reinterpret_cast<intptr_t>(file), _O_WRONLY);
-#else
-  const int fd = open(filename, O_CREAT | O_RDWR | O_TRUNC, 0644);
-#endif
-  if (fd < 0) return __LINE__;
+  std::unique_ptr<File> file = OpenFileOrNull(filename, "w+");
+  if (!file) return __LINE__;
+  File* pfile = file.get();  // not owned

  std::atomic_flag err = ATOMIC_FLAG_INIT;
  pool.Run(0, requests.size(),
-           [fd, &requests, &err](uint64_t i, size_t /*thread*/) {
-             if (!IO::Write(requests[i].data, requests[i].size,
-                            requests[i].offset, fd)) {
+           [pfile, &requests, &err](uint64_t i, size_t /*thread*/) {
+             if (!pfile->Write(requests[i].data, requests[i].size,
+                               requests[i].offset)) {
               err.test_and_set();
             }
           });
-  HWY_ASSERT(close(fd) != -1);
  if (err.test_and_set()) return __LINE__;
  return 0;
}
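Both `ReadAll` and `WriteAll` aggregate failures through a shared `std::atomic_flag` instead of returning early from workers: any worker sets the flag, and the caller checks it once after the pool finishes. The same pattern can be sketched with plain `std::thread` (the `RunAll` name and thread-striping scheme here are illustrative, not from the repo):

```cpp
#include <algorithm>
#include <atomic>
#include <cassert>
#include <cstdint>
#include <thread>
#include <vector>

// Runs work(i) for i in [0, n) across several threads. Any failure sets the
// sticky flag; no locks and no early exits, mirroring ReadAll/WriteAll.
template <typename Work>
bool RunAll(uint64_t n, const Work& work) {
  std::atomic_flag err = ATOMIC_FLAG_INIT;
  const unsigned num_threads =
      std::max(1u, std::thread::hardware_concurrency());
  std::vector<std::thread> threads;
  for (unsigned t = 0; t < num_threads; ++t) {
    threads.emplace_back([&, t]() {
      for (uint64_t i = t; i < n; i += num_threads) {
        if (!work(i)) err.test_and_set();  // sticky: workers never clear it
      }
    });
  }
  for (std::thread& th : threads) th.join();
  // test_and_set returns the previous value: true iff some worker failed.
  return !err.test_and_set();
}
```

`std::atomic_flag` suits this because the only operation workers need is "record that at least one failure happened", which `test_and_set` provides wait-free.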


@ -19,8 +19,10 @@
#include <stddef.h>
#include <stdint.h>

+#include <memory>
#include <vector>

+#include "compression/io.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"  // hwy::uint128_t
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -59,10 +61,10 @@ struct BlobIO {
class BlobReader {
 public:
  BlobReader() { requests_.reserve(500); }
-  ~BlobReader();
+  ~BlobReader() = default;

  // Opens `filename` and reads its header.
-  BlobError Open(const char* filename);
+  BlobError Open(const Path& filename);

  // Enqueues read requests if `key` is found and its size matches `size`.
  BlobError Enqueue(hwy::uint128_t key, void* data, size_t size);
@ -73,7 +75,7 @@ class BlobReader {
 private:
  BlobStorePtr blob_store_;  // holds header, not the entire file
  std::vector<BlobIO> requests_;
-  int fd_ = 0;
+  std::unique_ptr<File> file_;
};

class BlobWriter {
@ -84,7 +86,7 @@ class BlobWriter {
  }

  // Stores all blobs to disk in the given order with padding for alignment.
-  BlobError WriteAll(hwy::ThreadPool& pool, const char* filename) const;
+  BlobError WriteAll(hwy::ThreadPool& pool, const Path& filename);

 private:
  std::vector<hwy::uint128_t> keys_;


@ -23,11 +23,8 @@
#include <array>

-// copybara:import_next_line:gemma_cpp
#include "compression/blob_store.h"
-// copybara:import_next_line:gemma_cpp
#include "compression/compress.h"
-// copybara:import_next_line:gemma_cpp
#include "compression/distortion.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
@ -44,9 +41,7 @@
#define THIRD_PARTY_GEMMA_CPP_COMPRESS_TOGGLE
#endif

-// copybara:import_next_line:gemma_cpp
#include "compression/nuq-inl.h"
-// copybara:import_next_line:gemma_cpp
#include "compression/sfp-inl.h"
#include "hwy/contrib/dot/dot-inl.h"
#include "hwy/highway.h"
@ -63,6 +58,7 @@ struct CompressTraits {};
template <>
struct CompressTraits<float> {
  using MatT = float;
+  static constexpr bool kSupportsEvenOdd = false;

  template <class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
@ -116,6 +112,7 @@ struct CompressTraits<float> {
template <>
struct CompressTraits<hwy::bfloat16_t> {
  using MatT = hwy::bfloat16_t;
+  static constexpr bool kSupportsEvenOdd = true;

  template <class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
@ -224,11 +221,59 @@ struct CompressTraits<hwy::bfloat16_t> {
    // bf16*bf16.
    return hn::Dot::Compute<kAssumptions>(d_vec, vec_aligned, in + in_ofs, num);
  }
// Computes the dot product of an even-odd deinterleaved, f32 `vec_aligned`
  // and a column-major matrix `in`. `vec_aligned` should be aligned and
// alternate even-indexed `hn::Lanes(df32)` elements followed by odd-indexed
// `hn::Lanes(df32)` elements.
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE float DotEO(
const DF df32, const hwy::bfloat16_t* HWY_RESTRICT in, size_t in_ofs,
const float* HWY_RESTRICT vec_aligned, size_t num) {
HWY_DASSERT(num >= (hn::Lanes(df32) * 2) && (num % (hn::Lanes(df32) * 2)) == 0);
HWY_DASSERT((in_ofs % (hn::Lanes(df32) * 2)) == 0);
HWY_DASSERT(hn::IsAligned(df32, vec_aligned));
const hn::Repartition<hwy::bfloat16_t, DF> dbf16;
using VF32 = decltype(Zero(df32));
const size_t N = Lanes(dbf16);
VF32 sum0 = Zero(df32);
VF32 sum1 = Zero(df32);
VF32 sum2 = Zero(df32);
VF32 sum3 = Zero(df32);
const hn::RebindToUnsigned<decltype(df32)> du32;
using VU32 = hn::VFromD<decltype(du32)>;
const VU32 odd = Set(du32, 0xFFFF0000u);
for (size_t i = 0; i < num; /* i += 2 * N */) {
const auto interleaved0 = hn::LoadU(dbf16, in + in_ofs + i);
const VF32 ae0 = Load(df32, vec_aligned + i);
const VF32 ao0 = Load(df32, vec_aligned + i + (N / 2));
sum0 = hn::MulAdd(ae0, hn::PromoteEvenTo(df32, interleaved0), sum0);
sum1 = hn::MulAdd(ao0, hn::PromoteOddTo(df32, interleaved0), sum1);
i += N;
const auto interleaved1 = hn::LoadU(dbf16, in + in_ofs + i);
const VF32 ae1 = Load(df32, vec_aligned + i);
const VF32 ao1 = Load(df32, vec_aligned + i + (N / 2));
sum2 = hn::MulAdd(ae1, hn::PromoteEvenTo(df32, interleaved1), sum2);
sum3 = hn::MulAdd(ao1, hn::PromoteOddTo(df32, interleaved1), sum3);
i += N;
}
sum0 = Add(sum0, sum1);
sum2 = Add(sum2, sum3);
sum0 = Add(sum0, sum2);
return ReduceSum(df32, sum0);
}
};
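The even/odd layout that `DotEO` consumes can be illustrated without SIMD: for every block of `2 * kLanes` inputs, the caller stores the even-indexed elements first, then the odd-indexed ones, and the dot product then pairs even with even and odd with odd. A scalar model (names are illustrative; `kLanes` stands in for `hn::Lanes(df32)`), assuming the length is a multiple of `2 * kLanes`:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Deinterleaves v block-wise: even-indexed lanes first, then odd-indexed.
std::vector<float> DeinterleaveEvenOdd(const std::vector<float>& v,
                                       size_t kLanes) {
  std::vector<float> out(v.size());
  for (size_t block = 0; block < v.size(); block += 2 * kLanes) {
    for (size_t j = 0; j < kLanes; ++j) {
      out[block + j] = v[block + 2 * j];               // even lanes first
      out[block + kLanes + j] = v[block + 2 * j + 1];  // then odd lanes
    }
  }
  return out;
}

// Dot product of a deinterleaved vector with an interleaved column: pairing
// even-with-even and odd-with-odd recovers the ordinary dot product.
float DotEvenOdd(const std::vector<float>& col,
                 const std::vector<float>& vec_eo, size_t kLanes) {
  float sum = 0.0f;
  for (size_t block = 0; block < col.size(); block += 2 * kLanes) {
    for (size_t j = 0; j < kLanes; ++j) {
      sum += vec_eo[block + j] * col[block + 2 * j];
      sum += vec_eo[block + kLanes + j] * col[block + 2 * j + 1];
    }
  }
  return sum;
}
```

The SIMD version wins because `PromoteEvenTo`/`PromoteOddTo` yield the even and odd bf16 lanes of one load directly, so the pre-deinterleaved `vec_aligned` lines up with them without any in-loop shuffles.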
template <>
struct CompressTraits<SfpStream> {
  using MatT = SfpStream;
+  static constexpr bool kSupportsEvenOdd = false;

  template <class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE void Compress(DF df, const float* in, size_t num,
@ -278,6 +323,7 @@ struct CompressTraits<SfpStream> {
template <>
struct CompressTraits<NuqStream> {
  using MatT = NuqStream;
+  static constexpr bool kSupportsEvenOdd = false;

  template <class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE void Compress(DF df, const float* in, size_t num,
@ -430,16 +476,22 @@ HWY_INLINE float Dot(DF df, const ArrayT& compressed, size_t compressed_ofs,
}
// Returns dot product with `vec_aligned` of length `num`.
-template <class DF, typename MatT, size_t kCapacity, typename VecT>
+template <bool kVecEO, class DF, typename MatT, size_t kCapacity, typename VecT>
HWY_INLINE float Dot(DF df, const CompressedArray<MatT, kCapacity>& compressed,
                     size_t compressed_ofs, const VecT* vec_aligned,
                     size_t num) {
  HWY_DASSERT(compressed_ofs + num <= compressed.size());
  HWY_DASSERT(hn::IsAligned(df, vec_aligned));
  using Traits = CompressTraits<MatT>;
-  return (compressed.scale() * Traits::Dot(df, compressed.size(),
-                                           compressed.data(), compressed_ofs,
-                                           vec_aligned, num));
+  float dot_result;
+  if constexpr (kVecEO) {
+    dot_result = Traits::DotEO(df, compressed.data(), compressed_ofs,
+                               vec_aligned, num);
+  } else {
+    dot_result = Traits::Dot(df, compressed.size(), compressed.data(),
+                             compressed_ofs, vec_aligned, num);
+  }
+  return compressed.scale() * dot_result;
}

// Callback used by ForeachTensor.
@ -464,11 +516,11 @@ class Compressor {
    }
  }

-  void WriteAll(hwy::ThreadPool& pool, const char* blob_filename) {
+  void WriteAll(hwy::ThreadPool& pool, const Path& blob_filename) {
    const BlobError err = writer_.WriteAll(pool, blob_filename);
    if (err != 0) {
-      fprintf(stderr, "Failed to write blobs to %s (error %d)\n", blob_filename,
-              err);
+      fprintf(stderr, "Failed to write blobs to %s (error %d)\n",
+              blob_filename.path.c_str(), err);
    }
  }


@ -27,19 +27,15 @@
#include <vector>

// IWYU pragma: begin_exports
-// copybara:import_next_line:gemma_cpp
#include "compression/blob_store.h"
-// copybara:import_next_line:gemma_cpp
+#include "compression/io.h"
#include "compression/nuq.h"
-// copybara:import_next_line:gemma_cpp
#include "compression/sfp.h"
// IWYU pragma: end_exports
-// copybara:import_next_line:gemma_cpp
#include "compression/distortion.h"
#include "hwy/base.h"  // hwy::bfloat16_t
#include "hwy/contrib/thread_pool/thread_pool.h"

#if COMPRESS_STATS
-// copybara:import_next_line:gemma_cpp
#include "compression/stats.h"
#endif
@ -171,13 +167,13 @@ hwy::uint128_t CacheKey(const char* name) {
class CacheLoader {
 public:
-  explicit CacheLoader(const char* blob_filename) {
+  explicit CacheLoader(const Path& blob_filename) {
    err_ = reader_.Open(blob_filename);
    if (err_ != 0) {
      fprintf(stderr,
              "Cached compressed weights does not exist yet (code %d), "
              "compressing weights and creating file: %s.\n",
-              err_, blob_filename);
+              err_, blob_filename.path.c_str());
    }
  }


@ -15,85 +15,214 @@
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_DISTORTION_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_DISTORTION_H_

#include <math.h>  // pow
#include <stddef.h>
+#include <stdio.h>

+#include <vector>

+#include "compression/stats.h"
+#include "hwy/aligned_allocator.h"  // HWY_ALIGNMENT
#include "hwy/base.h"  // ScalarAbs
+#include "hwy/contrib/sort/vqsort.h"

namespace gcpp {
// Returns `sum` and `err` such that `sum + err` is exactly equal to `a + b`,
// despite floating-point rounding. `sum` is already the best estimate, so do
// not actually add `err` to it. Knuth98/Moller65. Unlike Fast2Sum [Dekker71],
// this does not require any relative ordering of the exponents of a and b.
template <typename T>
static inline T TwoSum(T a, T b, T& err) {
const T sum = a + b;
const T a2 = sum - b;
const T b2 = sum - a2;
const T err_a = a - a2;
const T err_b = b - b2;
err = err_a + err_b;
return sum;
}
// Accumulates numbers with about twice the precision of T using 7 * n FLOPS.
// Rump/Ogita/Oishi08, Algorithm 6.11 in Handbook of Floating-Point Arithmetic.
template <typename T>
class CascadedSummation {
public:
void Notify(T t) {
T err;
sum_ = TwoSum(sum_, t, err);
sum_err_ += err;
}
void Assimilate(const CascadedSummation& other) {
Notify(other.sum_);
sum_err_ += other.sum_err_;
}
// Allows users to observe how much difference the extra precision made.
T Err() const { return sum_err_; }
// Returns the sum of all `t` passed to `Notify`.
T Total() const { return sum_ + sum_err_; }
private:
T sum_ = T{0};
T sum_err_ = T{0};
};
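What `TwoSum` buys is easiest to see when adding a tiny value to a huge one: plain addition loses the tiny term to rounding, while `err` recovers it exactly. This standalone sketch repeats the `TwoSum` definition above so it compiles in isolation:

```cpp
#include <cassert>

// Error-free transformation: sum + err == a + b exactly (Knuth/Moller).
template <typename T>
T TwoSum(T a, T b, T& err) {
  const T sum = a + b;
  const T a2 = sum - b;
  const T b2 = sum - a2;
  err = (a - a2) + (b - b2);
  return sum;
}
```

With `a = 2^53` and `b = 1.0`, the rounded `sum` is still `2^53` (the 1 falls below the last mantissa bit), but `err` comes out as exactly `1.0`, so `sum + err` represents the true result. `CascadedSummation` accumulates these `err` terms, which is why `Total()` can return the exact 2 in the Priest example from the test below.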
// Summarizes the error of a distortion (e.g. quantization) applied to a series
// of numbers.
// Users should check all four resulting metrics (NumExact, NumRoundedToZero,
// GeomeanValueDivL1, WeightedAverageL1) because each covers different aspects.
class DistortionStats {
 public:
  void Notify(float original, float distorted) {
    (void)padding_;  // prevent unused member warning
-    const double l1 = hwy::ScalarAbs(original - distorted);
+    const bool rounded_to_zero = (original != 0.0f) && (distorted == 0.0f);
+    // We expect original == 0 is not distorted (can be exactly represented).
+    HWY_ASSERT(original != 0.0f || distorted == 0.0f);
+
+    s_original_.Notify(original);
+    const float l1f = hwy::ScalarAbs(original - distorted);
+    const double l1 = static_cast<double>(l1f);
+    s_l1_.Notify(l1f);
+    b_l1_.Notify(HWY_MIN(99, static_cast<int>(l1f * 1E4)));
+    if (l1f != 0.0f) {
+      l1_.push_back(l1f);
+    }
+    sum_l1_.Notify(l1f);
+    if (rounded_to_zero) sum_l1_rounded_.Notify(l1f);

-    if (l1 > max_l1_) {
-      max_l1_ = l1;
-      max_idx_ = n_;
-    }
-    const double pow3 = l1 * l1 * l1;
-    sum_pow3_ += pow3;
-    sum_pow6_ += pow3 * pow3;
-    n_ += 1;
+    // Event counts
+    {
+      n_ += 1;
+      // Rounding (small) negative numbers to 0 does not influence dot products
+      // as much as an actual sign flip, so do not count them.
+      n_sign_flip_ +=
+          ((original < 0.0f) != (distorted < 0.0f)) && !rounded_to_zero;
+      n_exact_ += (l1f == 0.0f);
+      n_rounded_to_zero += rounded_to_zero;
+    }

-    // Avoid division by zero, which happens when there is no error. NumExact()
-    // reports the number of times this happens.
-    if (l1 != 0.0) {
-      const double rel = 1.0 + hwy::ScalarAbs(original) / l1;
-      // Logarithm is required to prevent overflow. A hierarchical geomean
-      // could also work, but that is more complex and not necessarily better.
-      sum_log_rel_ += log(rel);
-      num_rel_ += 1;
+    // Signal to noise ratio (Shannon's channel capacity, NOT the L2-based and
+    // logarithmic PSNR) to estimate the ratios of original to the L1 norm.
+    if (l1f != 0.0) {  // prevent division by zero
+      const double snr =
+          1.0 + static_cast<double>(hwy::ScalarAbs(original)) / l1;
+      // For numerical purposes (prevents overflow). A hierarchical geomean
+      // could also work, but that is more complex and not necessarily better.
+      // We will return exp() of the arithmetic mean.
+      sum_log_snr_ += log(snr);
+      num_snr_ += 1;
    }
  }
  void Assimilate(const DistortionStats& other) {
-    if (other.max_l1_ > max_l1_) {
-      max_l1_ = other.max_l1_;
-      max_idx_ = other.max_idx_;
-    }
-    sum_pow3_ += other.sum_pow3_;
-    sum_pow6_ += other.sum_pow6_;
+    s_original_.Assimilate(other.s_original_);
+    s_l1_.Assimilate(other.s_l1_);
+    b_l1_.Assimilate(other.b_l1_);
+    sum_l1_.Assimilate(other.sum_l1_);
+    sum_l1_rounded_.Assimilate(other.sum_l1_rounded_);
+    l1_.insert(l1_.end(), other.l1_.begin(), other.l1_.end());
    n_ += other.n_;
+    n_sign_flip_ += other.n_sign_flip_;
+    n_exact_ += other.n_exact_;
+    n_rounded_to_zero += other.n_rounded_to_zero;
-    sum_log_rel_ += other.sum_log_rel_;
-    num_rel_ += other.num_rel_;
+    sum_log_snr_ += other.sum_log_snr_;
+    num_snr_ += other.num_snr_;
  }
-  size_t NumExact() const { return n_ - num_rel_; }
+  size_t NumExact() const { return n_exact_; }
+  size_t NumSignFlip() const { return n_sign_flip_; }
+  size_t NumRoundedToZero() const { return n_rounded_to_zero; }

+  // Total absolute error.
+  double SumL1() const { return sum_l1_.Total(); }
+  // Total absolute error for numbers that were rounded to zero.
+  double SumL1Rounded() const { return sum_l1_rounded_.Total(); }

+  // Returns geomean of 1 + S/N (Shannon channel capacity). This is computed via
+  // the ratio of input magnitude to nonzero L1 norms. Higher is better.
  double GeomeanValueDivL1() const {
-    if (num_rel_ == 0) return 0.0;
-    return exp(sum_log_rel_ / static_cast<double>(num_rel_));
+    if (num_snr_ == 0) return 0.0;
+    return exp(sum_log_snr_ / static_cast<double>(num_snr_));
  }

-  double PNorm() const {
-    // p-norms are a compromise between max-norm (penalizes the largest error
-    // without dilution, but does not notice any other errors) and L1 (all
-    // errors contribute, but large errors are diluted by smaller ones).
-    const double norm3 = pow(sum_pow3_ / static_cast<double>(n_), 1.0 / 3);
-    const double norm6 = pow(sum_pow6_ / static_cast<double>(n_), 1.0 / 6);
-    return 0.5 * (norm3 + norm6);
+  // Returns weighted average of nonzero L1 norms. Those further from the median
+  // L1 norm are much more heavily weighted, such that this behaves more like
+  // the L-infinity norm, but still includes all differences, not just the max.
+  // Lower is better, magnitude depends on the input magnitude.
+  double WeightedAverageL1() const {
+    if (l1_.empty()) return 0.0f;  // all exact
std::vector<float> weights(l1_); // copy so we can modify
const float median = [&weights]() {
const size_t mid = weights.size() / 2;
// We just want the median; partial sort is faster if available (v1.2).
#if HWY_MAJOR > 1 || HWY_MINOR >= 2
hwy::VQSelect(weights.data(), weights.size(), mid, hwy::SortAscending());
#else
hwy::VQSort(weights.data(), weights.size(), hwy::SortAscending());
#endif
return weights[mid];
}();
weights = l1_; // restore original order
// Replace with distance from median (might have too few samples for mode).
float max_abs = -1.0f;
for (float& d : weights) {
d = hwy::ScalarAbs(d - median);
max_abs = HWY_MAX(max_abs, d);
}
HWY_ASSERT(max_abs >= 0.0f);
// All equal - return the distance value to prevent division by zero.
if (max_abs == 0.0f) return median;
// Normalize to max difference and exponentiate.
const double inv_max = 1.0 / static_cast<double>(max_abs);
double sum_weights = 0.0;
for (float& w : weights) {
const double normalized = static_cast<double>(w) * inv_max;
const double amplified = exp(4.0 * normalized * normalized);
sum_weights += amplified;
w = static_cast<float>(amplified);
}
// At least 1.0 per weight, plus more for at least one weight because we
// verified via max_abs that not all are equal.
HWY_ASSERT(sum_weights > static_cast<double>(weights.size()));
// Return weighted average.
double weighted_sum = 0.0;
for (size_t i = 0; i < weights.size(); ++i) {
weighted_sum += l1_[i] * weights[i];
}
return weighted_sum / sum_weights;
} }
-  size_t MaxIndex() const { return max_idx_; }
-  double MaxL1() const { return max_l1_; }
+  Stats& L1() { return s_l1_; }
+  Stats& Original() { return s_original_; }

 private:
+  Stats s_original_;
+  Stats s_l1_;
+  Bins<100> b_l1_;
+  CascadedSummation<double> sum_l1_;          // all
+  CascadedSummation<double> sum_l1_rounded_;  // only if rounded_to_zero
+  std::vector<float> l1_;

+  // Event counts
  size_t n_ = 0;
-  size_t max_idx_ = 0;  // index that had l1 = max_l1_.
-  double max_l1_ = -1.0;
+  size_t n_sign_flip_ = 0;
+  size_t n_exact_ = 0;
+  size_t n_rounded_to_zero = 0;

-  double sum_pow3_ = 0.0;
-  double sum_pow6_ = 0.0;
-  double sum_log_rel_ = 0.0;
-  size_t num_rel_ = 0;
+  double sum_log_snr_ = 0.0;
+  size_t num_snr_ = 0;

-  double padding_;  // prevents false sharing
+  uint8_t padding_[HWY_ALIGNMENT];  // prevents false sharing
};

}  // namespace gcpp
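`GeomeanValueDivL1` computes a geometric mean as `exp` of the arithmetic mean of logs, which cannot overflow even when the product of the ratios would. The trick in isolation (a sketch; `Geomean` is an illustrative name, not from the repo):

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Geomean of positive ratios via sum of logs: the product of n ratios can
// overflow double, but the sum of their logs grows only linearly in n.
double Geomean(const std::vector<double>& ratios) {
  if (ratios.empty()) return 0.0;
  double sum_log = 0.0;
  for (double r : ratios) sum_log += std::log(r);
  return std::exp(sum_log / static_cast<double>(ratios.size()));
}
```

For example, the geomean of 10 and 1000 is sqrt(10 * 1000) = 100, and ratios like 1e200 and 1e-200, whose direct product would overflow or underflow, still average cleanly to 1 in log space.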


@ -0,0 +1,99 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "compression/distortion.h"
#include <stdio.h>
#include "compression/test_util.h"
#include "hwy/nanobenchmark.h"
#include "hwy/tests/hwy_gtest.h"
#include "hwy/tests/test_util-inl.h" // HWY_ASSERT_EQ
namespace gcpp {
namespace {
#if !HWY_TEST_STANDALONE
class DistortionTest : public testing::Test {};
#endif
TEST(DistortionTest, TestCascadedSummation) {
CascadedSummation<double> cs;
// Example from Priest92. Exact sum is 2.
const double kHuge = 9007199254740992.0 * hwy::Unpredictable1(); // 2^53
const double kNeg = -4503599627370495.0 * hwy::Unpredictable1(); // -(2^52-1)
const double kIn[6] = {kHuge, kHuge - 2.0, kNeg, kNeg, kNeg, kNeg};
for (double in : kIn) {
cs.Notify(in);
}
HWY_ASSERT_EQ(2.0, cs.Total());
}
// Number of exact and rounded-to-zero matches expectations.
TEST(DistortionTest, TestCounts) {
// Arbitrary positive/negative original, zero distorted.
DistortionStats stats;
for (size_t i = 1; i < 10; ++i) {
stats.Notify(i / 100.0f, 0.0f);
stats.Notify(i / -100.0f, 0.0f);
}
HWY_ASSERT(stats.NumExact() == 0);
HWY_ASSERT(stats.NumRoundedToZero() == 18);
// Add some exact (same):
size_t num_exact = 0;
for (float x = 0.0f; x <= 1.5f; x += 0.25f) {
stats.Notify(x, x);
stats.Notify(-x, -x);
num_exact += 2;
}
HWY_ASSERT_EQ(num_exact, stats.NumExact());
HWY_ASSERT(stats.NumRoundedToZero() == 18); // unchanged
}
// Few large differences are diluted in SNR but not WeightedAverageL1.
TEST(DistortionTest, TestDilution) {
DistortionStats stats;
for (size_t i = 0; i < 100; ++i) {
stats.Notify(0.998f, 0.999f); // small
}
HWY_ASSERT(IsInside(900.0, 1000.0, stats.GeomeanValueDivL1()));
// All-equal WeightedSum is exact.
HWY_ASSERT(IsNear(0.001, stats.WeightedAverageL1()));
// Now add a large difference:
stats.Notify(1.875f - 0.0625f, 1.875f); // max magnitude, 3-bit mantissa
// .. WeightedAverageL1 is closer to it.
HWY_ASSERT(IsInside(0.020, 0.025, stats.WeightedAverageL1()));
// Add a small and large difference:
stats.Notify((1.75f - 0.125f) / 1024, 1.75f / 1024); // small, 2-bit mantissa
stats.Notify(-1.875f + 0.0625f, -1.875f); // larger negative
// .. SNR is still barely affected.
HWY_ASSERT(IsInside(890.0, 900.0, stats.GeomeanValueDivL1()));
// .. WeightedAverageL1 is higher after another large error.
HWY_ASSERT(IsInside(0.030, 0.035, stats.WeightedAverageL1()));
// With these inputs, none are exact nor round to zero.
HWY_ASSERT(stats.NumExact() == 0);
HWY_ASSERT(stats.NumRoundedToZero() == 0);
HWY_ASSERT_EQ(0.0, stats.SumL1Rounded());
HWY_ASSERT(IsInside(0.220, 0.23, stats.SumL1()));
}
} // namespace
} // namespace gcpp
HWY_TEST_MAIN();

compression/io.cc (new file, 121 lines)

@ -0,0 +1,121 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Safe to be first, does not include POSIX headers.
#include "hwy/detect_compiler_arch.h"
// Only compile this file on non-Windows; it replaces io_win.cc. It is easier to
// check this in source code because we support multiple build systems.
#if !HWY_OS_WIN
// Request POSIX 2008, including `pread()` and `posix_fadvise()`.
#if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700
#undef _XOPEN_SOURCE
#define _XOPEN_SOURCE 700
#endif
#if !defined(_POSIX_C_SOURCE) || _POSIX_C_SOURCE < 200809
#define _POSIX_C_SOURCE 200809
#endif
// Make `off_t` 64-bit even on 32-bit systems. Works for Android >= r15c.
#undef _FILE_OFFSET_BITS
#define _FILE_OFFSET_BITS 64
#include <fcntl.h> // open
#include <stddef.h>
#include <stdint.h>
#include <stdio.h> // SEEK_END - unistd isn't enough for IDE.
#include <sys/stat.h> // O_RDONLY
#include <unistd.h> // read, write, close
#include <memory>
#include "compression/io.h"
#include "hwy/base.h" // HWY_ASSERT
namespace gcpp {
class FilePosix : public File {
int fd_ = 0;
public:
explicit FilePosix(int fd) : fd_(fd) { HWY_ASSERT(fd > 0); }
~FilePosix() override {
if (fd_ != 0) {
HWY_ASSERT(close(fd_) != -1);
}
}
uint64_t FileSize() const override {
static_assert(sizeof(off_t) == 8, "64-bit off_t required");
const off_t size = lseek(fd_, 0, SEEK_END);
if (size < 0) {
return 0;
}
return static_cast<uint64_t>(size);
}
bool Read(uint64_t offset, uint64_t size, void* to) const override {
uint8_t* bytes = reinterpret_cast<uint8_t*>(to);
uint64_t pos = 0;
for (;;) {
// pread seems to be faster than lseek + read when parallelized.
const auto bytes_read = pread(fd_, bytes + pos, size - pos, offset + pos);
if (bytes_read <= 0) break;
pos += bytes_read;
HWY_ASSERT(pos <= size);
if (pos == size) break;
}
return pos == size; // success if managed to read desired size
}
bool Write(const void* from, uint64_t size, uint64_t offset) override {
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(from);
uint64_t pos = 0;
for (;;) {
const auto bytes_written =
pwrite(fd_, bytes + pos, size - pos, offset + pos);
if (bytes_written <= 0) break;
pos += bytes_written;
HWY_ASSERT(pos <= size);
if (pos == size) break;
}
return pos == size; // success if managed to write desired size
}
}; // FilePosix
HWY_MAYBE_UNUSED extern std::unique_ptr<File> OpenFileGoogle(
const Path& filename, const char* mode);
std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode) {
std::unique_ptr<File> file; // OpenFileGoogle omitted
if (file) return file;
const bool is_read = mode[0] != 'w';
const int flags = is_read ? O_RDONLY : O_CREAT | O_RDWR | O_TRUNC;
const int fd = open(filename.path.c_str(), flags, 0644);
if (fd < 0) return file;
#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21)
if (is_read) {
// Doubles the readahead window, which seems slightly faster when cached.
(void)posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL);
}
#endif
return std::make_unique<FilePosix>(fd);
}
} // namespace gcpp
#endif // !HWY_OS_WIN
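The retry loops in `FilePosix::Read`/`Write` exist because POSIX allows `pread`/`pwrite` to transfer fewer bytes than requested even without an error. The same pattern can be exercised standalone against a temporary file (plain POSIX, no gcpp types; function names are illustrative):

```cpp
#include <cassert>
#include <cstdint>
#include <cstdlib>   // mkstemp
#include <cstring>
#include <unistd.h>  // pread, pwrite, close, unlink

// Writes `size` bytes at `offset`, retrying until done or error, because
// pwrite may perform a short transfer.
bool PwriteAll(int fd, const void* from, uint64_t size, uint64_t offset) {
  const uint8_t* bytes = static_cast<const uint8_t*>(from);
  uint64_t pos = 0;
  while (pos < size) {
    const ssize_t n = pwrite(fd, bytes + pos, size - pos, offset + pos);
    if (n <= 0) return false;  // error (or no progress): give up
    pos += static_cast<uint64_t>(n);
  }
  return true;
}

// Reads `size` bytes at `offset`, retrying on short reads.
bool PreadAll(int fd, uint64_t offset, uint64_t size, void* to) {
  uint8_t* bytes = static_cast<uint8_t*>(to);
  uint64_t pos = 0;
  while (pos < size) {
    const ssize_t n = pread(fd, bytes + pos, size - pos, offset + pos);
    if (n <= 0) return false;  // EOF before `size` bytes, or error
    pos += static_cast<uint64_t>(n);
  }
  return true;
}
```

Using the positioned variants (`pread`/`pwrite`) rather than `lseek` + `read` also makes the calls safe to issue concurrently from a thread pool, since there is no shared file position to race on.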

compression/io.h (new file, 88 lines)

@ -0,0 +1,88 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_
#include <stddef.h>
#include <stdint.h>
#include <memory>
#include <string>
#include <utility> // std::move
namespace gcpp {
// Forward-declare to break the circular dependency: OpenFileOrNull returns
// File and has a Path argument, and Path::Exists calls OpenFileOrNull. We
// prefer to define Exists inline because there are multiple io*.cc files.
struct Path;
// Abstract base class enables multiple I/O backends in the same binary.
class File {
public:
File() = default;
virtual ~File() = default;
// Noncopyable.
File(const File& other) = delete;
const File& operator=(const File& other) = delete;
// Returns size in bytes or 0.
virtual uint64_t FileSize() const = 0;
// Returns true if all the requested bytes were read.
virtual bool Read(uint64_t offset, uint64_t size, void* to) const = 0;
// Returns true if all the requested bytes were written.
virtual bool Write(const void* from, uint64_t size, uint64_t offset) = 0;
};
// Returns nullptr on failure. `mode` is either "r" or "w+". This is not just
// named 'OpenFile' to avoid a conflict with Windows.h #define.
std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode);
// Wrapper for strings representing a path name. Differentiates vs. arbitrary
// strings and supports shortening for display purposes.
struct Path {
Path() {}
explicit Path(const char* p) : path(p) {}
explicit Path(std::string p) : path(std::move(p)) {}
Path& operator=(const char* other) {
path = other;
return *this;
}
std::string Shortened() const {
constexpr size_t kMaxLen = 48;
constexpr size_t kCutPoint = kMaxLen / 2 - 5;
if (path.size() > kMaxLen) {
return std::string(begin(path), begin(path) + kCutPoint) + " ... " +
std::string(end(path) - kCutPoint, end(path));
}
if (path.empty()) return "[no path specified]";
return path;
}
// Returns whether the file existed when this was called.
bool Exists() const { return !!OpenFileOrNull(*this, "r"); }
std::string path;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_
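`Path::Shortened` can be checked with plain strings: with `kMaxLen = 48` and `kCutPoint = 48 / 2 - 5 = 19`, an over-long path keeps its first and last 19 characters around a `" ... "` separator (43 characters total). A standalone copy of the logic as a free function:

```cpp
#include <cassert>
#include <string>

// Mirrors Path::Shortened: paths over 48 chars keep 19 chars from each end.
std::string Shortened(const std::string& path) {
  constexpr size_t kMaxLen = 48;
  constexpr size_t kCutPoint = kMaxLen / 2 - 5;  // 19
  if (path.size() > kMaxLen) {
    return path.substr(0, kCutPoint) + " ... " +
           path.substr(path.size() - kCutPoint);
  }
  if (path.empty()) return "[no path specified]";
  return path;
}
```

Note the empty-path check comes after the length check in the original as well; it only matters for empty input, so the ordering is harmless.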

compression/io_win.cc (new file, 115 lines)

@ -0,0 +1,115 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "hwy/detect_compiler_arch.h"

// Only compile this file on Windows; it replaces io.cc. It is easier to check
// this in source code because we support multiple build systems.
#if HWY_OS_WIN

#include <stddef.h>
#include <stdint.h>

#include "compression/io.h"
#include "hwy/base.h"  // HWY_ASSERT

#ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN
#endif
#ifndef VC_EXTRALEAN
#define VC_EXTRALEAN
#endif
#include <Windows.h>

namespace gcpp {

class FileWin : public File {
  HANDLE hFile_ = INVALID_HANDLE_VALUE;

 public:
  FileWin(HANDLE hFile) : hFile_(hFile) {
    HWY_ASSERT(hFile != INVALID_HANDLE_VALUE);
  }

  ~FileWin() override {
    if (hFile_ != INVALID_HANDLE_VALUE) {
      HWY_ASSERT(CloseHandle(hFile_) != 0);
    }
  }

  uint64_t FileSize() const override {
    DWORD hi;
    const DWORD lo = GetFileSize(hFile_, &hi);
    if (lo == INVALID_FILE_SIZE) return 0;
    return (static_cast<uint64_t>(hi) << 32) | lo;
  }

  bool Read(uint64_t offset, uint64_t size, void* to) const override {
    uint8_t* bytes = reinterpret_cast<uint8_t*>(to);
    OVERLAPPED overlapped = {0};
    // Loop is required because the ReadFile[Ex] size argument is 32-bit.
    while (size != 0) {
      overlapped.Offset = offset & 0xFFFFFFFF;
      overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
      const DWORD want =
          static_cast<DWORD>(HWY_MIN(size, uint64_t{0xFFFFFFFF}));
      DWORD got;
      if (!ReadFile(hFile_, bytes, want, &got, &overlapped)) {
        if (GetLastError() != ERROR_HANDLE_EOF) {
          return false;
        }
      }
      offset += got;
      bytes += got;
      size -= got;
    }
    return true;  // read everything => success
  }

  bool Write(const void* from, uint64_t size, uint64_t offset) override {
    const uint8_t* bytes = reinterpret_cast<const uint8_t*>(from);
    OVERLAPPED overlapped = {0};
    // Loop is required because the WriteFile[Ex] size argument is 32-bit.
    while (size != 0) {
      overlapped.Offset = offset & 0xFFFFFFFF;
      overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
      const DWORD want =
          static_cast<DWORD>(HWY_MIN(size, uint64_t{0xFFFFFFFF}));
      DWORD got;
      if (!WriteFile(hFile_, bytes, want, &got, &overlapped)) {
        if (GetLastError() != ERROR_HANDLE_EOF) {
          return false;
        }
      }
      offset += got;
      bytes += got;
      size -= got;
    }
    return true;  // wrote everything => success
  }
};  // FileWin

std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode) {
  const bool is_read = mode[0] != 'w';
  const DWORD flags =
      FILE_ATTRIBUTE_NORMAL | (is_read ? FILE_FLAG_SEQUENTIAL_SCAN : 0);
  const DWORD access = is_read ? GENERIC_READ : GENERIC_WRITE;
  const DWORD share = is_read ? FILE_SHARE_READ : 0;
  const DWORD create = is_read ? OPEN_EXISTING : CREATE_ALWAYS;
  const HANDLE hFile = CreateFileA(filename.path.c_str(), access, share,
                                   nullptr, create, flags, nullptr);
  if (hFile == INVALID_HANDLE_VALUE) return std::unique_ptr<File>();
  return std::make_unique<FileWin>(hFile);
}

}  // namespace gcpp
#endif  // HWY_OS_WIN


@@ -20,9 +20,7 @@
 #include <stddef.h>
 #include <stdint.h>
-// copybara:import_next_line:gemma_cpp
 #include "compression/nuq.h"
-// copybara:import_next_line:gemma_cpp
 #include "compression/sfp.h"
 #include "hwy/base.h"
@@ -37,7 +35,6 @@
 #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE
 #endif
-// copybara:import_next_line:gemma_cpp
 #include "compression/sfp-inl.h"
 #include "hwy/contrib/sort/vqsort-inl.h"
 #include "hwy/highway.h"
@@ -107,9 +104,8 @@ class NuqClustering {
     // Callers are responsible for ignoring lanes where last < first.
     HWY_DASSERT(first < kGroupSize);
     HWY_DASSERT(last < kGroupSize);
-    const size_t len = last - first + 1;
-    const hn::Vec<DF> vlen =
-        hn::Iota(df, static_cast<float>(static_cast<int>(len)));
+    const int len = static_cast<int>(last) - static_cast<int>(first) + 1;
+    const hn::Vec<DF> vlen = hn::Iota(df, static_cast<float>(len));
     const hn::Vec<DF> u_lo = hn::Set(df, cumsum_[first]);
     const hn::Vec<DF> u_lo2 = hn::Set(df, cumsum2_[first]);
@@ -207,7 +203,7 @@ class NuqClustering {
     for (size_t num_clusters = 1; num_clusters < kClusters; ++num_clusters) {
       // For each batch starting at `last`, one per lane:
       for (size_t last = 0; last < kGroupSize; last += N) {
-        VF min = cc(df, 0, last);
+        VF min = hn::LoadU(df, &D(0, last));
         VI arg = hn::Zero(di);
         // For each j (start of rightmost cluster):
         VI vj = k1;


@@ -18,6 +18,8 @@
 #define HWY_DISABLED_TARGETS HWY_SCALAR
 #endif

+#include "compression/nuq.h"
+
 #include <stddef.h>
 #include <stdint.h>
 #include <stdio.h>
@@ -25,8 +27,10 @@
 #include <algorithm>  // std::shuffle
 #include <random>

+#include "compression/test_util.h"
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
+#include "hwy/tests/test_util.h"
 #include "hwy/timer.h"

 // clang-format off
@@ -35,12 +39,7 @@
 // clang-format on
 #include "hwy/foreach_target.h"  // IWYU pragma: keep
 // Other headers that include Highway must come after foreach_target.h
-// copybara:import_next_line:gemma_cpp
 #include "compression/nuq-inl.h"
-// copybara:import_next_line:gemma_cpp
-#include "compression/nuq.h"
-// copybara:import_next_line:gemma_cpp
-#include "compression/test_util.h"
 #include "hwy/highway.h"
 #include "hwy/tests/hwy_gtest.h"
 #include "hwy/tests/test_util-inl.h"
@@ -117,12 +116,16 @@ struct TestPlateaus {
       HWY_ASSERT(indices[i] < kClusters);
       stats.Notify(in[i], centers[indices[i]]);
     }
-    const float pnorm = stats.PNorm();
-    const float snr = stats.GeomeanValueDivL1();
-    fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr,
-            stats.MaxIndex(), stats.MaxL1());
-    HWY_ASSERT(pnorm == 0.0f);
-    HWY_ASSERT(snr == 0.0f);
+    // Zero error.
+    HWY_ASSERT_EQ(kGroupSize, stats.NumExact());
+    HWY_ASSERT_EQ(0, stats.NumSignFlip());
+    HWY_ASSERT_EQ(0, stats.NumRoundedToZero());
+    HWY_ASSERT_EQ(0.0, stats.SumL1());
+    HWY_ASSERT_EQ(0.0f, stats.GeomeanValueDivL1());
+    HWY_ASSERT_EQ(0.0f, stats.WeightedAverageL1());
+    // Input was symmetric and zero-mean.
+    HWY_ASSERT(gcpp::IsInside(-0.05, 0.05, stats.Original().Mean()));
+    HWY_ASSERT(gcpp::IsNear(0.0, stats.Original().Skewness()));
   }
 };
@@ -160,16 +163,19 @@ struct TestRamp {
       HWY_ASSERT(indices[i] < kClusters);
       stats.Notify(in[i], centers[indices[i]]);
     }
-    const float pnorm = stats.PNorm();
-    const float snr = stats.GeomeanValueDivL1();
-    fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr,
-            stats.MaxIndex(), stats.MaxL1());
-    static_assert(kGroupSize == 128 || kGroupSize == 256, "Update expected");
-    const float expected_pnorm = kGroupSize == 128 ? 2.08E-2f : 2.1E-2f;
-    const float expected_snr = kGroupSize == 128 ? 16.9f : 17.6f;
-    HWY_ASSERT(expected_pnorm <= pnorm && pnorm < 1.02f * expected_pnorm);
-    HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr);
+    // Low error.
+    HWY_ASSERT_EQ(0, stats.NumExact());
+    HWY_ASSERT(stats.NumSignFlip() < 10);
+    HWY_ASSERT_EQ(0, stats.NumRoundedToZero());
+    HWY_ASSERT_EQ(kGroupSize / kClusters / 4.0, stats.SumL1());
+    HWY_ASSERT(gcpp::IsInside(17.0, 18.0, stats.GeomeanValueDivL1()));
+    HWY_ASSERT(gcpp::IsInside(0.005, 0.010, stats.WeightedAverageL1()));
+    HWY_ASSERT(stats.L1().Max() <= 0.04f);
+    // Input was symmetric about 0.05.
+    HWY_ASSERT(gcpp::IsNear(0.05, stats.Original().Mean(), 0.01));
+    HWY_ASSERT(gcpp::IsNear(0.0, stats.Original().Skewness(), 1E-4));
+    static_assert(kGroupSize == 256, "Update expected");
   }
 };
@@ -210,15 +216,16 @@ struct TestNormal {
       HWY_ASSERT(indices[i] < kClusters);
       stats.Notify(in[i], centers[indices[i]]);
     }
-    const float pnorm = stats.PNorm();
-    const float snr = stats.GeomeanValueDivL1();
-    fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr,
-            stats.MaxIndex(), stats.MaxL1());
+    // Moderate error.
+    HWY_ASSERT_EQ(0, stats.NumExact());
+    HWY_ASSERT(stats.NumSignFlip() < kGroupSize / kClusters);
+    HWY_ASSERT_EQ(0, stats.NumRoundedToZero());
+    HWY_ASSERT(gcpp::IsInside(5.0, 6.0, stats.SumL1()));
+    HWY_ASSERT(gcpp::IsInside(12.7, 12.8, stats.GeomeanValueDivL1()));
+    HWY_ASSERT(gcpp::IsInside(0.036, 0.037, stats.WeightedAverageL1()));
+    HWY_ASSERT(stats.L1().Max() <= 0.10f);
     static_assert(kGroupSize == 256, "Update expected");
-    const float expected_pnorm = 3.68E-2f;
-    const float expected_snr = 12.7f;
-    HWY_ASSERT(expected_pnorm <= pnorm && pnorm < 1.02f * expected_pnorm);
-    HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr);
   }
 };
@@ -238,10 +245,9 @@ struct TestOffset {
     auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
     HWY_ASSERT(in && dec1 && dec2 && nuq);

-    std::mt19937 rng(123);
-    std::normal_distribution<float> dist{0.001f, 0.3f};
+    hwy::RandomState rng;
     for (size_t i = 0; i < total; ++i) {
-      in[i] = dist(rng);
+      in[i] = static_cast<float>(RandomGaussian(rng));
     }

     // Encode + decode everything
@@ -281,11 +287,13 @@ struct TestStream {
     auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(num));
     HWY_ASSERT(in && out && nuq);

-    std::mt19937 rng(123);
-    std::normal_distribution<float> dist{0.001f, 0.3f};
+    hwy::RandomState rng;
+    Stats in_stats;
     for (size_t i = 0; i < num; ++i) {
-      in[i] = dist(rng);
+      in[i] = static_cast<float>(RandomGaussian(rng));
+      in_stats.Notify(in[i]);
     }
+    VerifyGaussian(in_stats);

     ClusterBuf buf;
     double elapsed = hwy::HighestValue<double>();
@@ -314,15 +322,16 @@ struct TestStream {
     for (size_t i = 0; i < num; ++i) {
       stats.Notify(in[i], hwy::ConvertScalarTo<float>(out[i]));
     }
-    const float pnorm = stats.PNorm();
-    const float snr = stats.GeomeanValueDivL1();
-    fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr,
-            stats.MaxIndex(), stats.MaxL1());
-    static_assert(kGroupSize == 128 || kGroupSize == 256, "Update expected");
-    const float expected_pnorm = kGroupSize == 128 ? 3.44E-2f : 3.88E-2f;
-    const float expected_snr = kGroupSize == 128 ? 15.0f : 13.3f;
-    HWY_ASSERT(expected_pnorm <= pnorm && pnorm < 1.02f * expected_pnorm);
-    HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr);
+    // Moderate error.
+    HWY_ASSERT_EQ(0, stats.NumExact());
+    HWY_ASSERT(stats.NumSignFlip() < num / kClusters);
+    HWY_ASSERT_EQ(0, stats.NumRoundedToZero());
+    HWY_ASSERT(gcpp::IsInside(23.0, 24.0, stats.SumL1()));
+    HWY_ASSERT(gcpp::IsInside(13.0, 13.3, stats.GeomeanValueDivL1()));
+    HWY_ASSERT(gcpp::IsInside(0.034, 0.035, stats.WeightedAverageL1()));
+    HWY_ASSERT(stats.L1().Max() <= 0.11f);
+    static_assert(kGroupSize == 256, "Update expected");
   }
 };
@@ -351,9 +360,8 @@ struct TestDot {
     hwy::RandomState rng;
     Stats in_stats;
     for (size_t i = 0; i < num; ++i) {
-      const float r = static_cast<float>(RandomGaussian(rng));
-      in_stats.Notify(r);
-      in[i] = r;
+      in[i] = static_cast<float>(RandomGaussian(rng));
+      in_stats.Notify(in[i]);
     }
     for (size_t i = 0; i < num; ++i) {
       const float r = static_cast<float>(RandomGaussian(rng));
@@ -368,7 +376,7 @@ struct TestDot {
     HWY_ASSERT(unused_clusters == 0);

     // Compute dot product without decompression.
-    double actual = 0.0;
+    float actual = 0.0f;
     double elapsed = hwy::HighestValue<double>();
     for (size_t rep = 0; rep < 20; ++rep) {
       hn::Vec<decltype(df)> sum0 = hn::Zero(df);
@@ -389,8 +397,8 @@ struct TestDot {
             num * sizeof(in[0]) * 1E-6 / elapsed);

     // Exact and decompressed dot products for comparison.
-    double exact = 0.0;     // using original input
-    double expected = 0.0;  // using decoded NUQ
+    float exact = 0.0f;     // using original input
+    float expected = 0.0f;  // using decoded NUQ
     DistortionStats dec_stats;
     Stats ratios;
     for (size_t i = 0; i < num; ++i) {
@@ -402,24 +410,42 @@ struct TestDot {
         ratios.Notify(exact / expected);
       }
     }
+    const bool isBF = sizeof(T) == 2;
     const double dec_snr = dec_stats.GeomeanValueDivL1();
+    const double dec_wl1 = dec_stats.WeightedAverageL1();
     const double dot_snr = 1.0 / hwy::ScalarAbs(1.0 - ratios.GeometricMean());
     // exact and actual fluctuate due to the combination of NUQ imprecision,
     // and whether vec[i] is negative or positive, so this is quite loose.
     const float final_ratio = HWY_MIN(exact / actual, actual / exact);
+    if (HWY_ONCE) {
       fprintf(stderr, "ratios %s\n", ratios.ToString().c_str());
       fprintf(stderr,
               "exact %.3f e2 %.4f actual %.4f final_ratio %.3f dec_snr %.2f "
-            "dot_snr %.2f\n",
-            exact, expected, actual, final_ratio, dec_snr, dot_snr);
+              "dot_snr %.2f dec_wl1 %.4f\n",
+              exact, expected, actual, final_ratio, dec_snr, dot_snr, dec_wl1);
+    }
     // Final values are not too far apart.
-    HWY_ASSERT(0.88f <= final_ratio && final_ratio <= 1.0f);
+    HWY_ASSERT(gcpp::IsInside(0.88f, 1.0f, final_ratio));
     // Decompressed and uncompressed dot should match exactly.
-    HWY_ASSERT(hwy::ScalarAbs(expected - actual) < 1E-4f);
-    // dec[] is close to in[], but we already check that in TestStream.
-    HWY_ASSERT(dec_snr >= 13.0);
-    // Geomean of ratios for each i is an approximation of the actual SNR.
-    HWY_ASSERT(dot_snr >= (sizeof(T) == 2 ? 17.0 : 14.0));
+    HWY_ASSERT(gcpp::IsNear(expected, actual, 1E-4f));
+    // Geomean of ratios for each i should be very close to one.
+    HWY_ASSERT(dot_snr >= (isBF ? 17.7 : 14.3));
+    // dec[] is close to in[], but we already check that in TestStream with the
+    // same input distribution.
+    HWY_ASSERT(gcpp::IsNear(13.1, dec_snr, 0.1));
+    HWY_ASSERT(gcpp::IsNear(0.034, dec_wl1, 0.001));
+    HWY_ASSERT(gcpp::IsNear(23.5, dec_stats.SumL1(), 0.1));
+    HWY_ASSERT(dec_stats.NumSignFlip() < num / kClusters);
+    HWY_ASSERT_EQ(0, dec_stats.NumExact());
+    HWY_ASSERT_EQ(0, dec_stats.NumRoundedToZero());
+    HWY_ASSERT_EQ(0.0, dec_stats.SumL1Rounded());
+    // Absolute decode errors are in [0, 0.11], and somewhat right-tailed.
+    HWY_ASSERT(gcpp::IsInside(0.0f, 2E-5f, dec_stats.L1().Min()));
+    HWY_ASSERT(gcpp::IsInside(0.09f, 0.11f, dec_stats.L1().Max()));
+    HWY_ASSERT(gcpp::IsInside(0.02, 0.03, dec_stats.L1().Mean()));
+    HWY_ASSERT(gcpp::IsInside(1.0, 1.1, dec_stats.L1().Skewness()));
+    HWY_ASSERT(gcpp::IsInside(4.0, 5.0, dec_stats.L1().Kurtosis()));
     static_assert(kGroupSize == 256, "Update expected*");
   }
 };


@@ -20,7 +20,6 @@
 #include <stddef.h>
 #include <stdint.h>
-// copybara:import_next_line:gemma_cpp
 #include "compression/sfp.h"
 #include "hwy/base.h"


@@ -18,7 +18,6 @@
 #define HWY_DISABLED_TARGETS HWY_SCALAR
 #endif
-// copybara:import_next_line:gemma_cpp
 #include "compression/sfp.h"

 #include <stddef.h>
@@ -27,6 +26,7 @@
 #include <set>

+#include "compression/test_util.h"
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 #include "hwy/timer.h"
@@ -37,10 +37,7 @@
 // clang-format on
 #include "hwy/foreach_target.h"  // IWYU pragma: keep
 // Any highway.h must come after foreach_target.h
-// copybara:import_next_line:gemma_cpp
 #include "compression/sfp-inl.h"
-// copybara:import_next_line:gemma_cpp
-#include "compression/test_util.h"
 #include "hwy/highway.h"
 #include "hwy/tests/hwy_gtest.h"
 #include "hwy/tests/test_util-inl.h"
@@ -307,10 +304,37 @@ struct TestEncDec {
       sum += hwy::ConvertScalarTo<double>(hwy::ScalarAbs(in[i]));
       stats.Notify(hwy::ConvertScalarTo<float>(in[i]), out);
     }
-    const double avg = sum / num;
-    fprintf(stderr, "Avg magnitude %.3E, p-norm %.3E snr %.2f @%zu = %.4E\n",
-            avg, stats.PNorm(), stats.GeomeanValueDivL1(), stats.MaxIndex(),
-            stats.MaxL1());
+    const double avg_in = sum / num;
+    const double snr = stats.GeomeanValueDivL1();
+    const double wl1 = stats.WeightedAverageL1();
+    if (false) {
+      fprintf(stderr,
+              "Num inputs %zu, avg %.3E, exact %zu round0 %zu (sum %E) snr "
+              "%.2f wL1 %f\n",
+              num, avg_in, stats.NumExact(), stats.NumRoundedToZero(),
+              stats.SumL1Rounded(), snr, wl1);
+    }
+    HWY_ASSERT(stats.Original().Count() == stats.L1().Count());
+    // Inputs are in [-1.875, 1.875], symmetric, and heavy-tailed.
+    HWY_ASSERT(stats.Original().Min() == -1.875f);
+    HWY_ASSERT(stats.Original().Max() == 1.875f);
+    HWY_ASSERT(gcpp::IsInside(-1E-6, 1E-6, stats.Original().Mean()));
+    HWY_ASSERT(gcpp::IsInside(-1E-6, 1E-6, stats.Original().Skewness()));
+    HWY_ASSERT(gcpp::IsInside(80.0, 100.0, stats.Original().Kurtosis()));
+    // Absolute errors are in [0, 0.0625], and (heavy) right-tailed.
+    HWY_ASSERT(stats.L1().Min() == 0.0f);
+    HWY_ASSERT(stats.L1().Max() == 0.0625f);
+    HWY_ASSERT(gcpp::IsInside(4E-4, 5E-4, stats.L1().Mean()));
+    HWY_ASSERT(gcpp::IsInside(10.0, 15.0, stats.L1().Skewness()));
+    HWY_ASSERT(gcpp::IsInside(150.0, 200.0, stats.L1().Kurtosis()));
+    // SNR is low because many *tiny* numbers are rounded to zero.
+    HWY_ASSERT_EQ(3322, stats.NumRoundedToZero());
+    HWY_ASSERT(gcpp::IsInside(5E-6, 6E-6, stats.SumL1Rounded()));
+    HWY_ASSERT(gcpp::IsInside(1.880, 1.885, stats.SumL1()));
+    HWY_ASSERT_EQ(256, stats.NumExact());
+    HWY_ASSERT_EQ(0, stats.NumSignFlip());
+    HWY_ASSERT(gcpp::IsInside(2.70, 2.75, snr));
+    HWY_ASSERT(gcpp::IsInside(0.010, 0.011, wl1));  // = half of mean |x|.
     }
   }
 };
@@ -381,7 +405,7 @@ struct TestDot {
     SfpCodec::Enc(d, in.get(), num, sfp.get());

     // Compute dot product without decompression.
-    double actual = 0.0;
+    float actual = 0.0f;
     double elapsed = hwy::HighestValue<double>();
     for (size_t rep = 0; rep < 200; ++rep) {
       hn::Vec<decltype(df)> sum0 = hn::Zero(df);
@@ -417,24 +441,41 @@ struct TestDot {
         ratios.Notify(exact / expected);
       }
     }
+    const bool isBF = sizeof(T) == 2;
     const double dec_snr = dec_stats.GeomeanValueDivL1();
+    const double dec_wl1 = dec_stats.WeightedAverageL1();
     const double dot_snr = 1.0 / hwy::ScalarAbs(1.0 - ratios.GeometricMean());
     // exact and actual fluctuate due to the combination of SFP imprecision,
     // and whether vec[i] is negative or positive, so this is quite loose.
     const float final_ratio = HWY_MIN(exact / actual, actual / exact);
+    if (HWY_ONCE) {
       fprintf(stderr, "ratios %s\n", ratios.ToString().c_str());
       fprintf(stderr,
               "exact %.3f e2 %.4f actual %.4f final_ratio %.3f dec_snr %.2f "
-            "dot_snr %.2f\n",
-            exact, expected, actual, final_ratio, dec_snr, dot_snr);
+              "dot_snr %.2f dec_wl1 %.5f\n",
+              exact, expected, actual, final_ratio, dec_snr, dot_snr, dec_wl1);
+    }
     // Final values are not too far apart.
-    HWY_ASSERT(0.87f <= final_ratio && final_ratio <= 1.0f);
+    HWY_ASSERT(gcpp::IsInside(0.87f, 1.0f, final_ratio));
     // Decompressed and uncompressed dot should match exactly.
-    HWY_ASSERT(hwy::ScalarAbs(expected - actual) < 1E-4f);
-    // dec[] is close to in[], but we already check that in TestEncDec.
-    HWY_ASSERT(dec_snr >= 50.0);
+    HWY_ASSERT(gcpp::IsNear(expected, actual, 1E-4f));
     // Geomean of ratios for each i should be very close to one.
-    HWY_ASSERT(dot_snr >= (sizeof(T) == 2 ? 70.0 : 1000.0));
+    HWY_ASSERT(dot_snr >= (isBF ? 70.0 : 1000.0));
+    // dec[] is close to in[]. We also check that in TestEncDec, but for much
+    // smaller input magnitudes.
+    HWY_ASSERT(gcpp::IsNear(isBF ? 51.0 : 64.0, dec_snr, 1.0));
+    HWY_ASSERT(gcpp::IsNear(isBF ? 0.013 : 0.012, dec_wl1, 0.001));
+    HWY_ASSERT(gcpp::IsNear(isBF ? 6.2 : 6.3, dec_stats.SumL1(), 0.1));
+    HWY_ASSERT_EQ(0, dec_stats.NumSignFlip());
+    HWY_ASSERT_EQ(0, dec_stats.NumRoundedToZero());
+    HWY_ASSERT_EQ(0.0, dec_stats.SumL1Rounded());
+    // Absolute decode errors are in [0, 5E-2], and somewhat right-tailed.
+    HWY_ASSERT(gcpp::IsInside(0.0f, 2E-6f, dec_stats.L1().Min()));
+    HWY_ASSERT(gcpp::IsInside(3E-2f, 5E-2f, dec_stats.L1().Max()));
+    HWY_ASSERT(gcpp::IsInside(4E-3, 7E-3, dec_stats.L1().Mean()));
+    HWY_ASSERT(gcpp::IsInside(1.8, 1.9, dec_stats.L1().Skewness()));
+    HWY_ASSERT(gcpp::IsInside(6.0, 7.0, dec_stats.L1().Kurtosis()));
   }
 };


@@ -13,7 +13,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-// copybara:import_next_line:gemma_cpp
 #include "compression/stats.h"

 #include <stdio.h>


@@ -19,7 +19,6 @@
 #include <stdint.h>
 #include <stdio.h>

-#include <algorithm>
 #include <cmath>
 #include <string>
@@ -45,7 +44,7 @@ class Bins {
     }
   }

-  void Print(const char* caption) {
+  void Print(const char* caption) const {
     fprintf(stderr, "\n%s [%zu]\n", caption, N);
     size_t last_nonzero = 0;
     for (size_t i = N - 1; i < N; --i) {
@@ -77,8 +76,8 @@ class Stats {
   void Notify(const float x) {
     ++n_;

-    min_ = std::min(min_, x);
-    max_ = std::max(max_, x);
+    min_ = HWY_MIN(min_, x);
+    max_ = HWY_MAX(max_, x);

     product_ *= x;
@@ -119,7 +118,7 @@ class Stats {
   // Near zero for normal distributions; if positive on a unimodal distribution,
   // the right tail is fatter. Assumes n_ is large.
   double SampleSkewness() const {
-    if (std::abs(m2_) < 1E-7) return 0.0;
+    if (hwy::ScalarAbs(m2_) < 1E-7) return 0.0;
     return m3_ * std::sqrt(static_cast<double>(n_)) / std::pow(m2_, 1.5);
   }

   // Corrected for bias (same as Wikipedia and Minitab but not Excel).
@@ -132,7 +131,7 @@ class Stats {
   // Near zero for normal distributions; smaller values indicate fewer/smaller
   // outliers and larger indicates more/larger outliers. Assumes n_ is large.
   double SampleKurtosis() const {
-    if (std::abs(m2_) < 1E-7) return 0.0;
+    if (hwy::ScalarAbs(m2_) < 1E-7) return 0.0;
     return m4_ * n_ / (m2_ * m2_);
   }

   // Corrected for bias (same as Wikipedia and Minitab but not Excel).


@@ -24,9 +24,7 @@
 #include "hwy/base.h"

 // IWYU pragma: begin_exports
-// copybara:import_next_line:gemma_cpp
 #include "compression/distortion.h"
-// copybara:import_next_line:gemma_cpp
 #include "compression/stats.h"
 #include "hwy/tests/test_util.h"  // RandomState
 // IWYU pragma: end_exports
@@ -51,12 +49,26 @@
   return plus_minus_1 * std::sqrt(kReps / 3.0);
 };

+// Returns true if val is inside [min, max].
+template <typename T>
+static inline bool IsInside(T expected_min, T expected_max, T val) {
+  return expected_min <= val && val <= expected_max;
+}
+
+template <typename T>
+static inline bool IsNear(T expected, T val, T epsilon = T{1E-6}) {
+  return IsInside(expected - epsilon, expected + epsilon, val);
+}
+
 HWY_INLINE void VerifyGaussian(Stats& stats) {
-  const double stddev = stats.StandardDeviation();
-  HWY_ASSERT(-0.01 <= stats.Mean() && stats.Mean() <= 0.01);
-  HWY_ASSERT(0.30 <= stddev && stddev <= 0.35);
-  HWY_ASSERT(-1.1 <= stats.Min() && stats.Min() <= -0.9);
-  HWY_ASSERT(0.9 <= stats.Max() && stats.Max() <= 1.1);
+  // Inputs are roughly [-1, 1] and symmetric about zero.
+  HWY_ASSERT(IsNear(-1.0f, stats.Min(), 0.10f));
+  HWY_ASSERT(IsNear(+1.0f, stats.Max(), 0.10f));
+  HWY_ASSERT(IsInside(-2E-3, 2E-3, stats.Mean()));
+  HWY_ASSERT(IsInside(-0.15, 0.15, stats.Skewness()));
+  // Near-Gaussian.
+  HWY_ASSERT(IsInside(0.30, 0.35, stats.StandardDeviation()));
+  HWY_ASSERT(IsNear(3.0, stats.Kurtosis(), 0.3));
 }

 }  // namespace gcpp

debug_prompt.cc (new file)

@@ -0,0 +1,132 @@
#include <fstream>
#include <iostream>
#include <string>

#include "gemma/gemma.h"
#include "nlohmann/json.hpp"
#include "util/app.h"
#include "util/args.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

using json = nlohmann::json;

class PromptArgs : public gcpp::ArgsBase<PromptArgs> {
 public:
  PromptArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }

  gcpp::Path layers_output;
  std::string prompt;

  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(layers_output.path, "layers_output", std::string(""),
            "Path to store layers output", 2);
    visitor(prompt, "prompt", std::string(""), "Prompt to the model", 2);
  }
};

std::pair<std::string, int> QueryModel(
    gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
    gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input,
    gcpp::LayersOutputT* layers_output) {
  std::vector<int> prompt;
  HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
  // For both pre-trained and instruction-tuned models: prepend "<bos>" token
  // if needed.
  prompt.insert(prompt.begin(), 2);
  std::string res;
  size_t total_tokens = 0;
  auto accept_token = [](int) { return true; };
  std::mt19937 gen;
  gen.seed(42);
  auto stream_token = [&res, &total_tokens, &app,
                       tokenizer = model.Tokenizer()](int token, float) {
    ++total_tokens;
    std::string token_text;
    HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text));
    res += token_text;
    return true;
  };
  if (app.verbosity >= 2) {
    std::cout << args.max_tokens << " " << args.max_generated_tokens << " "
              << args.temperature;
  }
  GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
                args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
                stream_token, accept_token, gen, app.verbosity, layers_output);
  return {res, total_tokens};
}

class OutputJsonLogger {
 public:
  json json_output;
  gcpp::LayersOutputT layers_output_log_f =
      [this](int pos, const std::string& key, const float* values,
             size_t values_len) {
        std::vector<float> v{values, values + values_len};
        json_output[std::to_string(pos)][key] = v;
      };
};

/* Run this in the same way as gemma, e.g.:
   ./debug_prompt --tokenizer tokenizer.spm --model 2b-it --weights \
     2b-it-sfp.sbs --prompt "..." --layers_output [path]
*/
int main(int argc, char** argv) {
  gcpp::LoaderArgs loader(argc, argv);
  gcpp::InferenceArgs args(argc, argv);  // inference
  gcpp::AppArgs app(argc, argv);
  PromptArgs prompt_args(argc, argv);

  if (const char* error = loader.Validate()) {
    HWY_ABORT("\nInvalid loader args: %s", error);
  }
  if (const char* error = args.Validate()) {
    HWY_ABORT("\nInvalid inference args: %s", error);
  }

  const bool log_layers_output = !prompt_args.layers_output.path.empty();
  OutputJsonLogger json_logger;
  gcpp::LayersOutputT* layers_output =
      log_layers_output ? &json_logger.layers_output_log_f : nullptr;

  hwy::ThreadPool pool(app.num_threads);
  // For many-core, pinning threads to cores helps.
  if (app.num_threads > 10) {
    gcpp::PinThreadToCore(app.num_threads - 1);  // Main thread
    pool.Run(0, pool.NumThreads(), [](uint64_t /*task*/, size_t thread) {
      gcpp::PinThreadToCore(thread);
    });
  }

  gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
  auto kv_cache = CreateKVCache(loader.ModelType());

  const std::string& prompt = prompt_args.prompt;
  if (prompt.empty()) {
    std::cout << "Please specify --prompt" << std::endl;
    return EXIT_FAILURE;
  }
  const auto [answer, token_count] =
      QueryModel(model, args, app, kv_cache, pool, prompt, layers_output);
  std::cout << answer.substr(prompt.size()) << "\n" << std::flush;

  if (log_layers_output) {
    std::ofstream output_f(prompt_args.layers_output.path, std::ofstream::out);
    if (!output_f) {
      std::cout << "Opening file failed" << std::endl;
      return EXIT_FAILURE;
    }
    output_f << json_logger.json_output.dump();
    if (!output_f) {
      std::cout << "Writing to file failed" << std::endl;
      return EXIT_FAILURE;
    }
    output_f.close();
  }
  return EXIT_SUCCESS;
}


@@ -15,12 +15,9 @@
 #include <iostream>

-// copybara:import_next_line:gemma_cpp
-#include "gemma.h"
-// copybara:import_next_line:gemma_cpp
+#include "third_party/gemma_cpp/gemma.h"
 #include "util/app.h"  // LoaderArgs
 #include "hwy/contrib/thread_pool/thread_pool.h"
-// copybara:import_next_line:gemma_cpp
 #include "util/args.h"

 std::vector<int> tokenize(const std::string& prompt_string,


@ -8,16 +8,13 @@
#include <vector> #include <vector>
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"
// copybara:import_next_line:gemma_cpp #include "gemma/gemma.h"
#include "gemma.h" #include "util/app.h"
#include "util/args.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h" #include "hwy/highway.h"
#include "hwy/timer.h" #include "hwy/timer.h"
// copybara:import_next_line:gemma_cpp
#include "util/app.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h"
using json = nlohmann::json; using json = nlohmann::json;
@@ -61,8 +58,7 @@ void LogSpeedStats(const double time_start, size_t total_tokens) {
 std::pair<std::string, int> QueryModel(
     gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
-    gcpp::KVCache& kv_cache, hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-    const std::string& input) {
+    gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input) {
   std::vector<int> prompt;
   HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
@@ -93,7 +89,7 @@ std::pair<std::string, int> QueryModel(
   }
   GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
                 args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
-                inner_pool, stream_token, accept_token, gen, app.verbosity);
+                stream_token, accept_token, gen, app.verbosity);
   if (app.verbosity >= 1) {
     LogSpeedStats(time_start, total_tokens);
   }
@@ -134,8 +130,7 @@ std::string ReadFile(const gcpp::Path& path) {
 int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
                      gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
-                     hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-                     const std::string& golden_path) {
+                     hwy::ThreadPool& pool, const std::string& golden_path) {
   const std::vector<std::pair<std::string, std::string>> queries_answers =
       load_goldens(golden_path);
   int correct_answers = 0;
@@ -143,7 +138,7 @@ int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
   const double time_start = hwy::platform::Now();
   for (const auto& [question, expected_answer] : queries_answers) {
     const auto [answer, token_count] =
-        QueryModel(model, args, app, kv_cache, inner_pool, pool, question);
+        QueryModel(model, args, app, kv_cache, pool, question);
     total_tokens += token_count;
     if (answer.find(expected_answer) != std::string::npos) {
       correct_answers++;
@@ -167,14 +162,13 @@ int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
 int BenchmarkSummary(gcpp::Gemma& model, gcpp::InferenceArgs& args,
                      gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
-                     hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-                     const gcpp::Path& text) {
+                     hwy::ThreadPool& pool, const gcpp::Path& text) {
   std::string prompt("Here is some text to summarize:\n");
   prompt.append(ReadFile(text));
   prompt.append("\nSummarize this text.\n");
   const double time_start = hwy::platform::Now();
   const auto [answer, token_count] =
-      QueryModel(model, args, app, kv_cache, inner_pool, pool, prompt);
+      QueryModel(model, args, app, kv_cache, pool, prompt);
   std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
   LogSpeedStats(time_start, token_count);
   return EXIT_SUCCESS;
@@ -182,8 +176,8 @@ int BenchmarkSummary(gcpp::Gemma& model, gcpp::InferenceArgs& args,
 int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
                           gcpp::InferenceArgs& args, gcpp::AppArgs& app,
-                          hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-                          const gcpp::Path& text, size_t batch_tokens) {
+                          hwy::ThreadPool& pool, const gcpp::Path& text,
+                          size_t batch_tokens) {
   std::string input = ReadFile(text);
   std::vector<int> prompt;
   HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
@@ -200,7 +194,7 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
     auto kv_cache = CreateKVCache(model_type);
     float entropy =
         ComputeCrossEntropy(model, num_tokens, prompt_slice, kv_cache, pool,
-                            inner_pool, app.verbosity);
+                            app.verbosity);
     total_entropy += entropy;
     LogSpeedStats(time_start, pos + num_tokens);
     std::string text_slice;
@@ -214,8 +208,8 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
 int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
                       gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
-                      hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-                      const gcpp::Path& json_file, size_t max_questions) {
+                      hwy::ThreadPool& pool, const gcpp::Path& json_file,
+                      size_t max_questions) {
   std::ifstream trivia_file(json_file.path);
   if (!trivia_file) {
     std::cout << "Could not load file: " << json_file.path << "\n"
@@ -228,7 +222,7 @@ int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
   while (std::getline(trivia_file, line)) {
     json data = json::parse(line);
     const auto [answer, token_count] = QueryModel(
-        model, args, app, kv_cache, inner_pool, pool, data["question"]);
+        model, args, app, kv_cache, pool, data["question"]);
     std::cout << answer << "\n";
     bool correct = false;
     for (const std::string expected : data["answer"]["aliases"]) {
@@ -266,7 +260,6 @@ int main(int argc, char** argv) {
     HWY_ABORT("\nInvalid inference args: %s", error);
   }
-  hwy::ThreadPool inner_pool(0);
   hwy::ThreadPool pool(app.num_threads);
   // For many-core, pinning threads to cores helps.
   if (app.num_threads > 10) {
@@ -283,17 +276,16 @@ int main(int argc, char** argv) {
   if (!benchmark_args.goldens.path.empty()) {
     const std::string golden_path =
         benchmark_args.goldens.path + "/" + loader.model_type_str + ".txt";
-    return BenchmarkGoldens(model, args, app, kv_cache, inner_pool, pool,
-                            golden_path);
+    return BenchmarkGoldens(model, args, app, kv_cache, pool, golden_path);
   } else if (!benchmark_args.summarize_text.path.empty()) {
-    return BenchmarkSummary(model, args, app, kv_cache, inner_pool, pool,
+    return BenchmarkSummary(model, args, app, kv_cache, pool,
                             benchmark_args.summarize_text);
   } else if (!benchmark_args.cross_entropy.path.empty()) {
     return BenchmarkCrossEntropy(model, loader.ModelType(), args, app,
-                                 inner_pool, pool, benchmark_args.cross_entropy,
+                                 pool, benchmark_args.cross_entropy,
                                  benchmark_args.batch_tokens);
   } else if (!benchmark_args.trivia_qa.path.empty()) {
-    return BenchmarkTriviaQA(model, args, app, kv_cache, inner_pool, pool,
+    return BenchmarkTriviaQA(model, args, app, kv_cache, pool,
                              benchmark_args.trivia_qa,
                              benchmark_args.max_questions);
   }


@@ -18,12 +18,8 @@
 #include <iostream>
 #include <string>
-// copybara:import_next_line:gemma_cpp
-#include "gemma.h"  // Gemma
-// copybara:end
-// copybara:import_next_line:gemma_cpp
+#include "gemma/gemma.h"  // Gemma
 #include "util/args.h"
-// copybara:end
 namespace gcpp {
@@ -59,7 +55,7 @@ struct Args : public ArgsBase<Args> {
     return "Missing --compressed_weights flag, a file for the compressed "
            "model.";
   }
-  if (!weights.exists()) {
+  if (!weights.Exists()) {
     return "Can't open file specified with --weights flag.";
   }
   return nullptr;
@@ -74,7 +70,7 @@ struct Args : public ArgsBase<Args> {
  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(weights, "weights", Path(),
-           "Path name of model weights (.sbs) file.\n"
+           "Path to model weights (.bin) file.\n"
            "    Required argument.");
    visitor(model_type_str, "model", std::string(),
            "Model type\n    2b-it = 2B parameters, instruction-tuned\n    "
@@ -84,7 +80,7 @@ struct Args : public ArgsBase<Args> {
            "gr2b-pt = griffin 2B parameters, pretrained\n    "
            "    Required argument.");
    visitor(compressed_weights, "compressed_weights", Path(),
-           "Path name where compressed weights file will be written.\n"
+           "Path name where compressed weights (.sbs) file will be written.\n"
            "    Required argument.");
    visitor(num_threads, "num_threads",
            kDefaultNumThreads,  // see ChooseNumThreads


@@ -15,8 +15,8 @@
 // Model configurations
-#ifndef THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
-#define THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_
 // Allow changing pre-allocated kv cache size as a compiler flag
 #ifndef GEMMA_MAX_SEQLEN
@@ -28,11 +28,15 @@
 #define GEMMA_TOPK 1
 #endif  // !GEMMA_TOPK
+// Allow changing upper bound on threads as a compiler flag
+#ifndef GEMMA_MAX_THREADS
+#define GEMMA_MAX_THREADS 128
+#endif  // !GEMMA_MAX_THREADS
 #include <stddef.h>
 #include <array>
-// copybara:import_next_line:gemma_cpp
 #include "compression/sfp.h"
 #include "hwy/base.h"  // hwy::bfloat16_t
@@ -46,6 +50,7 @@ namespace gcpp {
 static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
 static constexpr size_t kTopK = GEMMA_TOPK;
+static constexpr size_t kMaxThreads = GEMMA_MAX_THREADS;
 enum class LayerAttentionType {
   kGemma,
@@ -62,18 +67,36 @@ constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig(
   return config;
 }
+template <size_t kNumLayers>
+constexpr size_t NumLayersOfTypeBefore(
+    const std::array<LayerAttentionType, kNumLayers>& layers,
+    LayerAttentionType type, size_t num) {
+  size_t count = 0;
+  for (size_t i = 0; i < num; i++) {
+    if (layers[i] == type) count++;
+  }
+  return count;
+}
 struct ConfigGemma7B {
   static constexpr int kSeqLen = gcpp::kSeqLen;
   static constexpr int kVocabSize = 256000;
   static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
       FixedLayerConfig<28>(LayerAttentionType::kGemma);
   static constexpr int kLayers = kLayerConfig.size();
+  static constexpr int kGemmaLayers =
+      NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
+  static constexpr int kGriffinLayers =
+      NumLayersOfTypeBefore(kLayerConfig,
+                            LayerAttentionType::kGriffinRecurrentBlock,
+                            kLayers);
   static constexpr int kModelDim = 3072;
   static constexpr int kFFHiddenDim = 16 * 3072 / 2;  // = 24576
   static constexpr int kHeads = 16;
   static constexpr int kKVHeads = 16;  // standard MHA
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
+  static constexpr bool kAbsolutePE = false;
   // SSM config.
   static constexpr int kConv1dWidth = 0;
@@ -92,12 +115,19 @@ struct ConfigGemma2B {
   static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
       FixedLayerConfig<18>(LayerAttentionType::kGemma);
   static constexpr int kLayers = kLayerConfig.size();
+  static constexpr int kGemmaLayers =
+      NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
+  static constexpr int kGriffinLayers =
+      NumLayersOfTypeBefore(kLayerConfig,
+                            LayerAttentionType::kGriffinRecurrentBlock,
+                            kLayers);
   static constexpr int kModelDim = 2048;
   static constexpr int kFFHiddenDim = 16 * 2048 / 2;  // = 16384
   static constexpr int kHeads = 8;
   static constexpr int kKVHeads = 1;
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
+  static constexpr bool kAbsolutePE = false;
   // SSM config.
   static constexpr int kConv1dWidth = 0;
@@ -144,12 +174,19 @@ struct ConfigGriffin2B {
       LayerAttentionType::kGriffinRecurrentBlock,
   };
   static constexpr int kLayers = kLayerConfig.size();
+  static constexpr int kGemmaLayers =
+      NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
+  static constexpr int kGriffinLayers =
+      NumLayersOfTypeBefore(kLayerConfig,
+                            LayerAttentionType::kGriffinRecurrentBlock,
+                            kLayers);
   static constexpr int kModelDim = 2560;
   static constexpr int kFFHiddenDim = 7680;
   static constexpr int kHeads = 10;
   static constexpr int kKVHeads = 1;
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
+  static constexpr bool kAbsolutePE = false;
   // SSM config.
   static constexpr int kConv1dWidth = 4;
@@ -164,4 +201,4 @@ struct ConfigGriffin2B {
 }  // namespace gcpp
-#endif  // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_

File diff suppressed because it is too large.


@@ -13,42 +13,42 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_
-#define THIRD_PARTY_GEMMA_CPP_GEMMA_H_
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
 #include <functional>
 #include <memory>
 #include <random>
+#include <string>
 #include <vector>
-// copybara:import_next_line:gemma_cpp
-#include "compression/compress.h"  // SfpStream/NuqStream
-// copybara:import_next_line:gemma_cpp
-#include "configs.h"
+#include "compression/io.h"  // Path
+#include "gemma/configs.h"
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"  // hwy::bfloat16_t
 #include "hwy/contrib/thread_pool/thread_pool.h"
-// copybara:import_next_line:gemma_cpp
-#include "util/args.h"  // Path
-// copybara:import_next_line:sentencepiece
-#include "src/sentencepiece_processor.h"
 namespace gcpp {
 using GemmaWeightT = GEMMA_WEIGHT_T;
 using EmbedderInputT = hwy::bfloat16_t;
+// Will be called for layers output with:
+// - position in the tokens sequence
+// - name of the data, e.g. "tokens", "block.1", "final_norm"
+// - pointer to the data array
+// - size of the data array
+using LayersOutputT =
+    std::function<void(int, const std::string&, const float*, size_t)>;
 constexpr size_t kPrefillBatchSize = 16;
 constexpr bool kSystemPrompt = false;
 struct KVCache {
   hwy::AlignedFreeUniquePtr<float[]>
-      key_cache;  // kSeqLen * kNumGemmaLayers * kKVHeads * kQKVDim
-  hwy::AlignedFreeUniquePtr<float[]>
-      value_cache;  // kSeqLen * kNumGemmaLayers * kKVHeads * kQKVDim
-  hwy::AlignedFreeUniquePtr<float[]>
-      conv1d_cache;  // (kConv1dWidth - 1) * kModelDim * kNumGriffinLayers
-  hwy::AlignedFreeUniquePtr<float[]>
-      rglru_cache;  // kModelDim * kNumGriffinLayers
+      kv_cache;  // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim * 2
+  hwy::AlignedFreeUniquePtr<float[]>
+      conv1d_cache;  // (kConv1dWidth - 1) * kModelDim * kGriffinLayers
+  hwy::AlignedFreeUniquePtr<float[]>
+      rglru_cache;  // kModelDim * kGriffinLayers
 };
 // Model variants: see configs.h for details.
@@ -71,6 +71,7 @@ struct GemmaInterface;
 class GemmaTokenizer {
  public:
+  virtual ~GemmaTokenizer() = default;
   virtual bool Encode(const std::string& input,
                       std::vector<std::string>* pieces) const = 0;
   virtual bool Encode(const std::string& input,
@@ -82,7 +83,7 @@ class GemmaTokenizer {
 struct Gemma {
   Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
         hwy::ThreadPool& pool);
-  ~Gemma();  // must be defined after GemmaInterface's dtor is defined.
+  ~Gemma();  // must be defined after the GemmaInterface dtor is defined.
   const GemmaTokenizer* Tokenizer() const;
   std::unique_ptr<GemmaInterface> impl_;
 };
@@ -96,16 +97,17 @@ KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
 using StreamFunc = std::function<bool(int, float)>;
 using AcceptFunc = std::function<bool(int)>;
+// layers_output is optional; if set - it will be called with the activations
+// output after applying each layer.
 void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
                    float temperature, const std::vector<int>& prompt,
                    size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
-                   hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
+                   const StreamFunc& stream_token,
                    const AcceptFunc& accept_token, std::mt19937& gen,
-                   int verbosity);
+                   int verbosity, LayersOutputT* layers_output = nullptr);
 // Convenience function for the common case:
 // - Bundle runtime parameters as RuntimeConfig
-// - No threadpools within threadpools (inner_pool = dummy)
 // - All tokens accepted
 void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
                    const std::vector<int>& prompt, size_t start_pos,
@@ -117,11 +119,10 @@ void CompressWeights(gcpp::Model model, const Path& weights,
 float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
                           const std::vector<int>& prompt, KVCache& kv_cache,
-                          hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                          int verbosity);
+                          hwy::ThreadPool& pool, int verbosity);
 constexpr int EOS_ID = 1;
 }  // namespace gcpp
-#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_H_
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_


@@ -13,14 +13,16 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-// copybara:import_next_line:gemma_cpp
-#include "gemma.h"
+#include "gemma/gemma.h"
-#include <thread>
+#include <algorithm>
+#include <iostream>
+#include <random>
+#include <string>
+#include <thread>  // NOLINT
+#include <vector>
-// copybara:import_next_line:gemma_cpp
-#include "ops.h"
-// copybara:import_next_line:gemma_cpp
+#include "gemma/ops.h"
 #include "util/args.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/tests/test_util-inl.h"
@@ -34,7 +36,6 @@ class GemmaTest : public ::testing::Test {
       : weights("./2b-it-mqa.sbs"),
         tokenizer("./tokenizer.spm"),
         pool(std::min<int>(20, (std::thread::hardware_concurrency() - 1) / 2)),
-        inner_pool(0),
         model_type(gcpp::Model::GEMMA_2B),
         model(tokenizer, weights, model_type, pool) {
     kv_cache = CreateKVCache(model_type);
@@ -58,8 +59,8 @@ class GemmaTest : public ::testing::Test {
     gcpp::GenerateGemma(
         model, /*max_tokens=*/3072, /*max_generated_tokens=*/2048,
         /*temperature=*/1.0, prompt, /*start_pos=*/0, kv_cache, pool,
-        inner_pool, stream_token,
-        /*accept=*/[](int) { return true; }, gen, /*verbosity=*/0);
+        stream_token, /*accept=*/[](int) { return true; }, gen,
+        /*verbosity=*/0);
     std::string response_text;
     HWY_ASSERT(model.Tokenizer()->Decode(response, &response_text));
     return response_text;
@@ -69,8 +70,7 @@ class GemmaTest : public ::testing::Test {
     std::vector<int> prompt;
     HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt));
     return gcpp::ComputeCrossEntropy(model, /*max_tokens=*/3072, prompt,
-                                     kv_cache, pool, inner_pool,
-                                     /*verbosity=*/0) /
+                                     kv_cache, pool, /*verbosity=*/0) /
            prompt_string.size();
   }
@@ -79,7 +79,7 @@ class GemmaTest : public ::testing::Test {
       std::cout << "Question " << i + 1 << "\n\n";
       std::string response = GemmaReply(kQA[i][0]);
       std::cout << response << "\n\n";
-      EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos);
+      EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos);  // NOLINT
     }
   }
@@ -87,7 +87,6 @@ class GemmaTest : public ::testing::Test {
   gcpp::Path tokenizer;
   gcpp::KVCache kv_cache;
   hwy::ThreadPool pool;
-  hwy::ThreadPool inner_pool;
   gcpp::Model model_type = {};
   gcpp::Gemma model;
 };


@@ -14,8 +14,9 @@
 // limitations under the License.
 // Include guard for non-SIMD code.
-#ifndef THIRD_PARTY_GEMMA_CPP_OPS_H_
-#define THIRD_PARTY_GEMMA_CPP_OPS_H_
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
 #include <stddef.h>
 #include <stdint.h>
@@ -24,6 +25,7 @@
 #include <random>
 #include <type_traits>  // std::enable_if_t
+#include "compression/compress.h"  // CompressedArray
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/profiler.h"
@@ -43,7 +45,7 @@ static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
 }  // namespace gcpp
-#endif  // THIRD_PARTY_GEMMA_CPP_OPS_H_
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
 // Include guard for (potentially) SIMD code.
 #if defined(THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE) == defined(HWY_TARGET_TOGGLE)
@@ -53,7 +55,6 @@ static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
 #define THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE
 #endif
-// copybara:import_next_line:gemma_cpp
 #include "compression/compress-inl.h"
 #include "hwy/contrib/algo/transform-inl.h"
 #include "hwy/contrib/dot/dot-inl.h"
@@ -92,12 +93,60 @@ HWY_INLINE constexpr size_t RowsPerStrip() {
   return kRowsPerStrip;
 }
+// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
+// ops_test across instruction sets.
+template <size_t kM, size_t kN, size_t kK>
+HWY_INLINE void MatMul(const float* HWY_RESTRICT a, const float* HWY_RESTRICT b,
+                       float* HWY_RESTRICT out) {
+  int i, j, k;
+  for (i = 0; i < kM; ++i) {
+    for (k = 0; k < kN; ++k) {
+      for (j = 0; j < kK; ++j) {
+        out[i * kK + j] += a[i * kN + k] * b[k * kK + j];
+      }
+    }
+  }
+}
+HWY_INLINE void ToEvenOddF32(const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
+                             const size_t size, float* HWY_RESTRICT out) {
+  const hn::ScalableTag<float> df;
+  const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf16;
+  HWY_DASSERT(size % hn::Lanes(dbf16) == 0);
+  HWY_DASSERT(hn::IsAligned(df, vec_aligned));
+  for (size_t i = 0; i < size; i += hn::Lanes(dbf16)) {
+    const auto interleaved = hn::LoadU(dbf16, vec_aligned + i);
+    hn::Store(hn::PromoteEvenTo(df, interleaved), df, out + i);
+    hn::Store(hn::PromoteOddTo(df, interleaved), df, out + i + hn::Lanes(df));
+  }
+}
+HWY_INLINE void ToEvenOddF32(const float* HWY_RESTRICT vec_aligned,
+                             const size_t size, float* HWY_RESTRICT out) {
+  const hn::ScalableTag<float> df;
+  using VF = hn::Vec<decltype(df)>;
+  HWY_DASSERT(size % (hn::Lanes(df) * 2) == 0);
+  HWY_DASSERT(hn::IsAligned(df, vec_aligned));
+  VF vec0, vec1;
+  for (size_t i = 0; i < size; i += hn::Lanes(df) * 2) {
+    hn::LoadInterleaved2(df, vec_aligned + i, vec0, vec1);
+    hn::Store(vec0, df, out + i);
+    hn::Store(vec1, df, out + i + hn::Lanes(df));
+  }
+}
 // Simple version without tiling nor threading.
+// even_odd is precomputed for the current thread.
 template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
           typename VecT, typename AddT>
 HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
                               const VecT* HWY_RESTRICT vec_aligned,
                               const AddT* HWY_RESTRICT add,
+                              float* HWY_RESTRICT even_odd,
                               float* HWY_RESTRICT out) {
   PROFILER_ZONE("MatVecAddLoop");
   const hn::ScalableTag<float> df;
@@ -113,12 +162,40 @@ HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
   }
 }
+#if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16
+template <bool kAdd, size_t kOuter, size_t kInner, typename VecT, typename AddT,
+          size_t kCapacity>
+HWY_INLINE void MatVecAddLoop(
+    const CompressedArray<hwy::bfloat16_t, kCapacity>& mat,
+    const size_t mat_ofs, const VecT* HWY_RESTRICT vec_aligned,
+    const AddT* HWY_RESTRICT add, float* HWY_RESTRICT even_odd,
+    float* HWY_RESTRICT out) {
+  PROFILER_ZONE("MatVecAddLoop");
+  constexpr bool kVecIsEvenOdd = true;
+  const hn::ScalableTag<float> df;
+  ToEvenOddF32(vec_aligned, kInner, even_odd);
+  for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
+    const size_t row_ofs = mat_ofs + idx_row * kInner;
+    if constexpr (kAdd) {
+      out[idx_row] = hwy::ConvertScalarTo<float>(add[idx_row]) +
+                     Dot<kVecIsEvenOdd>(df, mat, row_ofs, even_odd, kInner);
+    } else {
+      out[idx_row] = Dot<kVecIsEvenOdd>(df, mat, row_ofs, even_odd, kInner);
+    }
+  }
+}
+#endif
+// even_odd is precomputed for the current thread.
 template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
 HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs,
                            const VecT* HWY_RESTRICT vec_aligned,
+                           float* HWY_RESTRICT even_odd,
                            float* HWY_RESTRICT out) {
-  MatVecAddLoop<false, kOuter, kInner, ArrayT, VecT, VecT>(
-      mat, mat_ofs, vec_aligned, /*add=*/nullptr, out);
+  MatVecAddLoop</*kAdd=*/false, kOuter, kInner>(
+      mat, mat_ofs, vec_aligned, /*add=*/static_cast<VecT*>(nullptr), even_odd,
+      out);
 }
 // Simple version without tiling nor threading, but two offsets/outputs.
@@ -156,7 +233,7 @@ HWY_INLINE void TwoOfsMatVecLoop(const ArrayT& mat, const size_t mat_ofs0,
                                  const VecT* HWY_RESTRICT vec_aligned,
                                  float* HWY_RESTRICT out0,
                                  float* HWY_RESTRICT out1) {
-  TwoOfsMatVecAddLoop<false, kOuter, kInner, ArrayT, VecT, VecT>(
+  TwoOfsMatVecAddLoop</*kAdd=*/false, kOuter, kInner, ArrayT, VecT, VecT>(
      mat, mat_ofs0, mat_ofs1, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr,
      out0, out1);
 }
@@ -166,20 +243,23 @@ namespace detail {
 // For each i = [0, num_rows), compute partial (length `num_cols`) dot product
 // of row i with `vec_aligned` and add into `out[i]`. The upper-left coordinate
 // of the tile is r0, c0.
-template <class DF, typename ArrayT, typename VecT>
+template <bool kVecEO, class DF, typename ArrayT, typename VecT>
 HWY_INLINE void AccumulatePartialDotProducts(
     DF df, const ArrayT& mat, size_t mat_ofs, size_t mat_stride, size_t r0,
     size_t c0, size_t num_rows, size_t num_cols,
     const VecT* HWY_RESTRICT vec_aligned, float* HWY_RESTRICT out) {
   for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
     const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
-    out[idx_row] += Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
+    out[idx_row] +=
+        Dot<kVecEO>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
   }
 }
-// Same as above, but sets out[i] to the first partial dot product +
-// init (if kInit), which avoids having to zero-initialize and accumulate.
-template <bool kInit, class DF, typename ArrayT, typename VecT, typename InitT>
+// Same as AccumulatePartialDotProducts, but sets out[i] to the first partial
+// dot product + init (if kInit), which avoids having to zero-initialize and
+// accumulate.
+template <bool kVecEO, bool kInit, class DF, typename ArrayT, typename VecT,
+          typename InitT>
 HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
                                            size_t mat_ofs, size_t mat_stride,
                                            size_t r0, size_t c0,
@@ -190,10 +270,12 @@ HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
   for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
     const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
     if constexpr (kInit) {
-      out[idx_row] = hwy::ConvertScalarTo<float>(init[idx_row + r0]) +
+      out[idx_row] =
Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols); hwy::ConvertScalarTo<float>(init[idx_row + r0]) +
Dot<kVecEO>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
} else { } else {
out[idx_row] = Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols); out[idx_row] =
Dot<kVecEO>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
} }
} }
} }
@ -202,7 +284,8 @@ HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
// horizontal strip of the entire matrix); the result is the full dot product // horizontal strip of the entire matrix); the result is the full dot product
// for rows r in [r0, r0 + num_rows) + optionally the add vector, which we store // for rows r in [r0, r0 + num_rows) + optionally the add vector, which we store
// into in out[r - r0]. // into in out[r - r0].
template <bool kAdd, class DF, typename ArrayT, typename VecT, typename AddT> template <bool kVecEO, bool kAdd, class DF, typename ArrayT, typename VecT,
typename AddT>
HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat, HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
size_t mat_ofs, size_t mat_stride, size_t mat_ofs, size_t mat_stride,
size_t r0, size_t num_rows, size_t r0, size_t num_rows,
@ -211,25 +294,66 @@ HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
float* HWY_RESTRICT out) { float* HWY_RESTRICT out) {
// Tall and skinny: set `out` to the single dot product. // Tall and skinny: set `out` to the single dot product.
if (mat_stride < MaxCols()) { if (mat_stride < MaxCols()) {
SetFirstPartialDotProducts<kAdd>(df, mat, mat_ofs, mat_stride, r0, 0, SetFirstPartialDotProducts<kVecEO, kAdd>(df, mat, mat_ofs, mat_stride, r0,
num_rows, mat_stride, vec_aligned, add, 0, num_rows, mat_stride,
out); vec_aligned, add, out);
return; return;
} }
// We have at least MaxCols, so start by setting `out` to that: // We have at least MaxCols, so start by setting `out` to that:
SetFirstPartialDotProducts<kAdd>(df, mat, mat_ofs, mat_stride, r0, 0, SetFirstPartialDotProducts<kVecEO, kAdd>(df, mat, mat_ofs, mat_stride, r0, 0,
num_rows, MaxCols(), vec_aligned, add, out); num_rows, MaxCols(), vec_aligned,
add, out);
// For further multiples of MaxCols, accumulate. Remainders handled below. // For further multiples of MaxCols, accumulate. Remainders handled below.
size_t c0 = MaxCols(); size_t c0 = MaxCols();
for (; c0 <= mat_stride - MaxCols(); c0 += MaxCols()) { for (; c0 <= mat_stride - MaxCols(); c0 += MaxCols()) {
AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows, AccumulatePartialDotProducts<kVecEO>(df, mat, mat_ofs, mat_stride, r0, c0,
MaxCols(), vec_aligned, out); num_rows, MaxCols(), vec_aligned, out);
} }
if (c0 < mat_stride) { // Final cols if (c0 < mat_stride) { // Final cols
AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows, AccumulatePartialDotProducts<kVecEO>(df, mat, mat_ofs, mat_stride, r0, c0,
mat_stride - c0, vec_aligned, out); num_rows, mat_stride - c0, vec_aligned,
out);
}
}
template <bool kVecIsEvenOdd, bool kAdd, size_t kOuter, size_t kInner,
typename ArrayT, typename VecT, typename AddT>
HWY_INLINE void MatVecAddInner(const ArrayT& mat, const size_t mat_ofs,
const VecT* HWY_RESTRICT const vec_aligned,
const AddT* HWY_RESTRICT const add,
float* HWY_RESTRICT even_odd,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
const hn::ScalableTag<float> df;
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
// Sanity check: each thread can write without race conditions.
if (HWY_IS_TSAN) {
pool.Run(
0, pool.NumWorkers(), [even_odd](uint64_t /*task*/, size_t thread) {
even_odd[thread * kInner] = -static_cast<float>(thread);
even_odd[thread * kInner + kInner - 1] = static_cast<float>(thread);
});
}
// For each entire strip.
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
PROFILER_ZONE("MatVec.lambda");
const size_t r0 = strip * kRowsPerStrip;
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
df, mat, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add,
out + r0);
});
// Remaining rows
const size_t r0 = kNumStrips * kRowsPerStrip;
if (r0 < kOuter) {
PROFILER_ZONE("MatVec remainder");
const size_t num_rows = kOuter - r0;
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
df, mat, mat_ofs, kInner, r0, num_rows, vec_aligned, add, out + r0);
} }
} }
@ -243,38 +367,32 @@ template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
const VecT* HWY_RESTRICT const vec_aligned, const VecT* HWY_RESTRICT const vec_aligned,
const AddT* HWY_RESTRICT const add, const AddT* HWY_RESTRICT const add,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) { float* HWY_RESTRICT even_odd, float* HWY_RESTRICT out,
hwy::ThreadPool& pool) {
PROFILER_ZONE("MatVecAdd"); PROFILER_ZONE("MatVecAdd");
const hn::ScalableTag<float> df; #if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>(); if constexpr (CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd &&
constexpr size_t kNumStrips = kOuter / kRowsPerStrip; hwy::IsSameEither<VecT, float, hwy::bfloat16_t>()) {
ToEvenOddF32(vec_aligned, kInner, even_odd);
// For each entire strip. detail::MatVecAddInner</*kVecIsEvenOdd=*/true, kAdd, kOuter, kInner>(
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR { mat, mat_ofs, even_odd, add, even_odd, out, pool);
PROFILER_ZONE("MatVec.lambda"); return;
const size_t r0 = strip * kRowsPerStrip;
detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, kInner, r0,
kRowsPerStrip, vec_aligned, add,
out + r0);
});
// Remaining rows
const size_t r0 = kNumStrips * kRowsPerStrip;
if (r0 < kOuter) {
PROFILER_ZONE("MatVec remainder");
const size_t num_rows = kOuter - r0;
detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, kInner, r0,
num_rows, vec_aligned, add, out + r0);
} }
#endif
detail::MatVecAddInner</*kVecIsEvenOdd=*/false, kAdd, kOuter, kInner>(
mat, mat_ofs, vec_aligned, add, even_odd, out, pool);
} }
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT> template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
HWY_INLINE void MatVec(const ArrayT& mat, const size_t mat_ofs, HWY_INLINE void MatVec(const ArrayT& mat, const size_t mat_ofs,
const VecT* HWY_RESTRICT const vec_aligned, const VecT* HWY_RESTRICT const vec_aligned,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) { float* HWY_RESTRICT even_odd, float* HWY_RESTRICT out,
MatVecAdd<false, kOuter, kInner, ArrayT, VecT, VecT>( hwy::ThreadPool& pool) {
mat, mat_ofs, vec_aligned, /*add=*/nullptr, out, pool); MatVecAdd</*kAdd=*/false, kOuter, kInner>(mat, mat_ofs, vec_aligned,
/*add=*/static_cast<VecT*>(nullptr),
even_odd, out, pool);
} }
template <class D, HWY_IF_F32_D(D)> template <class D, HWY_IF_F32_D(D)>
@ -396,16 +514,17 @@ HWY_NOINLINE void TwoMatVecAdd(
const hn::ScalableTag<float> df; const hn::ScalableTag<float> df;
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>(); constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
constexpr size_t kNumStrips = kOuter / kRowsPerStrip; constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
constexpr bool kVecIsEvenOdd = false;
// For each entire strip. // For each entire strip.
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR { pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
PROFILER_ZONE("TwoMatVec.lambda"); PROFILER_ZONE("TwoMatVec.lambda");
const size_t r0 = strip * kRowsPerStrip; const size_t r0 = strip * kRowsPerStrip;
detail::FullDotProductsForStrip<kAdd>(df, mat0, mat_ofs, kInner, r0, detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
kRowsPerStrip, vec_aligned, add0, df, mat0, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add0,
out0 + r0); out0 + r0);
detail::FullDotProductsForStrip<kAdd>(df, mat1, mat_ofs, kInner, r0, detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
kRowsPerStrip, vec_aligned, add1, df, mat1, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add1,
out1 + r0); out1 + r0);
}); });
@ -414,9 +533,9 @@ HWY_NOINLINE void TwoMatVecAdd(
if (r0 < kOuter) { if (r0 < kOuter) {
PROFILER_ZONE("TwoMatVec remainder"); PROFILER_ZONE("TwoMatVec remainder");
const size_t num_rows = kOuter - r0; const size_t num_rows = kOuter - r0;
detail::FullDotProductsForStrip<kAdd>( detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
df, mat0, mat_ofs, kInner, r0, num_rows, vec_aligned, add0, out0 + r0); df, mat0, mat_ofs, kInner, r0, num_rows, vec_aligned, add0, out0 + r0);
detail::FullDotProductsForStrip<kAdd>( detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
df, mat1, mat_ofs, kInner, r0, num_rows, vec_aligned, add1, out1 + r0); df, mat1, mat_ofs, kInner, r0, num_rows, vec_aligned, add1, out1 + r0);
} }
} }
@ -427,7 +546,7 @@ HWY_NOINLINE void TwoMatVec(const ArrayT& mat0, const ArrayT& mat1,
const VecT* HWY_RESTRICT vec_aligned, const VecT* HWY_RESTRICT vec_aligned,
float* HWY_RESTRICT out0, float* HWY_RESTRICT out1, float* HWY_RESTRICT out0, float* HWY_RESTRICT out1,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
TwoMatVecAdd<false, kOuter, kInner, ArrayT, VecT, VecT>( TwoMatVecAdd</*kAdd=*/false, kOuter, kInner, ArrayT, VecT, VecT>(
mat0, mat1, mat_ofs, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr, mat0, mat1, mat_ofs, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr,
out0, out1, pool); out0, out1, pool);
} }

View File

@@ -17,22 +17,27 @@
 #define HWY_DISABLED_TARGETS HWY_SCALAR
 #endif

+#include <stddef.h>
+
+#include <algorithm>
 #include <array>
 #include <random>
+#include <vector>

+#include "compression/compress.h"
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"

 // clang-format off
 #undef HWY_TARGET_INCLUDE
-#define HWY_TARGET_INCLUDE "ops_test.cc"  //NOLINT
+#define HWY_TARGET_INCLUDE "gemma/ops_test.cc"  //NOLINT
 // clang-format on
 #include "hwy/foreach_target.h"  // IWYU pragma: keep
 #include "hwy/highway.h"
 #include "hwy/tests/test_util-inl.h"
 // After highway.h
-// copybara:import_next_line:gemma_cpp
-#include "ops.h"
+#include "gemma/ops.h"

 HWY_BEFORE_NAMESPACE();
 namespace gcpp {

@@ -373,15 +378,54 @@ CompressedArray<float, kOuter * kInner> GenerateMat(size_t offset) {
   return mat;
 }

+template <size_t kOuter, size_t kInner>
+CompressedArray<float, kOuter * kInner> GenerateZeroMat(size_t offset) {
+  hwy::ThreadPool pool(static_cast<size_t>(std::clamp(
+      static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 4)));
+  gcpp::CompressWorkingSet ws;
+  CompressedArray<float, kOuter * kInner> mat;
+  std::array<float, kOuter * kInner> content;
+
+  pool.Run(0, kOuter, [&](const size_t i, size_t thread) {
+    for (size_t j = 0; j < kInner; j++) {
+      content[i * kInner + j] = 0.0f;
+    }
+  });
+
+  Compress(content, ws, mat, pool);
+  mat.set_scale(1.0f);
+  return mat;
+}
+
 template <size_t length>
 hwy::AlignedFreeUniquePtr<float[]> GenerateVec(size_t offset) {
   hwy::AlignedFreeUniquePtr<float[]> vec = hwy::AllocateAligned<float>(length);
+  HWY_ASSERT(vec);
   for (size_t idx = 0; idx < length; idx++) {
     vec[idx] = static_cast<float>(idx + offset);
   }
   return vec;
 }

+// A simple matrix multiplication. No optimization / tiling.
+template <size_t kM, size_t kN, size_t kK>
+hwy::AlignedFreeUniquePtr<float[]> SimpleMatMul(
+    const hwy::AlignedFreeUniquePtr<float[]>& a,
+    const hwy::AlignedFreeUniquePtr<float[]>& b) {
+  hwy::AlignedFreeUniquePtr<float[]> out = hwy::AllocateAligned<float>(kM * kK);
+  hwy::ZeroBytes(out.get(), kM * kK * sizeof(float));
+
+  int i, j, k;
+  for (i = 0; i < kM; ++i) {
+    for (j = 0; j < kK; ++j) {
+      for (k = 0; k < kN; ++k) {
+        out[i * kK + j] += a[i * kN + k] * b[k * kK + j];
+      }
+    }
+  }
+  return out;
+}
 template <size_t kOuter, size_t kInner>
 hwy::AlignedFreeUniquePtr<float[]> SimpleMatVecAdd(
     const CompressedArray<float, kOuter * kInner>& mat,

@@ -389,8 +433,9 @@ hwy::AlignedFreeUniquePtr<float[]> SimpleMatVecAdd(
     const hwy::AlignedFreeUniquePtr<float[]>& add) {
   hwy::AlignedFreeUniquePtr<float[]> uncompressed_mat =
       hwy::AllocateAligned<float>(kOuter * kInner);
-  Decompress(mat, 0, uncompressed_mat.get(), kOuter * kInner);
   hwy::AlignedFreeUniquePtr<float[]> out = hwy::AllocateAligned<float>(kOuter);
+  HWY_ASSERT(uncompressed_mat && out);
+  Decompress(mat, 0, uncompressed_mat.get(), kOuter * kInner);
   for (size_t idx_row = 0; idx_row < kOuter; idx_row++) {
     out[idx_row] = add[idx_row];
     for (size_t idx_col = 0; idx_col < kInner; idx_col++) {

@@ -412,6 +457,52 @@ void AssertClose(const hwy::AlignedFreeUniquePtr<float[]>& a,
   }
 }

+template <typename MatT>
+void AssertClose(const hwy::AlignedFreeUniquePtr<MatT[]>& expected,
+                 const hwy::AlignedFreeUniquePtr<MatT[]>& actual, size_t num) {
+  for (size_t idx = 0; idx < num; idx++) {
+    double expected_value = hwy::ConvertScalarTo<double>(expected[idx]);
+    double actual_value = hwy::ConvertScalarTo<double>(actual[idx]);
+
+    const double tolerance =
+        expected_value * 20 * 1.0 / (1ULL << hwy::MantissaBits<MatT>());
+
+    if (!(expected_value - tolerance <= actual_value &&
+          actual_value <= expected_value + tolerance)) {
+      fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f\n", idx,
+              expected_value, idx, actual_value);
+      HWY_ASSERT(0);
+    }
+  }
+}
+void TestMatMul() {
+  hwy::ThreadPool pool(0);
+
+  constexpr size_t kM = 128 * 3;  // 384
+  constexpr size_t kK = 128 * 5;  // 640
+  constexpr size_t kN = 128 * 6;  // 768
+
+  CompressedArray<float, kM * kN> a1 = GenerateMat<kM, kN>(0);
+  CompressedArray<float, kN * kK> b1 = GenerateMat<kN, kK>(0);
+
+  hwy::AlignedFreeUniquePtr<float[]> a = hwy::AllocateAligned<float>(kM * kN);
+  Decompress(a1, 0, a.get(), kM * kN);
+
+  hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kN * kK);
+  Decompress(b1, 0, b.get(), kN * kK);
+
+  hwy::AlignedFreeUniquePtr<float[]> expected_out1 =
+      SimpleMatMul<kM, kN, kK>(a, b);
+
+  CompressedArray<float, kM * kK> compressed_c = GenerateZeroMat<kM, kK>(0);
+  hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
+  Decompress(compressed_c, 0, c.get(), kM * kK);
+
+  MatMul<kM, kN, kK>(a.get(), b.get(), c.get());
+
+  AssertClose(expected_out1, c, kM * kK);
+}
+
 void TestMatVecAdd() {
   hwy::ThreadPool pool(0);
   constexpr size_t kOuter = 128 * 3;

@@ -419,27 +510,15 @@ void TestMatVecAdd() {
   CompressedArray<float, kOuter * kInner> mat = GenerateMat<kOuter, kInner>(0);
   hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
   hwy::AlignedFreeUniquePtr<float[]> add = GenerateVec<kOuter>(0);
+  hwy::AlignedFreeUniquePtr<float[]> even_odd =
+      hwy::AllocateAligned<float>(kInner * pool.NumWorkers());
   hwy::AlignedFreeUniquePtr<float[]> expected_out =
       SimpleMatVecAdd<kOuter, kInner>(mat, vec, add);
   hwy::AlignedFreeUniquePtr<float[]> actual_out =
       hwy::AllocateAligned<float>(kOuter);
-  MatVecAdd<true, kOuter, kInner>(mat, 0, vec.get(), add.get(),
-                                  actual_out.get(), pool);
-  AssertClose<kOuter>(actual_out, expected_out);
-}
-
-void TestMatVecAddLoop() {
-  constexpr size_t kOuter = 128 * 3;
-  constexpr size_t kInner = 128 * 5;
-  CompressedArray<float, kOuter * kInner> mat = GenerateMat<kOuter, kInner>(0);
-  hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
-  hwy::AlignedFreeUniquePtr<float[]> add = GenerateVec<kOuter>(0);
-  hwy::AlignedFreeUniquePtr<float[]> expected_out =
-      SimpleMatVecAdd<kOuter, kInner>(mat, vec, add);
-  hwy::AlignedFreeUniquePtr<float[]> actual_out =
-      hwy::AllocateAligned<float>(kOuter);
-  MatVecAddLoop<true, kOuter, kInner>(mat, 0, vec.get(), add.get(),
-                                      actual_out.get());
+  HWY_ASSERT(vec && add && even_odd && expected_out && actual_out);
+  MatVecAdd</*kAdd=*/true, kOuter, kInner>(
+      mat, 0, vec.get(), add.get(), even_odd.get(), actual_out.get(), pool);
   AssertClose<kOuter>(actual_out, expected_out);
 }

@@ -460,6 +539,8 @@ void TestTwoMatVecAdd() {
       hwy::AllocateAligned<float>(kOuter);
   hwy::AlignedFreeUniquePtr<float[]> actual_out1 =
       hwy::AllocateAligned<float>(kOuter);
+  HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 &&
+             expected_out1 && actual_out1);
   TwoMatVecAdd<true, kOuter, kInner>(mat0, mat1, 0, vec.get(), add0.get(),
                                      add1.get(), actual_out0.get(),
                                      actual_out1.get(), pool);

@@ -482,6 +563,8 @@ void TestTwoOfsMatVecAddLoop() {
       hwy::AllocateAligned<float>(kOuter);
   hwy::AlignedFreeUniquePtr<float[]> actual_out1 =
       hwy::AllocateAligned<float>(kOuter);
+  HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 &&
+             expected_out1 && actual_out1);
   TwoOfsMatVecAddLoop<true, kOuter, kInner>(mat, 0, 0, vec.get(), add0.get(),
                                             add1.get(), actual_out0.get(),
                                             actual_out1.get());

@@ -521,8 +604,8 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
+HWY_EXPORT_AND_TEST_P(OpsTest, TestMatMul);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
-HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAddLoop);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoOfsMatVecAddLoop);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid);

View File

@@ -23,20 +23,16 @@
 #include <vector>

 // Placeholder for internal header, do not modify.
-// copybara:import_next_line:gemma_cpp
 #include "compression/compress.h"
-// copybara:import_next_line:gemma_cpp
-#include "gemma.h"  // Gemma
+#include "gemma/gemma.h"  // Gemma
+#include "util/app.h"
+#include "util/args.h"  // HasHelp
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/highway.h"
 #include "hwy/per_target.h"
 #include "hwy/profiler.h"
 #include "hwy/timer.h"
-// copybara:import_next_line:gemma_cpp
-#include "util/app.h"
-// copybara:import_next_line:gemma_cpp
-#include "util/args.h"  // HasHelp

 static constexpr bool kVerboseLogTokens = false;

@@ -98,11 +94,10 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,

 void ReplGemma(gcpp::Gemma& model, ModelTraining training,
                gcpp::KVCache& kv_cache, hwy::ThreadPool& pool,
-               hwy::ThreadPool& inner_pool, const InferenceArgs& args,
-               int verbosity, const gcpp::AcceptFunc& accept_token,
-               std::string& eot_line) {
+               const InferenceArgs& args, int verbosity,
+               const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
   PROFILER_ZONE("Gen.misc");
-  int abs_pos = 0;      // absolute token index over all turns
+  size_t abs_pos = 0;   // absolute token index over all turns
   int current_pos = 0;  // token index within the current turn
   int prompt_size{};

@@ -185,7 +180,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
       // For instruction-tuned models: add control tokens.
       prompt_string = "<start_of_turn>user\n" + prompt_string +
                       "<end_of_turn>\n<start_of_turn>model\n";
-      if (abs_pos > 0) {
+      if (abs_pos != 0) {
        // Prepend "<end_of_turn>" token if this is a multi-turn dialogue
        // continuation.
        prompt_string = "<end_of_turn>\n" + prompt_string;

@@ -213,7 +208,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
     const double time_start = hwy::platform::Now();
     GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
-                  args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool,
+                  args.temperature, prompt, abs_pos, kv_cache, pool,
                   stream_token, accept_token, gen, verbosity);
     const double time_end = hwy::platform::Now();
     const double tok_sec = current_pos / (time_end - time_start);

@@ -233,7 +228,6 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
 void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   PROFILER_ZONE("Run.misc");

-  hwy::ThreadPool inner_pool(0);
   hwy::ThreadPool pool(app.num_threads);
   // For many-core, pinning threads to cores helps.
   if (app.num_threads > 10) {

@@ -275,8 +269,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   }

   ReplGemma(
-      model, loader.ModelTraining(), kv_cache, pool, inner_pool, inference,
-      app.verbosity,
+      model, loader.ModelTraining(), kv_cache, pool, inference, app.verbosity,
       /*accept_token=*/[](int) { return true; }, app.eot_line);
 }

View File

@@ -18,7 +18,6 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
 #define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_

-#include <iterator>
 #if HWY_OS_LINUX
 #include <sched.h>

@@ -32,13 +31,11 @@
 #include <algorithm>  // std::clamp
 #include <thread>     // NOLINT

-// copybara:import_next_line:gemma_cpp
-#include "configs.h"
-// copybara:import_next_line:gemma_cpp
-#include "gemma.h"
-#include "hwy/base.h"  // HWY_ASSERT
-// copybara:import_next_line:gemma_cpp
+#include "compression/io.h"  // Path
+#include "gemma/configs.h"
+#include "gemma/gemma.h"
 #include "util/args.h"
+#include "hwy/base.h"  // HWY_ASSERT

 namespace gcpp {

@@ -49,6 +46,10 @@ static inline const char* CompiledConfig() {
     return "msan";
   } else if (HWY_IS_TSAN) {
     return "tsan";
+#if defined(HWY_IS_HWASAN)
+  } else if (HWY_IS_HWASAN) {
+    return "hwasan";
+#endif
 #if defined(HWY_IS_UBSAN)
   } else if (HWY_IS_UBSAN) {
     return "ubsan";

@@ -84,8 +85,7 @@ class AppArgs : public ArgsBase<AppArgs> {
   void ChooseNumThreads() {
     if (num_threads == kDefaultNumThreads) {
       // This is a rough heuristic, replace with something better in the future.
-      num_threads = static_cast<size_t>(std::clamp(
-          static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
+      num_threads = GetSupportedThreadCount();
     }
   }

@@ -95,6 +95,12 @@ class AppArgs : public ArgsBase<AppArgs> {
     ChooseNumThreads();
   }

+  static inline size_t GetSupportedThreadCount() {
+    return static_cast<size_t>(
+        std::clamp(static_cast<int>(std::thread::hardware_concurrency()) - 2, 1,
+                   HWY_MIN(static_cast<int>(kMaxThreads), 18)));
+  }
   Path log;  // output
   int verbosity;
   size_t num_threads;

@@ -137,7 +143,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
     if (tokenizer.path.empty()) {
       return "Missing --tokenizer flag, a file for the tokenizer is required.";
     }
-    if (!tokenizer.exists()) {
+    if (!tokenizer.Exists()) {
       return "Can't open file specified with --tokenizer flag.";
     }
     if (!compressed_weights.path.empty()) {

@@ -152,7 +158,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
     if (weights.path.empty()) {
       return "Missing --weights flag, a file for the model weights.";
     }
-    if (!weights.exists()) {
+    if (!weights.Exists()) {
       return "Can't open file specified with --weights flag.";
     }
     return nullptr;

View File

@@ -23,48 +23,11 @@
 #include <algorithm>  // std::transform
 #include <string>

+#include "compression/io.h"
 #include "hwy/base.h"  // HWY_ABORT

-#if defined(_WIN32)
-#include <io.h>
-#define F_OK 0
-#define access _access
-#else
-#include <unistd.h>
-#endif
-
 namespace gcpp {

-// Wrapper for strings representing a path name. Differentiates vs. arbitrary
-// strings and supports shortening for display purposes.
-struct Path {
-  Path() {}
-  explicit Path(const char* p) : path(p) {}
-
-  Path& operator=(const char* other) {
-    path = other;
-    return *this;
-  }
-
-  std::string Shortened() const {
-    constexpr size_t kMaxLen = 48;
-    constexpr size_t kCutPoint = kMaxLen / 2 - 5;
-    if (path.size() > kMaxLen) {
-      return std::string(begin(path), begin(path) + kCutPoint) + " ... " +
-             std::string(end(path) - kCutPoint, end(path));
-    }
-    if (path.empty()) return "[no path specified]";
-    return path;
-  }
-
-  // Beware, TOCTOU.
-  bool exists() const { return (access(path.c_str(), F_OK) == 0); }
-
-  std::string path;
-};
-
 // Args is a class that provides a ForEach member function which visits each of
 // its member variables. ArgsBase provides functions called by Args to
 // initialize values to their defaults (passed as an argument to the visitor),