Compare commits


57 Commits

Author SHA1 Message Date
Wang Xinping 2c038e1285 Make `cmake install` work 2024-05-03 23:44:12 +08:00
Copybara-Service 8ed22e52bf Merge pull request #177 from szabadka:gemma2
PiperOrigin-RevId: 630388843
2024-05-03 07:52:27 -07:00
Zoltan Szabadka 19017fdb6d Fix expression in DASSERT() 2024-05-03 13:54:20 +00:00
Phil Culliton 28ca001d5e Matmul and test functions
PiperOrigin-RevId: 630373984
2024-05-03 06:39:36 -07:00
Zoltan Szabadka 429eb78512 Remove unused vars. 2024-05-03 13:37:17 +00:00
Zoltan Szabadka 3d72f17261 Use more parallelism in attention block in prefill mode.
Move the loop over the tokens inside the attention block and
then create kHeads * num_tokens threads.

This helps multi-threaded speed only for the 2b gemma model, but for
consistency we move the loop over the tokens inside the griffin recurrent
layer and the FFW layer as well. This also prepares for using the MatMul
operation later.

Benchmark results (summarization with 1600 tokens for prefill
and essay writing with 500 tokens for generation):

```
                   Prefill speed
Num threads      BEFORE       AFTER
32               61.76 t/s    65.08 t/s
64               89.46 t/s    98.62 t/s
```
2024-05-03 13:23:07 +00:00
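A minimal sketch of the scheduling idea in the commit above, assuming a `hwy::ThreadPool` as used elsewhere in the codebase; `AttendHeadToken` and the parameter names are illustrative stand-ins, not the repository's actual kernels:

```
#include <cstddef>
#include <cstdint>

#include "hwy/contrib/thread_pool/thread_pool.h"

// Hypothetical per-(head, token) kernel standing in for the real attention code.
void AttendHeadToken(size_t /*head*/, size_t /*token*/) {}

// Fusing the head and token loops gives the pool kHeads * num_tokens
// independent tasks during prefill, instead of only kHeads per token.
void PrefillAttention(hwy::ThreadPool& pool, size_t kHeads, size_t num_tokens) {
  pool.Run(0, kHeads * num_tokens, [&](uint64_t task, size_t /*thread*/) {
    AttendHeadToken(task % kHeads, task / kHeads);
  });
}
```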
Copybara-Service 6eeef2e2d9 Merge pull request #166 from samkaufman:deinterleave-vecs
PiperOrigin-RevId: 630360778
2024-05-03 05:23:31 -07:00
Copybara-Service 2a71333c8a Merge pull request #176 from szabadka:gemma3
PiperOrigin-RevId: 630131001
2024-05-02 11:41:05 -07:00
Zoltan Szabadka 9a2682d544 Use more parallelism in the QKV projections of the MHA block.
We compute all three projections with one MatVec and then copy
the kv part to the cache.

Benchmark results for the 7b-it model, which uses MHA blocks (summarization with
1600 tokens for prefill and essay writing with 500 tokens for generation):

```
                   Prefill speed                Generation speed
Num threads      BEFORE       AFTER            BEFORE       AFTER
32               13.75 t/s    14.80 t/s       9.22 t/s     9.77 t/s
64               19.89 t/s    24.83 t/s      12.46 t/s    13.66 t/s
```
2024-05-02 13:46:45 +00:00
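A scalar sketch of the fused projection described above (illustrative only; the repository's MatVec kernels, weight layout, and cache layout are not shown, and all names here are assumptions):

```
#include <cstddef>
#include <cstring>
#include <vector>

// Stacking W_q, W_k and W_v into one (3 * kQKVDim) x kModelDim matrix lets a
// single matrix-vector product produce q, k and v at once; the k/v part is
// then copied into the KV cache.
void QKVProjection(const std::vector<float>& w_qkv,  // (3 * kQKVDim) x kModelDim
                   const std::vector<float>& x,      // kModelDim activations
                   size_t kQKVDim, size_t kModelDim,
                   std::vector<float>& qkv,          // resized to 3 * kQKVDim
                   float* kv_cache) {                // room for 2 * kQKVDim
  qkv.resize(3 * kQKVDim);
  for (size_t r = 0; r < 3 * kQKVDim; ++r) {
    float dot = 0.0f;
    for (size_t c = 0; c < kModelDim; ++c) dot += w_qkv[r * kModelDim + c] * x[c];
    qkv[r] = dot;
  }
  // Rows [kQKVDim, 3 * kQKVDim) hold k followed by v.
  std::memcpy(kv_cache, qkv.data() + kQKVDim, 2 * kQKVDim * sizeof(float));
}
```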
Copybara-Service bafb8382f8 Merge pull request #175 from szabadka:gemma2
PiperOrigin-RevId: 630044058
2024-05-02 06:27:15 -07:00
Zoltan Szabadka 0afa480d90 Use more parallelism in the final output of the attention block.
We use MatVec instead of MatVecLoop for the per-head dense layers,
because there is more parallelism available over the rows of the matrix
than over the number of heads. This will be even more efficient after
we rearrange the weights and can have a single MatVec operation.

Benchmark results (summarization with 1600 tokens for prefill
and essay writing with 500 tokens for generation):

```
                   Prefill speed                Generation speed
Num threads      BEFORE       AFTER            BEFORE       AFTER
32               58.24 t/s    61.79 t/s      32.11 t/s    32.62 t/s
64               83.62 t/s    92.00 t/s      41.10 t/s    41.80 t/s
```
2024-05-02 09:30:07 +00:00
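A sketch of the parallelism argument, again assuming a `hwy::ThreadPool`; `OutputRow` and the dimension names are placeholders rather than the repository's API:

```
#include <cstddef>
#include <cstdint>

#include "hwy/contrib/thread_pool/thread_pool.h"

// Hypothetical kernel: accumulate one row of the attention output projection
// across all heads.
void OutputRow(size_t /*row*/, size_t /*num_heads*/) {}

void AttentionOutput(hwy::ThreadPool& pool, size_t kHeads, size_t kModelDim) {
  // MatVecLoop-style parallelism: at most kHeads tasks.
  // MatVec-style parallelism: one task per output row, i.e. kModelDim tasks,
  // typically orders of magnitude more than the number of heads.
  pool.Run(0, kModelDim, [&](uint64_t row, size_t /*thread*/) {
    OutputRow(row, kHeads);
  });
}
```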
Sam Kaufman 4a6173d929 Remove unused vars. 2024-05-02 00:41:44 -07:00
Sam Kaufman 564937ede6 Merge branch 'dev' into deinterleave-vecs 2024-04-30 16:23:04 -07:00
Sam Kaufman 2829ef17ad Check for HWY_NATIVE_DOT_BF16. 2024-04-30 15:19:28 -07:00
Sam Kaufman 59ebecce22 Fix: specialized MatVecAdd was never called. 2024-04-30 15:17:27 -07:00
Jan Wassenberg 12fb2f05cf Add per-thread even_odd storage for #166.
Also inline ProjQ and ProjKV lambdas,
add missing includes/deps for ops_test.

PiperOrigin-RevId: 629460608
2024-04-30 10:42:23 -07:00
Copybara-Service 8f04a8346d Merge pull request #172 from szabadka:gemma2
PiperOrigin-RevId: 629438917
2024-04-30 09:33:38 -07:00
Zoltan Szabadka f8ccb8e37c Fix kv offset computation for MHA config. 2024-04-30 16:19:14 +00:00
Copybara-Service 374fd7478a Merge pull request #170 from szabadka:gemma2
PiperOrigin-RevId: 629408279
2024-04-30 07:40:30 -07:00
Zoltan Szabadka afaca4efa8 Use more parallelism in the QKV projections in MQA mode.
Instead of MatVecLoop, we use MatVec: k and v are packed into one
vector of length 2 * kQKVDim so that the K and V projections can be
computed with a single MatVec operation.

Benchmark results (summarization with 1600 tokens for prefill
and essay writing with 500 tokens for generation):

```
                   Prefill speed                Generation speed
Num threads      BEFORE       AFTER            BEFORE       AFTER
4                 9.81 t/s     9.96 t/s       8.39 t/s     8.46 t/s
18               31.50 t/s    36.67 t/s      23.10 t/s    25.83 t/s
32               45.36 t/s    58.91 t/s      27.60 t/s    31.25 t/s
64               57.72 t/s    80.64 t/s      35.40 t/s    39.76 t/s
```
2024-04-30 13:10:14 +00:00
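The same stacking idea as in the MHA sketch above, restricted to K and V; a rough illustration with a hypothetical per-row kernel rather than the actual MatVec:

```
#include <cstddef>
#include <cstdint>

#include "hwy/contrib/thread_pool/thread_pool.h"

// Dot product of one weight row with the activation vector.
void ProjectRow(const float* w_row, size_t kModelDim, const float* x,
                float* out) {
  float dot = 0.0f;
  for (size_t c = 0; c < kModelDim; ++c) dot += w_row[c] * x[c];
  *out = dot;
}

// With W_k stacked on top of W_v, one (2 * kQKVDim)-row product fills a single
// 2 * kQKVDim output vector (k then v), and the pool parallelizes over twice
// as many rows as either projection alone.
void ProjectKV(hwy::ThreadPool& pool, const float* w_kv, const float* x,
               size_t kQKVDim, size_t kModelDim, float* kv_out) {
  pool.Run(0, 2 * kQKVDim, [&](uint64_t row, size_t /*thread*/) {
    ProjectRow(w_kv + row * kModelDim, kModelDim, x, kv_out + row);
  });
}
```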
Copybara-Service befe9fb07e Merge pull request #167 from szabadka:gemma2
PiperOrigin-RevId: 629325219
2024-04-30 01:00:37 -07:00
Sam Kaufman 6a78a23f4c Abstracted some MatVecAdd spec. dupes. 2024-04-29 16:23:38 -07:00
Sam Kaufman f608337fef Remove Bf16ToF32EO and use PromoteEvenTo and PromoteOddTo. 2024-04-29 14:13:07 -07:00
Sam Kaufman aa0b113214 (VecT*) to static_cast<VecT*>. 2024-04-29 12:53:47 -07:00
Sam Kaufman 5cb63346aa supports_eo -> kSupportsEvenOdd 2024-04-29 12:51:35 -07:00
Zoltan Szabadka 27117cc39f Simplify threading: remove the use of inner_pool.
We only used inner_pool in the prefill FFW function, and there we
can achieve sufficient parallelism on the rows of the matrix-vector
multiplications.

Benchmark results on a 1600-token summarization task:

```
               Prefill speed
Num threads    BEFORE         AFTER
4               9.24 t/s       9.76 t/s
18             31.41 t/s      31.16 t/s
32             31.41 t/s      45.13 t/s
64             31.03 t/s      57.85 t/s
```
2024-04-29 16:07:30 +00:00
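A sketch of the simplified threading, assuming a `hwy::ThreadPool`; `FFWHiddenRow` and `kFFHiddenDim` are illustrative names, not the repository's actual functions:

```
#include <cstddef>
#include <cstdint>

#include "hwy/contrib/thread_pool/thread_pool.h"

// Hypothetical kernel for one row of the FFW hidden layer of one token.
void FFWHiddenRow(size_t /*token*/, size_t /*row*/) {}

// A single flat pool over the kFFHiddenDim rows already exposes plenty of
// tasks per token, so no nested inner_pool is needed.
void PrefillFFW(hwy::ThreadPool& pool, size_t num_tokens, size_t kFFHiddenDim) {
  for (size_t token = 0; token < num_tokens; ++token) {
    pool.Run(0, kFFHiddenDim, [&](uint64_t row, size_t /*thread*/) {
      FFWHiddenRow(token, row);
    });
  }
}
```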
Paul Chang 1d18c5a129 Improve documentation for compress_weights flags
PiperOrigin-RevId: 629053191
2024-04-29 06:49:50 -07:00
Sam Kaufman 0816a1070d Even-odd layout MatVecs for bf16 weights. 2024-04-28 20:09:25 -07:00
Jan Wassenberg 7a12e29027 Add error-checking for py binding, add missing include+hwasan check
PiperOrigin-RevId: 628453112
2024-04-26 10:59:41 -07:00
Paul Chang e8f59bb411 Fix underflow in NUQ ClusterCost()
PiperOrigin-RevId: 628137904
2024-04-25 11:28:51 -07:00
Phil Culliton 9e0ac5de34 Update Clif wrapper to work with latest gemma.cpp and add simple example
PiperOrigin-RevId: 628134201
2024-04-25 11:17:16 -07:00
Paul Chang 2d4de6b08b Support absolute positional embeddings from vanilla transformer
PiperOrigin-RevId: 628100831
2024-04-25 09:32:14 -07:00
Paul Chang 75eca87039 Simplify prefill early-exit (originally Merge #156)
PiperOrigin-RevId: 627788524
2024-04-24 11:11:42 -07:00
Copybara-Service b27d8d6b92 Merge pull request #156 from zeerd:dev
PiperOrigin-RevId: 627706909
2024-04-24 06:19:14 -07:00
Charles Chan ea45d7c4d7 Use a lambda to split the function, and allow stream_token to break out of prefill, too 2024-04-23 22:55:01 +08:00
Paul Chang e8d29792ac New token validity assertions, improve prompt truncation warning
PiperOrigin-RevId: 627376194
2024-04-23 07:05:59 -07:00
Jan Wassenberg 3bf22abb22 Fix sign comparison warnings
PiperOrigin-RevId: 627299902
2024-04-23 01:16:51 -07:00
Jan Wassenberg ca971ef50f Document weight conversion
PiperOrigin-RevId: 626957718
2024-04-22 01:58:30 -07:00
Jan Wassenberg e9a0caed87 Further improve IO, enable multiple backends without -D.
Move Path into io.h and use for opening files.
Removes dependency of gemma_lib on args.
Separate Windows codepath instead of emulating POSIX functions.

Plus lint fixes.

PiperOrigin-RevId: 626279004
2024-04-19 00:40:29 -07:00
Paul Chang 38f1ea9b80 Eliminate redundant copies of TokenString()
Move this function outside of HWY_NAMESPACE since it doesn't need to be
optimized for any particular architecture.

PiperOrigin-RevId: 626098641
2024-04-18 11:31:50 -07:00
Jan Wassenberg a8ceb75f43 Improved IO abstraction layer
Move to unique_ptr-like File class.
Move `if OS_WIN` into wrapper functions.
exists -> Exists.

PiperOrigin-RevId: 625923056
2024-04-17 23:15:07 -07:00
Jan Wassenberg a939b5fc9f Update distortion.h to weighted average, add distortion_test.
More thorough checks in sfp_test and nuq_test.
nuq_test: use deterministic input generator.

PiperOrigin-RevId: 625602019
2024-04-17 01:44:19 -07:00
Copybara-Service 05e7e2b2bb Merge pull request #145 from atorero:dev
PiperOrigin-RevId: 624221085
2024-04-12 10:27:18 -07:00
Andrey Mikhaylov 4ef3da733a Fixed minor things and added comments. 2024-04-12 15:39:16 +00:00
Andrey Mikhaylov 2c5706f159 Add comments regarding layers output usage. 2024-04-12 15:39:16 +00:00
Andrey Mikhaylov 03284d752e Added layers output functionality to gemma and a binary debug_output to save the outputs to a json file. 2024-04-12 15:39:16 +00:00
Copybara-Service 342e998cb6 Merge pull request #142 from ufownl:refactor/data_structures
PiperOrigin-RevId: 623503486
2024-04-10 08:35:18 -07:00
RangerUFO e541707caa Rename the fields of Griffin weights 2024-04-10 21:04:31 +08:00
RangerUFO 4e960d67f6 Fix typos 2024-04-10 20:38:18 +08:00
RangerUFO 809bd0709d Refactor data structures to reduce memory usage 2024-04-10 19:35:23 +08:00
Jan Wassenberg 54120a5571 Mention Makefile contributed by @jart
PiperOrigin-RevId: 623436818
2024-04-10 03:21:10 -07:00
Jan Wassenberg 881eeffe0a Lint fixes: strcat, includes, arg naming
PiperOrigin-RevId: 623435210
2024-04-10 03:12:41 -07:00
Copybara-Service da91f4c4be Merge pull request #137 from zond:main
PiperOrigin-RevId: 623255639
2024-04-09 12:57:57 -07:00
Copybara-Service 827fec1904 Merge pull request #139 from ufownl:feature/public_layers
PiperOrigin-RevId: 623254705
2024-04-09 12:54:23 -07:00
RangerUFO 2099b37732 Change `NumGemmaLayers` and `NumGriffinLayers` to constants in configs 2024-04-09 20:44:41 +08:00
Jan Wassenberg a982ec1287 Move code to gemma/ so we can remove error-prone copybara: comments.
Also fix includes and Lint warnings.

PiperOrigin-RevId: 623127487
2024-04-09 04:45:42 -07:00
zond 9ca662dc14 Clarified README
Made it more visible that the recurrent weights are at a different Kaggle page.
2024-04-09 09:58:47 +02:00
39 changed files with 2041 additions and 1087 deletions

.vscode/settings.json (vendored, new file, +6 lines)

@ -0,0 +1,6 @@
{
"cmake.configureOnOpen": false,
"files.associations": {
"array": "cpp"
}
}

View File

@ -22,9 +22,7 @@ exports_files(["LICENSE"])
cc_library(
name = "ops",
hdrs = [
"ops.h",
],
hdrs = ["gemma/ops.h"],
deps = [
"//compression:compress",
"@hwy//:algo",
@ -41,42 +39,34 @@ cc_library(
cc_test(
name = "ops_test",
size = "small",
srcs = ["ops_test.cc"],
srcs = ["gemma/ops_test.cc"],
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["hwy_ops_test"],
deps = [
":ops",
"@googletest//:gtest_main",
"//compression:compress",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
],
)
cc_library(
name = "args",
hdrs = [
"util/args.h",
],
deps = [
"@hwy//:hwy",
"@hwy//:thread_pool",
],
)
cc_library(
name = "gemma_lib",
srcs = [
"gemma.cc",
"gemma/gemma.cc",
],
hdrs = [
"configs.h",
"gemma.h",
"gemma/configs.h",
"gemma/gemma.h",
],
deps = [
":args",
":ops",
# "//base",
"//compression:compress",
"//compression:io",
"@hwy//:hwy",
"@hwy//:matvec",
"@hwy//:nanobenchmark", # timer
@ -86,9 +76,29 @@ cc_library(
],
)
cc_library(
name = "args",
hdrs = ["util/args.h"],
deps = [
"//compression:io",
"@hwy//:hwy",
],
)
cc_library(
name = "app",
hdrs = ["util/app.h"],
deps = [
":args",
":gemma_lib",
"//compression:io",
"@hwy//:hwy",
],
)
cc_test(
name = "gemma_test",
srcs = ["gemma_test.cc"],
srcs = ["gemma/gemma_test.cc"],
# Requires model files
tags = [
"local",
@ -105,23 +115,9 @@ cc_test(
],
)
cc_library(
name = "app",
hdrs = [
"util/app.h",
],
deps = [
":args",
":gemma_lib",
"@hwy//:hwy",
],
)
cc_binary(
name = "gemma",
srcs = [
"run.cc",
],
srcs = ["gemma/run.cc"],
deps = [
":app",
":args",
@ -137,9 +133,7 @@ cc_binary(
cc_binary(
name = "compress_weights",
srcs = [
"compress_weights.cc",
],
srcs = ["gemma/compress_weights.cc"],
deps = [
":args",
":gemma_lib",
@ -154,8 +148,25 @@ cc_binary(
cc_binary(
name = "benchmark",
srcs = ["gemma/benchmark.cc"],
deps = [
":app",
":args",
":gemma_lib",
# "//base",
"//compression:compress",
"@hwy//:hwy",
"@hwy//:nanobenchmark",
"@hwy//:profiler",
"@hwy//:thread_pool",
"@nlohmann_json//:json",
],
)
cc_binary(
name = "debug_prompt",
srcs = [
"benchmark.cc",
"debug_prompt.cc",
],
deps = [
":app",

View File

@ -34,16 +34,22 @@ FetchContent_Declare(json GIT_REPOSITORY https://github.com/nlohmann/json.git GI
FetchContent_MakeAvailable(json)
set(SOURCES
gemma.cc
compression/blob_store.cc
compression/blob_store.h
compression/compress.h
compression/compress-inl.h
compression/io_win.cc
compression/io.cc
compression/io.h
compression/nuq.h
compression/nuq-inl.h
compression/sfp.h
compression/sfp-inl.h
compression/test_util.h
gemma/configs.h
gemma/gemma.cc
gemma/gemma.h
gemma/ops.h
util/app.h
util/args.h
)
@ -72,26 +78,31 @@ set_property(TARGET libgemma PROPERTY CXX_STANDARD 17)
set_target_properties(libgemma PROPERTIES PREFIX "")
set_property(TARGET libgemma PROPERTY POSITION_INDEPENDENT_CODE ON)
target_include_directories(libgemma PUBLIC ./)
target_link_libraries(libgemma hwy hwy_contrib sentencepiece)
target_link_libraries(libgemma hwy hwy_contrib sentencepiece-static)
target_include_directories(libgemma PUBLIC ${sentencepiece_SOURCE_DIR})
target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
install(TARGETS libgemma DESTINATION lib)
# Executable Target
add_executable(gemma run.cc)
add_executable(gemma gemma/run.cc)
target_link_libraries(gemma libgemma hwy hwy_contrib)
install(TARGETS gemma DESTINATION bin)
add_executable(benchmark benchmark.cc)
add_executable(benchmark gemma/benchmark.cc)
target_link_libraries(benchmark libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
add_executable(debug_prompt debug_prompt.cc)
target_link_libraries(debug_prompt libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
## Tests
set(GEMMA_ENABLE_TESTS OFF CACHE BOOL "Enable Gemma tests")
if (GEMMA_ENABLE_TESTS)
set(GEMMA_TEST_FILES
ops_test.cc
gemma_test.cc
gemma/ops_test.cc
gemma/gemma_test.cc
)
foreach (TESTFILE IN LISTS GEMMA_TEST_FILES)
@ -112,5 +123,5 @@ endif() # GEMMA_ENABLE_TESTS
## Tools
add_executable(compress_weights compress_weights.cc)
add_executable(compress_weights gemma/compress_weights.cc)
target_link_libraries(compress_weights libgemma hwy hwy_contrib)

View File

@ -83,6 +83,23 @@ A `.clang-format` configuration is provided with our defaults, please run source
files through `clang-format` (or a formatter that produces equivalent behavior)
before finalizing PR for submission.
## Converting weights
We use a stripped-down binary blob (.sbs) artifact to accelerate weight loading
in C++. These files can be downloaded directly from Kaggle and Hugging Face. You
can also convert PyTorch or Keras checkpoints to .sbs, but most end users should
not have to do this.
If starting with Keras, first run this script to convert to PyTorch:
https://github.com/keras-team/keras-nlp/blob/master/tools/gemma/export_gemma_to_torch_xla.py
From PyTorch, use the following script to generate uncompressed weights:
https://github.com/google/gemma.cpp/blob/dev/util/convert_weights.py
Then run gemma/compress_weights.cc (Bazel target :compress_weights), specifying
the resulting file as `--weights` and the desired .sbs output name as
`--compressed_weights`.
## Compile-Time Flags (Advanced)
There are several compile-time flags to be aware of (note these may or may not
@ -169,9 +186,9 @@ inference path of the Gemma model.
The sentencepiece library we depend on requires some additional work to build
with the Bazel build system. First, it does not export its BUILD file, so we
provide `bazel/sentencepiece.bazel`. Second, it ships with a vendored subset of
the Abseil library. `bazel/com_google_sentencepiece.patch` changes the code to
support Abseil as a standalone dependency without third_party/ prefixes, similar
to the transforms we apply to Gemma via Copybara.
the Abseil library. `bazel/sentencepiece.patch` changes the code to support
Abseil as a standalone dependency without third_party/ prefixes, similar to the
transforms we apply to Gemma via Copybara.
## Discord

View File

@ -33,7 +33,7 @@ http_archive(
strip_prefix = "sentencepiece-0.1.96",
urls = ["https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip"],
build_file = "@//bazel:sentencepiece.bazel",
patches = ["@//bazel:com_google_sentencepiece.patch"],
patches = ["@//bazel:sentencepiece.patch"],
patch_args = ["-p1"],
)

View File

@ -206,6 +206,12 @@ bazel build -c opt --cxxopt=-std=c++20 :gemma
If the build is successful, you should now have a `gemma` executable in the `bazel-bin/` directory.
#### Make
If you prefer Makefiles, @jart has made one available here:
https://github.com/jart/gemma3/blob/main/Makefile
### Step 4: Run
You can now run `gemma` from inside the `build/` directory.
@ -252,7 +258,7 @@ here provide a C++ implementation of this model based on the paper.
To use the recurrent version of Gemma included in this repository, build the
gemma binary as noted above in Step 3. Download the compressed weights and
tokenizer from
tokenizer from the RecurrentGemma
[Kaggle](https://www.kaggle.com/models/google/recurrentgemma/gemmaCpp) as in
Step 1, and run the binary as follows:

View File

@ -1,4 +1,4 @@
# Required for referencing bazel:com_google_sentencepiece.patch
# Required for referencing bazel:sentencepiece.patch
package(
default_applicable_licenses = ["//:license"],
default_visibility = ["//:__subpackages__"],

View File

@ -11,14 +11,24 @@ package(
)
cc_library(
name = "blob_store",
name = "io",
srcs = [
"blob_store.cc",
],
hdrs = [
"blob_store.h",
"io.cc",
# Placeholder for io backend, do not remove
],
hdrs = ["io.h"],
deps = [
# Placeholder for io deps, do not remove
"@hwy//:hwy",
],
)
cc_library(
name = "blob_store",
srcs = ["blob_store.cc"],
hdrs = ["blob_store.h"],
deps = [
":io",
"@hwy//:hwy",
"@hwy//:thread_pool",
],
@ -39,7 +49,23 @@ cc_library(
name = "distortion",
hdrs = ["distortion.h"],
deps = [
":stats",
"@hwy//:hwy",
"@hwy//hwy/contrib/sort:vqsort",
],
)
cc_test(
name = "distortion_test",
size = "small",
srcs = ["distortion_test.cc"],
deps = [
":distortion",
":test_util",
"@googletest//:gtest_main",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:nanobenchmark", # Unpredictable1
],
)
@ -56,12 +82,8 @@ cc_library(
cc_library(
name = "sfp",
hdrs = [
"sfp.h",
],
textual_hdrs = [
"sfp-inl.h",
],
hdrs = ["sfp.h"],
textual_hdrs = ["sfp-inl.h"],
deps = [
"@hwy//:hwy",
],
@ -88,12 +110,8 @@ cc_test(
cc_library(
name = "nuq",
hdrs = [
"nuq.h",
],
textual_hdrs = [
"nuq-inl.h",
],
hdrs = ["nuq.h"],
textual_hdrs = ["nuq-inl.h"],
deps = [
":sfp",
"@hwy//:hwy",
@ -134,6 +152,7 @@ cc_library(
deps = [
":blob_store",
":distortion",
":io",
":nuq",
":sfp",
":stats",
@ -146,9 +165,7 @@ cc_library(
# For internal experimentation
cc_library(
name = "analyze",
textual_hdrs = [
"analyze.h",
],
textual_hdrs = ["analyze.h"],
deps = [
":distortion",
":nuq",

View File

@ -26,11 +26,8 @@
#include <cstdlib> // std::abs
#include <vector>
// copybara:import_next_line:gemma_cpp
#include "compression/distortion.h"
// copybara:import_next_line:gemma_cpp
#include "compression/nuq.h"
// copybara:import_next_line:gemma_cpp
#include "compression/stats.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -46,9 +43,7 @@
#define THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
#endif
// copybara:import_next_line:gemma_cpp
#include "compression/nuq-inl.h"
// copybara:import_next_line:gemma_cpp
#include "compression/sfp-inl.h"
#include "hwy/contrib/sort/vqsort-inl.h"
#include "hwy/highway.h"

View File

@ -13,89 +13,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Request POSIX 2008, including `pread()` and `posix_fadvise()`.
#if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700
#undef _XOPEN_SOURCE
#define _XOPEN_SOURCE 700
#endif
#if !defined(_POSIX_C_SOURCE) || _POSIX_C_SOURCE < 200809
#define _POSIX_C_SOURCE 200809
#endif
// Make `off_t` 64-bit even on 32-bit systems. Works for Android >= r15c.
#undef _FILE_OFFSET_BITS
#define _FILE_OFFSET_BITS 64
// copybara:import_next_line:gemma_cpp
#include "compression/blob_store.h"
#include <fcntl.h> // open
#include <stddef.h>
#include <stdint.h>
#include <stdio.h> // SEEK_END - unistd isn't enough for IDE.
#include <sys/stat.h> // O_RDONLY
#if HWY_OS_WIN
#include <fileapi.h>
#include <io.h> // read, write, close
#else
#include <unistd.h> // read, write, close
#endif
#include <atomic>
#include <memory>
#include <vector>
#include "compression/io.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/detect_compiler_arch.h"
namespace {
#if HWY_OS_WIN
// pread is not supported on Windows
static int64_t pread(int fd, void* buf, uint64_t size, uint64_t offset) {
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
if (file == INVALID_HANDLE_VALUE) {
return -1;
}
OVERLAPPED overlapped = {0};
overlapped.Offset = offset & 0xFFFFFFFF;
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
DWORD bytes_read;
if (!ReadFile(file, buf, size, &bytes_read, &overlapped)) {
if (GetLastError() != ERROR_HANDLE_EOF) {
return -1;
}
}
return bytes_read;
}
// pwrite is not supported on Windows
static int64_t pwrite(int fd, const void* buf, uint64_t size, uint64_t offset) {
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
if (file == INVALID_HANDLE_VALUE) {
return -1;
}
OVERLAPPED overlapped = {0};
overlapped.Offset = offset & 0xFFFFFFFF;
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
DWORD bytes_written;
if (!WriteFile(file, buf, size, &bytes_written, &overlapped)) {
if (GetLastError() != ERROR_HANDLE_EOF) {
return -1;
}
}
return bytes_written;
}
#endif
} // namespace
namespace gcpp {
hwy::uint128_t MakeKey(const char* string) {
@ -132,61 +64,6 @@ void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data,
}
} // namespace
struct IO {
// Returns size in bytes or 0.
static uint64_t FileSize(const char* filename) {
int fd = open(filename, O_RDONLY);
if (fd < 0) {
return 0;
}
#if HWY_OS_WIN
const int64_t size = _lseeki64(fd, 0, SEEK_END);
HWY_ASSERT(close(fd) != -1);
if (size < 0) {
return 0;
}
#else
static_assert(sizeof(off_t) == 8, "64-bit off_t required");
const off_t size = lseek(fd, 0, SEEK_END);
HWY_ASSERT(close(fd) != -1);
if (size == static_cast<off_t>(-1)) {
return 0;
}
#endif
return static_cast<uint64_t>(size);
}
static bool Read(int fd, uint64_t offset, uint64_t size, void* to) {
uint8_t* bytes = reinterpret_cast<uint8_t*>(to);
uint64_t pos = 0;
for (;;) {
// pread seems to be faster than lseek + read when parallelized.
const auto bytes_read = pread(fd, bytes + pos, size - pos, offset + pos);
if (bytes_read <= 0) break;
pos += bytes_read;
HWY_ASSERT(pos <= size);
if (pos == size) break;
}
return pos == size; // success if managed to read desired size
}
static bool Write(const void* from, uint64_t size, uint64_t offset, int fd) {
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(from);
uint64_t pos = 0;
for (;;) {
const auto bytes_written =
pwrite(fd, bytes + pos, size - pos, offset + pos);
if (bytes_written <= 0) break;
pos += bytes_written;
HWY_ASSERT(pos <= size);
if (pos == size) break;
}
return pos == size; // success if managed to write desired size
}
}; // IO
static_assert(HWY_IS_LITTLE_ENDIAN, "Assumes little endian");
// On-disk representation (little-endian).
@ -323,26 +200,13 @@ class BlobStore {
};
#pragma pack(pop)
BlobError BlobReader::Open(const char* filename) {
#if HWY_OS_WIN
DWORD flags = FILE_ATTRIBUTE_NORMAL | FILE_FLAG_SEQUENTIAL_SCAN;
HANDLE file = CreateFileA(filename, GENERIC_READ, FILE_SHARE_READ, nullptr,
OPEN_EXISTING, flags, nullptr);
if (file == INVALID_HANDLE_VALUE) return __LINE__;
fd_ = _open_osfhandle(reinterpret_cast<intptr_t>(file), _O_RDONLY);
#else
fd_ = open(filename, O_RDONLY);
#endif
if (fd_ < 0) return __LINE__;
#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21)
// Doubles the readahead window, which seems slightly faster when cached.
(void)posix_fadvise(fd_, 0, 0, POSIX_FADV_SEQUENTIAL);
#endif
BlobError BlobReader::Open(const Path& filename) {
file_ = OpenFileOrNull(filename, "r");
if (!file_) return __LINE__;
// Read first part of header to get actual size.
BlobStore bs;
if (!IO::Read(fd_, 0, sizeof(bs), &bs)) return __LINE__;
if (!file_->Read(0, sizeof(bs), &bs)) return __LINE__;
const size_t padded_size = bs.PaddedHeaderSize();
HWY_ASSERT(padded_size >= sizeof(bs));
@ -354,18 +218,11 @@ BlobError BlobReader::Open(const char* filename) {
hwy::CopySameSize(&bs, blob_store_.get());
// Read the rest of the header, but not the full file.
uint8_t* bytes = reinterpret_cast<uint8_t*>(blob_store_.get());
if (!IO::Read(fd_, sizeof(bs), padded_size - sizeof(bs),
bytes + sizeof(bs))) {
if (!file_->Read(sizeof(bs), padded_size - sizeof(bs), bytes + sizeof(bs))) {
return __LINE__;
}
return blob_store_->CheckValidity(IO::FileSize(filename));
}
BlobReader::~BlobReader() {
if (fd_ >= 0) {
HWY_ASSERT(close(fd_) != -1);
}
return blob_store_->CheckValidity(file_->FileSize());
}
BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) {
@ -392,14 +249,14 @@ BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) {
// between consecutive runs.
// - memory-mapped I/O is less predictable and adds noise to measurements.
BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
const int fd = fd_;
File* pfile = file_.get(); // not owned
const auto& requests = requests_;
std::atomic_flag err = ATOMIC_FLAG_INIT;
// >5x speedup from parallel reads when cached.
pool.Run(0, requests.size(),
[fd, &requests, &err](uint64_t i, size_t /*thread*/) {
if (!IO::Read(fd, requests[i].offset, requests[i].size,
requests[i].data)) {
[pfile, &requests, &err](uint64_t i, size_t /*thread*/) {
if (!pfile->Read(requests[i].offset, requests[i].size,
requests[i].data)) {
err.test_and_set();
}
});
@ -407,8 +264,7 @@ BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
return 0;
}
BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool,
const char* filename) const {
BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
HWY_ASSERT(keys_.size() == blobs_.size());
// Concatenate blobs in memory.
@ -419,26 +275,18 @@ BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool,
keys_.data(), blobs_.data(), keys_.size(), bs.get());
// Create/replace existing file.
#if HWY_OS_WIN
DWORD flags = FILE_ATTRIBUTE_NORMAL;
HANDLE file = CreateFileA(filename, GENERIC_WRITE, 0, nullptr, CREATE_ALWAYS,
flags, nullptr);
if (file == INVALID_HANDLE_VALUE) return __LINE__;
const int fd = _open_osfhandle(reinterpret_cast<intptr_t>(file), _O_WRONLY);
#else
const int fd = open(filename, O_CREAT | O_RDWR | O_TRUNC, 0644);
#endif
if (fd < 0) return __LINE__;
std::unique_ptr<File> file = OpenFileOrNull(filename, "w+");
if (!file) return __LINE__;
File* pfile = file.get(); // not owned
std::atomic_flag err = ATOMIC_FLAG_INIT;
pool.Run(0, requests.size(),
[fd, &requests, &err](uint64_t i, size_t /*thread*/) {
if (!IO::Write(requests[i].data, requests[i].size,
requests[i].offset, fd)) {
[pfile, &requests, &err](uint64_t i, size_t /*thread*/) {
if (!pfile->Write(requests[i].data, requests[i].size,
requests[i].offset)) {
err.test_and_set();
}
});
HWY_ASSERT(close(fd) != -1);
if (err.test_and_set()) return __LINE__;
return 0;
}

View File

@ -19,8 +19,10 @@
#include <stddef.h>
#include <stdint.h>
#include <memory>
#include <vector>
#include "compression/io.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // hwy::uint128_t
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -59,10 +61,10 @@ struct BlobIO {
class BlobReader {
public:
BlobReader() { requests_.reserve(500); }
~BlobReader();
~BlobReader() = default;
// Opens `filename` and reads its header.
BlobError Open(const char* filename);
BlobError Open(const Path& filename);
// Enqueues read requests if `key` is found and its size matches `size`.
BlobError Enqueue(hwy::uint128_t key, void* data, size_t size);
@ -73,7 +75,7 @@ class BlobReader {
private:
BlobStorePtr blob_store_; // holds header, not the entire file
std::vector<BlobIO> requests_;
int fd_ = 0;
std::unique_ptr<File> file_;
};
class BlobWriter {
@ -84,7 +86,7 @@ class BlobWriter {
}
// Stores all blobs to disk in the given order with padding for alignment.
BlobError WriteAll(hwy::ThreadPool& pool, const char* filename) const;
BlobError WriteAll(hwy::ThreadPool& pool, const Path& filename);
private:
std::vector<hwy::uint128_t> keys_;

View File

@ -23,11 +23,8 @@
#include <array>
// copybara:import_next_line:gemma_cpp
#include "compression/blob_store.h"
// copybara:import_next_line:gemma_cpp
#include "compression/compress.h"
// copybara:import_next_line:gemma_cpp
#include "compression/distortion.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
@ -44,9 +41,7 @@
#define THIRD_PARTY_GEMMA_CPP_COMPRESS_TOGGLE
#endif
// copybara:import_next_line:gemma_cpp
#include "compression/nuq-inl.h"
// copybara:import_next_line:gemma_cpp
#include "compression/sfp-inl.h"
#include "hwy/contrib/dot/dot-inl.h"
#include "hwy/highway.h"
@ -63,6 +58,7 @@ struct CompressTraits {};
template <>
struct CompressTraits<float> {
using MatT = float;
static constexpr bool kSupportsEvenOdd = false;
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
@ -116,6 +112,7 @@ struct CompressTraits<float> {
template <>
struct CompressTraits<hwy::bfloat16_t> {
using MatT = hwy::bfloat16_t;
static constexpr bool kSupportsEvenOdd = true;
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
@ -224,11 +221,59 @@ struct CompressTraits<hwy::bfloat16_t> {
// bf16*bf16.
return hn::Dot::Compute<kAssumptions>(d_vec, vec_aligned, in + in_ofs, num);
}
// Computes the dot product of an even-odd deinterleaved, f32 `vec_aligned`
// and a column-major matrix `in`. `vec_aligned` should be aligned and
// alternate even-indexed `hn::Lanes(df32)` elements followed by odd-indexed
// `hn::Lanes(df32)` elements.
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE float DotEO(
const DF df32, const hwy::bfloat16_t* HWY_RESTRICT in, size_t in_ofs,
const float* HWY_RESTRICT vec_aligned, size_t num) {
HWY_DASSERT(num >= (hn::Lanes(df32) * 2) && (num % (hn::Lanes(df32) * 2)) == 0);
HWY_DASSERT((in_ofs % (hn::Lanes(df32) * 2)) == 0);
HWY_DASSERT(hn::IsAligned(df32, vec_aligned));
const hn::Repartition<hwy::bfloat16_t, DF> dbf16;
using VF32 = decltype(Zero(df32));
const size_t N = Lanes(dbf16);
VF32 sum0 = Zero(df32);
VF32 sum1 = Zero(df32);
VF32 sum2 = Zero(df32);
VF32 sum3 = Zero(df32);
const hn::RebindToUnsigned<decltype(df32)> du32;
using VU32 = hn::VFromD<decltype(du32)>;
const VU32 odd = Set(du32, 0xFFFF0000u);
for (size_t i = 0; i < num; /* i += 2 * N */) {
const auto interleaved0 = hn::LoadU(dbf16, in + in_ofs + i);
const VF32 ae0 = Load(df32, vec_aligned + i);
const VF32 ao0 = Load(df32, vec_aligned + i + (N / 2));
sum0 = hn::MulAdd(ae0, hn::PromoteEvenTo(df32, interleaved0), sum0);
sum1 = hn::MulAdd(ao0, hn::PromoteOddTo(df32, interleaved0), sum1);
i += N;
const auto interleaved1 = hn::LoadU(dbf16, in + in_ofs + i);
const VF32 ae1 = Load(df32, vec_aligned + i);
const VF32 ao1 = Load(df32, vec_aligned + i + (N / 2));
sum2 = hn::MulAdd(ae1, hn::PromoteEvenTo(df32, interleaved1), sum2);
sum3 = hn::MulAdd(ao1, hn::PromoteOddTo(df32, interleaved1), sum3);
i += N;
}
sum0 = Add(sum0, sum1);
sum2 = Add(sum2, sum3);
sum0 = Add(sum0, sum2);
return ReduceSum(df32, sum0);
}
};
template <>
struct CompressTraits<SfpStream> {
using MatT = SfpStream;
static constexpr bool kSupportsEvenOdd = false;
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Compress(DF df, const float* in, size_t num,
@ -278,6 +323,7 @@ struct CompressTraits<SfpStream> {
template <>
struct CompressTraits<NuqStream> {
using MatT = NuqStream;
static constexpr bool kSupportsEvenOdd = false;
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Compress(DF df, const float* in, size_t num,
@ -430,16 +476,22 @@ HWY_INLINE float Dot(DF df, const ArrayT& compressed, size_t compressed_ofs,
}
// Returns dot product with `vec_aligned` of length `num`.
template <class DF, typename MatT, size_t kCapacity, typename VecT>
template <bool kVecEO, class DF, typename MatT, size_t kCapacity, typename VecT>
HWY_INLINE float Dot(DF df, const CompressedArray<MatT, kCapacity>& compressed,
size_t compressed_ofs, const VecT* vec_aligned,
size_t num) {
HWY_DASSERT(compressed_ofs + num <= compressed.size());
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
using Traits = CompressTraits<MatT>;
return (compressed.scale() * Traits::Dot(df, compressed.size(),
compressed.data(), compressed_ofs,
vec_aligned, num));
float dot_result;
if constexpr (kVecEO) {
dot_result = Traits::DotEO(df, compressed.data(), compressed_ofs,
vec_aligned, num);
} else {
dot_result = Traits::Dot(df, compressed.size(), compressed.data(),
compressed_ofs, vec_aligned, num);
}
return compressed.scale() * dot_result;
}
// Callback used by ForeachTensor.
@ -464,11 +516,11 @@ class Compressor {
}
}
void WriteAll(hwy::ThreadPool& pool, const char* blob_filename) {
void WriteAll(hwy::ThreadPool& pool, const Path& blob_filename) {
const BlobError err = writer_.WriteAll(pool, blob_filename);
if (err != 0) {
fprintf(stderr, "Failed to write blobs to %s (error %d)\n", blob_filename,
err);
fprintf(stderr, "Failed to write blobs to %s (error %d)\n",
blob_filename.path.c_str(), err);
}
}

View File

@ -27,19 +27,15 @@
#include <vector>
// IWYU pragma: begin_exports
// copybara:import_next_line:gemma_cpp
#include "compression/blob_store.h"
// copybara:import_next_line:gemma_cpp
#include "compression/io.h"
#include "compression/nuq.h"
// copybara:import_next_line:gemma_cpp
#include "compression/sfp.h"
// IWYU pragma: end_exports
// copybara:import_next_line:gemma_cpp
#include "compression/distortion.h"
#include "hwy/base.h" // hwy::bfloat16_t
#include "hwy/contrib/thread_pool/thread_pool.h"
#if COMPRESS_STATS
// copybara:import_next_line:gemma_cpp
#include "compression/stats.h"
#endif
@ -171,13 +167,13 @@ hwy::uint128_t CacheKey(const char* name) {
class CacheLoader {
public:
explicit CacheLoader(const char* blob_filename) {
explicit CacheLoader(const Path& blob_filename) {
err_ = reader_.Open(blob_filename);
if (err_ != 0) {
fprintf(stderr,
"Cached compressed weights does not exist yet (code %d), "
"compressing weights and creating file: %s.\n",
err_, blob_filename);
err_, blob_filename.path.c_str());
}
}

View File

@ -15,85 +15,214 @@
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_DISTORTION_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_DISTORTION_H_
#include <math.h> // pow
#include <stddef.h>
#include <stdio.h>
#include "hwy/base.h" // ScalarAbs
#include <vector>
#include "compression/stats.h"
#include "hwy/aligned_allocator.h" // HWY_ALIGNMENT
#include "hwy/base.h" // ScalarAbs
#include "hwy/contrib/sort/vqsort.h"
namespace gcpp {
// Returns `sum` and `err` such that `sum + err` is exactly equal to `a + b`,
// despite floating-point rounding. `sum` is already the best estimate, so do
// not actually add `err` to it. Knuth98/Moller65. Unlike Fast2Sum [Dekker71],
// this does not require any relative ordering of the exponents of a and b.
template <typename T>
static inline T TwoSum(T a, T b, T& err) {
const T sum = a + b;
const T a2 = sum - b;
const T b2 = sum - a2;
const T err_a = a - a2;
const T err_b = b - b2;
err = err_a + err_b;
return sum;
}
// Accumulates numbers with about twice the precision of T using 7 * n FLOPS.
// Rump/Ogita/Oishi08, Algorithm 6.11 in Handbook of Floating-Point Arithmetic.
template <typename T>
class CascadedSummation {
public:
void Notify(T t) {
T err;
sum_ = TwoSum(sum_, t, err);
sum_err_ += err;
}
void Assimilate(const CascadedSummation& other) {
Notify(other.sum_);
sum_err_ += other.sum_err_;
}
// Allows users to observe how much difference the extra precision made.
T Err() const { return sum_err_; }
// Returns the sum of all `t` passed to `Notify`.
T Total() const { return sum_ + sum_err_; }
private:
T sum_ = T{0};
T sum_err_ = T{0};
};
// Summarizes the error of a distortion (e.g. quantization) applied to a series
// of numbers.
// Users should check all four resulting metrics (NumExact, NumRoundedToZero,
// GeomeanValueDivL1, WeightedAverageL1) because each covers different aspects.
class DistortionStats {
public:
void Notify(float original, float distorted) {
(void)padding_; // prevent unused member warning
const double l1 = hwy::ScalarAbs(original - distorted);
const bool rounded_to_zero = (original != 0.0f) && (distorted == 0.0f);
// We expect original == 0 is not distorted (can be exactly represented).
HWY_ASSERT(original != 0.0f || distorted == 0.0f);
if (l1 > max_l1_) {
max_l1_ = l1;
max_idx_ = n_;
s_original_.Notify(original);
const float l1f = hwy::ScalarAbs(original - distorted);
const double l1 = static_cast<double>(l1f);
s_l1_.Notify(l1f);
b_l1_.Notify(HWY_MIN(99, static_cast<int>(l1f * 1E4)));
if (l1f != 0.0f) {
l1_.push_back(l1f);
}
sum_l1_.Notify(l1f);
if (rounded_to_zero) sum_l1_rounded_.Notify(l1f);
// Event counts
{
n_ += 1;
// Rounding (small) negative numbers to 0 does not influence dot products
// as much as an actual sign flip, so do not count them.
n_sign_flip_ +=
((original < 0.0f) != (distorted < 0.0f)) && !rounded_to_zero;
n_exact_ += (l1f == 0.0f);
n_rounded_to_zero += rounded_to_zero;
}
const double pow3 = l1 * l1 * l1;
sum_pow3_ += pow3;
sum_pow6_ += pow3 * pow3;
n_ += 1;
// Avoid division by zero, which happens when there is no error. NumExact()
// reports the number of times this happens.
if (l1 != 0.0) {
const double rel = 1.0 + hwy::ScalarAbs(original) / l1;
// Logarithm is required to prevent overflow. A hierarchical geomean
// Signal to noise ratio (Shannon's channel capacity, NOT the L2-based and
// logarithmic PSNR) to estimate the ratios of original to the L1 norm.
if (l1f != 0.0) { // prevent division by zero
const double snr =
1.0 + static_cast<double>(hwy::ScalarAbs(original)) / l1;
// For numerical purposes (prevents overflow). A hierarchical geomean
// could also work, but that is more complex and not necessarily better.
sum_log_rel_ += log(rel);
num_rel_ += 1;
// We will return exp() of the arithmetic mean.
sum_log_snr_ += log(snr);
num_snr_ += 1;
}
}
void Assimilate(const DistortionStats& other) {
if (other.max_l1_ > max_l1_) {
max_l1_ = other.max_l1_;
max_idx_ = other.max_idx_;
}
s_original_.Assimilate(other.s_original_);
s_l1_.Assimilate(other.s_l1_);
b_l1_.Assimilate(other.b_l1_);
sum_l1_.Assimilate(other.sum_l1_);
sum_l1_rounded_.Assimilate(other.sum_l1_rounded_);
l1_.insert(l1_.end(), other.l1_.begin(), other.l1_.end());
sum_pow3_ += other.sum_pow3_;
sum_pow6_ += other.sum_pow6_;
n_ += other.n_;
n_sign_flip_ += other.n_sign_flip_;
n_exact_ += other.n_exact_;
n_rounded_to_zero += other.n_rounded_to_zero;
sum_log_rel_ += other.sum_log_rel_;
num_rel_ += other.num_rel_;
sum_log_snr_ += other.sum_log_snr_;
num_snr_ += other.num_snr_;
}
size_t NumExact() const { return n_ - num_rel_; }
size_t NumExact() const { return n_exact_; }
size_t NumSignFlip() const { return n_sign_flip_; }
size_t NumRoundedToZero() const { return n_rounded_to_zero; }
// Total absolute error.
double SumL1() const { return sum_l1_.Total(); }
// Total absolute error for numbers that were rounded to zero.
double SumL1Rounded() const { return sum_l1_rounded_.Total(); }
// Returns geomean of 1 + S/N (Shannon channel capacity). This is computed via
// the ratio of input magnitude to nonzero L1 norms. Higher is better.
double GeomeanValueDivL1() const {
if (num_rel_ == 0) return 0.0;
return exp(sum_log_rel_ / static_cast<double>(num_rel_));
if (num_snr_ == 0) return 0.0;
return exp(sum_log_snr_ / static_cast<double>(num_snr_));
}
double PNorm() const {
// p-norms are a compromise between max-norm (penalizes the largest error
// without dilution, but does not notice any other errors) and L1 (all
// errors contribute, but large errors are diluted by smaller ones).
const double norm3 = pow(sum_pow3_ / static_cast<double>(n_), 1.0 / 3);
const double norm6 = pow(sum_pow6_ / static_cast<double>(n_), 1.0 / 6);
return 0.5 * (norm3 + norm6);
// Returns weighted average of nonzero L1 norms. Those further from the median
// L1 norm are much more heavily weighted, such that this behaves more like
// the L-infinity norm, but still includes all differences, not just the max.
// Lower is better, magnitude depends on the input magnitude.
double WeightedAverageL1() const {
if (l1_.empty()) return 0.0f; // all exact
std::vector<float> weights(l1_); // copy so we can modify
const float median = [&weights]() {
const size_t mid = weights.size() / 2;
// We just want the median; partial sort is faster if available (v1.2).
#if HWY_MAJOR > 1 || HWY_MINOR >= 2
hwy::VQSelect(weights.data(), weights.size(), mid, hwy::SortAscending());
#else
hwy::VQSort(weights.data(), weights.size(), hwy::SortAscending());
#endif
return weights[mid];
}();
weights = l1_; // restore original order
// Replace with distance from median (might have too few samples for mode).
float max_abs = -1.0f;
for (float& d : weights) {
d = hwy::ScalarAbs(d - median);
max_abs = HWY_MAX(max_abs, d);
}
HWY_ASSERT(max_abs >= 0.0f);
// All equal - return the distance value to prevent division by zero.
if (max_abs == 0.0f) return median;
// Normalize to max difference and exponentiate.
const double inv_max = 1.0 / static_cast<double>(max_abs);
double sum_weights = 0.0;
for (float& w : weights) {
const double normalized = static_cast<double>(w) * inv_max;
const double amplified = exp(4.0 * normalized * normalized);
sum_weights += amplified;
w = static_cast<float>(amplified);
}
// At least 1.0 per weight, plus more for at least one weight because we
// verified via max_abs that not all are equal.
HWY_ASSERT(sum_weights > static_cast<double>(weights.size()));
// Return weighted average.
double weighted_sum = 0.0;
for (size_t i = 0; i < weights.size(); ++i) {
weighted_sum += l1_[i] * weights[i];
}
return weighted_sum / sum_weights;
}
size_t MaxIndex() const { return max_idx_; }
double MaxL1() const { return max_l1_; }
Stats& L1() { return s_l1_; }
Stats& Original() { return s_original_; }
private:
Stats s_original_;
Stats s_l1_;
Bins<100> b_l1_;
CascadedSummation<double> sum_l1_; // all
CascadedSummation<double> sum_l1_rounded_; // only if rounded_to_zero
std::vector<float> l1_;
// Event counts
size_t n_ = 0;
size_t max_idx_ = 0; // index that had l1 = max_l1_.
double max_l1_ = -1.0;
size_t n_sign_flip_ = 0;
size_t n_exact_ = 0;
size_t n_rounded_to_zero = 0;
double sum_pow3_ = 0.0;
double sum_pow6_ = 0.0;
double sum_log_snr_ = 0.0;
size_t num_snr_ = 0;
double sum_log_rel_ = 0.0;
size_t num_rel_ = 0;
double padding_; // prevents false sharing
uint8_t padding_[HWY_ALIGNMENT]; // prevents false sharing
};
} // namespace gcpp

View File

@ -0,0 +1,99 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "compression/distortion.h"
#include <stdio.h>
#include "compression/test_util.h"
#include "hwy/nanobenchmark.h"
#include "hwy/tests/hwy_gtest.h"
#include "hwy/tests/test_util-inl.h" // HWY_ASSERT_EQ
namespace gcpp {
namespace {
#if !HWY_TEST_STANDALONE
class DistortionTest : public testing::Test {};
#endif
TEST(DistortionTest, TestCascadedSummation) {
CascadedSummation<double> cs;
// Example from Priest92. Exact sum is 2.
const double kHuge = 9007199254740992.0 * hwy::Unpredictable1(); // 2^53
const double kNeg = -4503599627370495.0 * hwy::Unpredictable1(); // -(2^52-1)
const double kIn[6] = {kHuge, kHuge - 2.0, kNeg, kNeg, kNeg, kNeg};
for (double in : kIn) {
cs.Notify(in);
}
HWY_ASSERT_EQ(2.0, cs.Total());
}
// Number of exact and rounded-to-zero matches expectations.
TEST(DistortionTest, TestCounts) {
// Arbitrary positive/negative original, zero distorted.
DistortionStats stats;
for (size_t i = 1; i < 10; ++i) {
stats.Notify(i / 100.0f, 0.0f);
stats.Notify(i / -100.0f, 0.0f);
}
HWY_ASSERT(stats.NumExact() == 0);
HWY_ASSERT(stats.NumRoundedToZero() == 18);
// Add some exact (same):
size_t num_exact = 0;
for (float x = 0.0f; x <= 1.5f; x += 0.25f) {
stats.Notify(x, x);
stats.Notify(-x, -x);
num_exact += 2;
}
HWY_ASSERT_EQ(num_exact, stats.NumExact());
HWY_ASSERT(stats.NumRoundedToZero() == 18); // unchanged
}
// Few large differences are diluted in SNR but not WeightedAverageL1.
TEST(DistortionTest, TestDilution) {
DistortionStats stats;
for (size_t i = 0; i < 100; ++i) {
stats.Notify(0.998f, 0.999f); // small
}
HWY_ASSERT(IsInside(900.0, 1000.0, stats.GeomeanValueDivL1()));
// All-equal WeightedSum is exact.
HWY_ASSERT(IsNear(0.001, stats.WeightedAverageL1()));
// Now add a large difference:
stats.Notify(1.875f - 0.0625f, 1.875f); // max magnitude, 3-bit mantissa
// .. WeightedAverageL1 is closer to it.
HWY_ASSERT(IsInside(0.020, 0.025, stats.WeightedAverageL1()));
// Add a small and large difference:
stats.Notify((1.75f - 0.125f) / 1024, 1.75f / 1024); // small, 2-bit mantissa
stats.Notify(-1.875f + 0.0625f, -1.875f); // larger negative
// .. SNR is still barely affected.
HWY_ASSERT(IsInside(890.0, 900.0, stats.GeomeanValueDivL1()));
// .. WeightedAverageL1 is higher after another large error.
HWY_ASSERT(IsInside(0.030, 0.035, stats.WeightedAverageL1()));
// With these inputs, none are exact nor round to zero.
HWY_ASSERT(stats.NumExact() == 0);
HWY_ASSERT(stats.NumRoundedToZero() == 0);
HWY_ASSERT_EQ(0.0, stats.SumL1Rounded());
HWY_ASSERT(IsInside(0.220, 0.23, stats.SumL1()));
}
} // namespace
} // namespace gcpp
HWY_TEST_MAIN();

compression/io.cc (new file, +121 lines)

@ -0,0 +1,121 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Safe to be first, does not include POSIX headers.
#include "hwy/detect_compiler_arch.h"
// Only compile this file on non-Windows; it replaces io_win.cc. It is easier to
// check this in source code because we support multiple build systems.
#if !HWY_OS_WIN
// Request POSIX 2008, including `pread()` and `posix_fadvise()`.
#if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700
#undef _XOPEN_SOURCE
#define _XOPEN_SOURCE 700
#endif
#if !defined(_POSIX_C_SOURCE) || _POSIX_C_SOURCE < 200809
#define _POSIX_C_SOURCE 200809
#endif
// Make `off_t` 64-bit even on 32-bit systems. Works for Android >= r15c.
#undef _FILE_OFFSET_BITS
#define _FILE_OFFSET_BITS 64
#include <fcntl.h> // open
#include <stddef.h>
#include <stdint.h>
#include <stdio.h> // SEEK_END - unistd isn't enough for IDE.
#include <sys/stat.h> // O_RDONLY
#include <unistd.h> // read, write, close
#include <memory>
#include "compression/io.h"
#include "hwy/base.h" // HWY_ASSERT
namespace gcpp {
class FilePosix : public File {
int fd_ = 0;
public:
explicit FilePosix(int fd) : fd_(fd) { HWY_ASSERT(fd > 0); }
~FilePosix() override {
if (fd_ != 0) {
HWY_ASSERT(close(fd_) != -1);
}
}
uint64_t FileSize() const override {
static_assert(sizeof(off_t) == 8, "64-bit off_t required");
const off_t size = lseek(fd_, 0, SEEK_END);
if (size < 0) {
return 0;
}
return static_cast<uint64_t>(size);
}
bool Read(uint64_t offset, uint64_t size, void* to) const override {
uint8_t* bytes = reinterpret_cast<uint8_t*>(to);
uint64_t pos = 0;
for (;;) {
// pread seems to be faster than lseek + read when parallelized.
const auto bytes_read = pread(fd_, bytes + pos, size - pos, offset + pos);
if (bytes_read <= 0) break;
pos += bytes_read;
HWY_ASSERT(pos <= size);
if (pos == size) break;
}
return pos == size; // success if managed to read desired size
}
bool Write(const void* from, uint64_t size, uint64_t offset) override {
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(from);
uint64_t pos = 0;
for (;;) {
const auto bytes_written =
pwrite(fd_, bytes + pos, size - pos, offset + pos);
if (bytes_written <= 0) break;
pos += bytes_written;
HWY_ASSERT(pos <= size);
if (pos == size) break;
}
return pos == size; // success if managed to write desired size
}
}; // FilePosix
HWY_MAYBE_UNUSED extern std::unique_ptr<File> OpenFileGoogle(
const Path& filename, const char* mode);
std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode) {
std::unique_ptr<File> file; // OpenFileGoogle omitted
if (file) return file;
const bool is_read = mode[0] != 'w';
const int flags = is_read ? O_RDONLY : O_CREAT | O_RDWR | O_TRUNC;
const int fd = open(filename.path.c_str(), flags, 0644);
if (fd < 0) return file;
#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21)
if (is_read) {
// Doubles the readahead window, which seems slightly faster when cached.
(void)posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL);
}
#endif
return std::make_unique<FilePosix>(fd);
}
} // namespace gcpp
#endif // !HWY_OS_WIN

compression/io.h (new file, +88 lines)

@ -0,0 +1,88 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_
#include <stddef.h>
#include <stdint.h>
#include <memory>
#include <string>
#include <utility> // std::move
namespace gcpp {
// Forward-declare to break the circular dependency: OpenFileOrNull returns
// File and has a Path argument, and Path::Exists calls OpenFileOrNull. We
// prefer to define Exists inline because there are multiple io*.cc files.
struct Path;
// Abstract base class enables multiple I/O backends in the same binary.
class File {
public:
File() = default;
virtual ~File() = default;
// Noncopyable.
File(const File& other) = delete;
const File& operator=(const File& other) = delete;
// Returns size in bytes or 0.
virtual uint64_t FileSize() const = 0;
// Returns true if all the requested bytes were read.
virtual bool Read(uint64_t offset, uint64_t size, void* to) const = 0;
// Returns true if all the requested bytes were written.
virtual bool Write(const void* from, uint64_t size, uint64_t offset) = 0;
};
// Returns nullptr on failure. `mode` is either "r" or "w+". This is not just
// named 'OpenFile' to avoid a conflict with Windows.h #define.
std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode);
// Wrapper for strings representing a path name. Differentiates vs. arbitrary
// strings and supports shortening for display purposes.
struct Path {
Path() {}
explicit Path(const char* p) : path(p) {}
explicit Path(std::string p) : path(std::move(p)) {}
Path& operator=(const char* other) {
path = other;
return *this;
}
std::string Shortened() const {
constexpr size_t kMaxLen = 48;
constexpr size_t kCutPoint = kMaxLen / 2 - 5;
if (path.size() > kMaxLen) {
return std::string(begin(path), begin(path) + kCutPoint) + " ... " +
std::string(end(path) - kCutPoint, end(path));
}
if (path.empty()) return "[no path specified]";
return path;
}
// Returns whether the file existed when this was called.
bool Exists() const { return !!OpenFileOrNull(*this, "r"); }
std::string path;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_
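A short usage sketch of the `File`/`Path` interface defined above (the file name and buffer size are arbitrary examples; only calls declared in `io.h` are used):

```
#include <cstdint>
#include <cstdio>
#include <memory>
#include <vector>

#include "compression/io.h"

int main() {
  const gcpp::Path path("weights.sbs");  // arbitrary example path
  if (!path.Exists()) {
    fprintf(stderr, "No file at %s\n", path.Shortened().c_str());
    return 1;
  }
  std::unique_ptr<gcpp::File> file = gcpp::OpenFileOrNull(path, "r");
  if (!file) return 1;
  printf("Size: %llu bytes\n",
         static_cast<unsigned long long>(file->FileSize()));
  // Read the first 64 bytes; Read returns true only if all bytes were read.
  std::vector<uint8_t> header(64);
  if (!file->Read(/*offset=*/0, header.size(), header.data())) return 1;
  return 0;
}
```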

compression/io_win.cc (new file, +115 lines)

@ -0,0 +1,115 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "hwy/detect_compiler_arch.h"
// Only compile this file on Windows; it replaces io.cc. It is easier to check
// this in source code because we support multiple build systems.
#if HWY_OS_WIN
#include <stddef.h>
#include <stdint.h>
#include "compression/io.h"
#include "hwy/base.h" // HWY_ASSERT
#ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN
#endif
#ifndef VC_EXTRALEAN
#define VC_EXTRALEAN
#endif
#include <Windows.h>
namespace gcpp {
class FileWin : public File {
HANDLE hFile_ = INVALID_HANDLE_VALUE;
public:
FileWin(HANDLE hFile) : hFile_(hFile) {
HWY_ASSERT(hFile != INVALID_HANDLE_VALUE);
}
~FileWin() override {
if (hFile_ != INVALID_HANDLE_VALUE) {
HWY_ASSERT(CloseHandle(hFile_) != 0);
}
}
uint64_t FileSize() const override {
DWORD hi;
const DWORD lo = GetFileSize(hFile_, &hi);
if (lo == INVALID_FILE_SIZE) return 0;
return (static_cast<uint64_t>(hi) << 32) | lo;
}
bool Read(uint64_t offset, uint64_t size, void* to) const override {
uint8_t* bytes = reinterpret_cast<uint8_t*>(to);
OVERLAPPED overlapped = {0};
// Loop is required because ReadFile[Ex] size argument is 32-bit.
while (size != 0) {
overlapped.Offset = offset & 0xFFFFFFFF;
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
const DWORD want =
static_cast<DWORD>(HWY_MIN(size, uint64_t{0xFFFFFFFF}));
DWORD got;
if (!ReadFile(hFile_, bytes, want, &got, &overlapped)) {
if (GetLastError() != ERROR_HANDLE_EOF) {
return false;
}
}
offset += got;
bytes += got;
size -= got;
}
return true; // read everything => success
}
bool Write(const void* from, uint64_t size, uint64_t offset) override {
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(from);
OVERLAPPED overlapped = {0};
// Loop is required because WriteFile[Ex] size argument is 32-bit.
while (size != 0) {
overlapped.Offset = offset & 0xFFFFFFFF;
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
const DWORD want =
static_cast<DWORD>(HWY_MIN(size, uint64_t{0xFFFFFFFF}));
DWORD got;
if (!WriteFile(hFile_, bytes, want, &got, &overlapped)) {
if (GetLastError() != ERROR_HANDLE_EOF) {
return false;
}
}
offset += got;
bytes += got;
size -= got;
}
return true; // wrote everything => success
}
}; // FileWin
std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode) {
const bool is_read = mode[0] != 'w';
const DWORD flags =
FILE_ATTRIBUTE_NORMAL | (is_read ? FILE_FLAG_SEQUENTIAL_SCAN : 0);
const DWORD access = is_read ? GENERIC_READ : GENERIC_WRITE;
const DWORD share = is_read ? FILE_SHARE_READ : 0;
const DWORD create = is_read ? OPEN_EXISTING : CREATE_ALWAYS;
const HANDLE hFile = CreateFileA(filename.path.c_str(), access, share,
nullptr, create, flags, nullptr);
if (hFile == INVALID_HANDLE_VALUE) return std::unique_ptr<File>();
return std::make_unique<FileWin>(hFile);
}
} // namespace gcpp
#endif // HWY_OS_WIN

View File

@ -20,9 +20,7 @@
#include <stddef.h>
#include <stdint.h>
// copybara:import_next_line:gemma_cpp
#include "compression/nuq.h"
// copybara:import_next_line:gemma_cpp
#include "compression/sfp.h"
#include "hwy/base.h"
@ -37,7 +35,6 @@
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE
#endif
// copybara:import_next_line:gemma_cpp
#include "compression/sfp-inl.h"
#include "hwy/contrib/sort/vqsort-inl.h"
#include "hwy/highway.h"
@ -107,9 +104,8 @@ class NuqClustering {
// Callers are responsible for ignoring lanes where last < first.
HWY_DASSERT(first < kGroupSize);
HWY_DASSERT(last < kGroupSize);
const size_t len = last - first + 1;
const hn::Vec<DF> vlen =
hn::Iota(df, static_cast<float>(static_cast<int>(len)));
const int len = static_cast<int>(last) - static_cast<int>(first) + 1;
const hn::Vec<DF> vlen = hn::Iota(df, static_cast<float>(len));
const hn::Vec<DF> u_lo = hn::Set(df, cumsum_[first]);
const hn::Vec<DF> u_lo2 = hn::Set(df, cumsum2_[first]);
@ -207,7 +203,7 @@ class NuqClustering {
for (size_t num_clusters = 1; num_clusters < kClusters; ++num_clusters) {
// For each batch starting at `last`, one per lane:
for (size_t last = 0; last < kGroupSize; last += N) {
VF min = cc(df, 0, last);
VF min = hn::LoadU(df, &D(0, last));
VI arg = hn::Zero(di);
// For each j (start of rightmost cluster):
VI vj = k1;

View File

@ -18,6 +18,8 @@
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif
#include "compression/nuq.h"
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
@ -25,8 +27,10 @@
#include <algorithm> // std::shuffle
#include <random>
#include "compression/test_util.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/tests/test_util.h"
#include "hwy/timer.h"
// clang-format off
@ -35,12 +39,7 @@
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
// Other headers that include Highway must come after foreach_target.h
// copybara:import_next_line:gemma_cpp
#include "compression/nuq-inl.h"
// copybara:import_next_line:gemma_cpp
#include "compression/nuq.h"
// copybara:import_next_line:gemma_cpp
#include "compression/test_util.h"
#include "hwy/highway.h"
#include "hwy/tests/hwy_gtest.h"
#include "hwy/tests/test_util-inl.h"
@ -117,12 +116,16 @@ struct TestPlateaus {
HWY_ASSERT(indices[i] < kClusters);
stats.Notify(in[i], centers[indices[i]]);
}
const float pnorm = stats.PNorm();
const float snr = stats.GeomeanValueDivL1();
fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr,
stats.MaxIndex(), stats.MaxL1());
HWY_ASSERT(pnorm == 0.0f);
HWY_ASSERT(snr == 0.0f);
// Zero error.
HWY_ASSERT_EQ(kGroupSize, stats.NumExact());
HWY_ASSERT_EQ(0, stats.NumSignFlip());
HWY_ASSERT_EQ(0, stats.NumRoundedToZero());
HWY_ASSERT_EQ(0.0, stats.SumL1());
HWY_ASSERT_EQ(0.0f, stats.GeomeanValueDivL1());
HWY_ASSERT_EQ(0.0f, stats.WeightedAverageL1());
// Input was symmetric and zero-mean.
HWY_ASSERT(gcpp::IsInside(-0.05, 0.05, stats.Original().Mean()));
HWY_ASSERT(gcpp::IsNear(0.0, stats.Original().Skewness()));
}
};
@ -160,16 +163,19 @@ struct TestRamp {
HWY_ASSERT(indices[i] < kClusters);
stats.Notify(in[i], centers[indices[i]]);
}
const float pnorm = stats.PNorm();
const float snr = stats.GeomeanValueDivL1();
fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr,
stats.MaxIndex(), stats.MaxL1());
static_assert(kGroupSize == 128 || kGroupSize == 256, "Update expected");
const float expected_pnorm = kGroupSize == 128 ? 2.08E-2f : 2.1E-2f;
const float expected_snr = kGroupSize == 128 ? 16.9f : 17.6f;
HWY_ASSERT(expected_pnorm <= pnorm && pnorm < 1.02f * expected_pnorm);
HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr);
// Low error.
HWY_ASSERT_EQ(0, stats.NumExact());
HWY_ASSERT(stats.NumSignFlip() < 10);
HWY_ASSERT_EQ(0, stats.NumRoundedToZero());
HWY_ASSERT_EQ(kGroupSize / kClusters / 4.0, stats.SumL1());
HWY_ASSERT(gcpp::IsInside(17.0, 18.0, stats.GeomeanValueDivL1()));
HWY_ASSERT(gcpp::IsInside(0.005, 0.010, stats.WeightedAverageL1()));
HWY_ASSERT(stats.L1().Max() <= 0.04f);
// Input was symmetric about 0.05.
HWY_ASSERT(gcpp::IsNear(0.05, stats.Original().Mean(), 0.01));
HWY_ASSERT(gcpp::IsNear(0.0, stats.Original().Skewness(), 1E-4));
static_assert(kGroupSize == 256, "Update expected");
}
};
@ -210,15 +216,16 @@ struct TestNormal {
HWY_ASSERT(indices[i] < kClusters);
stats.Notify(in[i], centers[indices[i]]);
}
const float pnorm = stats.PNorm();
const float snr = stats.GeomeanValueDivL1();
fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr,
stats.MaxIndex(), stats.MaxL1());
// Moderate error.
HWY_ASSERT_EQ(0, stats.NumExact());
HWY_ASSERT(stats.NumSignFlip() < kGroupSize / kClusters);
HWY_ASSERT_EQ(0, stats.NumRoundedToZero());
HWY_ASSERT(gcpp::IsInside(5.0, 6.0, stats.SumL1()));
HWY_ASSERT(gcpp::IsInside(12.7, 12.8, stats.GeomeanValueDivL1()));
HWY_ASSERT(gcpp::IsInside(0.036, 0.037, stats.WeightedAverageL1()));
HWY_ASSERT(stats.L1().Max() <= 0.10f);
static_assert(kGroupSize == 256, "Update expected");
const float expected_pnorm = 3.68E-2f;
const float expected_snr = 12.7f;
HWY_ASSERT(expected_pnorm <= pnorm && pnorm < 1.02f * expected_pnorm);
HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr);
}
};
@ -238,10 +245,9 @@ struct TestOffset {
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
HWY_ASSERT(in && dec1 && dec2 && nuq);
std::mt19937 rng(123);
std::normal_distribution<float> dist{0.001f, 0.3f};
hwy::RandomState rng;
for (size_t i = 0; i < total; ++i) {
in[i] = dist(rng);
in[i] = static_cast<float>(RandomGaussian(rng));
}
// Encode + decode everything
@ -281,11 +287,13 @@ struct TestStream {
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(num));
HWY_ASSERT(in && out && nuq);
std::mt19937 rng(123);
std::normal_distribution<float> dist{0.001f, 0.3f};
hwy::RandomState rng;
Stats in_stats;
for (size_t i = 0; i < num; ++i) {
in[i] = dist(rng);
in[i] = static_cast<float>(RandomGaussian(rng));
in_stats.Notify(in[i]);
}
VerifyGaussian(in_stats);
ClusterBuf buf;
double elapsed = hwy::HighestValue<double>();
@ -314,15 +322,16 @@ struct TestStream {
for (size_t i = 0; i < num; ++i) {
stats.Notify(in[i], hwy::ConvertScalarTo<float>(out[i]));
}
const float pnorm = stats.PNorm();
const float snr = stats.GeomeanValueDivL1();
fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr,
stats.MaxIndex(), stats.MaxL1());
static_assert(kGroupSize == 128 || kGroupSize == 256, "Update expected");
const float expected_pnorm = kGroupSize == 128 ? 3.44E-2f : 3.88E-2f;
const float expected_snr = kGroupSize == 128 ? 15.0f : 13.3f;
HWY_ASSERT(expected_pnorm <= pnorm && pnorm < 1.02f * expected_pnorm);
HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr);
// Moderate error.
HWY_ASSERT_EQ(0, stats.NumExact());
HWY_ASSERT(stats.NumSignFlip() < num / kClusters);
HWY_ASSERT_EQ(0, stats.NumRoundedToZero());
HWY_ASSERT(gcpp::IsInside(23.0, 24.0, stats.SumL1()));
HWY_ASSERT(gcpp::IsInside(13.0, 13.3, stats.GeomeanValueDivL1()));
HWY_ASSERT(gcpp::IsInside(0.034, 0.035, stats.WeightedAverageL1()));
HWY_ASSERT(stats.L1().Max() <= 0.11f);
static_assert(kGroupSize == 256, "Update expected");
}
};
@ -351,9 +360,8 @@ struct TestDot {
hwy::RandomState rng;
Stats in_stats;
for (size_t i = 0; i < num; ++i) {
const float r = static_cast<float>(RandomGaussian(rng));
in_stats.Notify(r);
in[i] = r;
in[i] = static_cast<float>(RandomGaussian(rng));
in_stats.Notify(in[i]);
}
for (size_t i = 0; i < num; ++i) {
const float r = static_cast<float>(RandomGaussian(rng));
@ -368,7 +376,7 @@ struct TestDot {
HWY_ASSERT(unused_clusters == 0);
// Compute dot product without decompression.
double actual = 0.0;
float actual = 0.0f;
double elapsed = hwy::HighestValue<double>();
for (size_t rep = 0; rep < 20; ++rep) {
hn::Vec<decltype(df)> sum0 = hn::Zero(df);
@ -389,8 +397,8 @@ struct TestDot {
num * sizeof(in[0]) * 1E-6 / elapsed);
// Exact and decompressed dot products for comparison.
double exact = 0.0; // using original input
double expected = 0.0; // using decoded NUQ
float exact = 0.0f; // using original input
float expected = 0.0f; // using decoded NUQ
DistortionStats dec_stats;
Stats ratios;
for (size_t i = 0; i < num; ++i) {
@ -402,24 +410,42 @@ struct TestDot {
ratios.Notify(exact / expected);
}
}
const bool isBF = sizeof(T) == 2;
const double dec_snr = dec_stats.GeomeanValueDivL1();
const double dec_wl1 = dec_stats.WeightedAverageL1();
const double dot_snr = 1.0 / hwy::ScalarAbs(1.0 - ratios.GeometricMean());
// exact and actual fluctuate due to the combination of NUQ imprecision,
// and whether vec[i] is negative or positive, so this is quite loose.
const float final_ratio = HWY_MIN(exact / actual, actual / exact);
fprintf(stderr, "ratios %s\n", ratios.ToString().c_str());
fprintf(stderr,
"exact %.3f e2 %.4f actual %.4f final_ratio %.3f dec_snr %.2f "
"dot_snr %.2f\n",
exact, expected, actual, final_ratio, dec_snr, dot_snr);
if (HWY_ONCE) {
fprintf(stderr, "ratios %s\n", ratios.ToString().c_str());
fprintf(stderr,
"exact %.3f e2 %.4f actual %.4f final_ratio %.3f dec_snr %.2f "
"dot_snr %.2f dec_wl1 %.4f\n",
exact, expected, actual, final_ratio, dec_snr, dot_snr, dec_wl1);
}
// Final values are not too far apart.
HWY_ASSERT(0.88f <= final_ratio && final_ratio <= 1.0f);
HWY_ASSERT(gcpp::IsInside(0.88f, 1.0f, final_ratio));
// Decompressed and uncompressed dot should match exactly.
HWY_ASSERT(hwy::ScalarAbs(expected - actual) < 1E-4f);
// dec[] is close to in[], but we already check that in TestStream.
HWY_ASSERT(dec_snr >= 13.0);
// Geomean of ratios for each i is an approximation of the actual SNR.
HWY_ASSERT(dot_snr >= (sizeof(T) == 2 ? 17.0 : 14.0));
HWY_ASSERT(gcpp::IsNear(expected, actual, 1E-4f));
// Geomean of ratios for each i should be very close to one.
HWY_ASSERT(dot_snr >= (isBF ? 17.7 : 14.3));
// dec[] is close to in[], but we already check that in TestStream with the
// same input distribution.
HWY_ASSERT(gcpp::IsNear(13.1, dec_snr, 0.1));
HWY_ASSERT(gcpp::IsNear(0.034, dec_wl1, 0.001));
HWY_ASSERT(gcpp::IsNear(23.5, dec_stats.SumL1(), 0.1));
HWY_ASSERT(dec_stats.NumSignFlip() < num / kClusters);
HWY_ASSERT_EQ(0, dec_stats.NumExact());
HWY_ASSERT_EQ(0, dec_stats.NumRoundedToZero());
HWY_ASSERT_EQ(0.0, dec_stats.SumL1Rounded());
// Absolute decode errors are in [0, 0.11], and somewhat right-tailed.
HWY_ASSERT(gcpp::IsInside(0.0f, 2E-5f, dec_stats.L1().Min()));
HWY_ASSERT(gcpp::IsInside(0.09f, 0.11f, dec_stats.L1().Max()));
HWY_ASSERT(gcpp::IsInside(0.02, 0.03, dec_stats.L1().Mean()));
HWY_ASSERT(gcpp::IsInside(1.0, 1.1, dec_stats.L1().Skewness()));
HWY_ASSERT(gcpp::IsInside(4.0, 5.0, dec_stats.L1().Kurtosis()));
static_assert(kGroupSize == 256, "Update expected*");
}
};
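As a rough reading of the error metrics asserted in TestDot above (a sketch; exact and expected are the running dot-product sums over i, and actual is the dot product computed without decompression):

$$ \text{dot\_snr} = \frac{1}{\left|\,1 - \operatorname{geomean}_i\!\left(\text{exact}_i/\text{expected}_i\right)\right|}, \qquad \text{final\_ratio} = \min\!\left(\frac{\text{exact}}{\text{actual}},\, \frac{\text{actual}}{\text{exact}}\right) $$

Larger dot_snr means the decoded-NUQ dot product tracks the exact one more tightly, and final_ratio near 1 means the compressed (no-decompression) dot product stays close to the exact result.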

View File

@ -20,7 +20,6 @@
#include <stddef.h>
#include <stdint.h>
// copybara:import_next_line:gemma_cpp
#include "compression/sfp.h"
#include "hwy/base.h"

View File

@ -18,7 +18,6 @@
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif
// copybara:import_next_line:gemma_cpp
#include "compression/sfp.h"
#include <stddef.h>
@ -27,6 +26,7 @@
#include <set>
#include "compression/test_util.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/timer.h"
@ -37,10 +37,7 @@
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
// Any highway.h must come after foreach_target.h
// copybara:import_next_line:gemma_cpp
#include "compression/sfp-inl.h"
// copybara:import_next_line:gemma_cpp
#include "compression/test_util.h"
#include "hwy/highway.h"
#include "hwy/tests/hwy_gtest.h"
#include "hwy/tests/test_util-inl.h"
@ -307,10 +304,37 @@ struct TestEncDec {
sum += hwy::ConvertScalarTo<double>(hwy::ScalarAbs(in[i]));
stats.Notify(hwy::ConvertScalarTo<float>(in[i]), out);
}
const double avg = sum / num;
fprintf(stderr, "Avg magnitude %.3E, p-norm %.3E snr %.2f @%zu = %.4E\n",
avg, stats.PNorm(), stats.GeomeanValueDivL1(), stats.MaxIndex(),
stats.MaxL1());
const double avg_in = sum / num;
const double snr = stats.GeomeanValueDivL1();
const double wl1 = stats.WeightedAverageL1();
if (false) {
fprintf(stderr,
"Num inputs %zu, avg %.3E, exact %zu round0 %zu (sum %E) snr "
"%.2f wL1 %f\n",
num, avg_in, stats.NumExact(), stats.NumRoundedToZero(),
stats.SumL1Rounded(), snr, wl1);
}
HWY_ASSERT(stats.Original().Count() == stats.L1().Count());
// Inputs are in [-1.875, 1.875], symmetric, and heavy-tailed.
HWY_ASSERT(stats.Original().Min() == -1.875f);
HWY_ASSERT(stats.Original().Max() == 1.875f);
HWY_ASSERT(gcpp::IsInside(-1E-6, 1E-6, stats.Original().Mean()));
HWY_ASSERT(gcpp::IsInside(-1E-6, 1E-6, stats.Original().Skewness()));
HWY_ASSERT(gcpp::IsInside(80.0, 100.0, stats.Original().Kurtosis()));
// Absolute errors are in [0, 0.0625], and (heavy) right-tailed.
HWY_ASSERT(stats.L1().Min() == 0.0f);
HWY_ASSERT(stats.L1().Max() == 0.0625f);
HWY_ASSERT(gcpp::IsInside(4E-4, 5E-4, stats.L1().Mean()));
HWY_ASSERT(gcpp::IsInside(10.0, 15.0, stats.L1().Skewness()));
HWY_ASSERT(gcpp::IsInside(150.0, 200.0, stats.L1().Kurtosis()));
// SNR is low because many *tiny* numbers are rounded to zero.
HWY_ASSERT_EQ(3322, stats.NumRoundedToZero());
HWY_ASSERT(gcpp::IsInside(5E-6, 6E-6, stats.SumL1Rounded()));
HWY_ASSERT(gcpp::IsInside(1.880, 1.885, stats.SumL1()));
HWY_ASSERT_EQ(256, stats.NumExact());
HWY_ASSERT_EQ(0, stats.NumSignFlip());
HWY_ASSERT(gcpp::IsInside(2.70, 2.75, snr));
HWY_ASSERT(gcpp::IsInside(0.010, 0.011, wl1)); // = half of mean |x|.
}
}
};
@ -381,7 +405,7 @@ struct TestDot {
SfpCodec::Enc(d, in.get(), num, sfp.get());
// Compute dot product without decompression.
double actual = 0.0;
float actual = 0.0f;
double elapsed = hwy::HighestValue<double>();
for (size_t rep = 0; rep < 200; ++rep) {
hn::Vec<decltype(df)> sum0 = hn::Zero(df);
@ -417,24 +441,41 @@ struct TestDot {
ratios.Notify(exact / expected);
}
}
const bool isBF = sizeof(T) == 2;
const double dec_snr = dec_stats.GeomeanValueDivL1();
const double dec_wl1 = dec_stats.WeightedAverageL1();
const double dot_snr = 1.0 / hwy::ScalarAbs(1.0 - ratios.GeometricMean());
// exact and actual fluctuate due to the combination of SFP imprecision,
// and whether vec[i] is negative or positive, so this is quite loose.
const float final_ratio = HWY_MIN(exact / actual, actual / exact);
fprintf(stderr, "ratios %s\n", ratios.ToString().c_str());
fprintf(stderr,
"exact %.3f e2 %.4f actual %.4f final_ratio %.3f dec_snr %.2f "
"dot_snr %.2f\n",
exact, expected, actual, final_ratio, dec_snr, dot_snr);
if (HWY_ONCE) {
fprintf(stderr, "ratios %s\n", ratios.ToString().c_str());
fprintf(stderr,
"exact %.3f e2 %.4f actual %.4f final_ratio %.3f dec_snr %.2f "
"dot_snr %.2f dec_wl1 %.5f\n",
exact, expected, actual, final_ratio, dec_snr, dot_snr, dec_wl1);
}
// Final values are not too far apart.
HWY_ASSERT(0.87f <= final_ratio && final_ratio <= 1.0f);
HWY_ASSERT(gcpp::IsInside(0.87f, 1.0f, final_ratio));
// Decompressed and uncompressed dot should match exactly.
HWY_ASSERT(hwy::ScalarAbs(expected - actual) < 1E-4f);
// dec[] is close to in[], but we already check that in TestEncDec.
HWY_ASSERT(dec_snr >= 50.0);
HWY_ASSERT(gcpp::IsNear(expected, actual, 1E-4f));
// Geomean of ratios for each i should be very close to one.
HWY_ASSERT(dot_snr >= (sizeof(T) == 2 ? 70.0 : 1000.0));
HWY_ASSERT(dot_snr >= (isBF ? 70.0 : 1000.0));
// dec[] is close to in[]. We also check that in TestEncDec, but for much
// smaller input magnitudes.
HWY_ASSERT(gcpp::IsNear(isBF ? 51.0 : 64.0, dec_snr, 1.0));
HWY_ASSERT(gcpp::IsNear(isBF ? 0.013 : 0.012, dec_wl1, 0.001));
HWY_ASSERT(gcpp::IsNear(isBF ? 6.2 : 6.3, dec_stats.SumL1(), 0.1));
HWY_ASSERT_EQ(0, dec_stats.NumSignFlip());
HWY_ASSERT_EQ(0, dec_stats.NumRoundedToZero());
HWY_ASSERT_EQ(0.0, dec_stats.SumL1Rounded());
// Absolute decode errors are in [0, 5E-2], and somewhat right-tailed.
HWY_ASSERT(gcpp::IsInside(0.0f, 2E-6f, dec_stats.L1().Min()));
HWY_ASSERT(gcpp::IsInside(3E-2f, 5E-2f, dec_stats.L1().Max()));
HWY_ASSERT(gcpp::IsInside(4E-3, 7E-3, dec_stats.L1().Mean()));
HWY_ASSERT(gcpp::IsInside(1.8, 1.9, dec_stats.L1().Skewness()));
HWY_ASSERT(gcpp::IsInside(6.0, 7.0, dec_stats.L1().Kurtosis()));
}
};

View File

@ -13,7 +13,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// copybara:import_next_line:gemma_cpp
#include "compression/stats.h"
#include <stdio.h>

View File

@ -19,7 +19,6 @@
#include <stdint.h>
#include <stdio.h>
#include <algorithm>
#include <cmath>
#include <string>
@ -45,7 +44,7 @@ class Bins {
}
}
void Print(const char* caption) {
void Print(const char* caption) const {
fprintf(stderr, "\n%s [%zu]\n", caption, N);
size_t last_nonzero = 0;
for (size_t i = N - 1; i < N; --i) {
@ -77,8 +76,8 @@ class Stats {
void Notify(const float x) {
++n_;
min_ = std::min(min_, x);
max_ = std::max(max_, x);
min_ = HWY_MIN(min_, x);
max_ = HWY_MAX(max_, x);
product_ *= x;
@ -119,7 +118,7 @@ class Stats {
// Near zero for normal distributions; if positive on a unimodal distribution,
// the right tail is fatter. Assumes n_ is large.
double SampleSkewness() const {
if (std::abs(m2_) < 1E-7) return 0.0;
if (hwy::ScalarAbs(m2_) < 1E-7) return 0.0;
return m3_ * std::sqrt(static_cast<double>(n_)) / std::pow(m2_, 1.5);
}
// Corrected for bias (same as Wikipedia and Minitab but not Excel).
@ -132,7 +131,7 @@ class Stats {
// Near zero for normal distributions; smaller values indicate fewer/smaller
// outliers and larger indicates more/larger outliers. Assumes n_ is large.
double SampleKurtosis() const {
if (std::abs(m2_) < 1E-7) return 0.0;
if (hwy::ScalarAbs(m2_) < 1E-7) return 0.0;
return m4_ * n_ / (m2_ * m2_);
}
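For reference, a sketch of what SampleSkewness and SampleKurtosis above compute, assuming m2_, m3_, m4_ hold the accumulated central-moment sums $m_k = \sum_i (x_i - \bar{x})^k$ over n_ samples:

$$ g_1 = \frac{m_3 \sqrt{n}}{m_2^{3/2}}, \qquad g_2 = \frac{n\, m_4}{m_2^{2}} $$

A Gaussian has skewness near 0 and kurtosis near 3, which is what VerifyGaussian in test_util.h below asserts.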
// Corrected for bias (same as Wikipedia and Minitab but not Excel).

View File

@ -24,9 +24,7 @@
#include "hwy/base.h"
// IWYU pragma: begin_exports
// copybara:import_next_line:gemma_cpp
#include "compression/distortion.h"
// copybara:import_next_line:gemma_cpp
#include "compression/stats.h"
#include "hwy/tests/test_util.h" // RandomState
// IWYU pragma: end_exports
@ -51,12 +49,26 @@ HWY_INLINE double RandomGaussian(hwy::RandomState& rng) {
return plus_minus_1 * std::sqrt(kReps / 3.0);
};
// Returns true if val is inside [min, max].
template <typename T>
static inline bool IsInside(T expected_min, T expected_max, T val) {
return expected_min <= val && val <= expected_max;
}
template <typename T>
static inline bool IsNear(T expected, T val, T epsilon = T{1E-6}) {
return IsInside(expected - epsilon, expected + epsilon, val);
}
HWY_INLINE void VerifyGaussian(Stats& stats) {
const double stddev = stats.StandardDeviation();
HWY_ASSERT(-0.01 <= stats.Mean() && stats.Mean() <= 0.01);
HWY_ASSERT(0.30 <= stddev && stddev <= 0.35);
HWY_ASSERT(-1.1 <= stats.Min() && stats.Min() <= -0.9);
HWY_ASSERT(0.9 <= stats.Max() && stats.Max() <= 1.1);
// Inputs are roughly [-1, 1] and symmetric about zero.
HWY_ASSERT(IsNear(-1.0f, stats.Min(), 0.10f));
HWY_ASSERT(IsNear(+1.0f, stats.Max(), 0.10f));
HWY_ASSERT(IsInside(-2E-3, 2E-3, stats.Mean()));
HWY_ASSERT(IsInside(-0.15, 0.15, stats.Skewness()));
// Near-Gaussian.
HWY_ASSERT(IsInside(0.30, 0.35, stats.StandardDeviation()));
HWY_ASSERT(IsNear(3.0, stats.Kurtosis(), 0.3));
}
} // namespace gcpp

debug_prompt.cc (Normal file, 132 lines)
View File

@ -0,0 +1,132 @@
#include <fstream>
#include <iostream>
#include <string>
#include "gemma/gemma.h"
#include "nlohmann/json.hpp"
#include "util/app.h"
#include "util/args.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
using json = nlohmann::json;
class PromptArgs : public gcpp::ArgsBase<PromptArgs> {
public:
PromptArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
gcpp::Path layers_output;
std::string prompt;
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(layers_output.path, "layers_output", std::string(""),
"Path to store layers output", 2);
visitor(prompt, "prompt", std::string(""), "Prompt to the model", 2);
}
};
std::pair<std::string, int> QueryModel(
gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input,
gcpp::LayersOutputT* layers_output) {
std::vector<int> prompt;
HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
// if needed.
prompt.insert(prompt.begin(), 2);
std::string res;
size_t total_tokens = 0;
auto accept_token = [](int) { return true; };
std::mt19937 gen;
gen.seed(42);
auto stream_token = [&res, &total_tokens, &app,
tokenizer = model.Tokenizer()](int token, float) {
++total_tokens;
std::string token_text;
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text));
res += token_text;
return true;
};
if (app.verbosity >= 2) {
std::cout << args.max_tokens << " " << args.max_generated_tokens << " "
<< args.temperature;
}
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
stream_token, accept_token, gen, app.verbosity, layers_output);
return {res, total_tokens};
}
class OutputJsonLogger {
public:
json json_output;
gcpp::LayersOutputT layers_output_log_f =
[this](int pos, const std::string& key, const float* values, size_t values_len) {
std::vector<float> v{values, values + values_len};
json_output[std::to_string(pos)][key] = v;
};
};
/* Run this the same way as gemma, e.g.:
./debug_prompt --tokenizer tokenizer.spm --model 2b-it --weights \
2b-it-sfp.sbs --prompt "..." --layers_output [path]
*/
int main(int argc, char** argv) {
gcpp::LoaderArgs loader(argc, argv);
gcpp::InferenceArgs args(argc, argv); // inference
gcpp::AppArgs app(argc, argv);
PromptArgs prompt_args(argc, argv);
if (const char* error = loader.Validate()) {
HWY_ABORT("\nInvalid loader args: %s", error);
}
if (const char* error = args.Validate()) {
HWY_ABORT("\nInvalid inference args: %s", error);
}
const bool log_layers_output = !prompt_args.layers_output.path.empty();
OutputJsonLogger json_logger;
gcpp::LayersOutputT* layers_output =
log_layers_output ? &json_logger.layers_output_log_f : nullptr;
hwy::ThreadPool pool(app.num_threads);
// For many-core, pinning threads to cores helps.
if (app.num_threads > 10) {
gcpp::PinThreadToCore(app.num_threads - 1); // Main thread
pool.Run(0, pool.NumThreads(), [](uint64_t /*task*/, size_t thread) {
gcpp::PinThreadToCore(thread);
});
}
gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
auto kv_cache = CreateKVCache(loader.ModelType());
const std::string& prompt = prompt_args.prompt;
if (prompt.empty()) {
std::cout << "Please specify --prompt" << std::endl;
return EXIT_FAILURE;
}
const auto [answer, token_count] = QueryModel(
model, args, app, kv_cache, pool, prompt, layers_output);
std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
if (log_layers_output) {
std::ofstream output_f(prompt_args.layers_output.path, std::ofstream::out);
if (!output_f) {
std::cout << "Opening file failed" << std::endl;
return EXIT_FAILURE;
}
output_f << json_logger.json_output.dump();
if (!output_f) {
std::cout << "Writing to file failed" << std::endl;
return EXIT_FAILURE;
}
output_f.close();
}
return EXIT_SUCCESS;
}
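To illustrate the format written by --layers_output: OutputJsonLogger keys the JSON by token position and stores each named float array under it. A downstream consumer could read the file roughly like this (a sketch; the path and the use of nlohmann::json for reading are assumptions):
```
#include <cstdio>
#include <fstream>
#include "nlohmann/json.hpp"

int main() {
  std::ifstream f("layers_output.json");  // hypothetical path
  const nlohmann::json j = nlohmann::json::parse(f);
  // Top-level keys are token positions; each maps names such as "tokens",
  // "block.1" or "final_norm" to arrays of floats.
  for (const auto& [pos, entries] : j.items()) {
    for (const auto& [key, values] : entries.items()) {
      std::printf("pos=%s key=%s len=%zu\n", pos.c_str(), key.c_str(),
                  values.size());
    }
  }
  return 0;
}
```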

View File

@ -15,12 +15,9 @@
#include <iostream>
// copybara:import_next_line:gemma_cpp
#include "gemma.h"
// copybara:import_next_line:gemma_cpp
#include "third_party/gemma_cpp/gemma.h"
#include "util/app.h" // LoaderArgs
#include "hwy/contrib/thread_pool/thread_pool.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h"
std::vector<int> tokenize(const std::string& prompt_string,

View File

@ -8,16 +8,13 @@
#include <vector>
#include "nlohmann/json.hpp"
// copybara:import_next_line:gemma_cpp
#include "gemma.h"
#include "gemma/gemma.h"
#include "util/app.h"
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
#include "hwy/timer.h"
// copybara:import_next_line:gemma_cpp
#include "util/app.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h"
using json = nlohmann::json;
@ -61,8 +58,7 @@ void LogSpeedStats(const double time_start, size_t total_tokens) {
std::pair<std::string, int> QueryModel(
gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
gcpp::KVCache& kv_cache, hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
const std::string& input) {
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input) {
std::vector<int> prompt;
HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
@ -93,7 +89,7 @@ std::pair<std::string, int> QueryModel(
}
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
inner_pool, stream_token, accept_token, gen, app.verbosity);
stream_token, accept_token, gen, app.verbosity);
if (app.verbosity >= 1) {
LogSpeedStats(time_start, total_tokens);
}
@ -134,8 +130,7 @@ std::string ReadFile(const gcpp::Path& path) {
int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
const std::string& golden_path) {
hwy::ThreadPool& pool, const std::string& golden_path) {
const std::vector<std::pair<std::string, std::string>> queries_answers =
load_goldens(golden_path);
int correct_answers = 0;
@ -143,7 +138,7 @@ int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
const double time_start = hwy::platform::Now();
for (const auto& [question, expected_answer] : queries_answers) {
const auto [answer, token_count] =
QueryModel(model, args, app, kv_cache, inner_pool, pool, question);
QueryModel(model, args, app, kv_cache, pool, question);
total_tokens += token_count;
if (answer.find(expected_answer) != std::string::npos) {
correct_answers++;
@ -167,14 +162,13 @@ int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
int BenchmarkSummary(gcpp::Gemma& model, gcpp::InferenceArgs& args,
gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
const gcpp::Path& text) {
hwy::ThreadPool& pool, const gcpp::Path& text) {
std::string prompt("Here is some text to summarize:\n");
prompt.append(ReadFile(text));
prompt.append("\nSummarize this text.\n");
const double time_start = hwy::platform::Now();
const auto [answer, token_count] =
QueryModel(model, args, app, kv_cache, inner_pool, pool, prompt);
QueryModel(model, args, app, kv_cache, pool, prompt);
std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
LogSpeedStats(time_start, token_count);
return EXIT_SUCCESS;
@ -182,8 +176,8 @@ int BenchmarkSummary(gcpp::Gemma& model, gcpp::InferenceArgs& args,
int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
gcpp::InferenceArgs& args, gcpp::AppArgs& app,
hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
const gcpp::Path& text, size_t batch_tokens) {
hwy::ThreadPool& pool, const gcpp::Path& text,
size_t batch_tokens) {
std::string input = ReadFile(text);
std::vector<int> prompt;
HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
@ -200,7 +194,7 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
auto kv_cache = CreateKVCache(model_type);
float entropy =
ComputeCrossEntropy(model, num_tokens, prompt_slice, kv_cache, pool,
inner_pool, app.verbosity);
app.verbosity);
total_entropy += entropy;
LogSpeedStats(time_start, pos + num_tokens);
std::string text_slice;
@ -214,8 +208,8 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
const gcpp::Path& json_file, size_t max_questions) {
hwy::ThreadPool& pool, const gcpp::Path& json_file,
size_t max_questions) {
std::ifstream trivia_file(json_file.path);
if (!trivia_file) {
std::cout << "Could not load file: " << json_file.path << "\n"
@ -228,7 +222,7 @@ int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
while (std::getline(trivia_file, line)) {
json data = json::parse(line);
const auto [answer, token_count] = QueryModel(
model, args, app, kv_cache, inner_pool, pool, data["question"]);
model, args, app, kv_cache, pool, data["question"]);
std::cout << answer << "\n";
bool correct = false;
for (const std::string expected : data["answer"]["aliases"]) {
@ -266,7 +260,6 @@ int main(int argc, char** argv) {
HWY_ABORT("\nInvalid inference args: %s", error);
}
hwy::ThreadPool inner_pool(0);
hwy::ThreadPool pool(app.num_threads);
// For many-core, pinning threads to cores helps.
if (app.num_threads > 10) {
@ -283,17 +276,16 @@ int main(int argc, char** argv) {
if (!benchmark_args.goldens.path.empty()) {
const std::string golden_path =
benchmark_args.goldens.path + "/" + loader.model_type_str + ".txt";
return BenchmarkGoldens(model, args, app, kv_cache, inner_pool, pool,
golden_path);
return BenchmarkGoldens(model, args, app, kv_cache, pool, golden_path);
} else if (!benchmark_args.summarize_text.path.empty()) {
return BenchmarkSummary(model, args, app, kv_cache, inner_pool, pool,
return BenchmarkSummary(model, args, app, kv_cache, pool,
benchmark_args.summarize_text);
} else if (!benchmark_args.cross_entropy.path.empty()) {
return BenchmarkCrossEntropy(model, loader.ModelType(), args, app,
inner_pool, pool, benchmark_args.cross_entropy,
pool, benchmark_args.cross_entropy,
benchmark_args.batch_tokens);
} else if (!benchmark_args.trivia_qa.path.empty()) {
return BenchmarkTriviaQA(model, args, app, kv_cache, inner_pool, pool,
return BenchmarkTriviaQA(model, args, app, kv_cache, pool,
benchmark_args.trivia_qa,
benchmark_args.max_questions);
}

View File

@ -18,12 +18,8 @@
#include <iostream>
#include <string>
// copybara:import_next_line:gemma_cpp
#include "gemma.h" // Gemma
// copybara:end
// copybara:import_next_line:gemma_cpp
#include "gemma/gemma.h" // Gemma
#include "util/args.h"
// copybara:end
namespace gcpp {
@ -59,7 +55,7 @@ struct Args : public ArgsBase<Args> {
return "Missing --compressed_weights flag, a file for the compressed "
"model.";
}
if (!weights.exists()) {
if (!weights.Exists()) {
return "Can't open file specified with --weights flag.";
}
return nullptr;
@ -74,7 +70,7 @@ struct Args : public ArgsBase<Args> {
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(weights, "weights", Path(),
"Path name of model weights (.sbs) file.\n"
"Path to model weights (.bin) file.\n"
" Required argument.");
visitor(model_type_str, "model", std::string(),
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
@ -84,7 +80,7 @@ struct Args : public ArgsBase<Args> {
"gr2b-pt = griffin 2B parameters, pretrained\n "
" Required argument.");
visitor(compressed_weights, "compressed_weights", Path(),
"Path name where compressed weights file will be written.\n"
"Path name where compressed weights (.sbs) file will be written.\n"
" Required argument.");
visitor(num_threads, "num_threads",
kDefaultNumThreads, // see ChooseNumThreads

View File

@ -15,8 +15,8 @@
// Model configurations
#ifndef THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
#define THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_
// Allow changing pre-allocated kv cache size as a compiler flag
#ifndef GEMMA_MAX_SEQLEN
@ -28,11 +28,15 @@
#define GEMMA_TOPK 1
#endif // !GEMMA_TOPK
// Allow changing upper bound on threads as a compiler flag
#ifndef GEMMA_MAX_THREADS
#define GEMMA_MAX_THREADS 128
#endif // !GEMMA_MAX_THREADS
#include <stddef.h>
#include <array>
// copybara:import_next_line:gemma_cpp
#include "compression/sfp.h"
#include "hwy/base.h" // hwy::bfloat16_t
@ -46,6 +50,7 @@ namespace gcpp {
static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
static constexpr size_t kTopK = GEMMA_TOPK;
static constexpr size_t kMaxThreads = GEMMA_MAX_THREADS;
enum class LayerAttentionType {
kGemma,
@ -62,18 +67,36 @@ constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig(
return config;
}
template <size_t kNumLayers>
constexpr size_t NumLayersOfTypeBefore(
const std::array<LayerAttentionType, kNumLayers>& layers,
LayerAttentionType type, size_t num) {
size_t count = 0;
for (size_t i = 0; i < num; i++) {
if (layers[i] == type) count++;
}
return count;
}
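A small compile-time check illustrating NumLayersOfTypeBefore (the 4-layer config is hypothetical and for illustration only):
```
constexpr std::array<LayerAttentionType, 4> kDemoLayers = {
    LayerAttentionType::kGriffinRecurrentBlock, LayerAttentionType::kGemma,
    LayerAttentionType::kGriffinRecurrentBlock, LayerAttentionType::kGemma};
// Two Gemma layers occur before index 4; one Griffin layer before index 2.
static_assert(
    NumLayersOfTypeBefore(kDemoLayers, LayerAttentionType::kGemma, 4) == 2, "");
static_assert(
    NumLayersOfTypeBefore(kDemoLayers,
                          LayerAttentionType::kGriffinRecurrentBlock, 2) == 1,
    "");
```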
struct ConfigGemma7B {
static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
FixedLayerConfig<28>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers =
NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
static constexpr int kGriffinLayers =
NumLayersOfTypeBefore(kLayerConfig,
LayerAttentionType::kGriffinRecurrentBlock,
kLayers);
static constexpr int kModelDim = 3072;
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
static constexpr int kHeads = 16;
static constexpr int kKVHeads = 16; // standard MHA
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
// SSM config.
static constexpr int kConv1dWidth = 0;
@ -92,12 +115,19 @@ struct ConfigGemma2B {
static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
FixedLayerConfig<18>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers =
NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
static constexpr int kGriffinLayers =
NumLayersOfTypeBefore(kLayerConfig,
LayerAttentionType::kGriffinRecurrentBlock,
kLayers);
static constexpr int kModelDim = 2048;
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
static constexpr int kHeads = 8;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
// SSM config.
static constexpr int kConv1dWidth = 0;
@ -144,12 +174,19 @@ struct ConfigGriffin2B {
LayerAttentionType::kGriffinRecurrentBlock,
};
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers =
NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
static constexpr int kGriffinLayers =
NumLayersOfTypeBefore(kLayerConfig,
LayerAttentionType::kGriffinRecurrentBlock,
kLayers);
static constexpr int kModelDim = 2560;
static constexpr int kFFHiddenDim = 7680;
static constexpr int kHeads = 10;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
// SSM config.
static constexpr int kConv1dWidth = 4;
@ -164,4 +201,4 @@ struct ConfigGriffin2B {
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_

File diff suppressed because it is too large.

View File

@ -13,42 +13,42 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_H_
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
#include <functional>
#include <memory>
#include <random>
#include <string>
#include <vector>
// copybara:import_next_line:gemma_cpp
#include "compression/compress.h" // SfpStream/NuqStream
// copybara:import_next_line:gemma_cpp
#include "configs.h"
#include "compression/io.h" // Path
#include "gemma/configs.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // hwy::bfloat16_t
#include "hwy/contrib/thread_pool/thread_pool.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // Path
// copybara:import_next_line:sentencepiece
#include "src/sentencepiece_processor.h"
namespace gcpp {
using GemmaWeightT = GEMMA_WEIGHT_T;
using EmbedderInputT = hwy::bfloat16_t;
// Will be called for layers output with:
// - position in the tokens sequence
// - name of the data, e.g. "tokens", "block.1", "final_norm"
// - pointer to the data array
// - size of the data array
using LayersOutputT =
std::function<void(int, const std::string&, const float*, size_t)>;
constexpr size_t kPrefillBatchSize = 16;
constexpr bool kSystemPrompt = false;
struct KVCache {
hwy::AlignedFreeUniquePtr<float[]>
key_cache; // kSeqLen * kNumGemmaLayers * kKVHeads * kQKVDim
kv_cache; // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim * 2
hwy::AlignedFreeUniquePtr<float[]>
value_cache; // kSeqLen * kNumGemmaLayers * kKVHeads * kQKVDim
conv1d_cache; // (kConv1dWidth - 1) * kModelDim * kGriffinLayers
hwy::AlignedFreeUniquePtr<float[]>
conv1d_cache; // (kConv1dWidth - 1) * kModelDim * kNumGriffinLayers
hwy::AlignedFreeUniquePtr<float[]>
rglru_cache; // kModelDim * kNumGriffinLayers
rglru_cache; // kModelDim * kGriffinLayers
};
// Model variants: see configs.h for details.
@ -71,6 +71,7 @@ struct GemmaInterface;
class GemmaTokenizer {
public:
virtual ~GemmaTokenizer() = default;
virtual bool Encode(const std::string& input,
std::vector<std::string>* pieces) const = 0;
virtual bool Encode(const std::string& input,
@ -82,7 +83,7 @@ class GemmaTokenizer {
struct Gemma {
Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
hwy::ThreadPool& pool);
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
~Gemma(); // must be defined after the GemmaInterface dtor is defined.
const GemmaTokenizer* Tokenizer() const;
std::unique_ptr<GemmaInterface> impl_;
};
@ -96,16 +97,17 @@ KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
using StreamFunc = std::function<bool(int, float)>;
using AcceptFunc = std::function<bool(int)>;
// layers_output is optional; if set, it is called with the activations output
// after applying each layer.
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
float temperature, const std::vector<int>& prompt,
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity);
int verbosity, LayersOutputT* layers_output = nullptr);
// Convenience function for the common case:
// - Bundle runtime parameters as RuntimeConfig
// - No threadpools within threadpools (inner_pool = dummy)
// - All tokens accepted
void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
const std::vector<int>& prompt, size_t start_pos,
@ -117,11 +119,10 @@ void CompressWeights(gcpp::Model model, const Path& weights,
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
int verbosity);
hwy::ThreadPool& pool, int verbosity);
constexpr int EOS_ID = 1;
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_H_
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_

View File

@ -13,14 +13,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// copybara:import_next_line:gemma_cpp
#include "gemma.h"
#include "gemma/gemma.h"
#include <thread>
#include <algorithm>
#include <iostream>
#include <random>
#include <string>
#include <thread> // NOLINT
#include <vector>
// copybara:import_next_line:gemma_cpp
#include "ops.h"
// copybara:import_next_line:gemma_cpp
#include "gemma/ops.h"
#include "util/args.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/tests/test_util-inl.h"
@ -34,7 +36,6 @@ class GemmaTest : public ::testing::Test {
: weights("./2b-it-mqa.sbs"),
tokenizer("./tokenizer.spm"),
pool(std::min<int>(20, (std::thread::hardware_concurrency() - 1) / 2)),
inner_pool(0),
model_type(gcpp::Model::GEMMA_2B),
model(tokenizer, weights, model_type, pool) {
kv_cache = CreateKVCache(model_type);
@ -58,8 +59,8 @@ class GemmaTest : public ::testing::Test {
gcpp::GenerateGemma(
model, /*max_tokens=*/3072, /*max_generated_tokens=*/2048,
/*temperature=*/1.0, prompt, /*start_pos=*/0, kv_cache, pool,
inner_pool, stream_token,
/*accept=*/[](int) { return true; }, gen, /*verbosity=*/0);
stream_token, /*accept=*/[](int) { return true; }, gen,
/*verbosity=*/0);
std::string response_text;
HWY_ASSERT(model.Tokenizer()->Decode(response, &response_text));
return response_text;
@ -69,8 +70,7 @@ class GemmaTest : public ::testing::Test {
std::vector<int> prompt;
HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt));
return gcpp::ComputeCrossEntropy(model, /*max_tokens=*/3072, prompt,
kv_cache, pool, inner_pool,
/*verbosity=*/0) /
kv_cache, pool, /*verbosity=*/0) /
prompt_string.size();
}
@ -79,7 +79,7 @@ class GemmaTest : public ::testing::Test {
std::cout << "Question " << i + 1 << "\n\n";
std::string response = GemmaReply(kQA[i][0]);
std::cout << response << "\n\n";
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos);
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT
}
}
@ -87,7 +87,6 @@ class GemmaTest : public ::testing::Test {
gcpp::Path tokenizer;
gcpp::KVCache kv_cache;
hwy::ThreadPool pool;
hwy::ThreadPool inner_pool;
gcpp::Model model_type = {};
gcpp::Gemma model;
};

View File

@ -14,8 +14,9 @@
// limitations under the License.
// Include guard for non-SIMD code.
#ifndef THIRD_PARTY_GEMMA_CPP_OPS_H_
#define THIRD_PARTY_GEMMA_CPP_OPS_H_
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
#include <stddef.h>
#include <stdint.h>
@ -24,6 +25,7 @@
#include <random>
#include <type_traits> // std::enable_if_t
#include "compression/compress.h" // CompressedArray
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
@ -43,7 +45,7 @@ static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_OPS_H_
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE) == defined(HWY_TARGET_TOGGLE)
@ -53,7 +55,6 @@ static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
#define THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE
#endif
// copybara:import_next_line:gemma_cpp
#include "compression/compress-inl.h"
#include "hwy/contrib/algo/transform-inl.h"
#include "hwy/contrib/dot/dot-inl.h"
@ -92,12 +93,60 @@ HWY_INLINE constexpr size_t RowsPerStrip() {
return kRowsPerStrip;
}
// Largely unoptimized; reordering the innermost loops nets a ~5-10x speedup
// on ops_test across instruction sets.
template <size_t kM, size_t kN, size_t kK>
HWY_INLINE void MatMul(const float* HWY_RESTRICT a, const float* HWY_RESTRICT b,
float* HWY_RESTRICT out) {
int i, j, k;
for (i = 0; i < kM; ++i) {
for (k = 0; k < kN; ++k) {
for (j = 0; j < kK; ++j) {
out[i * kK + j] += a[i * kN + k] * b[k * kK + j];
}
}
}
}
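A tiny usage sketch of MatMul above with hypothetical sizes; note that it accumulates into out, so the caller must zero-initialize the output:
```
constexpr size_t kM = 2, kN = 3, kK = 2;
const float a[kM * kN] = {1, 2, 3, 4, 5, 6};  // 2x3, row-major
const float b[kN * kK] = {1, 0, 0, 1, 1, 1};  // 3x2, row-major
float c[kM * kK] = {0, 0, 0, 0};              // 2x2, accumulated into
MatMul<kM, kN, kK>(a, b, c);
// c is now {4, 5, 10, 11}, the row-major product of a and b.
```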
HWY_INLINE void ToEvenOddF32(const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
const size_t size, float* HWY_RESTRICT out) {
const hn::ScalableTag<float> df;
const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf16;
HWY_DASSERT(size % hn::Lanes(dbf16) == 0);
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
for (size_t i = 0; i < size; i += hn::Lanes(dbf16)) {
const auto interleaved = hn::LoadU(dbf16, vec_aligned + i);
hn::Store(hn::PromoteEvenTo(df, interleaved), df, out + i);
hn::Store(hn::PromoteOddTo(df, interleaved), df, out + i + hn::Lanes(df));
}
}
HWY_INLINE void ToEvenOddF32(const float* HWY_RESTRICT vec_aligned,
const size_t size, float* HWY_RESTRICT out) {
const hn::ScalableTag<float> df;
using VF = hn::Vec<decltype(df)>;
HWY_DASSERT(size % (hn::Lanes(df) * 2) == 0);
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
VF vec0, vec1;
for (size_t i = 0; i < size; i += hn::Lanes(df) * 2) {
hn::LoadInterleaved2(df, vec_aligned + i, vec0, vec1);
hn::Store(vec0, df, out + i);
hn::Store(vec1, df, out + i + hn::Lanes(df));
}
}
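For intuition, a scalar sketch of the layout both ToEvenOddF32 overloads produce (illustration only, not part of the library; `lanes` stands in for hn::Lanes(df)): within each block of 2*lanes inputs, the even-indexed elements are written first, followed by the odd-indexed elements.
```
inline void ToEvenOddF32Scalar(const float* in, size_t size, size_t lanes,
                               float* out) {
  for (size_t i = 0; i < size; i += 2 * lanes) {
    for (size_t j = 0; j < lanes; ++j) {
      out[i + j] = in[i + 2 * j];              // even-indexed elements
      out[i + lanes + j] = in[i + 2 * j + 1];  // odd-indexed elements
    }
  }
}
```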
// Simple version without tiling or threading.
// even_odd is precomputed for the current thread.
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
typename VecT, typename AddT>
HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
const VecT* HWY_RESTRICT vec_aligned,
const AddT* HWY_RESTRICT add,
float* HWY_RESTRICT even_odd,
float* HWY_RESTRICT out) {
PROFILER_ZONE("MatVecAddLoop");
const hn::ScalableTag<float> df;
@ -113,12 +162,40 @@ HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
}
}
#if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16
template <bool kAdd, size_t kOuter, size_t kInner, typename VecT, typename AddT,
size_t kCapacity>
HWY_INLINE void MatVecAddLoop(
const CompressedArray<hwy::bfloat16_t, kCapacity>& mat,
const size_t mat_ofs, const VecT* HWY_RESTRICT vec_aligned,
const AddT* HWY_RESTRICT add, float* HWY_RESTRICT even_odd,
float* HWY_RESTRICT out) {
PROFILER_ZONE("MatVecAddLoop");
constexpr bool kVecIsEvenOdd = true;
const hn::ScalableTag<float> df;
ToEvenOddF32(vec_aligned, kInner, even_odd);
for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
const size_t row_ofs = mat_ofs + idx_row * kInner;
if constexpr (kAdd) {
out[idx_row] = hwy::ConvertScalarTo<float>(add[idx_row]) +
Dot<kVecIsEvenOdd>(df, mat, row_ofs, even_odd, kInner);
} else {
out[idx_row] = Dot<kVecIsEvenOdd>(df, mat, row_ofs, even_odd, kInner);
}
}
}
#endif
// even_odd is precomputed for the current thread.
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs,
const VecT* HWY_RESTRICT vec_aligned,
float* HWY_RESTRICT even_odd,
float* HWY_RESTRICT out) {
MatVecAddLoop<false, kOuter, kInner, ArrayT, VecT, VecT>(
mat, mat_ofs, vec_aligned, /*add=*/nullptr, out);
MatVecAddLoop</*kAdd=*/false, kOuter, kInner>(
mat, mat_ofs, vec_aligned, /*add=*/static_cast<VecT*>(nullptr), even_odd,
out);
}
// Simple version without tiling or threading, but with two offsets/outputs.
@ -156,7 +233,7 @@ HWY_INLINE void TwoOfsMatVecLoop(const ArrayT& mat, const size_t mat_ofs0,
const VecT* HWY_RESTRICT vec_aligned,
float* HWY_RESTRICT out0,
float* HWY_RESTRICT out1) {
TwoOfsMatVecAddLoop<false, kOuter, kInner, ArrayT, VecT, VecT>(
TwoOfsMatVecAddLoop</*kAdd=*/false, kOuter, kInner, ArrayT, VecT, VecT>(
mat, mat_ofs0, mat_ofs1, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr,
out0, out1);
}
@ -166,20 +243,23 @@ namespace detail {
// For each i in [0, num_rows), compute the partial (length `num_cols`) dot
// product of row i with `vec_aligned` and add it into `out[i]`. The upper-left
// coordinate of the tile is (r0, c0).
template <class DF, typename ArrayT, typename VecT>
template <bool kVecEO, class DF, typename ArrayT, typename VecT>
HWY_INLINE void AccumulatePartialDotProducts(
DF df, const ArrayT& mat, size_t mat_ofs, size_t mat_stride, size_t r0,
size_t c0, size_t num_rows, size_t num_cols,
const VecT* HWY_RESTRICT vec_aligned, float* HWY_RESTRICT out) {
for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
out[idx_row] += Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
out[idx_row] +=
Dot<kVecEO>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
}
}
// Same as above, but sets out[i] to the first partial dot product +
// init (if kInit), which avoids having to zero-initialize and accumulate.
template <bool kInit, class DF, typename ArrayT, typename VecT, typename InitT>
// Same as AccumulatePartialDotProducts, but sets out[i] to the first partial
// dot product + init (if kInit), which avoids having to zero-initialize and
// accumulate.
template <bool kVecEO, bool kInit, class DF, typename ArrayT, typename VecT,
typename InitT>
HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
size_t mat_ofs, size_t mat_stride,
size_t r0, size_t c0,
@ -190,10 +270,12 @@ HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
if constexpr (kInit) {
out[idx_row] = hwy::ConvertScalarTo<float>(init[idx_row + r0]) +
Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
out[idx_row] =
hwy::ConvertScalarTo<float>(init[idx_row + r0]) +
Dot<kVecEO>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
} else {
out[idx_row] = Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
out[idx_row] =
Dot<kVecEO>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
}
}
}
@ -202,7 +284,8 @@ HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
// horizontal strip of the entire matrix); the result is the full dot product
// for rows r in [r0, r0 + num_rows) + optionally the add vector, which we store
// into out[r - r0].
template <bool kAdd, class DF, typename ArrayT, typename VecT, typename AddT>
template <bool kVecEO, bool kAdd, class DF, typename ArrayT, typename VecT,
typename AddT>
HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
size_t mat_ofs, size_t mat_stride,
size_t r0, size_t num_rows,
@ -211,25 +294,66 @@ HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
float* HWY_RESTRICT out) {
// Tall and skinny: set `out` to the single dot product.
if (mat_stride < MaxCols()) {
SetFirstPartialDotProducts<kAdd>(df, mat, mat_ofs, mat_stride, r0, 0,
num_rows, mat_stride, vec_aligned, add,
out);
SetFirstPartialDotProducts<kVecEO, kAdd>(df, mat, mat_ofs, mat_stride, r0,
0, num_rows, mat_stride,
vec_aligned, add, out);
return;
}
// We have at least MaxCols, so start by setting `out` to that:
SetFirstPartialDotProducts<kAdd>(df, mat, mat_ofs, mat_stride, r0, 0,
num_rows, MaxCols(), vec_aligned, add, out);
SetFirstPartialDotProducts<kVecEO, kAdd>(df, mat, mat_ofs, mat_stride, r0, 0,
num_rows, MaxCols(), vec_aligned,
add, out);
// For further multiples of MaxCols, accumulate. Remainders handled below.
size_t c0 = MaxCols();
for (; c0 <= mat_stride - MaxCols(); c0 += MaxCols()) {
AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows,
MaxCols(), vec_aligned, out);
AccumulatePartialDotProducts<kVecEO>(df, mat, mat_ofs, mat_stride, r0, c0,
num_rows, MaxCols(), vec_aligned, out);
}
if (c0 < mat_stride) { // Final cols
AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows,
mat_stride - c0, vec_aligned, out);
AccumulatePartialDotProducts<kVecEO>(df, mat, mat_ofs, mat_stride, r0, c0,
num_rows, mat_stride - c0, vec_aligned,
out);
}
}
template <bool kVecIsEvenOdd, bool kAdd, size_t kOuter, size_t kInner,
typename ArrayT, typename VecT, typename AddT>
HWY_INLINE void MatVecAddInner(const ArrayT& mat, const size_t mat_ofs,
const VecT* HWY_RESTRICT const vec_aligned,
const AddT* HWY_RESTRICT const add,
float* HWY_RESTRICT even_odd,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
const hn::ScalableTag<float> df;
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
// Sanity check: each thread can write without race conditions.
if (HWY_IS_TSAN) {
pool.Run(
0, pool.NumWorkers(), [even_odd](uint64_t /*task*/, size_t thread) {
even_odd[thread * kInner] = -static_cast<float>(thread);
even_odd[thread * kInner + kInner - 1] = static_cast<float>(thread);
});
}
// For each entire strip.
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
PROFILER_ZONE("MatVec.lambda");
const size_t r0 = strip * kRowsPerStrip;
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
df, mat, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add,
out + r0);
});
// Remaining rows
const size_t r0 = kNumStrips * kRowsPerStrip;
if (r0 < kOuter) {
PROFILER_ZONE("MatVec remainder");
const size_t num_rows = kOuter - r0;
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
df, mat, mat_ofs, kInner, r0, num_rows, vec_aligned, add, out + r0);
}
}
@ -243,38 +367,32 @@ template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
const VecT* HWY_RESTRICT const vec_aligned,
const AddT* HWY_RESTRICT const add,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
float* HWY_RESTRICT even_odd, float* HWY_RESTRICT out,
hwy::ThreadPool& pool) {
PROFILER_ZONE("MatVecAdd");
const hn::ScalableTag<float> df;
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
// For each entire strip.
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
PROFILER_ZONE("MatVec.lambda");
const size_t r0 = strip * kRowsPerStrip;
detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, kInner, r0,
kRowsPerStrip, vec_aligned, add,
out + r0);
});
// Remaining rows
const size_t r0 = kNumStrips * kRowsPerStrip;
if (r0 < kOuter) {
PROFILER_ZONE("MatVec remainder");
const size_t num_rows = kOuter - r0;
detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, kInner, r0,
num_rows, vec_aligned, add, out + r0);
#if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16
if constexpr (CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd &&
hwy::IsSameEither<VecT, float, hwy::bfloat16_t>()) {
ToEvenOddF32(vec_aligned, kInner, even_odd);
detail::MatVecAddInner</*kVecIsEvenOdd=*/true, kAdd, kOuter, kInner>(
mat, mat_ofs, even_odd, add, even_odd, out, pool);
return;
}
#endif
detail::MatVecAddInner</*kVecIsEvenOdd=*/false, kAdd, kOuter, kInner>(
mat, mat_ofs, vec_aligned, add, even_odd, out, pool);
}
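A caller-side sketch of the new even_odd scratch parameter (an assumption, matching the allocation in ops_test below: kInner floats per pool worker so each thread can use its own slice):
```
constexpr size_t kOuter = 384, kInner = 640;  // hypothetical sizes
hwy::ThreadPool pool(4);
auto even_odd = hwy::AllocateAligned<float>(kInner * pool.NumWorkers());
auto out = hwy::AllocateAligned<float>(kOuter);
// mat, vec and add come from elsewhere (see GenerateMat/GenerateVec in the
// test below); the call itself then looks like:
// MatVecAdd</*kAdd=*/true, kOuter, kInner>(mat, /*mat_ofs=*/0, vec.get(),
//                                          add.get(), even_odd.get(),
//                                          out.get(), pool);
```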
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
HWY_INLINE void MatVec(const ArrayT& mat, const size_t mat_ofs,
const VecT* HWY_RESTRICT const vec_aligned,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
MatVecAdd<false, kOuter, kInner, ArrayT, VecT, VecT>(
mat, mat_ofs, vec_aligned, /*add=*/nullptr, out, pool);
float* HWY_RESTRICT even_odd, float* HWY_RESTRICT out,
hwy::ThreadPool& pool) {
MatVecAdd</*kAdd=*/false, kOuter, kInner>(mat, mat_ofs, vec_aligned,
/*add=*/static_cast<VecT*>(nullptr),
even_odd, out, pool);
}
template <class D, HWY_IF_F32_D(D)>
@ -396,17 +514,18 @@ HWY_NOINLINE void TwoMatVecAdd(
const hn::ScalableTag<float> df;
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
constexpr bool kVecIsEvenOdd = false;
// For each entire strip.
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
PROFILER_ZONE("TwoMatVec.lambda");
const size_t r0 = strip * kRowsPerStrip;
detail::FullDotProductsForStrip<kAdd>(df, mat0, mat_ofs, kInner, r0,
kRowsPerStrip, vec_aligned, add0,
out0 + r0);
detail::FullDotProductsForStrip<kAdd>(df, mat1, mat_ofs, kInner, r0,
kRowsPerStrip, vec_aligned, add1,
out1 + r0);
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
df, mat0, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add0,
out0 + r0);
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
df, mat1, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add1,
out1 + r0);
});
// Remaining rows
@ -414,9 +533,9 @@ HWY_NOINLINE void TwoMatVecAdd(
if (r0 < kOuter) {
PROFILER_ZONE("TwoMatVec remainder");
const size_t num_rows = kOuter - r0;
detail::FullDotProductsForStrip<kAdd>(
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
df, mat0, mat_ofs, kInner, r0, num_rows, vec_aligned, add0, out0 + r0);
detail::FullDotProductsForStrip<kAdd>(
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
df, mat1, mat_ofs, kInner, r0, num_rows, vec_aligned, add1, out1 + r0);
}
}
@ -427,7 +546,7 @@ HWY_NOINLINE void TwoMatVec(const ArrayT& mat0, const ArrayT& mat1,
const VecT* HWY_RESTRICT vec_aligned,
float* HWY_RESTRICT out0, float* HWY_RESTRICT out1,
hwy::ThreadPool& pool) {
TwoMatVecAdd<false, kOuter, kInner, ArrayT, VecT, VecT>(
TwoMatVecAdd</*kAdd=*/false, kOuter, kInner, ArrayT, VecT, VecT>(
mat0, mat1, mat_ofs, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr,
out0, out1, pool);
}

View File

@ -17,22 +17,27 @@
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif
#include <stddef.h>
#include <algorithm>
#include <array>
#include <random>
#include <vector>
#include "compression/compress.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "ops_test.cc" //NOLINT
#define HWY_TARGET_INCLUDE "gemma/ops_test.cc" //NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
#include "hwy/tests/test_util-inl.h"
// After highway.h
// copybara:import_next_line:gemma_cpp
#include "ops.h"
#include "gemma/ops.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
@ -373,15 +378,54 @@ CompressedArray<float, kOuter * kInner> GenerateMat(size_t offset) {
return mat;
}
template <size_t kOuter, size_t kInner>
CompressedArray<float, kOuter * kInner> GenerateZeroMat(size_t offset) {
hwy::ThreadPool pool(static_cast<size_t>(std::clamp(
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 4)));
gcpp::CompressWorkingSet ws;
CompressedArray<float, kOuter * kInner> mat;
std::array<float, kOuter * kInner> content;
pool.Run(0, kOuter, [&](const size_t i, size_t thread) {
for (size_t j = 0; j < kInner; j++) {
content[i * kInner + j] = 0.0f;
}
});
Compress(content, ws, mat, pool);
mat.set_scale(1.0f);
return mat;
}
template <size_t length>
hwy::AlignedFreeUniquePtr<float[]> GenerateVec(size_t offset) {
hwy::AlignedFreeUniquePtr<float[]> vec = hwy::AllocateAligned<float>(length);
HWY_ASSERT(vec);
for (size_t idx = 0; idx < length; idx++) {
vec[idx] = static_cast<float>(idx + offset);
}
return vec;
}
// A simple matrix multiplication. No optimization / tiling.
template <size_t kM, size_t kN, size_t kK>
hwy::AlignedFreeUniquePtr<float[]> SimpleMatMul(
const hwy::AlignedFreeUniquePtr<float[]>& a,
const hwy::AlignedFreeUniquePtr<float[]>& b) {
hwy::AlignedFreeUniquePtr<float[]> out = hwy::AllocateAligned<float>(kM * kK);
hwy::ZeroBytes(out.get(), kM * kK * sizeof(float));
for (size_t i = 0; i < kM; ++i) {
for (size_t j = 0; j < kK; ++j) {
for (size_t k = 0; k < kN; ++k) {
out[i * kK + j] += a[i * kN + k] * b[k * kK + j];
}
}
}
return out;
}
template <size_t kOuter, size_t kInner>
hwy::AlignedFreeUniquePtr<float[]> SimpleMatVecAdd(
const CompressedArray<float, kOuter * kInner>& mat,
@ -389,8 +433,9 @@ hwy::AlignedFreeUniquePtr<float[]> SimpleMatVecAdd(
const hwy::AlignedFreeUniquePtr<float[]>& add) {
hwy::AlignedFreeUniquePtr<float[]> uncompressed_mat =
hwy::AllocateAligned<float>(kOuter * kInner);
Decompress(mat, 0, uncompressed_mat.get(), kOuter * kInner);
hwy::AlignedFreeUniquePtr<float[]> out = hwy::AllocateAligned<float>(kOuter);
HWY_ASSERT(uncompressed_mat && out);
Decompress(mat, 0, uncompressed_mat.get(), kOuter * kInner);
for (size_t idx_row = 0; idx_row < kOuter; idx_row++) {
out[idx_row] = add[idx_row];
for (size_t idx_col = 0; idx_col < kInner; idx_col++) {
@ -412,6 +457,52 @@ void AssertClose(const hwy::AlignedFreeUniquePtr<float[]>& a,
}
}
template <typename MatT>
void AssertClose(const hwy::AlignedFreeUniquePtr<MatT[]>& expected,
const hwy::AlignedFreeUniquePtr<MatT[]>& actual, size_t num) {
for (size_t idx = 0; idx < num; idx++) {
double expected_value = hwy::ConvertScalarTo<double>(expected[idx]);
double actual_value = hwy::ConvertScalarTo<double>(actual[idx]);
// Relative tolerance of about 20 ULPs of MatT (assumes expected_value >= 0).
const double tolerance =
expected_value * 20 * 1.0 / (1ULL << hwy::MantissaBits<MatT>());
if (!(expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance)) {
fprintf(stderr, "expected[%zu]: %f, actual[%zu]: %f\n", idx,
expected_value, idx, actual_value);
HWY_ASSERT(0);
}
}
}
void TestMatMul() {
hwy::ThreadPool pool(0);
constexpr size_t kM = 128 * 3; // 384
constexpr size_t kK = 128 * 5; // 640
constexpr size_t kN = 128 * 6; // 768
CompressedArray<float, kM * kN> a1 = GenerateMat<kM, kN>(0);
CompressedArray<float, kN * kK> b1 = GenerateMat<kN, kK>(0);
hwy::AlignedFreeUniquePtr<float[]> a = hwy::AllocateAligned<float>(kM * kN);
Decompress(a1, 0, a.get(), kM * kN);
hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kN * kK);
Decompress(b1, 0, b.get(), kN * kK);
hwy::AlignedFreeUniquePtr<float[]> expected_out1 =
SimpleMatMul<kM, kN, kK>(a, b);
CompressedArray<float, kM * kK> compressed_c = GenerateZeroMat<kM, kK>(0);
hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
Decompress(compressed_c, 0, c.get(), kM * kK);
MatMul<kM, kN, kK>(a.get(), b.get(), c.get());
AssertClose(expected_out1, c, kM * kK);
}
void TestMatVecAdd() {
hwy::ThreadPool pool(0);
constexpr size_t kOuter = 128 * 3;
@ -419,27 +510,15 @@ void TestMatVecAdd() {
CompressedArray<float, kOuter * kInner> mat = GenerateMat<kOuter, kInner>(0);
hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
hwy::AlignedFreeUniquePtr<float[]> add = GenerateVec<kOuter>(0);
hwy::AlignedFreeUniquePtr<float[]> even_odd =
hwy::AllocateAligned<float>(kInner * pool.NumWorkers());
hwy::AlignedFreeUniquePtr<float[]> expected_out =
SimpleMatVecAdd<kOuter, kInner>(mat, vec, add);
hwy::AlignedFreeUniquePtr<float[]> actual_out =
hwy::AllocateAligned<float>(kOuter);
MatVecAdd<true, kOuter, kInner>(mat, 0, vec.get(), add.get(),
actual_out.get(), pool);
AssertClose<kOuter>(actual_out, expected_out);
}
void TestMatVecAddLoop() {
constexpr size_t kOuter = 128 * 3;
constexpr size_t kInner = 128 * 5;
CompressedArray<float, kOuter * kInner> mat = GenerateMat<kOuter, kInner>(0);
hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
hwy::AlignedFreeUniquePtr<float[]> add = GenerateVec<kOuter>(0);
hwy::AlignedFreeUniquePtr<float[]> expected_out =
SimpleMatVecAdd<kOuter, kInner>(mat, vec, add);
hwy::AlignedFreeUniquePtr<float[]> actual_out =
hwy::AllocateAligned<float>(kOuter);
MatVecAddLoop<true, kOuter, kInner>(mat, 0, vec.get(), add.get(),
actual_out.get());
HWY_ASSERT(vec && add && even_odd && expected_out && actual_out);
MatVecAdd</*kAdd=*/true, kOuter, kInner>(
mat, 0, vec.get(), add.get(), even_odd.get(), actual_out.get(), pool);
AssertClose<kOuter>(actual_out, expected_out);
}
@ -460,6 +539,8 @@ void TestTwoMatVecAdd() {
hwy::AllocateAligned<float>(kOuter);
hwy::AlignedFreeUniquePtr<float[]> actual_out1 =
hwy::AllocateAligned<float>(kOuter);
HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 &&
expected_out1 && actual_out1);
TwoMatVecAdd<true, kOuter, kInner>(mat0, mat1, 0, vec.get(), add0.get(),
add1.get(), actual_out0.get(),
actual_out1.get(), pool);
@ -482,6 +563,8 @@ void TestTwoOfsMatVecAddLoop() {
hwy::AllocateAligned<float>(kOuter);
hwy::AlignedFreeUniquePtr<float[]> actual_out1 =
hwy::AllocateAligned<float>(kOuter);
HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 &&
expected_out1 && actual_out1);
TwoOfsMatVecAddLoop<true, kOuter, kInner>(mat, 0, 0, vec.get(), add0.get(),
add1.get(), actual_out0.get(),
actual_out1.get());
@ -521,8 +604,8 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
HWY_EXPORT_AND_TEST_P(OpsTest, TestMatMul);
HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAddLoop);
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoOfsMatVecAddLoop);
HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid);
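For reference, SimpleMatMul and TestMatMul above use the convention a: kM x kN, b: kN x kK, out: kM x kK, all row-major. The following is a minimal, Highway-free restatement of the same triple loop with a 2x3 times 3x2 product that can be checked by hand; it only illustrates the indexing, not the optimized MatMul being tested.

```
// Same indexing as SimpleMatMul above:
// out[i * kK + j] += a[i * kN + k] * b[k * kK + j].
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  constexpr size_t kM = 2, kN = 3, kK = 2;
  const std::vector<float> a = {1, 2, 3,
                                4, 5, 6};   // kM x kN, row-major
  const std::vector<float> b = {1, 2,
                                3, 4,
                                5, 6};      // kN x kK, row-major
  std::vector<float> out(kM * kK, 0.0f);    // kM x kK
  for (size_t i = 0; i < kM; ++i) {
    for (size_t j = 0; j < kK; ++j) {
      for (size_t k = 0; k < kN; ++k) {
        out[i * kK + j] += a[i * kN + k] * b[k * kK + j];
      }
    }
  }
  // Expected result: [[22, 28], [49, 64]].
  for (size_t i = 0; i < kM; ++i) {
    std::printf("%g %g\n", out[i * kK + 0], out[i * kK + 1]);
  }
  return 0;
}
```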

View File

@ -23,20 +23,16 @@
#include <vector>
// Placeholder for internal header, do not modify.
// copybara:import_next_line:gemma_cpp
#include "compression/compress.h"
// copybara:import_next_line:gemma_cpp
#include "gemma.h" // Gemma
#include "gemma/gemma.h" // Gemma
#include "util/app.h"
#include "util/args.h" // HasHelp
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
#include "hwy/per_target.h"
#include "hwy/profiler.h"
#include "hwy/timer.h"
// copybara:import_next_line:gemma_cpp
#include "util/app.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // HasHelp
static constexpr bool kVerboseLogTokens = false;
@ -98,11 +94,10 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
void ReplGemma(gcpp::Gemma& model, ModelTraining training,
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const InferenceArgs& args,
int verbosity, const gcpp::AcceptFunc& accept_token,
std::string& eot_line) {
const InferenceArgs& args, int verbosity,
const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
PROFILER_ZONE("Gen.misc");
int abs_pos = 0; // absolute token index over all turns
size_t abs_pos = 0; // absolute token index over all turns
int current_pos = 0; // token index within the current turn
int prompt_size{};
@ -185,7 +180,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
// For instruction-tuned models: add control tokens.
prompt_string = "<start_of_turn>user\n" + prompt_string +
"<end_of_turn>\n<start_of_turn>model\n";
if (abs_pos > 0) {
if (abs_pos != 0) {
// Prepend "<end_of_turn>" token if this is a multi-turn dialogue
// continuation.
prompt_string = "<end_of_turn>\n" + prompt_string;
@ -213,7 +208,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
const double time_start = hwy::platform::Now();
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool,
args.temperature, prompt, abs_pos, kv_cache, pool,
stream_token, accept_token, gen, verbosity);
const double time_end = hwy::platform::Now();
const double tok_sec = current_pos / (time_end - time_start);
@ -233,7 +228,6 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
PROFILER_ZONE("Run.misc");
hwy::ThreadPool inner_pool(0);
hwy::ThreadPool pool(app.num_threads);
// For many-core, pinning threads to cores helps.
if (app.num_threads > 10) {
@ -275,8 +269,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
}
ReplGemma(
model, loader.ModelTraining(), kv_cache, pool, inner_pool, inference,
app.verbosity,
model, loader.ModelTraining(), kv_cache, pool, inference, app.verbosity,
/*accept_token=*/[](int) { return true; }, app.eot_line);
}
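The prompt handling in ReplGemma above (unchanged in substance by this diff) wraps each user turn in control tokens and, for continuations of a dialogue, first closes the model's previous turn. A small standalone illustration using the same string literals follows; WrapTurn is a hypothetical helper introduced only for this sketch.

```
#include <cstdio>
#include <string>

// Mirrors the instruction-tuned prompt wrapping in ReplGemma: every user turn
// is wrapped in start/end-of-turn markers, and continuations (abs_pos != 0)
// additionally prepend an end-of-turn for the model's previous reply.
std::string WrapTurn(const std::string& user_text, bool is_continuation) {
  std::string prompt = "<start_of_turn>user\n" + user_text +
                       "<end_of_turn>\n<start_of_turn>model\n";
  if (is_continuation) {
    prompt = "<end_of_turn>\n" + prompt;
  }
  return prompt;
}

int main() {
  std::printf("%s\n---\n%s\n",
              WrapTurn("Hello", /*is_continuation=*/false).c_str(),
              WrapTurn("And then?", /*is_continuation=*/true).c_str());
  return 0;
}
```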

View File

@ -18,7 +18,6 @@
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
#include <iterator>
#if HWY_OS_LINUX
#include <sched.h>
@ -32,13 +31,11 @@
#include <algorithm> // std::clamp
#include <thread> // NOLINT
// copybara:import_next_line:gemma_cpp
#include "configs.h"
// copybara:import_next_line:gemma_cpp
#include "gemma.h"
#include "hwy/base.h" // HWY_ASSERT
// copybara:import_next_line:gemma_cpp
#include "compression/io.h" // Path
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "util/args.h"
#include "hwy/base.h" // HWY_ASSERT
namespace gcpp {
@ -49,6 +46,10 @@ static inline const char* CompiledConfig() {
return "msan";
} else if (HWY_IS_TSAN) {
return "tsan";
#if defined(HWY_IS_HWASAN)
} else if (HWY_IS_HWASAN) {
return "hwasan";
#endif
#if defined(HWY_IS_UBSAN)
} else if (HWY_IS_UBSAN) {
return "ubsan";
@ -84,8 +85,7 @@ class AppArgs : public ArgsBase<AppArgs> {
void ChooseNumThreads() {
if (num_threads == kDefaultNumThreads) {
// This is a rough heuristic, replace with something better in the future.
num_threads = static_cast<size_t>(std::clamp(
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
num_threads = GetSupportedThreadCount();
}
}
@ -95,6 +95,12 @@ class AppArgs : public ArgsBase<AppArgs> {
ChooseNumThreads();
}
static inline size_t GetSupportedThreadCount() {
return static_cast<size_t>(
std::clamp(static_cast<int>(std::thread::hardware_concurrency()) - 2, 1,
HWY_MIN(static_cast<int>(kMaxThreads), 18)));
}
Path log; // output
int verbosity;
size_t num_threads;
@ -137,7 +143,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
if (tokenizer.path.empty()) {
return "Missing --tokenizer flag, a file for the tokenizer is required.";
}
if (!tokenizer.exists()) {
if (!tokenizer.Exists()) {
return "Can't open file specified with --tokenizer flag.";
}
if (!compressed_weights.path.empty()) {
@ -152,7 +158,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
if (weights.path.empty()) {
return "Missing --weights flag, a file for the model weights.";
}
if (!weights.exists()) {
if (!weights.Exists()) {
return "Can't open file specified with --weights flag.";
}
return nullptr;
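As a quick check of the heuristic in GetSupportedThreadCount above: the hardware thread count minus two is clamped to at least 1 and at most HWY_MIN(kMaxThreads, 18). A standalone sketch follows; kMaxThreads is defined elsewhere in gemma.cpp, so the stand-in value below is an assumption.

```
#include <algorithm>
#include <cstdio>
#include <thread>

int main() {
  // Stand-in for gemma.cpp's kMaxThreads; its value is not shown in this diff.
  constexpr int kMaxThreadsGuess = 64;
  const int hw = static_cast<int>(std::thread::hardware_concurrency());
  // Leave roughly two hardware threads for the OS, never go below 1,
  // and cap at 18 (or lower if kMaxThreads is smaller).
  const int num_threads = std::clamp(hw - 2, 1, std::min(kMaxThreadsGuess, 18));
  std::printf("hardware_concurrency=%d -> num_threads=%d\n", hw, num_threads);
  return 0;
}
```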

View File

@ -23,48 +23,11 @@
#include <algorithm> // std::transform
#include <string>
#include "compression/io.h"
#include "hwy/base.h" // HWY_ABORT
#if defined(_WIN32)
#include <io.h>
#define F_OK 0
#define access _access
#else
#include <unistd.h>
#endif
namespace gcpp {
// Wrapper for strings representing a path name. Differentiates vs. arbitrary
// strings and supports shortening for display purposes.
struct Path {
Path() {}
explicit Path(const char* p) : path(p) {}
Path& operator=(const char* other) {
path = other;
return *this;
}
std::string Shortened() const {
constexpr size_t kMaxLen = 48;
constexpr size_t kCutPoint = kMaxLen / 2 - 5;
if (path.size() > kMaxLen) {
return std::string(begin(path), begin(path) + kCutPoint) + " ... " +
std::string(end(path) - kCutPoint, end(path));
}
if (path.empty()) return "[no path specified]";
return path;
}
// Beware, TOCTOU.
bool exists() const {
return (access(path.c_str(), F_OK) == 0);
}
std::string path;
};
// Args is a class that provides a ForEach member function which visits each of
// its member variables. ArgsBase provides functions called by Args to
// initialize values to their defaults (passed as an argument to the visitor),