initial commit

This commit is contained in:
Austin Huang 2024-02-13 06:30:41 +00:00 committed by austinvhuang
commit e29cd566cf
28 changed files with 7125 additions and 0 deletions

79
CMakeLists.txt Normal file
View File

@ -0,0 +1,79 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cmake_minimum_required(VERSION 3.11)
include(FetchContent)
project(gemma)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
FetchContent_MakeAvailable(highway)
## Note: absl meeds tp be installed by sentencepiece. This will only happen if
## cmake is invoked with -DSPM_ENABLE_SHARED=OFF and -DSPM_ABSL_PROVIDER=module
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
FetchContent_MakeAvailable(sentencepiece)
set(SOURCES
gemma.cc
compression/blob_store.cc
compression/blob_store.h
compression/compress.h
compression/compress-inl.h
compression/nuq.h
compression/nuq-inl.h
compression/sfp.h
compression/sfp-inl.h
util/app.h
util/args.h
)
add_compile_options($<$<CONFIG:Release>:-O2>)
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "Release")
endif()
# Allowable types for WEIGHT_TYPE:
# float - slow, not recommended
# hwy::bfloat16_t - bfloat16 as impemented by https://github.com/google/highway
# SfpStream - 8-bit switched floating point (recommended)
# NuqStream - experimental, work-in-progress
option(WEIGHT_TYPE "Set weight type" "")
if (WEIGHT_TYPE)
add_definitions(-DGEMMA_WEIGHT_T=${WEIGHT_TYPE})
endif()
# Executable Target
add_executable(gemma run.cc)
target_sources(gemma PRIVATE ${SOURCES})
set_property(TARGET gemma PROPERTY CXX_STANDARD 17)
target_link_libraries(gemma hwy hwy_contrib sentencepiece)
target_include_directories(gemma PRIVATE ./)
FetchContent_GetProperties(sentencepiece)
target_include_directories(gemma PRIVATE ${sentencepiece_SOURCE_DIR})
## Library Target
add_library(libgemma ${SOURCES})
set_property(TARGET libgemma PROPERTY CXX_STANDARD 17)
set_target_properties(libgemma PROPERTIES PREFIX "")
target_include_directories(libgemma PUBLIC ./)
target_link_libraries(libgemma hwy hwy_contrib sentencepiece)
target_include_directories(libgemma PRIVATE ${sentencepiece_SOURCE_DIR})

72
DEVELOPERS.md Normal file
View File

@ -0,0 +1,72 @@
# Developer Notes
## Motivation: A Minimalist C++ LLM Runtime for Research and Experimentation
In the past, neural network inference has been similar to a simple, opaque,
stateless function function with a single input and output. By contrast,
foundation model runtimes are better considered as systems with multiple forms
of state, subsystems, and heterogeneous inputs and outputs. They are often
integrated with a wide variety of other systems that have their own resources
(e.g. RAG and tools) and potentially interact with an external environment. They
have become compute engines to embed proximal tasks and goals within expansively
broad, general-purpose world models.
With this in mind, we believe that developing an experimental runtime that is
flexible and approachable will allow us to explore the design space of co-design
between high level model concerns and low-level runtime computation.
## Design Priorities
Given these motivations, we propose the following priorities for
making decisions regarding the direction and design of the codebase.
**Maximize Leverage with a Narrow Scope.** We focus on direct implementations of
foundation models like Gemma. This allows us to focus effort on bottlenecks of
specific models. We are willing to trade off generality to keep implementation
code relatively simple and readable at all layers of the stack, achieve good
performance, and maintain the velocity of a small team.
**Data Oriented Design.** Follow data oriented design principles where possible
to minimize unnecessary performance pessimization. It's best to apply these
optimizations during the initial design, or when refactoring a subcomponent. The
first step is to think in terms of batches or tuples of plain old data (POD)
types: separate arrays, instead of an array of structs. The second is to
de-emphasize control flow (if statements, virtual functions and class
hierarchies). The third step is to know intrinsic properties of data and bake
that into the layout and algorithm.
**Prioritize Small Batch Latency** Since production serving solutions are
available for large-scale serving powered by accelerators and optimizing for
throughput, this project focuses on the possibilities of local, interactive use
of foundation models. Although throughput remains important, low latency and
small batch sizes are prioritized, other things being equal.
**Maintain a Portable Baseline** Our starting point is a portable CPU SIMD (via
[highway](https://github.com/google/highway)). We expect to add accelerator and
hybrid CPU/GPU support in the future, but the project should continue to allow
builds using this portable baseline. This ensures that research-oriented and
experimental runtimes and hardware platforms will have a minimum viable option
to run Gemma even if specialized production-ready deployment paths are not
available.
## Code Organization
The implementation code is roughly split into 4 layers, from high to low level:
1. Frontends (`run.cc`) - Either interactive interfaces or automation
orchestration that interacts. Frontend code implements a use case objective
in terms of invocations to model inference and generation (2). Projects that
use gemma.cpp as a library are considered alternative frontends to `run.cc`.
We will add examples of additional frontends in the future.
2. Models (`gemma.cc`, `gemma.h`, `configs.h`) - Implements the compute graph
of the model including supporting functions such as loading and compressing
weights using transformer operations provided by layer (3).
3. Operations (`ops.h`) - A minimal set of transformer and supporting
mathematical operations implementations using compute backends (4). This
code should be agnostic to the specifics of the compute graph of the model
implementation (2).
4. Backend (`highway`) - Low-level hardware interface (SIMD in the case of
highway) supporting the implementations in (3).

202
LICENSE Normal file
View File

@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

26
LICENSE-BSD3 Normal file
View File

@ -0,0 +1,26 @@
Copyright (c) The gemma.cpp Project Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

335
README.md Normal file
View File

@ -0,0 +1,335 @@
# gemma.cpp
gemma.cpp is a lightweight, standalone C++ inference engine for the Gemma
foundation models from Google.
For additional information about Gemma, see
[ai.google.dev/gemma](https://ai.google.dev/gemma). Model weights, including gemma.cpp
specific artifacts, are [available on
kaggle](https://www.kaggle.com/models/google/gemma).
## Who is this project for?
Modern LLM inference engines are sophisticated systems, often with bespoke
capabilities extending beyond traditional neural network runtimes. With this
comes opportunities for research and innovation through co-design of high level
algorithms and low-level computation. However, there is a gap between
deployment-oriented C++ inference runtimes, which are not designed for
experimentation, and Python-centric ML research frameworks, which abstract away
low-level computation through compilation.
gemma.cpp provides a minimalist implementation of Gemma 2B and 7B models,
focusing on simplicity and directness rather than full generality. This is
inspired by vertically-integrated model implementations such as
[ggml](https://github.com/ggerganov/ggml),
[llama.c](https://github.com/karpathy/llama2.c), and
[llama.rs](https://github.com/srush/llama2.rs).
gemma.cpp targets experimentation and research use cases. It is intended to be
straightforward to embed in other projects with minimal dependencies and also
easily modifiable with a small ~2K LoC core implementation (along with ~4K LoC
of supporting utilities). We use the [Google
Highway](https://github.com/google/highway) Library to take advantage of
portable SIMD for CPU inference.
For production-oriented edge deployments we recommend standard deployment
pathways using Python frameworks like JAX, Keras, PyTorch, and Transformers
([all model variations here](https://www.kaggle.com/models/google/gemma)).
Community contributions large and small are welcome. This project follows
[Google's Open Source Community
Guidelines](https://opensource.google.com/conduct/).
## Quick Start
### System requirements
Before starting, you should have installed:
- [CMake](https://cmake.org/)
- [Clang C++ compiler](https://clang.llvm.org/get_started.html), supporting at
least C++17.
- `tar` for extracting archives from Kaggle.
### Step 1: Obtain model weights and tokenizer from Kaggle
Visit [the Gemma model page on
Kaggle](https://www.kaggle.com/models/google/gemma) and select `Model Variations
|> Gemma C++`. On this tab, the `Variation` dropdown includes the options below.
Note bfloat16 weights are higher fidelity, while 8-bit switched floating point
weights enable faster inference.
2B instruction-tuned (`it`) and pre-trained (`pt`) models:
| Model name | Description |
| ----------- | ----------- |
| `2b-it` | 2 billion parameter instruction-tuned model, bfloat16 |
| `2b-it-sfp` | 2 billion parameter instruction-tuned model, 8-bit switched floating point |
| `2b-pt` | 2 billion parameter pre-trained model, bfloat16 |
| `2b-pt-sfp` | 2 billion parameter pre-trained model, 8-bit switched floating point |
7B instruction-tuned (`it`) and pre-trained (`pt`) models:
| Model name | Description |
| ----------- | ----------- |
| `7b-it` | 7 billion parameter instruction-tuned model, bfloat16 |
| `7b-it-sfp` | 7 billion parameter instruction-tuned model, 8-bit switched floating point |
| `7b-pt` | 7 billion parameter pre-trained model, bfloat16 |
| `7b-pt-sfp` | 7 billion parameter pre-trained model, 8-bit switched floating point |
> [!NOTE]
> We *recommend starting with `2b-it-sfp`* to get up and running.
### Step 2: Extract Files
After filling out the consent form, the download should proceed to retrieve a
tar archive file `archive.tar.gz`. Extract files from `archive.tar.gz` (this can
take a few minutes):
```
tar -xf archive.tar.gz
```
This should produce a file containing model weights such as `2b-it-sfp.sbs` and
a tokenizer file (`tokenizer.spm`). You may want to move these files to a
convenient directory location (e.g. the `build/` directory in this repo).
### Step 3: Build
The build system uses [CMake](https://cmake.org/). To build the gemma inference
runtime, create a build directory and generate the build files using `cmake`
from the top-level project directory:
```sh
(cd build && cmake ..)
```
Then run `make` to build the `./gemma` executable:
```sh
cd build
make -j [number of parallel threads to use] gemma
```
For example, `make -j 8 gemma`. If this is successful, you should now have a
`gemma` executable in the `build/` directory.
> [!NOTE]
> On Windows Subsystem for Linux (WSL) users should set the number of
> parallel threads to 1. Using a larger number may result in errors.
### Step 4: Run
You can now run `gemma` from inside the `build/` directory.
`gemma` has the following required arguments:
| Argument | Description | Example value |
| -------- | ----------- | ------------- |
| `--model` | The model type. | `2b-it`, `2b-pt`, `7b-it`, `7b-pt`, ... (see above) |
| `--compressed_weights` | The compressed weights file. | `2b-it-sfp.sbs`, ... (see above) |
| `--tokenizer` | The tokenizer file. | `tokenizer.spm` |
`gemma` is invoked as:
```sh
./gemma \
--tokenizer [tokenizer file] \
--compressed_weights [compressed weights file] \
--model [2b-it or 2b-pt or 7b-it or 7b-pt or ...]
```
Example invocation for the following configuration:
- Compressed weights file `2b-it-sfp.sbs` (2B instruction-tuned model, 8-bit
switched floating point).
- Tokenizer file `tokenizer.spm`.
```sh
./gemma \
--tokenizer tokenizer.spm \
--compressed_weights 2b-it-sfp.sbs \
--model 2b-it
```
## Usage
`gemma` has different usage modes, controlled by the verbosity flag.
All usage modes are currently interactive, triggering text generation upon
newline input.
| Verbosity | Usage mode | Details |
| --------------- | ---------- | --------------------------------------------- |
| `--verbosity 0` | Minimal | Only prints generation output. Suitable as a CLI tool. |
| `--verbosity 1` | Default | Standard user-facing terminal UI. |
| `--verbosity 2` | Detailed | Shows additional developer and debug info. |
### Interactive Terminal App
By default, verbosity is set to 1, bringing up a terminal-based interactive
interface when `gemma` is invoked:
```console
$ ./gemma [...]
__ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __
/ _` |/ _ \ '_ ` _ \| '_ ` _ \ / _` | / __| '_ \| '_ \
| (_| | __/ | | | | | | | | | | (_| || (__| |_) | |_) |
\__, |\___|_| |_| |_|_| |_| |_|\__,_(_)___| .__/| .__/
__/ | | | | |
|___/ |_| |_|
tokenizer : tokenizer.spm
compressed_weights : 2b-it-sfp.sbs
model : 2b-it
weights : [no path specified]
max_tokens : 3072
max_generated_tokens : 2048
*Usage*
Enter an instruction and press enter (%Q quits).
*Examples*
- Write an email to grandma thanking her for the cookies.
- What are some historical attractions to visit around Massachusetts?
- Compute the nth fibonacci number in javascript.
- Write a standup comedy bit about WebGPU programming.
> What are some outdoorsy places to visit around Boston?
[ Reading prompt ] .....................
**Boston Harbor and Islands:**
* **Boston Harbor Islands National and State Park:** Explore pristine beaches, wildlife, and maritime history.
* **Charles River Esplanade:** Enjoy scenic views of the harbor and city skyline.
* **Boston Harbor Cruise Company:** Take a relaxing harbor cruise and admire the city from a different perspective.
* **Seaport Village:** Visit a charming waterfront area with shops, restaurants, and a seaport museum.
**Forest and Nature:**
* **Forest Park:** Hike through a scenic forest with diverse wildlife.
* **Quabbin Reservoir:** Enjoy boating, fishing, and hiking in a scenic setting.
* **Mount Forest:** Explore a mountain with breathtaking views of the city and surrounding landscape.
...
```
### Usage as a Command Line Tool
For using the `gemma` executable as a command line tool, it may be useful to
create an alias for gemma.cpp with arguments fully specified:
```sh
alias gemma2b="~/gemma.cpp/build/gemma -- --tokenizer ~/gemma.cpp/build/tokenizer.spm --compressed_weights ~/gemma.cpp/build/2b-it-sfp.sbs --model 2b-it --verbosity 0"
```
Replace the above paths with your own paths to the model and tokenizer paths
from the download.
Here is an example of prompting `gemma` with a truncated input
file (using a `gemma2b` alias like defined above):
```sh
cat configs.h | tail -35 | tr '\n' ' ' | xargs -0 echo "What does this C++ code do: " | gemma2b
```
> [!NOTE]
> CLI usage of gemma.cpp is experimental and should take context length
> limitations into account.
The output of the above command should look like:
```console
$ cat configs.h | tail -35 | tr '\n' ' ' | xargs -0 echo "What does this C++ code do: " | gemma2b
[ Reading prompt ] ......................................................................................................................................................................................................................................................................................................................................................................................................................................................................................
The code defines two C++ structs, `ConfigGemma7B` and `ConfigGemma2B`, which are used for configuring a deep learning model.
**ConfigGemma7B**:
* `seq_len`: Stores the length of the sequence to be processed. It's set to 7168.
* `vocab_size`: Stores the size of the vocabulary, which is 256128.
* `n_layers`: Number of layers in the deep learning model. It's set to 28.
* `dim_model`: Dimension of the model's internal representation. It's set to 3072.
* `dim_ffw_hidden`: Dimension of the feedforward and recurrent layers' hidden representations. It's set to 16 * 3072 / 2.
**ConfigGemma2B**:
* `seq_len`: Stores the length of the sequence to be processed. It's also set to 7168.
* `vocab_size`: Size of the vocabulary, which is 256128.
* `n_layers`: Number of layers in the deep learning model. It's set to 18.
* `dim_model`: Dimension of the model's internal representation. It's set to 2048.
* `dim_ffw_hidden`: Dimension of the feedforward and recurrent layers' hidden representations. It's set to 16 * 2048 / 2.
These structs are used to configure a deep learning model with specific parameters for either Gemma7B or Gemma2B architecture.
```
### Incorporating gemma.cpp as a Library in your Project
The easiest way to incorporate gemma.cpp in your own project is to pull in
gemma.cpp and dependencies using `FetchContent`. You can add the following to your
CMakeLists.txt:
```
include(FetchContent)
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
FetchContent_MakeAvailable(sentencepiece)
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main)
FetchContent_MakeAvailable(gemma)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
FetchContent_MakeAvailable(highway)
```
Note for the gemma.cpp `GIT_TAG`, you may replace `origin/main` for a specific
commit hash if you would like to pin the library version.
After your executable is defined (substitute your executable name for
`[Executable Name]` below):
```
target_link_libraries([Executable Name] libgemma hwy hwy_contrib sentencepiece)
FetchContent_GetProperties(gemma)
FetchContent_GetProperties(sentencepiece)
target_include_directories([Executable Name] PRIVATE ${gemma_SOURCE_DIR})
target_include_directories([Executable Name] PRIVATE ${sentencepiece_SOURCE_DIR})
```
### Building gemma.cpp as a Library
gemma.cpp can also be used as a library dependency in your own project. The
shared library artifact can be built by modifying the make invocation to build
the `libgemma` target instead of `gemma`.
> [!NOTE]
> If you are using gemma.cpp in your own project with the `FetchContent` steps
> in the previous section, building the library is done automatically by `cmake`
> and this section can be skipped.
First, run `cmake`:
```sh
(cd build && cmake ..)
```
Then, run `make` with the `libgemma` target:
```sh
cd build
make -j [number of parallel threads to use] libgemma
```
If this is successful, you should now have a
`libgemma` library file in the `build/` directory. On linux the filename is `libgemma.a`.
## Acknowledgements and Contacts
gemma.cpp was started in fall 2023 by [Austin Huang](austinvhuang@google.com)
and [Jan Wassenberg](janwas@google.com), and subsequently released February 2024
thanks to contributions from Phil Culliton, Paul Chang, and Dan Zheng.
This is not an officially supported Google product.

3
build/.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
*
!.gitignore
!.hgignore

244
compression/analyze.h Normal file
View File

@ -0,0 +1,244 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Normal include guard to placate lint.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h> // memcpy
#include <cmath> // std::signbit
#include <cstdlib> // std::abs
#include <vector>
// copybara:import_next_line:gemma_cpp
#include "compression/distortion.h"
// copybara:import_next_line:gemma_cpp
#include "compression/nuq.h"
// copybara:import_next_line:gemma_cpp
#include "compression/stats.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/timer.h"
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_
// Actual per-target include guard.
#if defined(THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
#endif
// copybara:import_next_line:gemma_cpp
#include "compression/nuq-inl.h"
// copybara:import_next_line:gemma_cpp
#include "compression/sfp-inl.h"
#include "hwy/contrib/sort/vqsort-inl.h"
#include "hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
class PerThread {
public:
void NotifyGroup(const float* group) {
Stats s_group;
for (size_t i = 0; i < kGroupSize; ++i) {
// Skip zero so we can see the lowest actual magnitude
if (group[i] == 0.0f || group[i] == -0.0f) continue;
s_all_.Notify(group[i]);
s_group.Notify(group[i]);
num_tiny_ += std::abs(group[i]) < 1e-3f;
// b_magn100_.Notify(group[i] * 40.0f + 20.0f);
const uint32_t binary32 =
hwy::BitCastScalar<uint32_t>(std::abs(group[i]));
// const int32_t exp = (binary32 >> 23) - 127;
b_exp256_.Notify(binary32 >> 23);
const uint32_t m4 = (binary32 & 0x7FFFFF) >> (23 - 4);
b_m4_.Notify(m4);
}
s_group_ranges_.Notify(s_group.Max() - s_group.Min());
s_group_mins_.Notify(s_group.Min());
s_group_maxs_.Notify(s_group.Max());
float desc[kGroupSize];
memcpy(desc, group, kGroupSize * sizeof(group[0]));
hn::VQSortStatic(desc, kGroupSize, hwy::SortDescending());
// Find largest |max/min| (dynamic range)
float max_ratio = 0.0f;
for (size_t i = 0; i < kGroupSize; ++i) {
if (desc[i] != 0.0f && desc[i] != -0.0f) {
max_ratio = std::max(max_ratio, std::abs(desc[0] / desc[i]));
}
}
s_group_max_vs_min_.Notify(max_ratio);
// Relative errors
float diffs[kGroupSize];
for (size_t i = 0; i < kGroupSize - 1; ++i) {
// was in descending order. Avoid div by 0. Ignore sign changes.
diffs[i] = std::abs(desc[i]) < 1e-5
? 0
: std::abs((desc[i] - desc[i + 1]) / desc[i]);
}
hn::VQSortStatic(diffs, kGroupSize, hwy::SortDescending());
s_cut15_.Notify(diffs[15]);
}
void Assimilate(const PerThread& other) {
num_tiny_ += other.num_tiny_;
s_all_.Assimilate(other.s_all_);
s_group_ranges_.Assimilate(other.s_group_ranges_);
s_group_mins_.Assimilate(other.s_group_mins_);
s_group_maxs_.Assimilate(other.s_group_maxs_);
s_group_max_vs_min_.Assimilate(other.s_group_max_vs_min_);
s_erange_.Assimilate(other.s_erange_);
s_km_1_.Assimilate(other.s_km_1_);
s_km_2_.Assimilate(other.s_km_2_);
s_cut15_.Assimilate(other.s_cut15_);
b_magn100_.Assimilate(other.b_magn100_);
b_exp256_.Assimilate(other.b_exp256_);
b_m4_.Assimilate(other.b_m4_);
}
void PrintAll() {
const int skip = Stats::kNoGeomean;
fprintf(stderr, "num tiny %zu\n", num_tiny_);
fprintf(stderr, "weights %s\n", s_all_.ToString(skip).c_str());
fprintf(stderr, " ranges %s\n", s_group_ranges_.ToString(skip).c_str());
fprintf(stderr, " mins %s\n", s_group_mins_.ToString(skip).c_str());
fprintf(stderr, " maxs %s\n", s_group_maxs_.ToString(skip).c_str());
fprintf(stderr, " Mvm %s\n", s_group_max_vs_min_.ToString(skip).c_str());
fprintf(stderr, " cut15 %s\n", s_cut15_.ToString(skip).c_str());
fprintf(stderr, " erange %s\n", s_erange_.ToString(skip).c_str());
fprintf(stderr, " km1 %s\n", s_km_1_.ToString(skip).c_str());
fprintf(stderr, " km2 %s\n", s_km_2_.ToString(skip).c_str());
// b_magn100_.Print("magn100");
// b_exp256_.Print("exp");
// b_m4_.Print("mantissa bits4");
fprintf(stderr, "\n");
}
private:
size_t num_tiny_ = 0;
Stats s_all_;
Stats s_group_ranges_;
Stats s_group_mins_;
Stats s_group_maxs_;
Stats s_group_max_vs_min_;
Stats s_erange_;
Stats s_km_1_;
Stats s_km_2_;
Stats s_cut15_;
Bins<100> b_magn100_;
Bins<256> b_exp256_;
Bins<16> b_m4_;
uint8_t padding_[64]; // prevent false sharing
};
class PerLayer {
public:
void NotifyGroup(const float* group) {
for (size_t i = 0; i < kGroupSize; ++i) {
s_layer_.Notify(group[i]);
}
}
void UpdateOutliers(const float* layer, size_t weights_per_layer) {
const float layer_mean = s_layer_.Mean();
const float layer_sd = s_layer_.StandardDeviation();
for (size_t i = 0; i < weights_per_layer; ++i) {
num_outliers_ +=
std::abs(std::abs(layer[i]) - layer_mean) >= 3.0f * layer_sd;
}
}
const Stats& GetStats() const { return s_layer_; }
size_t Outliers() const { return num_outliers_; }
private:
Stats s_layer_;
size_t num_outliers_ = 0;
uint8_t padding[64]; // prevent false sharing
};
static HWY_NOINLINE void Analyze(const char* caption, float* mat, size_t layers,
size_t weights_per_layer,
hwy::ThreadPool& pool) {
std::vector<PerThread> tls;
std::vector<PerLayer> per_layer(layers);
const auto init = [&](size_t num_threads) {
tls.resize(num_threads);
return true;
};
pool.Run(0, static_cast<uint32_t>(layers), init,
[&](uint32_t idx_layer, size_t idx_thread) {
PerThread& self = tls[idx_thread];
const float* layer = &mat[idx_layer * weights_per_layer];
// For each whole group in the layer
for (size_t group_start = 0;
group_start + kGroupSize <= weights_per_layer;
group_start += kGroupSize) {
const float* group = layer + group_start;
per_layer[idx_layer].NotifyGroup(group);
self.NotifyGroup(group);
}
per_layer[idx_layer].UpdateOutliers(layer, weights_per_layer);
});
const int skip = Stats::kNoGeomean;
fprintf(stderr, "\n------------%s\n", caption);
for (size_t i = 1; i < pool.NumThreads(); ++i) {
tls[0].Assimilate(tls[i]);
}
tls[0].PrintAll();
Stats s_layer_ranges;
Stats s_layer_outliers;
for (size_t i = 0; i < layers; ++i) {
fprintf(stderr, " %02zu %s\n", i,
per_layer[i].GetStats().ToString(skip).c_str());
const float range =
per_layer[i].GetStats().Max() - per_layer[i].GetStats().Min();
s_layer_ranges.Notify(range);
s_layer_outliers.Notify((100.0 * per_layer[i].Outliers()) /
weights_per_layer);
}
fprintf(stderr, "layer outliers%% %s\n",
s_layer_outliers.ToString(skip).c_str());
fprintf(stderr, "layer ranges %s\n", s_layer_ranges.ToString(skip).c_str());
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_

348
compression/blob_store.cc Normal file
View File

@ -0,0 +1,348 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// copybara:import_next_line:gemma_cpp
#include "compression/blob_store.h"
#include <fcntl.h> // open
#include <stdint.h>
#include <stdio.h> // SEEK_END - unistd isn't enough for IDE.
#include <sys/stat.h> // O_RDONLY
#include <unistd.h> // read, close
#include <atomic>
#include <vector>
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/detect_compiler_arch.h"
namespace gcpp {
hwy::uint128_t MakeKey(const char* string) {
size_t length = 0;
for (size_t i = 0; string[i] != '\0'; ++i) {
++length;
}
if (length > 16) {
HWY_ABORT("Key %s is too long, please truncate to 16 chars.", string);
}
hwy::uint128_t ret;
hwy::ZeroBytes<sizeof(ret)>(&ret);
hwy::CopyBytes(string, &ret, length);
return ret;
}
static void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data,
std::vector<BlobIO>& requests) {
// Split into chunks for load-balancing even if blob sizes vary.
constexpr size_t kChunkSize = 4 * 1024 * 1024;
// Split into whole chunks and possibly one remainder.
uint64_t pos = 0;
if (size >= kChunkSize) {
for (; pos <= size - kChunkSize; pos += kChunkSize) {
requests.emplace_back(offset + pos, kChunkSize, data + pos, 0);
}
}
if (pos != size) {
requests.emplace_back(offset + pos, size - pos, data + pos, 0);
}
}
struct IO {
// Returns size in bytes or 0.
static uint64_t FileSize(const char* filename) {
int fd = open(filename, O_RDONLY);
if (fd >= 0) {
const off_t size = lseek(fd, 0, SEEK_END);
HWY_ASSERT(close(fd) != -1);
if (size != static_cast<off_t>(-1)) {
return static_cast<uint64_t>(size);
}
}
return 0;
}
static bool Read(int fd, uint64_t offset, uint64_t size, void* to) {
uint8_t* bytes = reinterpret_cast<uint8_t*>(to);
uint64_t pos = 0;
for (;;) {
// pread seems to be faster than lseek + read when parallelized.
const auto bytes_read = pread(fd, bytes + pos, size - pos, offset + pos);
if (bytes_read <= 0) break;
pos += bytes_read;
HWY_ASSERT(pos <= size);
if (pos == size) break;
}
return pos == size; // success if managed to read desired size
}
static bool Write(const void* from, uint64_t size, uint64_t offset, int fd) {
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(from);
uint64_t pos = 0;
for (;;) {
const auto bytes_written =
pwrite(fd, bytes + pos, size - pos, offset + pos);
if (bytes_written <= 0) break;
pos += bytes_written;
HWY_ASSERT(pos <= size);
if (pos == size) break;
}
return pos == size; // success if managed to write desired size
}
}; // IO
static_assert(HWY_IS_LITTLE_ENDIAN, "Assumes little endian");
// On-disk representation (little-endian).
//
// Deliberately omits a version number because this file format is unchanging.
// Additional data may be added only inside new blobs. Changes to the blob
// contents or type should be handled by renaming keys.
#pragma pack(push, 1)
class BlobStore {
static constexpr uint32_t kMagic = 0x0A534253; // SBS\n
// Blob offsets on disk and memory addresses are a multiple of this, because
// we pad the header and each blob's size. This matches CUDA alignment and the
// maximum SVE vector size, and exceeds typical x86 cache line sizes (64 or
// 128), which can help performance.
static constexpr size_t kAlign = 256;
public:
// NOT including padding, so that we can also use ZeroFillPadding after
// copying the header.
static constexpr size_t HeaderSize(size_t num_blobs) {
// 16-byte fixed fields plus per-blob: 16-byte key, 16-byte offset/size.
return 16 + 32 * num_blobs;
}
// Returns how many bytes to allocate for the header without the subsequent
// blobs. Requires num_blobs_ to already be set, typically by reading
// sizeof(BlobStore) bytes from disk.
size_t PaddedHeaderSize() const {
return hwy::RoundUpTo(HeaderSize(num_blobs_), kAlign);
}
// Returns aligned offset and zero-fills between that and `offset`.
uint64_t ZeroFillPadding(uint64_t offset) {
uint8_t* const bytes = reinterpret_cast<uint8_t*>(this);
const uint64_t padded = hwy::RoundUpTo(offset, kAlign);
hwy::ZeroBytes(bytes + offset, padded - offset);
return padded;
}
BlobError CheckValidity(const uint64_t file_size) {
if (magic_ != kMagic) return __LINE__;
if (num_blobs_ == 0) return __LINE__;
if (file_size_ != file_size) return __LINE__;
// Ensure blobs are back to back, and zero-pad.
uint64_t offset = ZeroFillPadding(HeaderSize(num_blobs_));
for (size_t i = 0; i < num_blobs_; ++i) {
const hwy::uint128_t val = keys_[num_blobs_ + i];
if (val.lo != offset) return __LINE__;
offset = ZeroFillPadding(offset + val.hi);
}
if (offset != file_size_) return __LINE__;
return 0; // all OK
}
static BlobStorePtr Allocate(uint64_t total_size) {
uint8_t* bytes =
static_cast<uint8_t*>(hwy::AllocateAlignedBytes(total_size));
if (!bytes) return BlobStorePtr();
return BlobStorePtr(new (bytes) BlobStore(), hwy::AlignedFreer());
}
static std::vector<BlobIO> PrepareWriteRequests(
const hwy::uint128_t keys[], const hwy::Span<uint8_t> blobs[],
size_t num_blobs) {
// Sanity check and ensure the cast below is safe.
HWY_ASSERT(num_blobs < (1ULL << 20));
// Allocate var-length header.
const size_t header_size = HeaderSize(num_blobs);
const size_t padded_header_size = hwy::RoundUpTo(header_size, kAlign);
BlobStorePtr bs = Allocate(padded_header_size);
const uint64_t padded_header_end = bs->ZeroFillPadding(header_size);
HWY_ASSERT(padded_header_end == padded_header_size);
// All-zero buffer used to write padding to the file without copying the
// input blobs.
static uint8_t zeros[kAlign] = {0};
// Total file size will be the header plus all padded blobs.
uint64_t payload = 0;
for (size_t i = 0; i < num_blobs; ++i) {
payload += hwy::RoundUpTo(blobs[i].size(), kAlign);
}
const size_t total_size = padded_header_size + payload;
// Fill header.
bs->magic_ = kMagic;
bs->num_blobs_ = static_cast<uint32_t>(num_blobs);
bs->file_size_ = total_size;
hwy::CopyBytes(keys, bs->keys_, num_blobs * sizeof(keys[0]));
// First IO request is for the header (not yet filled!).
std::vector<BlobIO> requests;
requests.reserve(1 + 2 * num_blobs);
requests.emplace_back(/*offset=*/0, padded_header_size,
reinterpret_cast<uint8_t*>(bs.get()), 0);
// Fill second half of keys_ with offset/size and prepare IO requests.
uint64_t offset = padded_header_end;
for (size_t i = 0; i < num_blobs; ++i) {
bs->keys_[num_blobs + i].lo = offset;
bs->keys_[num_blobs + i].hi = blobs[i].size();
EnqueueChunkRequests(offset, blobs[i].size(), blobs[i].data(), requests);
offset += blobs[i].size();
const size_t padded_size = hwy::RoundUpTo(blobs[i].size(), kAlign);
if (padded_size != blobs[i].size()) {
const size_t padding = padded_size - blobs[i].size();
HWY_ASSERT(padding <= kAlign);
requests.emplace_back(offset, padding, zeros, 0);
offset += padding;
}
}
HWY_ASSERT(offset == total_size);
return requests;
}
bool FindKey(const hwy::uint128_t key, uint64_t& offset, size_t& size) const {
for (size_t i = 0; i < num_blobs_; ++i) {
if (keys_[i] == key) {
const hwy::uint128_t val = keys_[num_blobs_ + i];
offset = val.lo;
size = val.hi;
return true;
}
}
return false;
}
private:
uint32_t magic_;
uint32_t num_blobs_; // never 0
uint64_t file_size_; // must match actual size of file
hwy::uint128_t keys_[1]; // length: 2 * num_blobs
// Padding, then the blob identified by keys[0], then padding etc.
};
#pragma pack(pop)
BlobError BlobReader::Open(const char* filename) {
fd_ = open(filename, O_RDONLY);
if (fd_ < 0) return __LINE__;
#if _POSIX_C_SOURCE >= 200112L
// Doubles the readahead window, which seems slightly faster when cached.
(void)posix_fadvise(fd_, 0, 0, POSIX_FADV_SEQUENTIAL);
#endif
// Read first part of header to get actual size.
BlobStore bs;
if (!IO::Read(fd_, 0, sizeof(bs), &bs)) return __LINE__;
const size_t padded_size = bs.PaddedHeaderSize();
HWY_ASSERT(padded_size >= sizeof(bs));
// Allocate full header.
blob_store_ = BlobStore::Allocate(padded_size);
if (!blob_store_) return __LINE__;
// Copy what we already read (more efficient than seek + re-read).
hwy::CopySameSize(&bs, blob_store_.get());
// Read the rest of the header, but not the full file.
uint8_t* bytes = reinterpret_cast<uint8_t*>(blob_store_.get());
if (!IO::Read(fd_, sizeof(bs), padded_size - sizeof(bs),
bytes + sizeof(bs))) {
return __LINE__;
}
return blob_store_->CheckValidity(IO::FileSize(filename));
}
BlobReader::~BlobReader() {
if (fd_ >= 0) {
HWY_ASSERT(close(fd_) != -1);
}
}
BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) {
uint64_t offset;
size_t actual_size;
if (!blob_store_->FindKey(key, offset, actual_size)) return __LINE__;
if (actual_size != size) return __LINE__;
EnqueueChunkRequests(offset, actual_size, reinterpret_cast<uint8_t*>(data),
requests_);
return 0;
}
// Parallel synchronous I/O. Alternatives considered:
// - readv is limited to 0x7FFFF000 bytes on Linux (even 64-bit). Note that
// pread calls preadv with a single iovec.
// - O_DIRECT seems undesirable because we do want to use the OS cache
// between consecutive runs.
// - memory-mapped I/O is less predictable and adds noise to measurements.
BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
const int fd = fd_;
const auto& requests = requests_;
std::atomic_flag err = ATOMIC_FLAG_INIT;
// >5x speedup from parallel reads when cached.
pool.Run(0, requests.size(),
[fd, &requests, &err](uint64_t i, size_t /*thread*/) {
if (!IO::Read(fd, requests[i].offset, requests[i].size,
requests[i].data)) {
err.test_and_set();
}
});
if (err.test_and_set()) return __LINE__;
return 0;
}
BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool,
const char* filename) const {
HWY_ASSERT(keys_.size() == blobs_.size());
// Concatenate blobs in memory.
std::vector<BlobIO> requests = BlobStore::PrepareWriteRequests(
keys_.data(), blobs_.data(), keys_.size());
// Create/replace existing file.
const int fd = open(filename, O_CREAT | O_RDWR | O_TRUNC, 0644);
if (fd < 0) return __LINE__;
std::atomic_flag err = ATOMIC_FLAG_INIT;
pool.Run(0, requests.size(),
[fd, &requests, &err](uint64_t i, size_t /*thread*/) {
if (!IO::Write(requests[i].data, requests[i].size,
requests[i].offset, fd)) {
err.test_and_set();
}
});
if (err.test_and_set()) return __LINE__;
return 0;
}
} // namespace gcpp

90
compression/blob_store.h Normal file
View File

@ -0,0 +1,90 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_BLOB_STORE_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_BLOB_STORE_H_
#include <stddef.h>
#include <stdint.h>
#include <vector>
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // hwy::uint128_t
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
// Convenient way to construct a key from a string (<= 16 chars).
hwy::uint128_t MakeKey(const char* string);
// Ordered list of opaque blobs (~hundreds), identified by unique opaque
// 128-bit keys.
class BlobStore;
// Incomplete type, so dtor will not be called.
using BlobStorePtr = hwy::AlignedFreeUniquePtr<BlobStore>;
// 0 if successful, otherwise the line number of the failing check.
using BlobError = int;
struct BlobIO {
BlobIO(uint64_t offset, size_t size, void* data, uint64_t padding)
: offset(offset), size(size), data(data), padding(padding) {}
uint64_t offset;
size_t size;
void* data;
uint64_t padding;
};
class BlobReader {
public:
BlobReader() { requests_.reserve(500); }
~BlobReader();
// Opens `filename` and reads its header.
BlobError Open(const char* filename);
// Enqueues read requests if `key` is found and its size matches `size`.
BlobError Enqueue(hwy::uint128_t key, void* data, size_t size);
// Reads all enqueued requests.
BlobError ReadAll(hwy::ThreadPool& pool);
private:
BlobStorePtr blob_store_; // holds header, not the entire file
std::vector<BlobIO> requests_;
int fd_ = 0;
};
class BlobWriter {
public:
void Add(hwy::uint128_t key, void* data, size_t size) {
keys_.push_back(key);
blobs_.emplace_back(static_cast<uint8_t*>(data), size);
}
// Stores all blobs to disk in the given order with padding for alignment.
BlobError WriteAll(hwy::ThreadPool& pool, const char* filename) const;
private:
std::vector<hwy::uint128_t> keys_;
std::vector<hwy::Span<uint8_t>> blobs_;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_BLOB_STORE_H_

467
compression/compress-inl.h Normal file
View File

@ -0,0 +1,467 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Include guard for headers.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_INL_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_INL_H_
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <array>
// copybara:import_next_line:gemma_cpp
#include "compression/blob_store.h"
// copybara:import_next_line:gemma_cpp
#include "compression/compress.h"
// copybara:import_next_line:gemma_cpp
#include "compression/distortion.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/timer.h"
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_INL_H_
// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_COMPRESS_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_COMPRESS_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_COMPRESS_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_COMPRESS_TOGGLE
#endif
// copybara:import_next_line:gemma_cpp
#include "compression/nuq-inl.h"
// copybara:import_next_line:gemma_cpp
#include "compression/sfp-inl.h"
#include "hwy/contrib/dot/dot-inl.h"
#include "hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
// Enables generic code independent of compression type.
template <typename T> // primary, must specialize
struct CompressTraits {};
template <>
struct CompressTraits<float> {
using MatT = float;
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
size_t num, CompressPerThread& tls,
size_t /*out_capacity*/,
MatT* HWY_RESTRICT out, size_t out_ofs) {
using VF = hn::Vec<decltype(df)>;
const size_t N = hn::Lanes(df);
HWY_DASSERT(num >= 2 * N && num % (2 * N) == 0);
for (size_t i = 0; i < num; i += 2 * N) {
const VF in0 = hn::LoadU(df, in + i);
const VF in1 = hn::LoadU(df, in + i + N);
hn::StoreU(in0, df, out + out_ofs + i);
hn::StoreU(in1, df, out + out_ofs + i + N);
}
}
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Decompress(DF df, size_t /*in_capacity*/,
const MatT* HWY_RESTRICT in, size_t in_ofs,
float* HWY_RESTRICT out, size_t num) {
using VF = hn::Vec<decltype(df)>;
const size_t N = hn::Lanes(df);
HWY_DASSERT(num >= 2 * N && num % (2 * N) == 0);
for (size_t i = 0; i < num; i += 2 * N) {
const VF in0 = hn::LoadU(df, in + in_ofs + i);
const VF in1 = hn::LoadU(df, in + in_ofs + i + N);
hn::StoreU(in0, df, out + i);
hn::StoreU(in1, df, out + i + N);
}
}
// VecT can be float or hwy::bfloat16_t.
template <class DF, typename VecT, HWY_IF_F32_D(DF)>
static HWY_INLINE float Dot(DF df, size_t /*in_capacity*/,
const MatT* HWY_RESTRICT in, size_t in_ofs,
const VecT* HWY_RESTRICT vec_aligned,
size_t num) {
HWY_DASSERT(num >= hn::Lanes(df) && (num % hn::Lanes(df)) == 0);
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
constexpr int kAssumptions =
hn::Dot::kAtLeastOneVector | hn::Dot::kMultipleOfVector;
// vec_aligned must be the second argument because hn::Dot supports f32*bf16
// and f32*f32.
return hn::Dot::Compute<kAssumptions>(df, in + in_ofs, vec_aligned, num);
}
};
template <>
struct CompressTraits<hwy::bfloat16_t> {
using MatT = hwy::bfloat16_t;
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
size_t num, CompressPerThread& tls,
size_t /*out_capacity*/,
MatT* HWY_RESTRICT out, size_t out_ofs) {
const hn::RebindToUnsigned<decltype(df)> du;
const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf;
using VF = hn::Vec<decltype(df)>;
const size_t N = hn::Lanes(df);
hn::Vec<decltype(du)> or_sum = hn::Zero(du);
size_t i = 0;
if (num >= 2 * N) {
for (; i <= num - 2 * N; i += 2 * N) {
const VF in0 = hn::LoadU(df, in + i);
const VF in1 = hn::LoadU(df, in + i + N);
// Sticky bits so we can warn if any lower bits were set.
or_sum = hn::Or3(or_sum, hn::BitCast(du, in0), hn::BitCast(du, in1));
hn::StoreU(hn::OrderedDemote2To(dbf, in0, in1), dbf, out + out_ofs + i);
if (COMPRESS_STATS) {
DistortionStats stats;
for (size_t j = 0; j < 2 * N; ++j) {
stats.Notify(in[i + j], hwy::F32FromBF16(out[out_ofs + i + j]));
}
tls.stats.Notify(stats);
}
}
}
size_t remaining = num - i;
if (remaining != 0) {
const VF in0 = hn::LoadN(df, in + i, remaining);
const size_t remaining1 = remaining - HWY_MIN(remaining, N / 2);
const VF in1 = hn::LoadN(df, in + i + N, remaining1);
// Sticky bits so we can warn if any lower bits were set.
or_sum = hn::Or3(or_sum, hn::BitCast(du, in0), hn::BitCast(du, in1));
hn::StoreU(hn::OrderedDemote2To(dbf, in0, in1), dbf, out + out_ofs + i);
if (COMPRESS_STATS) {
DistortionStats stats;
for (size_t j = 0; j < remaining; ++j) {
stats.Notify(in[i + j], hwy::F32FromBF16(out[out_ofs + i + j]));
}
tls.stats.Notify(stats);
}
}
// If the lower 16 bits are not zero, we should implement rounding.
or_sum = hn::And(or_sum, hn::Set(du, 0xFFFF));
if (!hn::AllTrue(du, hn::Eq(or_sum, hn::Zero(du)))) {
// fprintf(stderr, "Warning: Lossy truncation.");
}
}
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Decompress(DF df, size_t /*in_capacity*/,
const MatT* HWY_RESTRICT in, size_t in_ofs,
float* HWY_RESTRICT out, size_t num) {
const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf;
using VBF = hn::Vec<decltype(dbf)>;
using VF = hn::Vec<decltype(df)>;
const size_t N16 = hn::Lanes(dbf);
size_t i = 0;
if (num >= N16) {
for (i = 0; i <= num - N16; i += N16) {
const VBF in16 = hn::LoadU(dbf, in + in_ofs + i);
const VF in0 = hn::PromoteLowerTo(df, in16);
const VF in1 = hn::PromoteUpperTo(df, in16);
hn::StoreU(in0, df, out + i);
hn::StoreU(in1, df, out + i + N16 / 2);
}
}
size_t remaining = num - i;
if (remaining != 0) {
const VBF in16 = hn::LoadN(dbf, in + in_ofs + i, remaining);
const VF in0 = hn::PromoteLowerTo(df, in16);
const VF in1 = hn::PromoteUpperTo(df, in16);
hn::StoreN(in0, df, out + i, remaining);
// Avoid wraparound, potentially store nothing.
const size_t remaining1 = remaining - HWY_MIN(remaining, N16 / 2);
hn::StoreN(in1, df, out + i + N16 / 2, remaining1);
}
}
// VecT can be float or hwy::bfloat16_t.
template <class DF, typename VecT, HWY_IF_F32_D(DF)>
static HWY_INLINE float Dot(DF df, size_t /*in_capacity*/,
const MatT* HWY_RESTRICT in, size_t in_ofs,
const VecT* HWY_RESTRICT vec_aligned,
size_t num) {
HWY_DASSERT(num >= hn::Lanes(df) && (num % hn::Lanes(df)) == 0);
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
const hn::Repartition<VecT, decltype(df)> d_vec;
constexpr int kAssumptions =
hn::Dot::kAtLeastOneVector | hn::Dot::kMultipleOfVector;
// vec_aligned must be first argument because hn::Dot supports f32*bf16 and
// bf16*bf16.
return hn::Dot::Compute<kAssumptions>(d_vec, vec_aligned, in + in_ofs, num);
}
};
template <>
struct CompressTraits<SfpStream> {
using MatT = SfpStream;
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Compress(DF df, const float* in, size_t num,
CompressPerThread& tls,
size_t /*out_capacity*/, MatT* out,
size_t out_ofs) {
SfpCodec::Enc(df, in, num, out + out_ofs);
if (COMPRESS_STATS) {
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
auto distorted = hwy::AllocateAligned<hwy::bfloat16_t>(num);
SfpCodec::Dec(dbf, out + out_ofs, num, distorted.get());
DistortionStats stats;
for (size_t i = 0; i < num; ++i) {
stats.Notify(in[i], hwy::F32FromBF16(distorted[i]));
}
tls.stats.Notify(stats);
}
}
template <class D, typename OutT>
static HWY_INLINE void Decompress(D d, size_t /*in_capacity*/, const MatT* in,
size_t in_ofs, OutT* out, size_t num) {
SfpCodec::Dec(d, in + in_ofs, num, out);
}
template <class DF, typename VecT, HWY_IF_F32_D(DF)>
static HWY_INLINE float Dot(DF df, size_t /*in_capacity*/, const MatT* in,
size_t in_ofs, const VecT* vec_aligned,
size_t num) {
using VF = hn::Vec<decltype(df)>;
VF sum0 = hn::Zero(df);
VF sum1 = hn::Zero(df);
VF sum2 = hn::Zero(df);
VF sum3 = hn::Zero(df);
SfpCodec::Dot(df, in + in_ofs, num, vec_aligned, sum0, sum1, sum2, sum3);
// Reduction tree: sum of all accumulators, then their lanes
sum0 = hn::Add(sum0, sum1);
sum2 = hn::Add(sum2, sum3);
sum0 = hn::Add(sum0, sum2);
return hn::ReduceSum(df, sum0);
}
};
template <>
struct CompressTraits<NuqStream> {
using MatT = NuqStream;
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Compress(DF df, const float* in, size_t num,
CompressPerThread& tls, size_t out_capacity,
MatT* out, size_t out_ofs) {
NuqCodec::Enc(df, in, num, tls.buf, out_capacity, out, out_ofs);
if (COMPRESS_STATS) {
for (size_t i = 0; i < num; ++i) {
tls.stats.NotifyIn(in[i] * 100 + 500);
}
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
auto distorted = hwy::AllocateAligned<hwy::bfloat16_t>(num);
NuqCodec::Dec(dbf, out_capacity, out, out_ofs, distorted.get(), num);
DistortionStats stats;
for (size_t i = 0; i < num; ++i) {
stats.Notify(in[i], hwy::F32FromBF16(distorted[i]));
}
tls.stats.Notify(stats);
}
}
template <class D, typename OutT>
static HWY_INLINE void Decompress(D d, size_t in_capacity, const MatT* in,
size_t in_ofs, OutT* out, size_t num) {
NuqCodec::Dec(d, in_capacity, in, in_ofs, out, num);
}
template <class DF, typename VecT, HWY_IF_F32_D(DF)>
static HWY_INLINE float Dot(DF df, size_t in_capacity, const MatT* in,
size_t in_ofs,
const VecT* HWY_RESTRICT vec_aligned,
size_t num) {
using VF = hn::Vec<decltype(df)>;
VF sum0 = hn::Zero(df);
VF sum1 = hn::Zero(df);
VF sum2 = hn::Zero(df);
VF sum3 = hn::Zero(df);
NuqCodec::Dot(df, in_capacity, in, in_ofs, vec_aligned, num, sum0, sum1,
sum2, sum3);
// Reduction tree: sum of all accumulators, then their lanes
sum0 = hn::Add(hn::Add(sum0, sum1), hn::Add(sum2, sum3));
return hn::ReduceSum(df, sum0);
}
};
// Compresses `num` inputs to `out` starting at `out_ofs`. This can be used for
// compressing sub-regions of an array.
template <typename MatT>
HWY_NOINLINE void Compress(const float* in, size_t num,
CompressWorkingSet& work, size_t out_capacity,
MatT* out, size_t out_ofs, hwy::ThreadPool& pool) {
HWY_DASSERT(out_ofs + num <= out_capacity);
work.tls.resize(pool.NumThreads());
if (COMPRESS_STATS) {
for (auto& tls : work.tls) {
tls.stats.Reset();
}
}
const double t0 = hwy::platform::Now();
using Traits = CompressTraits<MatT>;
constexpr size_t kBatch = 8192;
const size_t num_batches = hwy::DivCeil(num, kBatch);
pool.Run(0, num_batches,
[&](const uint32_t idx_batch, size_t thread) HWY_ATTR {
const hn::ScalableTag<float> df;
const size_t in_ofs = idx_batch * kBatch;
const size_t my_num =
idx_batch == num_batches - 1 ? (num - in_ofs) : kBatch;
Traits::Compress(df, in + in_ofs, my_num, work.tls[thread],
out_capacity, out, out_ofs + in_ofs);
});
const double t1 = hwy::platform::Now();
const double mb = num * sizeof(in[0]) * 1E-6;
const double mbps = mb / (t1 - t0);
fprintf(stderr, "Compress %.1f MB/s\n", mbps);
if (COMPRESS_STATS) {
for (size_t i = 1; i < work.tls.size(); ++i) {
work.tls[0].stats.Assimilate(work.tls[i].stats);
}
work.tls[0].stats.PrintAll();
}
}
// Compresses an entire std::array into `out`, which is assumed to have exactly
// that much capacity.
template <size_t kCapacity, typename MatT>
HWY_INLINE void Compress(const std::array<float, kCapacity>& in,
CompressWorkingSet& work,
CompressedArray<MatT, kCapacity>& compressed,
hwy::ThreadPool& pool) {
Compress(in.data(), kCapacity, work, kCapacity, compressed.data(), 0, pool);
}
// Decompresses `num` values from `compressed` starting at `compressed_ofs`.
template <typename MatT, size_t kCapacity, typename OutT>
HWY_NOINLINE void Decompress(const CompressedArray<MatT, kCapacity>& compressed,
size_t compressed_ofs, OutT* out, size_t num) {
HWY_DASSERT(compressed_ofs + num <= compressed.NumElements());
const hn::ScalableTag<OutT> d;
using Traits = CompressTraits<MatT>;
Traits::Decompress(d, kCapacity, compressed.data(), compressed_ofs, out, num);
}
// As above, but with threading and benchmarking.
template <typename MatT, size_t kCapacity, typename OutT>
HWY_INLINE void Decompress(const CompressedArray<MatT, kCapacity>& compressed,
size_t compressed_ofs, OutT* out, size_t num,
hwy::ThreadPool& pool) {
HWY_DASSERT(compressed_ofs + num <= compressed.NumElements());
const double t0 = hwy::platform::Now();
using Traits = CompressTraits<MatT>;
constexpr size_t kBatch = 8192;
const size_t num_batches = hwy::DivCeil(num, kBatch);
pool.Run(
0, num_batches, [&](const uint32_t idx_batch, size_t thread) HWY_ATTR {
const hn::ScalableTag<OutT> d;
const size_t ofs = idx_batch * kBatch;
const size_t num = idx_batch == num_batches - 1 ? (num - ofs) : kBatch;
Traits::Decompress(d, compressed.NumElements(), compressed.data(),
compressed_ofs + ofs, out + ofs, num);
});
const double t1 = hwy::platform::Now();
const double mb = num * sizeof(MatT) * 1E-6;
const double mbps = mb / (t1 - t0);
fprintf(stderr, "Decompress %.1f MB/s\n", mbps);
}
// Returns dot product with `vec_aligned` of length `num`.
template <class DF, typename MatT, size_t kCapacity, typename VecT>
HWY_INLINE float Dot(DF df, const CompressedArray<MatT, kCapacity>& compressed,
size_t compressed_ofs, const VecT* vec_aligned,
size_t num) {
HWY_DASSERT(compressed_ofs + num <= compressed.NumElements());
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
using Traits = CompressTraits<MatT>;
return Traits::Dot(df, kCapacity, compressed.data(), compressed_ofs,
vec_aligned, num);
}
// Callback used by ForeachTensor.
class Compressor {
public:
explicit Compressor(hwy::ThreadPool& pool) : pool_(pool) {}
// Called for each tensor; compresses it and stores to the cache.
template <typename MatT, size_t kCapacity>
void operator()(const char* name, const float* weights,
CompressedArray<MatT, kCapacity>& compressed) {
fprintf(stderr, "Regenerating %s (%zuM), please wait\n", name,
kCapacity / (1000 * 1000));
Compress(weights, kCapacity, work_, kCapacity, compressed.data(), 0, pool_);
writer_.Add(CacheKey<MatT>(name), compressed.data(),
compressed.CompressedSize());
}
void WriteAll(hwy::ThreadPool& pool, const char* blob_filename) {
const BlobError err = writer_.WriteAll(pool, blob_filename);
if (err != 0) {
fprintf(stderr, "Failed to write blobs to %s (error %d)\n", blob_filename,
err);
}
}
private:
CompressWorkingSet work_;
hwy::ThreadPool& pool_;
BlobWriter writer_;
};
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // NOLINT

215
compression/compress.h Normal file
View File

@ -0,0 +1,215 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Target-independent definitions.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
#define COMPRESS_STATS 0
#include <stddef.h>
#include <stdio.h>
#include <array>
#include <string>
#include <vector>
// IWYU pragma: begin_exports
// copybara:import_next_line:gemma_cpp
#include "compression/blob_store.h"
// copybara:import_next_line:gemma_cpp
#include "compression/nuq.h"
// copybara:import_next_line:gemma_cpp
#include "compression/sfp.h"
// IWYU pragma: end_exports
// copybara:import_next_line:gemma_cpp
#include "compression/distortion.h"
#include "hwy/base.h" // hwy::bfloat16_t
#include "hwy/contrib/thread_pool/thread_pool.h"
#if COMPRESS_STATS
// copybara:import_next_line:gemma_cpp
#include "compression/stats.h"
#endif
namespace gcpp {
static inline const char* TypeName(float) { return "f32"; }
static inline const char* TypeName(hwy::bfloat16_t) { return "b16"; }
namespace detail {
// How many MatT are required to store `capacity` weights. For all but
// NuqStream, this is the same as `capacity`. For use by CompressedArray.
template <typename MatT>
constexpr size_t CompressedArrayLen(size_t capacity) {
return capacity;
}
template <>
constexpr size_t CompressedArrayLen<NuqStream>(size_t capacity) {
return NuqStream::PackedEnd(capacity);
}
} // namespace detail
// Compressed representation of floating-point elements. The array length may
// differ from the number of elements. Associated operations such as Dot are
// implemented in SIMD code and are thus non-member functions.
template <typename MatT, size_t kCapacity>
class CompressedArray {
static constexpr size_t NumCompressed() {
return detail::CompressedArrayLen<MatT>(kCapacity);
}
public:
MatT* data() { return data_.data(); }
const MatT* data() const { return data_.data(); }
constexpr size_t NumElements() const { return kCapacity; }
constexpr size_t CompressedSize() const {
return NumCompressed() * sizeof(MatT);
}
private:
std::array<MatT, NumCompressed()> data_;
};
#if COMPRESS_STATS
class CompressStats {
public:
void Notify(const DistortionStats& stats) {
const float pnorm = stats.PNorm();
const float snr = stats.GeomeanValueDivL1();
num_exact_ += stats.NumExact();
s_pnorm_.Notify(pnorm);
// No loss - skip to avoid dragging down the average.
if (snr != 0.0f) {
s_snr_.Notify(snr);
}
}
void NotifyIn(int sfp) { hist_weights_.Notify(sfp); }
void Assimilate(const CompressStats& other) {
s_pnorm_.Assimilate(other.s_pnorm_);
s_snr_.Assimilate(other.s_snr_);
num_exact_ += other.num_exact_;
hist_weights_.Assimilate(other.hist_weights_);
}
void PrintAll() {
const int skip = Stats::kNoGeomean;
fprintf(stderr, " pnorm %s\n", s_pnorm_.ToString(skip).c_str());
fprintf(stderr, " SNR %s\n", s_snr_.ToString(skip).c_str());
fprintf(stderr, " #exact %.3E\n", static_cast<double>(num_exact_));
// hist_weights_.Print("indices");
}
void Reset() {
s_pnorm_.Reset();
s_snr_.Reset();
num_exact_ = 0;
hist_weights_.Reset();
}
private:
Stats s_pnorm_;
Stats s_snr_;
size_t num_exact_ = 0;
Bins<1000> hist_weights_;
char padding_[64]; // prevent false sharing
};
#else
struct CompressStats {
void Notify(const DistortionStats&) {}
void NotifyIn(int) {}
void Assimilate(const CompressStats&) {}
void PrintAll() {}
void Reset() {}
};
#endif // COMPRESS_STATS
struct CompressPerThread {
CompressStats stats;
ClusterBuf buf;
};
struct CompressWorkingSet {
std::vector<CompressPerThread> tls;
};
// Returns key for the given tensor name. Also encodes the type, so that
// changing the representation automatically invalidates prior cached files
// (the new blob name will not be found).
template <typename MatT>
hwy::uint128_t CacheKey(const char* name) {
// Already used/retired: s, S, n, 1
const char prefix = hwy::IsSame<MatT, float>() ? 'F'
: hwy::IsSame<MatT, hwy::bfloat16_t>() ? 'B'
: hwy::IsSame<MatT, SfpStream>() ? '$'
: hwy::IsSame<MatT, NuqStream>() ? '2'
: '?';
return MakeKey((std::string(1, prefix) + name).c_str());
}
class CacheLoader {
public:
explicit CacheLoader(const char* blob_filename) {
err_ = reader_.Open(blob_filename);
if (err_ != 0) {
fprintf(stderr,
"Cached compressed weights does not exist yet (code %d), "
"compressing weights and creating file: %s.\n",
err_, blob_filename);
}
}
// Called for each tensor, enqueues read requests.
template <typename MatT, size_t kCapacity>
void operator()(const char* name, const float* null,
CompressedArray<MatT, kCapacity>& compressed) {
HWY_DASSERT(null == nullptr);
// Skip if reader_ is invalid or any load failed: we will regenerate
// everything because it's rare to update only a few tensors.
if (err_ != 0) return;
err_ = reader_.Enqueue(CacheKey<MatT>(name), compressed.data(),
compressed.CompressedSize());
if (err_ != 0) {
fprintf(stderr, "Failed to read cache %s (error %d)\n", name, err_);
}
}
// Returns whether all tensors are successfully loaded from cache.
bool ReadAll(hwy::ThreadPool& pool) {
// reader_ invalid or any Enqueue failed
if (err_ != 0) return false;
err_ = reader_.ReadAll(pool);
if (err_ != 0) {
fprintf(stderr, "Failed to read all tensors (error %d)\n", err_);
return false;
}
return true;
}
private:
BlobReader reader_;
BlobError err_ = 0;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_

99
compression/distortion.h Normal file
View File

@ -0,0 +1,99 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_DISTORTION_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_DISTORTION_H_
#include <math.h> // pow
#include <stddef.h>
#include "hwy/base.h" // ScalarAbs
namespace gcpp {
class DistortionStats {
public:
void Notify(float original, float distorted) {
const double l1 = hwy::ScalarAbs(original - distorted);
if (l1 > max_l1_) {
max_l1_ = l1;
max_idx_ = n_;
}
const double pow3 = l1 * l1 * l1;
sum_pow3_ += pow3;
sum_pow6_ += pow3 * pow3;
n_ += 1;
// Avoid division by zero, which happens when there is no error. NumExact()
// reports the number of times this happens.
if (l1 != 0.0) {
const double rel = 1.0 + hwy::ScalarAbs(original) / l1;
// Logarithm is required to prevent overflow. A hierarchical geomean
// could also work, but that is more complex and not necessarily better.
sum_log_rel_ += log(rel);
num_rel_ += 1;
}
}
void Assimilate(const DistortionStats& other) {
if (other.max_l1_ > max_l1_) {
max_l1_ = other.max_l1_;
max_idx_ = other.max_idx_;
}
sum_pow3_ += other.sum_pow3_;
sum_pow6_ += other.sum_pow6_;
n_ += other.n_;
sum_log_rel_ += other.sum_log_rel_;
num_rel_ += other.num_rel_;
}
size_t NumExact() const { return n_ - num_rel_; }
double GeomeanValueDivL1() const {
if (num_rel_ == 0) return 0.0;
return exp(sum_log_rel_ / num_rel_);
}
double PNorm() const {
// p-norms are a compromise between max-norm (penalizes the largest error
// without dilution, but does not notice any other errors) and L1 (all
// errors contribute, but large errors are diluted by smaller ones).
const double norm3 = pow(sum_pow3_ / n_, 1.0 / 3);
const double norm6 = pow(sum_pow6_ / n_, 1.0 / 6);
return 0.5 * (norm3 + norm6);
}
size_t MaxIndex() const { return max_idx_; }
double MaxL1() const { return max_l1_; }
private:
size_t n_ = 0;
size_t max_idx_ = 0; // index that had l1 = max_l1_.
double max_l1_ = -1.0;
double sum_pow3_ = 0.0;
double sum_pow6_ = 0.0;
double sum_log_rel_ = 0.0;
size_t num_rel_ = 0;
double padding_; // prevents false sharing
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_DISTORTION_H_

730
compression/nuq-inl.h Normal file
View File

@ -0,0 +1,730 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Normal include guard.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_
#include <stddef.h>
#include <stdint.h>
// copybara:import_next_line:gemma_cpp
#include "compression/nuq.h"
// copybara:import_next_line:gemma_cpp
#include "compression/sfp.h"
#include "hwy/base.h"
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_
// Actual per-target include guard.
#if defined(THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE) == \
defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE
#endif
// copybara:import_next_line:gemma_cpp
#include "compression/sfp-inl.h"
#include "hwy/contrib/sort/vqsort-inl.h"
#include "hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
// For internal use by NuqCodec.
class NuqClustering {
// To go from sorted order back to the original order in O(1), we store the
// original index in the lower bits of the float32 mantissa, which means they
// are sorted alongside the value.
struct FloatPayload {
// Resets payload to zero; useful for displaying the actual value.
static HWY_INLINE float Clear(float f) {
const uint32_t binary32 = hwy::BitCastScalar<uint32_t>(f);
return hwy::BitCastScalar<float>(binary32 &
~static_cast<uint32_t>(kGroupSize - 1));
}
// Sets payload to `bits`.
static HWY_INLINE float Set(float f, size_t bits) {
HWY_DASSERT(bits < kGroupSize);
const uint32_t binary32 = hwy::BitCastScalar<uint32_t>(Clear(f));
return hwy::BitCastScalar<float>(static_cast<uint32_t>(binary32 | bits));
}
// Obtains the payload (index) previously set by `Set`.
static HWY_INLINE size_t Get(float f) {
return hwy::BitCastScalar<uint32_t>(f) &
static_cast<uint32_t>(kGroupSize - 1);
}
};
// Cumulative sums for O(1) mean and interval sums.
class ClusterCost {
public:
explicit ClusterCost(const float* sorted) {
cumsum_[0] = cumsum2_[0] = 0.0;
for (size_t i = 0; i < kGroupSize; ++i) {
const float x = FloatPayload::Clear(sorted[i]);
cumsum_[1 + i] = x + cumsum_[i];
cumsum2_[1 + i] = x * x + cumsum2_[i];
}
inv_len_[0] = 0.0f; // unused
for (size_t i = 0; i <= kGroupSize; ++i) {
inv_len_[i] = 1.0f / i;
}
}
float SumOfSorted(size_t first, size_t last) const {
return cumsum_[last + 1] - cumsum_[first];
}
// Returns cost of clustering first..last with their mean, for a vector of
// last. O(1) thanks to cumulative sums, which works for Lp-norms with p >
// 1; we choose p=2 for simplicity (fewer terms).
template <class DF>
hn::Vec<DF> operator()(DF df, size_t first, size_t last) const {
// Callers are responsible for ignoring lanes where last < first.
HWY_DASSERT(first < kGroupSize);
HWY_DASSERT(last < kGroupSize);
const size_t len = last - first + 1;
const hn::Vec<DF> vlen =
hn::Iota(df, static_cast<float>(static_cast<int>(len)));
const hn::Vec<DF> u_lo = hn::Set(df, cumsum_[first]);
const hn::Vec<DF> u_lo2 = hn::Set(df, cumsum2_[first]);
const hn::Vec<DF> hi = hn::LoadU(df, cumsum_ + last + 1);
const hn::Vec<DF> hi2 = hn::LoadU(df, cumsum2_ + last + 1);
const hn::Vec<DF> sum = hn::Sub(hi, u_lo);
const hn::Vec<DF> sum2 = hn::Sub(hi2, u_lo2);
// Compute mean: table lookup is faster than division.
const hn::Vec<DF> mu = hn::Mul(sum, hn::LoadU(df, inv_len_ + len));
// (x - mu)^2 = sum2 - 2mu*sum + mu^2
const hn::Vec<DF> mu2 = hn::Mul(mu, mu);
const hn::Vec<DF> two_mu = hn::Add(mu, mu);
return hn::NegMulAdd(two_mu, sum, hn::MulAdd(vlen, mu2, sum2));
}
private:
// Float has enough precision for our relatively small kGroupSize (128).
float cumsum_[kGroupSize + 1];
float cumsum2_[kGroupSize + 1];
float inv_len_[kGroupSize + 1];
};
// Cost of clustering 0..last, where the rightmost cluster is j..last. This is
// called in a loop over j, and we return the vector of costs for a batch of
// last = [last, last + N).
template <class DF>
static HWY_INLINE hn::Vec<DF> ClusterDynProg(
DF df, const AlignedMatrix<float>& D, const ClusterCost& cc,
const size_t num_clusters, const size_t last, const size_t j) {
HWY_DASSERT(last < kGroupSize);
HWY_DASSERT(0 != j && j < kGroupSize);
const hn::RebindToSigned<decltype(df)> di;
using VF = hn::Vec<decltype(df)>;
using VI = hn::Vec<decltype(di)>;
using MI = hn::Mask<decltype(di)>;
const VI vlast = hn::Iota(di, static_cast<int32_t>(last));
// We have a non-empty rightmost cluster if j <= last <==> j-1 < last.
const MI valid = hn::Lt(hn::Set(di, static_cast<int32_t>(j) - 1), vlast);
// If not valid, return an arbitrary high cost, which will not be the min.
const VF max = hn::Set(df, 1E38f);
// Cost of clustering 0..j-1 with one fewer cluster than now.
const VF vd = hn::Set(df, D(num_clusters - 1, j - 1));
// Eq2: add to that the cost of another cluster from j..last.
return hn::MaskedAddOr(max, RebindMask(df, valid), vd, cc(df, j, last));
}
public:
// Clusters `kGroupSize` values in `x`, which need not be sorted already nor
// aligned, by choosing and filling `centers` (size `kClusters`, ascending
// order, not necessarily equal to one of the `x`). Fills `indices` with the
// index of the cluster to which each `x` belongs (16-bit for bit-packing).
// `buf` is per-thread.
//
// Returns the number of unused clusters, i.e., the starting index within
// `centers`; prior centers are zero-initialized.
//
// O(kClusters * kGroupSize * kGroupSize), but the constant factors are so low
// that this is about 10 times as fast as the O(kClusters * kGroupSize) SMAWK
// as implemented in FAISS, for our kGroupSize <= 128.
template <class DF>
static HWY_NOINLINE size_t ClusterExactL2(DF df, const float* x,
ClusterBuf& buf,
float* HWY_RESTRICT centers,
uint16_t* HWY_RESTRICT indices) {
const hn::RebindToSigned<decltype(df)> di;
using VF = hn::Vec<decltype(df)>;
using VI = hn::Vec<decltype(di)>;
const VI k1 = hn::Set(di, 1);
const size_t N = hn::Lanes(df);
HWY_ALIGN float sorted_and_i[kGroupSize];
for (size_t i = 0; i < kGroupSize; ++i) {
sorted_and_i[i] = FloatPayload::Set(x[i], i);
}
hn::VQSortStatic(sorted_and_i, kGroupSize, hwy::SortAscending());
ClusterCost cc(sorted_and_i);
// Reference: https://arxiv.org/abs/1701.07204
// D[k-1][m] is the lowest cost of clustering x1..m into k clusters.
AlignedMatrix<float>& D = buf.d;
// T[k][m] is the starting index within sorted_and_i[] of the k-th cluster.
AlignedMatrix<int32_t>& T = buf.t;
// Initialize the first rows for a single cluster.
for (size_t last = 0; last < kGroupSize; last += N) {
hn::Store(cc(df, 0, last), df, &D(0, last)); // Cost of 0..last
hn::Store(Zero(di), di, &T(0, last)); // Cluster index = 0
}
for (size_t num_clusters = 1; num_clusters < kClusters; ++num_clusters) {
// For each batch starting at `last`, one per lane:
for (size_t last = 0; last < kGroupSize; last += N) {
VF min = cc(df, 0, last);
VI arg = hn::Zero(di);
// For each j (start of rightmost cluster):
VI vj = k1;
for (size_t j = 1; j < last + N; ++j, vj = Add(vj, k1)) {
const VF c = ClusterDynProg(df, D, cc, num_clusters, last, j);
// Retain the min cost and the j index that caused it.
const auto less = hn::Lt(c, min);
min = hn::IfThenElse(less, c, min);
arg = hn::IfThenElse(RebindMask(di, less), vj, arg);
}
hn::Store(min, df, &D(num_clusters, last));
hn::Store(arg, di, &T(num_clusters, last));
}
}
// Backtrack to find centers. Clusters are [T(k, last), last].
size_t last = kGroupSize - 1;
size_t unused_clusters = 0;
for (size_t k = kClusters - 1; k < kClusters; --k) {
const size_t start = static_cast<size_t>(T(k, last));
// Center = mean, O(1) thanks to cumulative sums.
const float sum = cc.SumOfSorted(start, last);
const int size = static_cast<int>(last) - static_cast<int>(start) + 1;
HWY_DASSERT(0 < size && size <= kGroupSize);
centers[k] = sum / size;
// We know the range inside sorted_and_i[]; translate to original indices,
// which are stored inside each of the sorted_and_i mantissas.
for (size_t i = start; i <= last; ++i) {
const size_t idx_x = FloatPayload::Get(sorted_and_i[i]);
HWY_DASSERT(idx_x < kGroupSize);
indices[idx_x] = static_cast<uint16_t>(k);
}
// Not using all clusters. Avoid out of bounds accesses by stopping early.
if (start == 0) {
unused_clusters = k;
for (size_t cluster = 0; cluster < unused_clusters; ++cluster) {
centers[cluster] = 0.0f;
}
break;
}
last = start - 1;
HWY_DASSERT(last < kGroupSize);
}
if (HWY_IS_DEBUG_BUILD) {
// Centers are in ascending order.
for (size_t i = unused_clusters + 1; i < kClusters; ++i) {
HWY_DASSERT(centers[i] >= centers[i - 1]);
}
}
return unused_clusters;
}
}; // NuqClustering
// Bit-packing 4-bit values is trivial if we have 2 or 4 independent vectors:
// simply shift+OR them together into a full vector of 8 or 16-bit lanes.
// However, the order then depends on the vector length, which is unacceptable
// because we may store the encoding to disk and decode on another CPU.
//
// The dependency on vector length could be removed by introducing fixed-size
// packets and loading the next vector from a fixed offset known to be at
// least the vector length. However, this may require packets that are larger
// than the seek granularity of the application (e.g. matrix rows).
//
// We instead choose a continuous stream layout, which seems to entail the
// nibbles being stored and decoded in-order. This involves nontrivial shuffle
// operations which benefit from special-casing for target and vector length.
class NibbleCodec {
public:
// Packs four u16 vectors' lanes to nibbles within one vector, in order, and
// stores that vector to `out`.
template <class D16, class V16 = hn::Vec<D16>>
static HWY_INLINE void OrderedPackU16(D16 d16, V16 in0, V16 in1, V16 in2,
V16 in3, uint8_t* HWY_RESTRICT out) {
const hn::Repartition<uint8_t, D16> d8;
const hn::Repartition<uint32_t, D16> d32;
const hn::Repartition<uint64_t, D16> d64;
using V8 = hn::Vec<decltype(d8)>;
// Pairwise compaction of a single vector so nibbles are packed in-order.
// v16 lanes hold a 4-bit value; OR together adjacent pairs into the lower
// byte of *even* u16.
const auto combine_u16_pair_to_8 = [d16, d32](V16 v16) HWY_ATTR {
return hn::Xor(
v16, hn::BitCast(d16, hn::ShiftRight<12>(hn::BitCast(d32, v16))));
};
const V16 u8_0 = combine_u16_pair_to_8(in0);
const V16 u8_1 = combine_u16_pair_to_8(in1);
const V16 u8_2 = combine_u16_pair_to_8(in2);
const V16 u8_3 = combine_u16_pair_to_8(in3);
V8 packed;
if (HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86) {
// 8-bit ConcatEven is efficient. Let digits denote eight u8 lanes
// of u8_1/0: ?d?3 ?c?2 / ?b?1 ?a?0. 8-bit ConcatEven = d3c2 b1a0, and
// again with the second x2_1 gives 7654 3210.
const V8 x2_0 = hn::ConcatEven(d8, BitCast(d8, u8_1), BitCast(d8, u8_0));
const V8 x2_1 = hn::ConcatEven(d8, BitCast(d8, u8_3), BitCast(d8, u8_2));
packed = hn::ConcatEven(d8, x2_1, x2_0);
} else {
// To avoid expensive 8-bit ConcatEven, compact pairs of u32 into the
// lower 16 bits in each u64, with other bits undefined.
const auto combine_u32_pair_to_16 = [d16, d64](V16 v16) HWY_ATTR {
return hn::Xor(
v16, hn::BitCast(d16, hn::ShiftRight<24>(hn::BitCast(d64, v16))));
};
const V16 u16_0 = combine_u32_pair_to_16(u8_0);
const V16 u16_1 = combine_u32_pair_to_16(u8_1);
const V16 u16_2 = combine_u32_pair_to_16(u8_2);
const V16 u16_3 = combine_u32_pair_to_16(u8_3);
// In-order compaction of four vectors into one, keeping only the low
// u16 of every u64. This is the same as above but with 16-bit Concat.
const V16 x2_0 = hn::ConcatEven(d16, u16_1, u16_0);
const V16 x2_1 = hn::ConcatEven(d16, u16_3, u16_2);
packed = hn::BitCast(d8, hn::ConcatEven(d16, x2_1, x2_0));
}
hn::StoreU(packed, d8, out);
}
// Unpacks `Lanes(d16)` nibbles to u16 lanes. The first comes from the low
// nibble of packed[0], then its high nibble, then the next low nibble, etc.
template <class D16, class V16 = hn::Vec<D16>>
static HWY_INLINE V16 OrderedUnpackU16(D16 d16, const uint8_t* packed) {
const hn::Repartition<uint8_t, D16> d8;
using V8 = hn::Vec<decltype(d8)>;
const hn::CappedTag<uint8_t, d16.MaxBytes() / 4> d_load;
// We replicate each byte 4x, so that its two nibbles propagate to both
// u16 lanes that they will initialize. The only performance-portable op to
// replicate bytes is TableLookupBytes, which shuffles 128-bit blocks
// independently. Thus each block receives 4 packed bytes, replicates them
// 4x, shifts/masks, and casts to 8 u16 lanes.
//
// Loading 16 bytes via LoadDup128 only works on AVX3; for smaller vectors,
// it may trigger asan errors from overrunning the end. We thus special-case
// vector lengths, handling any non-constexpr, and constexpr <= 512 bit.
V8 rep4;
if (HWY_HAVE_SCALABLE) {
// Non constexpr length: 4 per whole block equals size/4.
const size_t num_bytes = HWY_MAX(1, hn::Lanes(d8) / 4);
const V8 bytes = hn::LoadN(d8, packed, num_bytes);
// Replicate bytes 4x: lowest 4 = 0, next 4 = 1 etc.
const V8 idx = hn::And(hn::Iota(d8, 0), hn::Set(d8, 0xFCu));
rep4 = hn::TableLookupBytes(bytes, idx);
} else if (hn::MaxLanes(d16) <= 8) { // <= 128-bit
const V8 bytes = hn::ResizeBitCast(d8, hn::LoadU(d_load, packed));
alignas(16) static constexpr uint8_t kRep4[16] = {
HWY_REP4(0), HWY_REP4(1), HWY_REP4(2), HWY_REP4(3)};
rep4 = hn::TableLookupBytes(bytes, hn::Load(d8, kRep4));
} else if (HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86) {
// Plain load, can do 256..512-bit permute across blocks.
const V8 bytes = hn::ResizeBitCast(d8, hn::LoadU(d_load, packed));
alignas(64) static constexpr uint8_t kRep4[64] = {
HWY_REP4(0), HWY_REP4(1), HWY_REP4(2), HWY_REP4(3),
HWY_REP4(4), HWY_REP4(5), HWY_REP4(6), HWY_REP4(7),
HWY_REP4(8), HWY_REP4(9), HWY_REP4(10), HWY_REP4(11),
HWY_REP4(12), HWY_REP4(13), HWY_REP4(14), HWY_REP4(15)};
rep4 = hn::TableLookupLanes(bytes, hn::SetTableIndices(d8, kRep4));
} else if (hn::MaxLanes(d16) == 16) { // 256-bit
const V8 bytes = hn::ResizeBitCast(d8, hn::LoadU(d_load, packed));
// First copy to upper block for TableLookupBytes. This is slightly
// faster than 64-bit BroadcastLane.
const V8 bcast = hn::ConcatLowerLower(d8, bytes, bytes);
alignas(32) static constexpr uint8_t kRep4[32] = {
HWY_REP4(0), HWY_REP4(1), HWY_REP4(2), HWY_REP4(3),
HWY_REP4(4), HWY_REP4(5), HWY_REP4(6), HWY_REP4(7)};
rep4 = hn::TableLookupBytes(bcast, hn::Load(d8, kRep4));
} else if (hn::MaxLanes(d16) == 32) { // 512-bit
const V8 bytes = hn::LoadDup128(d8, packed);
alignas(64) static constexpr uint8_t kRep4[64] = {
HWY_REP4(0), HWY_REP4(1), HWY_REP4(2), HWY_REP4(3),
HWY_REP4(4), HWY_REP4(5), HWY_REP4(6), HWY_REP4(7),
HWY_REP4(8), HWY_REP4(9), HWY_REP4(10), HWY_REP4(11),
HWY_REP4(12), HWY_REP4(13), HWY_REP4(14), HWY_REP4(15)};
rep4 = hn::TableLookupBytes(bytes, hn::Load(d8, kRep4));
} else {
HWY_DASSERT(false);
}
const V16 mask4 = hn::Set(d16, 0xF);
const V16 u16 = BitCast(d16, rep4);
// In-order unpack. Right-shift odd u16 by 4. Example with two packed
// bytes, one digit representing a nibble:
// 32 32 32 32 | 10 10 10 10 u16
// z3 23 32 32 | z1 01 10 10 OddEven+ShiftRight
// zz z3 zz z2 | zz z1 zz z0 And (unpacked result)
return hn::And(mask4, hn::OddEven(hn::ShiftRight<4>(u16), u16));
}
};
// Encode/decode functions.
class NuqCodec {
// 256-bit vectors can hold 16 bf16, otherwise we require 2x128-bit.
template <class DU>
static constexpr size_t NumTables(DU du) {
return (!HWY_HAVE_SCALABLE && du.MaxBytes() >= 32) ? 1 : 2;
}
// Unpacks `centers` from SFP into bf16 and loads them into one or two vectors
// for use by [Two]TableLookups. Returns as u16 because TableLookupLanes might
// not be available for bf16.
template <class DU, HWY_IF_U16_D(DU)>
static HWY_INLINE hn::Vec<DU> LoadTable(DU du, const uint8_t* centers,
hn::Vec<DU>* HWY_RESTRICT tbl1) {
// Cap to the table size (kClusters) for decoding SFP - sufficient, and may
// be faster than a large vector.
const hn::CappedTag<hwy::bfloat16_t, kClusters> d_table;
// We ResizeCast tables to DU: if DU is bigger, table lookups will only
// access lanes < kClusters. If DU is smaller (128-bit), we have 2 tables.
HWY_DASSERT(hn::Lanes(du) >= hn::Lanes(d_table) || NumTables(du) == 2);
HWY_ALIGN hwy::bfloat16_t table[kClusters];
SfpCodec::Dec(d_table, reinterpret_cast<const SfpStream*>(centers),
kClusters, table);
// If we assume >= 128-bit vectors, we can use [Two]TableLookupLanes
// instead of TableLookupBytes, which requires extra interleaving of lo/hi.
HWY_DASSERT(hn::Lanes(du) >= 8);
if (NumTables(du) == 2) {
// Reduce cap for second half to avoid loading past the end of the table.
const hn::CappedTag<hwy::bfloat16_t, kClusters / 2> d_table2;
*tbl1 = hn::ResizeBitCast(du, hn::LoadU(d_table2, table + kClusters / 2));
}
return hn::ResizeBitCast(du, hn::Load(d_table, table));
}
// Unpacks per-weight indices and sets c0/c1 to the corresponding centers.
template <class DU>
static HWY_INLINE void TableLookups(DU du, hn::Vec<DU> tbl0, hn::Vec<DU> tbl1,
const uint8_t* packed, hn::Vec<DU>& c0,
hn::Vec<DU>& c1) {
using V16 = hn::Vec<decltype(du)>;
const size_t N16 = hn::Lanes(du);
const V16 idx0 = NibbleCodec::OrderedUnpackU16(du, packed);
const V16 idx1 = NibbleCodec::OrderedUnpackU16(du, packed + N16 / 2);
const auto indices0 = hn::IndicesFromVec(du, idx0);
const auto indices1 = hn::IndicesFromVec(du, idx1);
if (NumTables(du) == 1) {
(void)tbl1;
c0 = hn::TableLookupLanes(tbl0, indices0);
c1 = hn::TableLookupLanes(tbl0, indices1);
} else {
c0 = hn::TwoTablesLookupLanes(du, tbl0, tbl1, indices0);
c1 = hn::TwoTablesLookupLanes(du, tbl0, tbl1, indices1);
}
}
public:
// Encodes `num` floats starting from `in`. `out` points to compressed
// storage for `out_capacity` values and `out_ofs` indicates the destination
// offset within it, in units of float values, for parallel encoding by
// multiple threads. `num`, `out_capacity`, and `out_ofs` must all be
// multiples of `kGroupSize`. Returns the total number of unused clusters,
// which is expected to be zero.
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE size_t Enc(DF df, const float* const in, const size_t num,
ClusterBuf& buf, const size_t out_capacity,
NuqStream* const out, const size_t out_ofs) {
const hn::Repartition<uint8_t, DF> d8;
const hn::Repartition<uint16_t, DF> d16;
using V8 = hn::Vec<decltype(d8)>;
using V16 = hn::Vec<decltype(d16)>;
const size_t N16 = hn::Lanes(d16);
HWY_ASSERT(kGroupSize >= 4 * N16);
HWY_ASSERT(out_ofs + num <= out_capacity);
buf.Resize(num);
HWY_ASSERT(num % kGroupSize == 0);
HWY_ASSERT(out_capacity % kGroupSize == 0);
HWY_ASSERT(out_ofs % kGroupSize == 0);
const size_t num_groups = num / kGroupSize;
const size_t ofs_groups = out_ofs / kGroupSize;
size_t unused_clusters = 0;
for (size_t g = 0; g < num_groups; ++g) {
const float* HWY_RESTRICT g_in = in + g * kGroupSize;
float* HWY_RESTRICT g_centers = buf.centers.get() + g * kClusters;
uint16_t* HWY_RESTRICT g_idx = buf.idx.get() + g * kGroupSize;
unused_clusters +=
NuqClustering::ClusterExactL2(df, g_in, buf, g_centers, g_idx);
}
uint8_t* centers = &out->byte + ofs_groups * kClusters;
SfpCodec::Enc(df, buf.centers.get(), num_groups * kClusters,
reinterpret_cast<SfpStream*>(centers));
uint8_t* packed_start = &out->byte + NuqStream::PackedStart(out_capacity) +
ofs_groups * kGroupSize / 2;
HWY_UNROLL(1)
for (size_t g = 0; g < num_groups; ++g) {
const uint16_t* HWY_RESTRICT g_idx = buf.idx.get() + g * kGroupSize;
uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2;
HWY_UNROLL(1)
for (size_t i = 0; i < kGroupSize; i += 4 * N16) {
const V16 idx0 = hn::LoadU(d16, g_idx + i + N16 * 0);
const V16 idx1 = hn::LoadU(d16, g_idx + i + N16 * 1);
const V16 idx2 = hn::LoadU(d16, g_idx + i + N16 * 2);
const V16 idx3 = hn::LoadU(d16, g_idx + i + N16 * 3);
NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3,
g_packed + i / 2);
}
}
return unused_clusters;
}
// Decodes `num` values from the stream `in`, starting at the offset `in_ofs`
// (in units of values), to bf16 in `out`. `in_capacity`, `in_ofs` and `num`
// must all be multiples of `kGroupSize`.
template <class DF, HWY_IF_BF16_D(DF)>
static HWY_INLINE void Dec(DF dbf, const size_t in_capacity,
const NuqStream* const in, const size_t in_ofs,
hwy::bfloat16_t* const out, const size_t num) {
const hn::RebindToUnsigned<decltype(dbf)> d16;
using V16 = hn::Vec<decltype(d16)>;
const size_t N16 = hn::Lanes(d16);
HWY_DASSERT(kGroupSize >= 4 * N16);
HWY_DASSERT(in_ofs + num <= in_capacity);
HWY_DASSERT(in_capacity % kGroupSize == 0);
HWY_DASSERT(in_ofs % kGroupSize == 0);
HWY_DASSERT(num % kGroupSize == 0);
const size_t num_groups = num / kGroupSize;
const size_t ofs_groups = in_ofs / kGroupSize;
const uint8_t* tables = &in->byte + ofs_groups * kClusters;
const uint8_t* packed_start = &in->byte +
NuqStream::PackedStart(in_capacity) +
ofs_groups * kGroupSize / 2;
HWY_UNROLL(1)
for (size_t g = 0; g < num_groups; ++g) {
const uint8_t* g_centers = tables + g * kClusters;
const uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2;
hwy::bfloat16_t* HWY_RESTRICT g_out = out + g * kGroupSize;
V16 tbl1 = Zero(d16);
const V16 tbl0 = LoadTable(d16, g_centers, &tbl1);
HWY_UNROLL(1)
for (size_t i = 0; i < kGroupSize; i += 2 * N16) {
V16 c0, c1;
TableLookups(d16, tbl0, tbl1, g_packed + i / 2, c0, c1);
hn::StoreU(BitCast(dbf, c0), dbf, g_out + i + N16 * 0);
hn::StoreU(BitCast(dbf, c1), dbf, g_out + i + N16 * 1);
}
}
}
// Decodes `num` values from the stream `in`, starting at the offset
// `in_ofs` (in units of values), to f32 in `out`. `in_capacity`,
// `in_ofs` and `num` must all be multiples of `kGroupSize`.
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Dec(DF df, const size_t in_capacity,
const NuqStream* const in, const size_t in_ofs,
float* const out, const size_t num) {
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
const hn::RebindToUnsigned<decltype(dbf)> d16;
using V16 = hn::Vec<decltype(d16)>;
using VF = hn::Vec<DF>;
const size_t NF = hn::Lanes(df);
HWY_DASSERT(kGroupSize >= 4 * NF);
HWY_DASSERT(in_ofs + num <= in_capacity);
HWY_DASSERT(in_capacity % kGroupSize == 0);
HWY_DASSERT(in_ofs % kGroupSize == 0);
HWY_DASSERT(num % kGroupSize == 0);
const size_t ofs_groups = in_ofs / kGroupSize;
const size_t num_groups = num / kGroupSize;
const uint8_t* tables = &in->byte + ofs_groups * kClusters;
const uint8_t* packed_start = &in->byte +
NuqStream::PackedStart(in_capacity) +
ofs_groups * kGroupSize / 2;
HWY_UNROLL(1)
for (size_t g = 0; g < num_groups; ++g) {
const uint8_t* g_centers = tables + g * kClusters;
const uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2;
float* HWY_RESTRICT g_out = out + g * kGroupSize;
V16 tbl1 = Zero(d16);
const V16 tbl0 = LoadTable(d16, g_centers, &tbl1);
HWY_UNROLL(1)
for (size_t i = 0; i < kGroupSize; i += 4 * NF) {
V16 c0, c1;
TableLookups(d16, tbl0, tbl1, g_packed + i / 2, c0, c1);
const VF f0 = hn::PromoteLowerTo(df, BitCast(dbf, c0));
const VF f1 = hn::PromoteUpperTo(df, BitCast(dbf, c0));
const VF f2 = hn::PromoteLowerTo(df, BitCast(dbf, c1));
const VF f3 = hn::PromoteUpperTo(df, BitCast(dbf, c1));
hn::StoreU(f0, df, g_out + i + NF * 0);
hn::StoreU(f1, df, g_out + i + NF * 1);
hn::StoreU(f2, df, g_out + i + NF * 2);
hn::StoreU(f3, df, g_out + i + NF * 3);
}
}
}
// Accumulates into `sum0..3` dot products of decoded values with `num` bf16
// from `vec_aligned`. DF is f32 because sum0..3 are also f32. `in_capacity`,
// `in_ofs` and `num` must all be multiples of `kGroupSize`.
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Dot(DF df, const size_t in_capacity,
const NuqStream* const in, const size_t in_ofs,
const hwy::bfloat16_t* const vec_aligned,
const size_t num, hn::Vec<DF>& sum0,
hn::Vec<DF>& sum1, hn::Vec<DF>& sum2,
hn::Vec<DF>& sum3) {
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
const hn::RebindToUnsigned<decltype(dbf)> d16;
using VBF = hn::Vec<decltype(dbf)>;
using V16 = hn::Vec<decltype(d16)>;
const size_t N16 = hn::Lanes(d16);
HWY_DASSERT(kGroupSize >= 4 * N16);
HWY_DASSERT(in_ofs + num <= in_capacity);
HWY_DASSERT(in_capacity % kGroupSize == 0);
HWY_DASSERT(in_ofs % kGroupSize == 0);
HWY_DASSERT(num % kGroupSize == 0);
const size_t ofs_groups = in_ofs / kGroupSize;
const size_t num_groups = num / kGroupSize;
const uint8_t* tables = &in->byte + ofs_groups * kClusters;
const uint8_t* packed_start = &in->byte +
NuqStream::PackedStart(in_capacity) +
ofs_groups * kGroupSize / 2;
HWY_UNROLL(1)
for (size_t g = 0; g < num_groups; ++g) {
const uint8_t* g_centers = tables + g * kClusters;
const uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2;
const hwy::bfloat16_t* HWY_RESTRICT g_in = vec_aligned + g * kGroupSize;
V16 tbl1 = Zero(d16);
const V16 tbl0 = LoadTable(d16, g_centers, &tbl1);
HWY_UNROLL(1)
for (size_t i = 0; i < kGroupSize; i += 2 * N16) {
V16 c0, c1;
TableLookups(d16, tbl0, tbl1, g_packed + i / 2, c0, c1);
const VBF in0 = hn::Load(dbf, g_in + i + N16 * 0);
const VBF in1 = hn::Load(dbf, g_in + i + N16 * 1);
sum0 = hn::ReorderWidenMulAccumulate(df, in0, BitCast(dbf, c0), sum0,
sum1);
sum2 = hn::ReorderWidenMulAccumulate(df, in1, BitCast(dbf, c1), sum2,
sum3);
}
}
}
// Accumulates into `sum0..3` dot products of decoded values with `num` f32
// from `vec_aligned`. `in_capacity`, `in_ofs` and `num` must all be
// multiples of `kGroupSize`.
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Dot(DF df, const size_t in_capacity,
const NuqStream* const in, const size_t in_ofs,
const float* const vec_aligned, const size_t num,
hn::Vec<DF>& sum0, hn::Vec<DF>& sum1,
hn::Vec<DF>& sum2, hn::Vec<DF>& sum3) {
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
const hn::RebindToUnsigned<decltype(dbf)> d16;
using VF = hn::Vec<decltype(df)>;
using V16 = hn::Vec<decltype(d16)>;
const size_t NF = hn::Lanes(df);
HWY_DASSERT(kGroupSize >= 4 * NF);
HWY_DASSERT(in_ofs + num <= in_capacity);
HWY_DASSERT(in_capacity % kGroupSize == 0);
HWY_DASSERT(in_ofs % kGroupSize == 0);
HWY_DASSERT(num % kGroupSize == 0);
const size_t ofs_groups = in_ofs / kGroupSize;
const size_t num_groups = num / kGroupSize;
const uint8_t* tables = &in->byte + ofs_groups * kClusters;
const uint8_t* packed_start = &in->byte +
NuqStream::PackedStart(in_capacity) +
ofs_groups * kGroupSize / 2;
HWY_UNROLL(1)
for (size_t g = 0; g < num_groups; ++g) {
const uint8_t* g_centers = tables + g * kClusters;
const uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2;
const float* HWY_RESTRICT g_in = vec_aligned + g * kGroupSize;
V16 tbl1 = Zero(d16);
const V16 tbl0 = LoadTable(d16, g_centers, &tbl1);
HWY_UNROLL(1)
for (size_t i = 0; i < kGroupSize; i += 4 * NF) {
V16 c0, c1;
TableLookups(d16, tbl0, tbl1, g_packed + i / 2, c0, c1);
const VF in0 = hn::LoadU(df, g_in + i + NF * 0);
const VF in1 = hn::LoadU(df, g_in + i + NF * 1);
const VF in2 = hn::LoadU(df, g_in + i + NF * 2);
const VF in3 = hn::LoadU(df, g_in + i + NF * 3);
const VF f0 = hn::PromoteLowerTo(df, BitCast(dbf, c0));
const VF f1 = hn::PromoteUpperTo(df, BitCast(dbf, c0));
const VF f2 = hn::PromoteLowerTo(df, BitCast(dbf, c1));
const VF f3 = hn::PromoteUpperTo(df, BitCast(dbf, c1));
sum0 = hn::MulAdd(in0, f0, sum0);
sum1 = hn::MulAdd(in1, f1, sum1);
sum2 = hn::MulAdd(in2, f2, sum2);
sum3 = hn::MulAdd(in3, f3, sum3);
}
}
}
}; // NuqCodec
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_

116
compression/nuq.h Normal file
View File

@ -0,0 +1,116 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_H_
// Non-uniform quantization: a compressed representation of f32 inputs that
// supports seeking at a granularity of kGroupSize, decoding to bf16/f32, and a
// fused decode/dot product with bf16/f32 vectors.
#include <stddef.h>
#include <stdint.h>
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // HWY_INLINE
namespace gcpp {
// 4-bit indices are a sweet spot in terms of quality per size.
static constexpr size_t kClusters = 16;
// Number of weights that share a table. Larger = slower encode, higher error,
// smaller size (table amortized over more weights). This is the minimum
// granularity for seeking/decoding in the stream, and must be at least four
// times the number of bf16 elements per vector.
static constexpr size_t kGroupSize = 256;
// Points to the *start* of a NUQ stream. Aligning the allocation (see
// aligned_allocator.h) may be speed up decoding but is not required.
//
// See go/streaming-weight-decode for background and design. Layout: first one
// table of kClusters entries per group, in ascending order of group index,
// then two packed indices per byte.
//
// Indices are stored in-order to enable vector-length agnostic decode, because
// streams may be persisted to disk and used by other CPUs.
//
// To enable parallel encoding and decoding, Enc/Dec have `offset` parameters
// which refer to the stream, NOT the raw from/to pointers, which point directly
// to the source/destination. Offsets are in units of values, NOT compressed
// bytes within the stream.
#pragma pack(push, 1)
struct NuqStream {
// Returns offset of packed indices from the start of the stream. This matches
// the (padded) total table size because table entries are bytes. `capacity`
// is already a multiple of `kGroupSize`.
static constexpr size_t PackedStart(size_t capacity) {
// Round up to avoid cache-line splits when loading indices. No effect on
// size as long as capacity / kGroupSize is a multiple of 4.
return hwy::RoundUpTo((capacity / kGroupSize) * kClusters, 64);
}
// Returns number of NuqStream to allocate for the stream, which matches its
// size in bytes. `capacity` is already a multiple of `kGroupSize`.
static constexpr size_t PackedEnd(size_t capacity) {
return PackedStart(capacity) + capacity / 2; // two 4-bit indices per byte.
}
uint8_t byte;
};
#pragma pack(pop)
static inline const char* TypeName(NuqStream) { return "NUQ"; }
// Storage for dynamic programming. There are two matrices; we use separate
// allocations to avoid type punning.
template <class T>
class AlignedMatrix {
public:
AlignedMatrix() : mem_(hwy::AllocateAligned<T>(kClusters * kGroupSize)) {}
HWY_INLINE const T& operator()(size_t row, size_t col) const {
return mem_[row * kGroupSize + col];
}
HWY_INLINE T& operator()(size_t row, size_t col) {
return mem_[row * kGroupSize + col];
}
private:
hwy::AlignedFreeUniquePtr<T[]> mem_;
};
// Reuse memory across calls to Enc to avoid per-call allocations.
struct ClusterBuf {
void Resize(size_t new_num) {
if (new_num < num) return;
num = new_num;
const size_t num_groups = hwy::DivCeil(num, kGroupSize);
centers = hwy::AllocateAligned<float>(num_groups * kClusters);
idx = hwy::AllocateAligned<uint16_t>(num);
}
AlignedMatrix<float> d;
AlignedMatrix<int32_t> t;
size_t num = 0;
hwy::AlignedFreeUniquePtr<float[]> centers;
hwy::AlignedFreeUniquePtr<uint16_t[]> idx;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_H_

428
compression/nuq_test.cc Normal file
View File

@ -0,0 +1,428 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <algorithm> // std::shuffle
#include <random>
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"third_party/gemma_cpp/compression/nuq_test.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep
// Other headers that include Highway must come after foreach_target.h
// copybara:import_next_line:gemma_cpp
#include "compression/distortion.h"
// copybara:import_next_line:gemma_cpp
#include "compression/nuq-inl.h"
// copybara:import_next_line:gemma_cpp
#include "compression/nuq.h"
#include "hwy/highway.h"
#include "hwy/tests/hwy_gtest.h"
#include "hwy/tests/test_util-inl.h"
#include "hwy/timer.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
// All-equal inputs: only one cluster
struct TestFlat {
template <typename T, class DF>
HWY_INLINE void operator()(T /*unused*/, DF df) {
// Run this simple test only once to save time/debug output.
if (!(HWY_ONCE && hn::Lanes(df) == hn::Lanes(hn::ScalableTag<float>()))) {
return;
}
auto in = hwy::AllocateAligned<float>(kGroupSize);
HWY_ASSERT(in);
for (size_t i = 0; i < kGroupSize; ++i) {
in[i] = 0.5f;
}
ClusterBuf buf;
float centers[kClusters];
uint16_t indices[kGroupSize];
const size_t unused_clusters =
NuqClustering::ClusterExactL2(df, in.get(), buf, centers, indices);
HWY_ASSERT(unused_clusters == kClusters - 1);
for (size_t i = 0; i < unused_clusters; ++i) {
HWY_ASSERT(centers[i] == 0.0f);
}
HWY_ASSERT(centers[unused_clusters] == 0.5f);
for (size_t i = 0; i < kGroupSize; ++i) {
HWY_ASSERT(indices[i] == unused_clusters);
}
}
};
void TestAllFlat() { hn::ForGEVectors<64, TestFlat>()(float()); }
// Generate shuffled plateaus, one per cluster
struct TestPlateaus {
template <typename T, class DF>
HWY_INLINE void operator()(T /*unused*/, DF df) {
// Run this simple test only once to save time/debug output.
if (!(HWY_ONCE && hn::Lanes(df) == hn::Lanes(hn::ScalableTag<float>()))) {
return;
}
auto in = hwy::AllocateAligned<float>(kGroupSize);
HWY_ASSERT(in);
for (size_t i = 0; i < kGroupSize; ++i) {
const size_t idx_cluster = i / (kGroupSize / kClusters);
HWY_ASSERT(idx_cluster < kClusters);
in[i] = (1.0f * idx_cluster / kClusters) - 0.5f;
HWY_ASSERT(-0.5f <= in[i] && in[i] < 0.5f);
}
std::random_device rd;
std::mt19937 rng(rd());
std::shuffle(in.get(), in.get() + kGroupSize, rng);
ClusterBuf buf;
float centers[kClusters];
uint16_t indices[kGroupSize];
const size_t unused_clusters =
NuqClustering::ClusterExactL2(df, in.get(), buf, centers, indices);
HWY_ASSERT(unused_clusters == 0);
DistortionStats stats;
for (size_t i = 0; i < kGroupSize; ++i) {
HWY_ASSERT(indices[i] < kClusters);
stats.Notify(in[i], centers[indices[i]]);
}
const float pnorm = stats.PNorm();
const float snr = stats.GeomeanValueDivL1();
fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr,
stats.MaxIndex(), stats.MaxL1());
HWY_ASSERT(pnorm == 0.0f);
HWY_ASSERT(snr == 0.0f);
}
};
void TestAllPlateaus() { hn::ForGEVectors<64, TestPlateaus>()(float()); }
struct TestRamp {
template <typename T, class DF>
HWY_INLINE void operator()(T /*unused*/, DF df) {
// Run this simple test only once to save time/debug output.
if (!(HWY_ONCE && hn::Lanes(df) == hn::Lanes(hn::ScalableTag<float>()))) {
return;
}
auto in = hwy::AllocateAligned<float>(kGroupSize);
HWY_ASSERT(in);
for (size_t i = 0; i < kGroupSize; ++i) {
in[i] = (1.0f * i / kGroupSize) - 0.45f; // slightly asymmetric
HWY_ASSERT(-0.45f <= in[i] && in[i] < 0.55f);
}
std::random_device rd;
std::mt19937 rng(rd());
std::shuffle(in.get(), in.get() + kGroupSize, rng);
ClusterBuf buf;
float centers[kClusters];
uint16_t indices[kGroupSize];
const size_t unused_clusters =
NuqClustering::ClusterExactL2(df, in.get(), buf, centers, indices);
HWY_ASSERT(unused_clusters == 0);
DistortionStats stats;
for (size_t i = 0; i < kGroupSize; ++i) {
HWY_ASSERT(indices[i] < kClusters);
stats.Notify(in[i], centers[indices[i]]);
}
const float pnorm = stats.PNorm();
const float snr = stats.GeomeanValueDivL1();
fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr,
stats.MaxIndex(), stats.MaxL1());
static_assert(kGroupSize == 128 || kGroupSize == 256, "Update expected");
const float expected_pnorm = kGroupSize == 128 ? 2.08E-2f : 2.1E-2f;
const float expected_snr = kGroupSize == 128 ? 16.9f : 17.6f;
HWY_ASSERT(expected_pnorm <= pnorm && pnorm < 1.02f * expected_pnorm);
HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr);
}
};
void TestAllRamp() { hn::ForGEVectors<64, TestRamp>()(float()); }
struct TestNormal {
template <typename T, class DF>
HWY_INLINE void operator()(T /*unused*/, DF df) {
auto in = hwy::AllocateAligned<float>(kGroupSize);
HWY_ASSERT(in);
std::mt19937 rng(123);
std::normal_distribution<float> dist{0.001f, 0.3f};
for (size_t i = 0; i < kGroupSize; ++i) {
in[i] = dist(rng);
}
std::shuffle(in.get(), in.get() + kGroupSize, rng);
ClusterBuf buf;
float centers[kClusters];
uint16_t indices[kGroupSize];
double elapsed = hwy::HighestValue<double>();
for (size_t rep = 0; rep < 100; ++rep) {
const double t0 = hwy::platform::Now();
const size_t unused_clusters =
NuqClustering::ClusterExactL2(df, in.get(), buf, centers, indices);
HWY_ASSERT(unused_clusters == 0);
const double t1 = hwy::platform::Now();
elapsed = HWY_MIN(elapsed, t1 - t0);
}
fprintf(stderr, "Vec %zu Enc %.2f MB/s\n", Lanes(df) * 4,
kGroupSize * sizeof(float) * 1E-6 / elapsed);
DistortionStats stats;
for (size_t i = 0; i < kGroupSize; ++i) {
HWY_ASSERT(indices[i] < kClusters);
stats.Notify(in[i], centers[indices[i]]);
}
const float pnorm = stats.PNorm();
const float snr = stats.GeomeanValueDivL1();
fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr,
stats.MaxIndex(), stats.MaxL1());
static_assert(kGroupSize == 128 || kGroupSize == 256, "Update expected");
const float expected_pnorm = kGroupSize == 128 ? 3E-2f : 3.4E-2f;
const float expected_snr = kGroupSize == 128 ? 17.4f : 13.1f;
HWY_ASSERT(expected_pnorm <= pnorm && pnorm < 1.02f * expected_pnorm);
HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr);
}
};
void TestAllNormal() { hn::ForGEVectors<64, TestNormal>()(float()); }
// Can encode and decode sub-regions.
struct TestOffset {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const hn::Repartition<float, D> df;
const size_t total = 10 * kGroupSize;
const size_t kMidLen = 2 * kGroupSize; // length of middle piece
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
auto dec1 = hwy::AllocateAligned<T>(total);
auto dec2 = hwy::AllocateAligned<T>(kMidLen);
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
HWY_ASSERT(in && dec1 && dec2 && nuq);
std::mt19937 rng(123);
std::normal_distribution<float> dist{0.001f, 0.3f};
for (size_t i = 0; i < total; ++i) {
in[i] = dist(rng);
}
// Encode + decode everything
ClusterBuf buf;
(void)NuqCodec::Enc(df, in.get(), total, buf, total, nuq.get(), 0);
NuqCodec::Dec(d, total, nuq.get(), 0, dec1.get(), total);
// Overwrite middle with first inputs
const size_t offset = 5 * kGroupSize;
(void)NuqCodec::Enc(df, in.get(), kMidLen, buf, total, nuq.get(), offset);
// Decoded middle now matches previously decoded first
NuqCodec::Dec(d, total, nuq.get(), offset, dec2.get(), kMidLen);
for (size_t i = 0; i < kMidLen; ++i) {
HWY_ASSERT(dec1[i] == dec2[i]);
}
}
};
void TestAllOffsetF32() {
const hn::ForGEVectors<128, TestOffset> test;
test(float());
}
void TestAllOffsetBF16() {
const hn::ForGEVectors<128, TestOffset> test;
test(hwy::bfloat16_t());
}
struct TestStream {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const hn::Repartition<float, D> df;
const size_t num = 4 * kGroupSize;
auto in = hwy::AllocateAligned<float>(num); // Enc() requires f32
auto out = hwy::AllocateAligned<T>(num);
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(num));
HWY_ASSERT(in && out && nuq);
std::mt19937 rng(123);
std::normal_distribution<float> dist{0.001f, 0.3f};
for (size_t i = 0; i < num; ++i) {
in[i] = dist(rng);
}
ClusterBuf buf;
double elapsed = hwy::HighestValue<double>();
for (size_t rep = 0; rep < 100; ++rep) {
const double t0 = hwy::platform::Now();
const size_t unused_clusters =
NuqCodec::Enc(df, in.get(), num, buf, num, nuq.get(), 0);
HWY_ASSERT(unused_clusters == 0);
const double t1 = hwy::platform::Now();
elapsed = HWY_MIN(elapsed, t1 - t0);
}
fprintf(stderr, "Vec %zu Enc %.2f MB/s\n", Lanes(d) * sizeof(T),
num * sizeof(float) * 1E-6 / elapsed);
elapsed = hwy::HighestValue<double>();
for (size_t rep = 0; rep < 100; ++rep) {
const double t0 = hwy::platform::Now();
NuqCodec::Dec(d, num, nuq.get(), 0, out.get(), num);
const double t1 = hwy::platform::Now();
elapsed = HWY_MIN(elapsed, t1 - t0);
}
fprintf(stderr, "Vec %zu Dec %.2f MB/s\n", Lanes(d) * sizeof(T),
num * sizeof(T) * 1E-6 / elapsed);
DistortionStats stats;
for (size_t i = 0; i < num; ++i) {
stats.Notify(in[i], hwy::ConvertScalarTo<float>(out[i]));
}
const float pnorm = stats.PNorm();
const float snr = stats.GeomeanValueDivL1();
fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr,
stats.MaxIndex(), stats.MaxL1());
static_assert(kGroupSize == 128 || kGroupSize == 256, "Update expected");
const float expected_pnorm = kGroupSize == 128 ? 3.44E-2f : 3.88E-2f;
const float expected_snr = kGroupSize == 128 ? 15.0f : 13.3f;
HWY_ASSERT(expected_pnorm <= pnorm && pnorm < 1.02f * expected_pnorm);
HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr);
}
};
void TestAllStreamF32() {
const hn::ForGEVectors<128, TestStream> test;
test(float());
}
void TestAllStreamBF16() {
const hn::ForGEVectors<128, TestStream> test;
test(hwy::bfloat16_t());
}
struct TestDot {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const hn::Repartition<float, D> df;
const size_t num = 4 * kGroupSize;
auto in = hwy::AllocateAligned<float>(num);
auto dec = hwy::AllocateAligned<float>(num);
auto vec = hwy::AllocateAligned<T>(num);
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(num));
HWY_ASSERT(in && dec && vec && nuq);
std::mt19937 rng(123);
std::normal_distribution<float> dist{0.001f, 0.3f};
for (size_t i = 0; i < num; ++i) {
in[i] = dist(rng);
vec[i] = hwy::ConvertScalarTo<T>(dist(rng));
}
// This changes the correlation between in and vec, which considerably
// affects the error of the result.
std::shuffle(in.get(), in.get() + num, rng);
ClusterBuf buf;
const size_t unused_clusters =
NuqCodec::Enc(df, in.get(), num, buf, num, nuq.get(), 0);
HWY_ASSERT(unused_clusters == 0);
double actual = 0.0;
double elapsed = hwy::HighestValue<double>();
for (size_t rep = 0; rep < 20; ++rep) {
hn::Vec<decltype(df)> sum0 = hn::Zero(df);
hn::Vec<decltype(df)> sum1 = hn::Zero(df);
hn::Vec<decltype(df)> sum2 = hn::Zero(df);
hn::Vec<decltype(df)> sum3 = hn::Zero(df);
const double t0 = hwy::platform::Now();
NuqCodec::Dot(df, num, nuq.get(), 0, vec.get(), num, sum0, sum1, sum2,
sum3);
const double t1 = hwy::platform::Now();
elapsed = HWY_MIN(elapsed, t1 - t0);
sum0 = hn::Add(hn::Add(sum0, sum1), hn::Add(sum2, sum3));
actual = hn::ReduceSum(df, sum0);
}
NuqCodec::Dec(df, num, nuq.get(), 0, dec.get(), num);
fprintf(stderr, "Vec %zu Dec %.2f MB/s\n", Lanes(d) * sizeof(T),
num * sizeof(in[0]) * 1E-6 / elapsed);
double expected = 0.0; // using original input
double expected2 = 0.0; // using decoded NUQ
for (size_t i = 0; i < num; ++i) {
expected += in[i] * hwy::ConvertScalarTo<double>(vec[i]);
expected2 += dec[i] * hwy::ConvertScalarTo<double>(vec[i]);
}
const double l1 = hwy::ScalarAbs(expected - actual);
const double snr = 1.0 + hwy::ScalarAbs(expected) / l1;
fprintf(stderr, "expected %.3f e2 %.4f actual %.4f l1 %E snr %.2f\n",
expected, expected2, actual, l1, snr);
HWY_ASSERT(hwy::ScalarAbs(expected2 - actual) < 1E-4);
static_assert(kGroupSize == 128 || kGroupSize == 256, "Update expected");
const double expected_l1 = kGroupSize == 128 ? 7.3E-2 : 4.34E-2;
const double expected_snr = kGroupSize == 128 ? 9.7f
: sizeof(T) == 2 ? 14.5f
: 14.9f;
HWY_ASSERT(expected_l1 <= l1 && l1 < 1.02f * expected_l1);
HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr);
}
};
void TestAllDotF32() {
const hn::ForGEVectors<128, TestDot> test;
test(float());
}
void TestAllDotBF16() {
const hn::ForGEVectors<128, TestDot> test;
test(hwy::bfloat16_t());
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_BEFORE_TEST(NuqTest);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllFlat);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllPlateaus);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllRamp);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNormal);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllOffsetF32);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllOffsetBF16);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamF32);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamBF16);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllDotF32);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllDotBF16);
} // namespace gcpp
#endif

515
compression/sfp-inl.h Normal file
View File

@ -0,0 +1,515 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Normal include guard to placate lint.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_
#include <stddef.h>
#include <stdint.h>
// copybara:import_next_line:gemma_cpp
#include "compression/sfp.h"
#include "hwy/base.h"
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_
// Actual per-target include guard.
#if defined(THIRD_PARTY_GEMMA_CPP_SFP_INL_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_SFP_INL_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_SFP_INL_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_SFP_INL_TOGGLE
#endif
#include "hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
// For unsigned numbers with MSB zero, signed comparison is faster on x86.
template <class DU>
HWY_INLINE hn::Mask<DU> SignedGt(DU du, hn::Vec<DU> a, hn::Vec<DU> b) {
const hn::RebindToSigned<DU> di;
return hn::RebindMask(du, hn::Gt(BitCast(di, a), hn::BitCast(di, b)));
}
template <class DU>
HWY_INLINE hn::Mask<DU> SignedLt(DU du, hn::Vec<DU> a, hn::Vec<DU> b) {
return SignedGt(du, b, a);
}
// Encode/decode functions.
class SfpCodec {
public:
// Returns 8-bit packed representation of `lo` and `hi` bytes of bf16. 31 ops.
// Implementation detail, public because called by test.
template <class D, HWY_IF_U8_D(D)>
static HWY_INLINE hn::Vec<D> EncBytes(D d, const hn::Vec<D> lo,
const hn::Vec<D> hi) {
const hn::Vec<D> k1 = hn::Set(d, 1u);
const hn::Vec<D> k80 = hn::Set(d, 0x80u);
// Copy sign for later insertion.
const hn::Vec<D> sign_in_msb = hi;
// Biased exponent = lower 7 bits of hi and MSB of lo. Modified below.
hn::Vec<D> biased_e = hn::Or(hn::Add(hi, hi), hn::ShiftRight<7>(lo));
HWY_ASSERT(hn::AllTrue(d, hn::Lt(biased_e, k80))); // <= 2^0
// Clear MSB to isolate the mantissa and enable signed comparisons, then
// shift right by *one* (plus 1 to undo the prior add/left-shift) to leave
// headroom for overflow during rounding.
const hn::Vec<D> m6 = hn::ShiftRight<2>(hn::Add(lo, lo));
// The place to round depends on whether the exponent is large (>= -7) - if
// so, we retain three mantissa bits, otherwise two. However, rounding can
// also cause the exponent to increase. We first choose a threshold that
// rounds up to 1.0*2^-7 for both two and three bit mantissas:
// >= 1.1111 * 2^-8 (0.007568359375). This entails the exponent being
// greater, or equal and the mantissa > (1111000 >> 1) - 1 = 0x3B.
const hn::Vec<D> kMinLargeE = hn::Set(d, 127 - 8);
const hn::Mask<D> is_large_before_round = hn::Or(
SignedGt(d, biased_e, kMinLargeE),
hn::And(hn::Eq(biased_e, kMinLargeE), SignedGt(d, m6, Set(d, 0x3B))));
// To retain the most-significant 3 or 2 mantissa bits, we will right-shift
// by is_large_before_round ? 3 : 4. Variable Shr is expensive for 8-bit
// elements, so (<< 1) if is_large_before_round, then always (>> 4).
const hn::Vec<D> m_shl4 =
hn::MaskedAddOr(m6, is_large_before_round, m6, m6);
// Before shifting (truncation), round to nearest even to reduce bias. If
// the lowest remaining mantissa bit is odd, increase the offset. Example
// with the lowest remaining bit (left) and next lower two bits; the
// latter, plus two more, will be truncated.
// 0[00] + 1 = 0[01]
// 0[01] + 1 = 0[10]
// 0[10] + 1 = 0[11] (round down toward even)
// 0[11] + 1 = 1[00] (round up)
// 1[00] + 10 = 1[10]
// 1[01] + 10 = 1[11]
// 1[10] + 10 = C0[00] (round up toward even with C=1 carry out)
// 1[11] + 10 = C0[01] (round up toward even with C=1 carry out)
const hn::Vec<D> odd_bit = hn::And(hn::ShiftRight<4>(m_shl4), k1);
const hn::Vec<D> rounded = hn::Add(m_shl4, hn::Add(odd_bit, Set(d, 7)));
// Update the exponent if rounding overflowed.
const hn::Vec<D> carry_bit =
hn::IfThenElse(is_large_before_round, k80, hn::Set(d, 0x40u));
const hn::Vec<D> carry_clear = hn::AndNot(carry_bit, rounded);
HWY_DASSERT(hn::AllTrue(d, hn::Lt(carry_clear, carry_bit)));
const hn::Mask<D> is_overflow = hn::Ne(carry_clear, rounded);
biased_e = hn::MaskedAddOr(biased_e, is_overflow, biased_e, k1);
HWY_DASSERT(hn::AllTrue(d, hn::Lt(biased_e, Set(d, 128))));
// Detect if zero or the min exponent.
const hn::Vec<D> kMinNormal = hn::Set(d, 127 - 23);
const hn::Mask<D> is_zero = SignedLt(d, biased_e, kMinNormal);
const hn::Mask<D> is_min = hn::Eq(biased_e, kMinNormal);
// 1.1110xxx * 2^-8 was considered small above, and thus rounded up to 2^-7,
// which the decoder will consider large, and expect 3 mantissa bits. If we
// set the threshold above to 1.111, then it does NOT round up. Thus we
// check exponent >= -7 *after* rounding.
const hn::Mask<D> is_large = SignedGt(d, biased_e, hn::Set(d, 127 - 8));
// To extract and pack the mantissa, only is_large matters. Either it
// matches is_large_before_round, or the rounding resulted in mantissa=0, so
// we either extract two or three bits by shifting out the lower 5..6 bits.
// is_large_before is_large rounded want
// 0 0 0Cmm???? mm
// 0 1 0100???? 000
// 1 0 impossible -
// 1 1 Cmmm???0 mmm
hn::Vec<D> m = hn::ShiftRight<4>(carry_clear);
HWY_DASSERT(hn::AllTrue(
d, SignedLt(d, m,
hn::IfThenElse(is_large, hn::Set(d, 8), hn::Set(d, 4)))));
// 1.0 * 2^-23 has the same encoding as zero, so round it up to 1.01.
m = hn::MaskedMaxOr(m, is_min, m, k1);
const hn::Vec<D> e_bias = hn::IfThenElse(
is_large,
hn::Set(d, hwy::BitCastScalar<uint8_t>(static_cast<int8_t>(15 - 127))),
hn::Set(d, hwy::BitCastScalar<uint8_t>(static_cast<int8_t>(23 - 127))));
const hn::Vec<D> e = hn::Add(biased_e, e_bias);
HWY_DASSERT(
hn::AllTrue(d, hn::Lt(hn::IfThenZeroElse(is_zero, e), hn::Set(d, 16))));
// Shift exponent left 2 or 3 bits to make space for `m`.
const hn::Vec<D> em =
hn::Or(m, hn::ShiftLeft<2>(hn::MaskedAddOr(e, is_large, e, e)));
HWY_DASSERT(hn::AllTrue(d, hn::Lt(hn::IfThenZeroElse(is_zero, em), k80)));
const hn::Vec<D> encoded = hn::BitwiseIfThenElse(k80, sign_in_msb, em);
// Doing this last ensures -0 is replaced with 0.
return hn::IfThenZeroElse(is_zero, encoded);
}
// Decodes u8 `encoded` into `lo` and `hi` bytes of bf16. 12 ops.
// Implementation detail, public because called by test.
template <class D, HWY_IF_U8_D(D)>
static HWY_INLINE void DecBytes(D d, hn::Vec<D> encoded, hn::Vec<D>& lo,
hn::Vec<D>& hi) {
const hn::Vec<D> k0 = hn::Zero(d);
const hn::Vec<D> k80 = hn::Set(d, 0x80u);
HWY_DASSERT(hn::AllTrue(d, hn::Ne(encoded, k80))); // -0 is reserved
// Copy sign for later insertion via BitwiseIfThenElse.
const hn::Vec<D> sign_in_msb = encoded;
encoded = hn::AndNot(k80, encoded);
// Special-case zero, negated so we can use MaskedAddOr. Signed comparison
// is fine because we have cleared the sign bit.
const hn::Mask<D> is_nonzero = SignedGt(d, encoded, k0);
// If MSB is clear, we have two mantissa bits, otherwise three.
const hn::Mask<D> is_small_e = SignedLt(d, encoded, hn::Set(d, 64));
// If is_small_e, add/left-shift 0xxxx.mm to 0xxxx.mm0; else keep 1xxx.mmm.
const hn::Vec<D> e4m3 =
hn::MaskedAddOr(encoded, is_small_e, encoded, encoded);
HWY_DASSERT(hn::AllTrue(d, hn::Lt(e4m3, k80)));
const hn::Vec<D> e = hn::ShiftRight<3>(e4m3); // 4-bit exponent only
HWY_DASSERT(hn::AllTrue(d, hn::Lt(e, Set(d, 16u))));
// The encoded exponent for 2^0 is 15, so subtract 15. Add 127 for the
// binary32/bf16 bias. Subtract another 8 if is_small_e because its lowest
// encoded value (0) should be less than the lowest 'large' exponent 2^-7.
const hn::Vec<D> e_bias = hn::IfThenElse(
is_small_e, hn::Set(d, 127u - 15u - 8u), hn::Set(d, 127u - 15u));
// Special-case zero or add e_bias. If encoded=0, e and e4m3 are zero, but
// we must zero e_bias to get the desired all-zero bf16.
const hn::Vec<D> biased_e = hn::MaskedAddOr(k0, is_nonzero, e_bias, e);
// The decoded binary32 exponent should be at most 2^0.
HWY_DASSERT(hn::AllTrue(d, hn::Lt(biased_e, k80)));
// Shift the MSB of e4m3's mantissa into the MSB of the bf16 mantissa.
const hn::Vec<D> m7 = hn::ShiftLeft<4>(e4m3);
// Lower byte of bf16 = exponent LSB || mantissa.
lo = hn::BitwiseIfThenElse(k80, hn::ShiftLeft<7>(biased_e), m7);
// Upper byte of bf16 = sign || lower 7 bits of exponent.
hi = hn::BitwiseIfThenElse(k80, sign_in_msb, hn::ShiftRight<1>(biased_e));
}
// Encodes `num` bf16 values from `in_bf` to `out_packed`.
template <class DBF, HWY_IF_BF16_D(DBF)>
static HWY_INLINE void Enc(DBF dbf, const hwy::bfloat16_t* HWY_RESTRICT in_bf,
size_t num, SfpStream* HWY_RESTRICT out_packed) {
const hn::Repartition<uint8_t, DBF> d8;
using V8 = hn::Vec<decltype(d8)>;
const size_t N16 = hn::Lanes(dbf);
size_t i = 0;
if (num >= 2 * N16) {
HWY_UNROLL(1)
for (; i <= num - 2 * N16; i += 2 * N16) {
const V8 packed = Enc2B(dbf, in_bf + i);
hn::StoreU(packed, d8, &out_packed->byte + i);
}
}
const size_t remaining = num - i;
HWY_DASSERT(remaining < 2 * N16);
if (remaining != 0) {
HWY_ALIGN hwy::bfloat16_t padded[2 * hn::MaxLanes(dbf)];
hwy::ZeroBytes(padded, sizeof(padded));
hwy::CopyBytes(in_bf + i, padded, remaining * sizeof(padded[0]));
const V8 packed = Enc2B(dbf, padded);
hn::StoreN(packed, d8, &out_packed->byte + i, remaining);
}
}
// Encodes `num` f32 values from `in_f` to `packed`.
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Enc(DF df, const float* HWY_RESTRICT in_f, size_t num,
SfpStream* HWY_RESTRICT out_packed) {
const hn::Repartition<uint8_t, DF> d8;
using V8 = hn::Vec<decltype(d8)>;
const size_t NF = hn::Lanes(df);
size_t i = 0;
if (num >= 4 * NF) {
HWY_UNROLL(1)
for (; i <= num - 4 * NF; i += 4 * NF) {
const V8 packed = Enc4F(df, in_f + i);
hn::StoreU(packed, d8, &out_packed->byte + i);
}
}
const size_t remaining = num - i;
HWY_DASSERT(remaining < 4 * NF);
if (remaining != 0) {
HWY_ALIGN float padded[4 * hn::MaxLanes(df)];
hwy::ZeroBytes(padded, sizeof(padded));
hwy::CopyBytes(in_f + i, padded, remaining * sizeof(padded[0]));
const V8 packed = Enc4F(df, padded);
hn::StoreN(packed, d8, &out_packed->byte + i, remaining);
}
}
// Decodes `num` values from `in_packed` to `out_bf`.
template <class DBF, HWY_IF_BF16_D(DBF)>
static HWY_INLINE void Dec(DBF dbf, const SfpStream* HWY_RESTRICT in_packed,
size_t num, hwy::bfloat16_t* HWY_RESTRICT out_bf) {
const hn::Repartition<uint8_t, DBF> d8;
using V8 = hn::Vec<decltype(d8)>;
using VBF = hn::Vec<decltype(dbf)>;
const size_t N16 = hn::Lanes(dbf);
size_t i = 0;
if (num >= 2 * N16) {
HWY_UNROLL(1)
for (; i <= num - 2 * N16; i += 2 * N16) {
const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
VBF bf0, bf1;
Dec2B(dbf, packed, bf0, bf1);
hn::StoreU(bf0, dbf, out_bf + i);
hn::StoreU(bf1, dbf, out_bf + i + N16);
}
}
const size_t remaining = num - i;
HWY_DASSERT(remaining < 2 * N16);
if (remaining != 0) {
const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining);
HWY_ALIGN hwy::bfloat16_t padded[2 * hn::MaxLanes(dbf)];
VBF bf0, bf1;
Dec2B(dbf, packed, bf0, bf1);
hn::StoreU(bf0, dbf, padded);
hn::StoreU(bf1, dbf, padded + N16);
hwy::CopyBytes(padded, out_bf + i, remaining * sizeof(padded[0]));
}
}
// Decodes `num` values from `in_packed` to `out_f`.
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Dec(DF df, const SfpStream* HWY_RESTRICT in_packed,
size_t num, float* HWY_RESTRICT out_f) {
const hn::Repartition<uint8_t, DF> d8;
using V8 = hn::Vec<decltype(d8)>;
using VF = hn::Vec<decltype(df)>;
const size_t NF = hn::Lanes(df);
size_t i = 0;
if (num >= 4 * NF) {
HWY_UNROLL(1)
for (; i <= num - 4 * NF; i += 4 * NF) {
const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
VF f0, f1, f2, f3;
Dec4F(df, packed, f0, f1, f2, f3);
hn::StoreU(f0, df, out_f + i + NF * 0);
hn::StoreU(f1, df, out_f + i + NF * 1);
hn::StoreU(f2, df, out_f + i + NF * 2);
hn::StoreU(f3, df, out_f + i + NF * 3);
}
}
const size_t remaining = num - i;
HWY_DASSERT(remaining < 4 * NF);
if (remaining != 0) {
const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining);
HWY_ALIGN float padded[4 * hn::MaxLanes(df)];
VF f0, f1, f2, f3;
Dec4F(df, packed, f0, f1, f2, f3);
hn::StoreU(f0, df, padded + NF * 0);
hn::StoreU(f1, df, padded + NF * 1);
hn::StoreU(f2, df, padded + NF * 2);
hn::StoreU(f3, df, padded + NF * 3);
hwy::CopyBytes(padded, out_f + i, remaining * sizeof(padded[0]));
}
}
// Fused decode and dot product with bf16 into four output accumulators.
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Dot(DF df, const SfpStream* HWY_RESTRICT in_packed,
size_t num,
const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
hn::Vec<DF>& sum0, hn::Vec<DF>& sum1,
hn::Vec<DF>& sum2, hn::Vec<DF>& sum3) {
const hn::Repartition<uint8_t, DF> d8;
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
using V8 = hn::Vec<decltype(d8)>;
using VBF = hn::Vec<decltype(dbf)>;
const size_t N16 = hn::Lanes(dbf);
size_t i = 0;
if (num >= 2 * N16) {
HWY_UNROLL(1)
for (; i <= num - 2 * N16; i += 2 * N16) {
const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
const VBF v0 = hn::LoadU(dbf, vec_aligned + i);
const VBF v1 = hn::LoadU(dbf, vec_aligned + i + N16);
VBF bf0, bf1;
Dec2B(dbf, packed, bf0, bf1);
sum0 = hn::ReorderWidenMulAccumulate(df, bf0, v0, sum0, sum1);
sum2 = hn::ReorderWidenMulAccumulate(df, bf1, v1, sum2, sum3);
}
}
const size_t remaining = num - i;
if (remaining != 0) {
const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining);
HWY_ALIGN hwy::bfloat16_t padded[2 * hn::MaxLanes(dbf)];
hwy::ZeroBytes(padded, sizeof(padded));
hwy::CopyBytes(vec_aligned + i, padded, remaining * sizeof(padded[0]));
const VBF v0 = hn::LoadU(dbf, padded);
const VBF v1 = hn::LoadU(dbf, padded + N16);
VBF bf0, bf1;
Dec2B(dbf, packed, bf0, bf1);
sum0 = hn::ReorderWidenMulAccumulate(df, bf0, v0, sum0, sum1);
sum2 = hn::ReorderWidenMulAccumulate(df, bf1, v1, sum2, sum3);
}
}
// Fused decode and dot product with f32 into four output accumulators.
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Dot(DF df, const SfpStream* HWY_RESTRICT in_packed,
size_t num, const float* HWY_RESTRICT vec_aligned,
hn::Vec<DF>& sum0, hn::Vec<DF>& sum1,
hn::Vec<DF>& sum2, hn::Vec<DF>& sum3) {
const hn::Repartition<uint8_t, DF> d8;
using V8 = hn::Vec<decltype(d8)>;
using VF = hn::Vec<decltype(df)>;
const size_t NF = hn::Lanes(df);
size_t i = 0;
if (num >= 4 * NF) {
HWY_UNROLL(1)
for (; i <= num - 4 * NF; i += 4 * NF) {
const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
const VF v0 = hn::LoadU(df, vec_aligned + i + NF * 0);
const VF v1 = hn::LoadU(df, vec_aligned + i + NF * 1);
const VF v2 = hn::LoadU(df, vec_aligned + i + NF * 2);
const VF v3 = hn::LoadU(df, vec_aligned + i + NF * 3);
VF f0, f1, f2, f3;
Dec4F(df, packed, f0, f1, f2, f3);
sum0 = hn::MulAdd(f0, v0, sum0);
sum1 = hn::MulAdd(f1, v1, sum1);
sum2 = hn::MulAdd(f2, v2, sum2);
sum3 = hn::MulAdd(f3, v3, sum3);
}
}
const size_t remaining = num - i;
if (remaining != 0) {
const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining);
HWY_ALIGN float padded[4 * hn::MaxLanes(df)];
hwy::ZeroBytes(padded, sizeof(padded));
hwy::CopyBytes(vec_aligned + i, padded, remaining * sizeof(padded[0]));
const VF v0 = hn::LoadU(df, padded + NF * 0);
const VF v1 = hn::LoadU(df, padded + NF * 1);
const VF v2 = hn::LoadU(df, padded + NF * 2);
const VF v3 = hn::LoadU(df, padded + NF * 3);
VF f0, f1, f2, f3;
Dec4F(df, packed, f0, f1, f2, f3);
sum0 = hn::MulAdd(f0, v0, sum0);
sum1 = hn::MulAdd(f1, v1, sum1);
sum2 = hn::MulAdd(f2, v2, sum2);
sum3 = hn::MulAdd(f3, v3, sum3);
}
}
private:
// Wrappers to avoid code duplication across float/bf16 input types and
// the main loop/remainder.
// Returns vector of packed bytes for callers to StoreU or StoreN.
template <class D16, HWY_IF_U16_D(D16),
class V8 = hn::Vec<hn::Repartition<uint8_t, D16>>>
static HWY_INLINE V8 Enc2U(D16 d16, const hn::Vec<D16> w0,
const hn::Vec<D16> w1) {
const hn::Repartition<uint8_t, D16> d8;
// Although more expensive on AVX3, in-order packing enables streaming
// decompression without fixed-size packets.
const V8 lo = hn::ConcatEven(d8, hn::BitCast(d8, w1), hn::BitCast(d8, w0));
const V8 hi = hn::ConcatOdd(d8, hn::BitCast(d8, w1), hn::BitCast(d8, w0));
return EncBytes(d8, lo, hi);
}
template <class DBF, HWY_IF_BF16_D(DBF),
class V8 = hn::Vec<hn::Repartition<uint8_t, DBF>>>
static HWY_INLINE V8 Enc2B(DBF dbf, const hwy::bfloat16_t* HWY_RESTRICT in) {
const hn::Repartition<uint16_t, DBF> d16;
const size_t N16 = hn::Lanes(d16);
using V16 = hn::Vec<decltype(d16)>;
const V16 w0 = hn::BitCast(d16, hn::LoadU(dbf, in));
const V16 w1 = hn::BitCast(d16, hn::LoadU(dbf, in + N16));
return Enc2U(d16, w0, w1);
}
template <class DF, HWY_IF_F32_D(DF),
class V8 = hn::Vec<hn::Repartition<uint8_t, DF>>>
static HWY_INLINE V8 Enc4F(DF df, const float* HWY_RESTRICT in) {
const hn::Repartition<uint16_t, DF> d16;
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
using VF = hn::Vec<decltype(df)>;
using V16 = hn::Vec<decltype(d16)>;
const size_t NF = hn::Lanes(df);
const VF f0 = hn::LoadU(df, in + NF * 0);
const VF f1 = hn::LoadU(df, in + NF * 1);
const VF f2 = hn::LoadU(df, in + NF * 2);
const VF f3 = hn::LoadU(df, in + NF * 3);
// Chop off the lower 16 bits; EncBytes still rounds properly.
const V16 w0 = hn::BitCast(d16, hn::OrderedDemote2To(dbf, f0, f1));
const V16 w1 = hn::BitCast(d16, hn::OrderedDemote2To(dbf, f2, f3));
return Enc2U(d16, w0, w1);
}
template <class D16, HWY_IF_U16_D(D16),
class V8 = hn::Vec<hn::Repartition<uint8_t, D16>>>
static HWY_INLINE void Dec2U(D16 d16, V8 packed, hn::Vec<D16>& w0,
hn::Vec<D16>& w1) {
const hn::Repartition<uint8_t, D16> d8;
V8 lo, hi;
DecBytes(d8, packed, lo, hi);
w0 = hn::BitCast(d16, hn::InterleaveWholeLower(d8, lo, hi));
w1 = hn::BitCast(d16, hn::InterleaveWholeUpper(d8, lo, hi));
}
template <class DBF, HWY_IF_BF16_D(DBF),
class V8 = hn::Vec<hn::Repartition<uint8_t, DBF>>>
static HWY_INLINE void Dec2B(DBF dbf, V8 packed, hn::Vec<DBF>& bf0,
hn::Vec<DBF>& bf1) {
const hn::Repartition<uint16_t, DBF> d16;
using V16 = hn::Vec<decltype(d16)>;
V16 w0, w1;
Dec2U(d16, packed, w0, w1);
bf0 = hn::BitCast(dbf, w0);
bf1 = hn::BitCast(dbf, w1);
}
template <class DF, HWY_IF_F32_D(DF),
class V8 = hn::Vec<hn::Repartition<uint8_t, DF>>>
static HWY_INLINE void Dec4F(DF df, V8 packed, hn::Vec<DF>& f0,
hn::Vec<DF>& f1, hn::Vec<DF>& f2,
hn::Vec<DF>& f3) {
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
using VBF = hn::Vec<decltype(dbf)>;
VBF bf0, bf1;
Dec2B(dbf, packed, bf0, bf1);
f0 = hn::PromoteLowerTo(df, bf0);
f1 = hn::PromoteUpperTo(df, bf0);
f2 = hn::PromoteLowerTo(df, bf1);
f3 = hn::PromoteUpperTo(df, bf1);
}
}; // SfpCodec
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_

51
compression/sfp.h Normal file
View File

@ -0,0 +1,51 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_H_
// Switching Floating Point: a hybrid 8-bit float representation of bf16/f32
// inputs that combines the advantages of e4m3 and e5m2 into a single format.
// It supports seeking at a granularity of 1, decoding to bf16/f32, and a
// fused decode/dot product with bf16/f32 vectors.
#include <stdint.h>
namespace gcpp {
// Points to the *start* of an SFP stream. Values are stored in-order to enable
// vector-length agnostic seeking, because streams may be written to disk for
// loading on other CPUs.
//
// Characteristics:
// - 24-bit dynamic range, with max exponent 2^0.
// - 3 bit mantissa for values >= 2^-7, otherwise 2.
//
// This is faster to decode than a straightforward implementation of eXmY, in
// part because SFP does not require subnormals. Unlike OCP MX, it also does not
// require side information (shared exponents).
//
// Although the representation could probably be shrunk to 6-7 bits, more
// savings can be had by non-uniform clustering - see nuq.h.
#pragma pack(push, 1)
struct SfpStream {
uint8_t byte;
};
#pragma pack(pop)
static inline const char* TypeName(SfpStream) { return "SFP"; }
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_H_

440
compression/sfp_test.cc Normal file
View File

@ -0,0 +1,440 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// copybara:import_next_line:gemma_cpp
#include "compression/sfp.h"
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <algorithm>
#include <random>
#include <set>
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"third_party/gemma_cpp/compression/sfp_test.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep
// Any highway.h must come after foreach_target.h
// copybara:import_next_line:gemma_cpp
#include "compression/distortion.h"
// copybara:import_next_line:gemma_cpp
#include "compression/sfp-inl.h"
#include "hwy/highway.h"
#include "hwy/tests/hwy_gtest.h"
#include "hwy/tests/test_util-inl.h"
#include "hwy/timer.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
// Decode
float F32FromSFP8(uint32_t sfp) {
HWY_ASSERT(sfp < 256);
HWY_ASSERT(sfp != 0x80); // -0 is reserved
const uint32_t sign32 = (sfp & 0x80) << 24;
sfp &= 0x7F;
const bool large_e = sfp >= 64;
const size_t m_bits = large_e ? 3 : 2;
uint32_t m = sfp & ((1u << m_bits) - 1u);
size_t e = sfp >> m_bits;
if (sfp == 0) return 0.0f;
const uint32_t e_bias = large_e ? 15 : 23;
const uint32_t exp32 = static_cast<uint32_t>(127 + e - e_bias) << 23;
const uint32_t mnt32 = m << (23 - m_bits);
const uint32_t binary32 = sign32 | exp32 | mnt32;
float result;
hwy::CopySameSize(&binary32, &result);
return result;
}
void TestAllUnique() {
std::set<float> unique;
for (uint32_t sfp = 0; sfp < 256; ++sfp) {
if (sfp == 0x80) continue; // -0 is reserved
unique.insert(F32FromSFP8(sfp));
}
HWY_ASSERT_EQ(size_t{255}, unique.size());
if (false) {
for (float f : unique) {
fprintf(stderr, "%e\n", f);
}
}
}
// ------------------------------ Foreach compressed representation
// Encode
HWY_INLINE uint32_t SFP8FromF32(float f) {
HWY_ASSERT(-1.875f <= f && f <= 1.875f);
constexpr uint32_t kMaskM = hwy::MantissaMask<float>();
uint32_t binary32;
hwy::CopySameSize(&f, &binary32);
const uint32_t s = (binary32 & hwy::SignMask<float>()) >> 24;
binary32 &= ~hwy::SignMask<float>();
f = hwy::ScalarAbs(f);
// >= 1.1111 * 2^-8 rounds up to 1.0*2^-7.
bool large_e = (f >= 0.007568359375f);
const uint32_t org_binary32 = binary32;
const uint32_t m32 = binary32 & kMaskM;
binary32 = (binary32 & ~kMaskM) | m32;
size_t m_bits = large_e ? 3 : 2;
const uint32_t is_odd = (m32 >> (23 - m_bits)) & 1;
const uint32_t round = is_odd + (1u << (23 - m_bits - 1)) - 1;
const uint32_t rounded = binary32 + round;
// >= 1.111 also rounds up, but only if it was considered !large_e before.
if (f >= 0.00732421875f) {
large_e = true;
m_bits = 3;
}
uint32_t m = (kMaskM & rounded) >> (23 - m_bits);
int32_t e = (rounded >> 23) - 127;
if (e <= -23) {
// 2^-23 is the smallest normal exponent. Zero has e = -127. Do not set the
// SFP sign bit because the encoding for -0 is reserved.
if (e < -23) return 0;
// e = 2^-23: round up mantissa because m=0 encodes 0.0f.
if (m == 0) m = 1;
}
if (false) {
fprintf(stderr, "in %x round %x rounded %x e %d m %x large_e %d\n",
org_binary32, round, rounded, e, m, large_e);
}
uint32_t e_sfp = e + (large_e ? 15 : 23);
HWY_ASSERT(e_sfp < 16);
const uint32_t encoded = (e_sfp << m_bits) | m | s;
HWY_ASSERT(encoded < 256);
return encoded;
}
// For every possible encoding: ensure re-encoding the decoded value matches it.
struct TestDecEnc {
template <class T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const hn::RepartitionToWide<D> d16;
const hn::Rebind<hwy::bfloat16_t, decltype(d16)> dbf;
const hn::Repartition<float, D> df;
for (uint32_t encoded = 0; encoded < 256; ++encoded) {
if (encoded == 0x80) continue; // -0 is reserved
const float decoded = F32FromSFP8(encoded);
const uint32_t encoded2 = SFP8FromF32(decoded);
hn::Vec<D> dec_lo, dec_hi;
SfpCodec::DecBytes(d, hn::Set(d, encoded), dec_lo, dec_hi);
const hn::Vec<decltype(dbf)> dec =
hn::BitCast(dbf, hn::ZipLower(d16, dec_lo, dec_hi));
const float vdecoded = hn::GetLane(hn::PromoteLowerTo(df, dec));
const uint32_t vencoded2 =
hn::GetLane(SfpCodec::EncBytes(d, dec_lo, dec_hi));
if (decoded != vdecoded || encoded2 != vencoded2 || encoded != encoded2) {
HWY_ABORT("enc %u -> dec %E=%x=%E -> enc %u %u\n", encoded, decoded,
hwy::BitCastScalar<uint32_t>(decoded), vdecoded, encoded2,
vencoded2);
}
}
}
};
void TestAllDecEnc() { hn::ForGEVectors<32, TestDecEnc>()(uint8_t()); }
// ------------------------------ Golden (known values)
// Generate values, encode, decode back to that value.
struct TestGolden {
template <class T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const hn::Repartition<float, D> df;
const hn::Repartition<hwy::bfloat16_t, D> dbf;
const hn::RebindToUnsigned<decltype(dbf)> d16;
struct Golden {
float in;
float out;
};
const Golden golden[] = {
// All mantissa bits set, all discarded zero (no rounding)
{0.46875f, 0.46875f},
{0.9375f, 0.9375f},
// All mantissa bits set, one below it set (round up to pow2)
{0.484375f, 0.5f},
{0.96875f, 1.0f},
// Lowest mantissa bit set, all discarded zero (no rounding)
{0.28125f, 0.28125f},
{0.5625f, 0.5625f},
// Lowest mantissa bit set, one below it set (round up to even)
{0.296875f, 0.3125f},
{0.59375f, 0.625f},
// All mantissa zero, all discarded set (round up)
{0.279296875f, 0.28125f},
{0.55859375f, 0.5625f},
// All mantissa zero, one below it set (round DOWN to pow2)
{0.265625f, 0.25f},
{0.53125f, 0.5f},
// At inflection point: 1.max*2^-8 rounds up to 1.0*2^-7
{0.0068359375f, 0.0068359375f}, // 1.11 -> 1.11
{0.00732421875f, 0.0078125f}, // 1.111 -> 1.11[1] -> 1.0
{0.007568359375f, 0.0078125f}, // 1.1111 -> 1.0
// Above 1.0: no longer special-cased.
{1.0f, 1.0f},
{1.0625f, 1.0f}, // 1.000100
// Smallest normal exponents - we no longer use subnormals.
{2.384185791015625E-7f, 2.384185791015625E-7f}, // 1.00p-22
{1.49011611938E-07f, 1.49011611938E-07f}, // 1.01p-23
{1.19209289551E-07f, 1.49011611938E-07f}, // 1.00p-23 -> 1.01p-23
{5.96046447754E-08f, 0.0f}, // 1.00p-24 -> 0
{8.94069671631E-08f, 0.0f}, // 1.10p-24 -> 0
{1.11758708954E-07f, 1.49011611938E-07f}, // 1.111p-24-> 1.01p-23
// 1100_010 * 2^-7 rounds down to 110
{0.013841f, 0.013671875f},
};
constexpr size_t kNumGolden = sizeof(golden) / sizeof(Golden);
for (uint32_t s : {0, 1}) {
for (size_t i = 0; i < kNumGolden; ++i) {
const float in = s ? -golden[i].in : golden[i].in;
const float out = s ? -golden[i].out : golden[i].out;
const hn::Vec<decltype(dbf)> in_bf =
hn::OrderedDemote2To(dbf, hn::Set(df, in), hn::Set(df, in));
const uint32_t encoded = SFP8FromF32(in);
const uint32_t vencoded = hn::GetLane(SfpCodec::EncBytes(
d, hn::BitCast(d, in_bf),
hn::BitCast(d, hn::ShiftRight<8>(hn::BitCast(d16, in_bf)))));
const float decoded = F32FromSFP8(encoded);
hn::Vec<D> dec_lo, dec_hi;
SfpCodec::DecBytes(d, hn::Set(d, encoded), dec_lo, dec_hi);
const hn::Vec<decltype(dbf)> dec =
hn::BitCast(dbf, hn::ZipLower(d16, dec_lo, dec_hi));
const float vdecoded = hn::GetLane(hn::PromoteLowerTo(df, dec));
if (decoded != vdecoded || decoded != out || encoded != vencoded) {
HWY_ABORT("@%zu in %E dec %E %E golden %E\n", i, in, decoded,
vdecoded, golden[i].out);
}
} // i
} // s
}
};
void TestAllGolden() {
// Full vectors only, other tests cover partial vectors.
TestGolden()(uint8_t(), hn::ScalableTag<uint8_t>());
}
// ------------------------------ Foreach bf16 input
// Generate all values, encode, decode back.
struct TestEncDec {
template <class T, class DBF>
HWY_INLINE void operator()(T /*unused*/, DBF dbf) {
const hn::Repartition<uint8_t, DBF> du8;
// We only use the upper 4 of 7 bf16 mantissa bits, so force the lower three
// bits to zero to reduce the number of inputs.
constexpr size_t kStep = 8;
const size_t max = 0x8000 / 8;
auto in = hwy::AllocateAligned<T>(max);
auto packed = hwy::AllocateAligned<SfpStream>(max);
auto dec = hwy::AllocateAligned<T>(max);
HWY_ASSERT(in && packed && dec);
size_t num = 0;
for (size_t i = 0; i < max; ++i) {
const uint16_t bits = i * kStep;
const float f = hwy::F32FromBF16(hwy::BitCastScalar<T>(bits));
// Keep if within range
if (hwy::ScalarIsFinite(f) && f <= 1.875f) {
in[num] = hwy::BF16FromF32(f);
in[num + 1] = hwy::BF16FromF32(-f);
num += 2;
}
}
double enc_elapsed = hwy::HighestValue<double>();
double dec_elapsed = hwy::HighestValue<double>();
for (size_t rep = 0; rep < 100; ++rep) {
const double t0 = hwy::platform::Now();
SfpCodec::Enc(dbf, in.get(), num, packed.get());
const double t1 = hwy::platform::Now();
SfpCodec::Dec(dbf, packed.get(), num, dec.get());
const double t2 = hwy::platform::Now();
enc_elapsed = HWY_MIN(enc_elapsed, t1 - t0);
dec_elapsed = HWY_MIN(dec_elapsed, t2 - t1);
}
const double enc_mbs = num * sizeof(T) * 1E-6 / enc_elapsed;
const double dec_mbs = num * sizeof(T) * 1E-6 / dec_elapsed;
fprintf(stderr, "Vec size %zu Enc %.2f MB/s Dec %.2f MB/s\n", Lanes(du8),
enc_mbs, dec_mbs);
{
double sum = 0.0;
DistortionStats stats;
for (size_t i = 0; i < num; ++i) {
const float out = hwy::F32FromBF16(dec[i]);
sum += hwy::ConvertScalarTo<double>(hwy::ScalarAbs(in[i]));
stats.Notify(in[i], out);
}
const double avg = sum / num;
fprintf(stderr, "Avg magnitude %.3E, p-norm %.3E snr %.2f @%zu = %.4E\n",
avg, stats.PNorm(), stats.GeomeanValueDivL1(), stats.MaxIndex(),
stats.MaxL1());
}
}
};
void TestAllEncDec() { hn::ForGEVectors<32, TestEncDec>()(hwy::bfloat16_t()); }
// ------------------------------ Order
// Store 8-bit iota, decode, encode, check iota == packed. This ensures
// Enc/Dec are preserving the order independent of vector length.
struct TestOrder {
template <class T, class DBF>
HWY_INLINE void operator()(T /*unused*/, DBF dbf) {
const hn::Repartition<uint8_t, DBF> du8;
const size_t num = 10 * hn::Lanes(du8) / 3;
auto iota = hwy::AllocateAligned<SfpStream>(num);
auto packed = hwy::AllocateAligned<SfpStream>(num);
auto bf = hwy::AllocateAligned<hwy::bfloat16_t>(num);
HWY_ASSERT(iota && packed && bf);
for (size_t i = 0; i < num; ++i) {
// Clear sign bit so we can also check that bf is in ascending order.
iota[i].byte = i & 127;
}
SfpCodec::Dec(dbf, iota.get(), num, bf.get());
SfpCodec::Enc(dbf, bf.get(), num, packed.get());
for (size_t i = 0; i < num; ++i) {
if (iota[i].byte != packed[i].byte) {
HWY_ABORT("@%zu: %d %d\n", i, iota[i].byte, packed[i].byte);
}
}
}
};
void TestAllOrder() { hn::ForGEVectors<32, TestOrder>()(hwy::bfloat16_t()); }
// ------------------------------ Dot
struct TestDot {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const hn::Repartition<float, D> df;
const size_t num = 384;
auto in = hwy::AllocateAligned<T>(num);
auto dec = hwy::AllocateAligned<T>(num);
auto vec = hwy::AllocateAligned<T>(num);
auto sfp = hwy::AllocateAligned<SfpStream>(num);
HWY_ASSERT(in && dec && vec && sfp);
std::mt19937 rng(123);
std::normal_distribution<float> dist{0.001f, 0.3f};
for (size_t i = 0; i < num; ++i) {
in[i] = hwy::ConvertScalarTo<T>(dist(rng));
vec[i] = hwy::ConvertScalarTo<T>(dist(rng));
}
// This changes the correlation between in and vec, which considerably
// affects the error of the result.
std::shuffle(in.get(), in.get() + num, rng);
SfpCodec::Enc(d, in.get(), num, sfp.get());
double actual = 0.0;
double elapsed = hwy::HighestValue<double>();
for (size_t rep = 0; rep < 200; ++rep) {
hn::Vec<decltype(df)> sum0 = hn::Zero(df);
hn::Vec<decltype(df)> sum1 = hn::Zero(df);
hn::Vec<decltype(df)> sum2 = hn::Zero(df);
hn::Vec<decltype(df)> sum3 = hn::Zero(df);
const double t0 = hwy::platform::Now();
SfpCodec::Dot(df, sfp.get(), num, vec.get(), sum0, sum1, sum2, sum3);
const double t1 = hwy::platform::Now();
elapsed = HWY_MIN(elapsed, t1 - t0);
sum0 = hn::Add(hn::Add(sum0, sum1), hn::Add(sum2, sum3));
actual = hn::ReduceSum(df, sum0);
}
SfpCodec::Dec(d, sfp.get(), num, dec.get());
fprintf(stderr, "Vec %zu Dot %.2f MB/s\n", Lanes(d) * sizeof(T),
num * sizeof(T) * 1E-6 / elapsed);
double expected = 0.0; // using original input
double expected2 = 0.0; // using decoded SFP
for (size_t i = 0; i < num; ++i) {
expected += hwy::ConvertScalarTo<double>(in[i]) *
hwy::ConvertScalarTo<double>(vec[i]);
expected2 += hwy::ConvertScalarTo<double>(dec[i]) *
hwy::ConvertScalarTo<double>(vec[i]);
}
const double l1 = hwy::ScalarAbs(expected - actual);
const double snr = 1.0 + hwy::ScalarAbs(expected) / l1;
fprintf(stderr, "expected %.3f e2 %.4f actual %.4f l1 %E snr %.2f\n",
expected, expected2, actual, l1, snr);
HWY_ASSERT(hwy::ScalarAbs(expected2 - actual) < 1E-4);
const double expected_l1 = sizeof(T) == 2 ? 1.52E-2 : 1.15E-2;
const double expected_snr = sizeof(T) == 2 ? 80.1f : 104.9f;
HWY_ASSERT(expected_l1 <= l1 && l1 < 1.02f * expected_l1);
HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr);
}
};
void TestAllDotF32() {
const hn::ForGEVectors<128, TestDot> test;
test(float());
}
void TestAllDotBF16() {
const hn::ForGEVectors<128, TestDot> test;
test(hwy::bfloat16_t());
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_BEFORE_TEST(SfpTest);
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllUnique);
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDecEnc);
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllGolden);
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllEncDec);
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllOrder);
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDotF32);
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDotBF16);
} // namespace gcpp
#endif

117
compression/stats.cc Normal file
View File

@ -0,0 +1,117 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// copybara:import_next_line:gemma_cpp
#include "compression/stats.h"
#include <stdio.h>
#include <algorithm> // std::min
#include <string>
#include "hwy/base.h" // HWY_ASSERT
void Stats::Assimilate(const Stats& other) {
const int64_t total_n = n_ + other.n_;
if (total_n == 0) return; // Nothing to do; prevents div by zero.
min_ = std::min(min_, other.min_);
max_ = std::max(max_, other.max_);
product_ *= other.product_;
const double product_n = n_ * other.n_;
const double n2 = n_ * n_;
const double other_n2 = other.n_ * other.n_;
const int64_t total_n2 = total_n * total_n;
const double total_n3 = static_cast<double>(total_n2) * total_n;
// Precompute reciprocal for speed - used at least twice.
const double inv_total_n = 1.0 / total_n;
const double inv_total_n2 = 1.0 / total_n2;
const double delta = other.m1_ - m1_;
const double delta2 = delta * delta;
const double delta3 = delta * delta2;
const double delta4 = delta2 * delta2;
m1_ = (n_ * m1_ + other.n_ * other.m1_) * inv_total_n;
const double new_m2 = m2_ + other.m2_ + delta2 * product_n * inv_total_n;
const double new_m3 =
m3_ + other.m3_ + delta3 * product_n * (n_ - other.n_) * inv_total_n2 +
3.0 * delta * (n_ * other.m2_ - other.n_ * m2_) * inv_total_n;
m4_ += other.m4_ +
delta4 * product_n * (n2 - product_n + other_n2) / total_n3 +
6.0 * delta2 * (n2 * other.m2_ + other_n2 * m2_) * inv_total_n2 +
4.0 * delta * (n_ * other.m3_ - other.n_ * m3_) * inv_total_n;
m2_ = new_m2;
m3_ = new_m3;
n_ = total_n;
}
std::string Stats::ToString(int exclude) const {
if (Count() == 0) return std::string("(none)");
char buf[300];
int pos = 0;
int ret; // snprintf - bytes written or negative for error.
if ((exclude & kNoCount) == 0) {
ret = snprintf(buf + pos, sizeof(buf) - pos, "Count=%9zu ",
static_cast<size_t>(Count()));
HWY_ASSERT(ret > 0);
pos += ret;
}
if ((exclude & kNoMeanSD) == 0) {
const float sd = StandardDeviation();
if (sd > 100) {
ret = snprintf(buf + pos, sizeof(buf) - pos, "Mean=%8.2E SD=%7.1E ",
Mean(), sd);
} else {
ret = snprintf(buf + pos, sizeof(buf) - pos, "Mean=%8.6f SD=%7.5f ",
Mean(), sd);
}
HWY_ASSERT(ret > 0);
pos += ret;
}
if ((exclude & kNoMinMax) == 0) {
ret = snprintf(buf + pos, sizeof(buf) - pos, "Min=%8.5e Max=%8.5e ", Min(),
Max());
HWY_ASSERT(ret > 0);
pos += ret;
}
if ((exclude & kNoSkewKurt) == 0) {
ret = snprintf(buf + pos, sizeof(buf) - pos, "Skew=%5.2f Kurt=%7.2f ",
Skewness(), Kurtosis());
HWY_ASSERT(ret > 0);
pos += ret;
}
if ((exclude & kNoGeomean) == 0) {
ret = snprintf(buf + pos, sizeof(buf) - pos, "GeoMean=%9.6f ",
GeometricMean());
HWY_ASSERT(ret > 0);
pos += ret;
}
HWY_ASSERT(pos < sizeof(buf));
return buf;
}

190
compression/stats.h Normal file
View File

@ -0,0 +1,190 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_STATS_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_STATS_H_
#include <stdint.h>
#include <stdio.h>
#include <algorithm>
#include <cmath>
#include <string>
#include "hwy/base.h" // HWY_ASSERT
// Thread-compatible.
template <size_t N>
class Bins {
public:
Bins() { Reset(); }
template <typename T>
void Notify(T bin) {
HWY_ASSERT(T{0} <= bin && bin < static_cast<T>(N));
counts_[static_cast<int32_t>(bin)]++;
}
void Assimilate(const Bins<N>& other) {
for (size_t i = 0; i < N; ++i) {
counts_[i] += other.counts_[i];
}
}
void Print(const char* caption) {
fprintf(stderr, "\n%s [%zu]\n", caption, N);
size_t last_nonzero = 0;
for (size_t i = N - 1; i < N; --i) {
if (counts_[i] != 0) {
last_nonzero = i;
break;
}
}
for (size_t i = 0; i <= last_nonzero; ++i) {
fprintf(stderr, " %zu\n", counts_[i]);
}
}
void Reset() {
for (size_t i = 0; i < N; ++i) {
counts_[i] = 0;
}
}
private:
size_t counts_[N];
};
// Descriptive statistics of a variable (4 moments). Thread-compatible.
class Stats {
public:
Stats() { Reset(); }
void Notify(const float x) {
++n_;
min_ = std::min(min_, x);
max_ = std::max(max_, x);
product_ *= x;
// Online moments. Reference: https://goo.gl/9ha694
const double d = x - m1_;
const double d_div_n = d / n_;
const double d2n1_div_n = d * (n_ - 1) * d_div_n;
const int64_t n_poly = n_ * n_ - 3 * n_ + 3;
m1_ += d_div_n;
m4_ += d_div_n * (d_div_n * (d2n1_div_n * n_poly + 6.0 * m2_) - 4.0 * m3_);
m3_ += d_div_n * (d2n1_div_n * (n_ - 2) - 3.0 * m2_);
m2_ += d2n1_div_n;
}
void Assimilate(const Stats& other);
int64_t Count() const { return n_; }
float Min() const { return min_; }
float Max() const { return max_; }
double GeometricMean() const {
return n_ == 0 ? 0.0 : pow(product_, 1.0 / n_);
}
double Mean() const { return m1_; }
// Same as Mu2. Assumes n_ is large.
double SampleVariance() const {
return n_ == 0 ? 0.0 : m2_ / static_cast<int>(n_);
}
// Unbiased estimator for population variance even for smaller n_.
double Variance() const {
if (n_ == 0) return 0.0;
if (n_ == 1) return m2_;
return m2_ / static_cast<int>(n_ - 1);
}
double StandardDeviation() const { return std::sqrt(Variance()); }
// Near zero for normal distributions; if positive on a unimodal distribution,
// the right tail is fatter. Assumes n_ is large.
double SampleSkewness() const {
if (std::abs(m2_) < 1E-7) return 0.0;
return m3_ * std::sqrt(static_cast<double>(n_)) / std::pow(m2_, 1.5);
}
// Corrected for bias (same as Wikipedia and Minitab but not Excel).
double Skewness() const {
if (n_ == 0) return 0.0;
const double biased = SampleSkewness();
const double r = (n_ - 1.0) / n_;
return biased * std::pow(r, 1.5);
}
// Near zero for normal distributions; smaller values indicate fewer/smaller
// outliers and larger indicates more/larger outliers. Assumes n_ is large.
double SampleKurtosis() const {
if (std::abs(m2_) < 1E-7) return 0.0;
return m4_ * n_ / (m2_ * m2_);
}
// Corrected for bias (same as Wikipedia and Minitab but not Excel).
double Kurtosis() const {
if (n_ == 0) return 0.0;
const double biased = SampleKurtosis();
const double r = (n_ - 1.0) / n_;
return biased * r * r;
}
// Central moments, useful for "method of moments"-based parameter estimation
// of a mixture of two Gaussians. Assumes Count() != 0.
double Mu1() const { return m1_; }
double Mu2() const { return m2_ / static_cast<int>(n_); }
double Mu3() const { return m3_ / static_cast<int>(n_); }
double Mu4() const { return m4_ / static_cast<int>(n_); }
// Which statistics to EXCLUDE in ToString
enum {
kNoCount = 1,
kNoMeanSD = 2,
kNoMinMax = 4,
kNoSkewKurt = 8,
kNoGeomean = 16
};
std::string ToString(int exclude = 0) const;
void Reset() {
n_ = 0;
min_ = hwy::HighestValue<float>();
max_ = hwy::LowestValue<float>();
product_ = 1.0;
m1_ = 0.0;
m2_ = 0.0;
m3_ = 0.0;
m4_ = 0.0;
}
private:
int64_t n_; // signed for faster conversion + safe subtraction
float min_;
float max_;
double product_; // for geomean
// Moments
double m1_;
double m2_;
double m3_;
double m4_;
};
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_STATS_H_

57
configs.h Normal file
View File

@ -0,0 +1,57 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Model configurations
#ifndef THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
#define THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
#include <cstddef>
namespace gcpp {
static constexpr size_t kSeqLen = 7168;
struct ConfigGemma7B {
// NOLINTBEGIN(google3-readability-class-member-naming)
static constexpr int seq_len = kSeqLen;
static constexpr int vocab_size = 256128;
static constexpr int n_layers = 28;
static constexpr int dim_model = 3072;
static constexpr int dim_ffw_hidden = 16 * 3072 / 2; // = 24576
static constexpr int n_heads = 16;
static constexpr int n_kv_heads = 16; // standard MHA, no GQA or MQA
static constexpr int dim_qkv = 256; // query size == key size == value size
static constexpr int top_k = 1;
// NOLINTEND(google3-readability-class-member-naming)
};
struct ConfigGemma2B {
// NOLINTBEGIN(google3-readability-class-member-naming)
static constexpr int seq_len = kSeqLen;
static constexpr int vocab_size = 256128;
static constexpr int n_layers = 18;
static constexpr int dim_model = 2048;
static constexpr int dim_ffw_hidden = 16 * 2048 / 2; // = 16384
static constexpr int n_heads = 8;
static constexpr int n_kv_heads = 8; // TODO(austinvhuang): add MQA support
static constexpr int dim_qkv = 256; // query size == key size == value size
static constexpr int top_k = 1;
// NOLINTEND(google3-readability-class-member-naming)
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_

32
docs/CONTRIBUTING.md Normal file
View File

@ -0,0 +1,32 @@
# How to Contribute
We would love to accept your patches and contributions to this project.
## Before you begin
### Sign our Contributor License Agreement
Contributions to this project must be accompanied by a
[Contributor License Agreement](https://cla.developers.google.com/about) (CLA).
You (or your employer) retain the copyright to your contribution; this simply
gives us permission to use and redistribute your contributions as part of the
project.
If you or your current employer have already signed the Google CLA (even if it
was for a different project), you probably don't need to do it again.
Visit <https://cla.developers.google.com/> to see your current agreements or to
sign a new one.
### Review our Community Guidelines
This project follows [Google's Open Source Community
Guidelines](https://opensource.google/conduct/).
## Contribution process
### Code Reviews
All submissions, including submissions by project members, require review. We
use [GitHub pull requests](https://docs.github.com/articles/about-pull-requests)
for this purpose.

811
gemma.cc Normal file
View File

@ -0,0 +1,811 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Lightweight C++ implementation of the gemma model.
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep
// Must come after foreach_target.h to avoid redefinition errors.
// copybara:import_next_line:gemma_cpp
#include "compression/compress-inl.h"
// copybara:import_next_line:gemma_cpp
#include "ops.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // Path
#include "hwy/contrib/matvec/matvec-inl.h"
#include "hwy/highway.h"
#include "hwy/profiler.h"
#include "hwy/timer.h"
// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
// compile pass, whereas we want this defined in the first.
#ifndef GEMMA_ONCE
#define GEMMA_ONCE
#include <stddef.h>
#include <stdio.h>
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdlib>
#include <filesystem> // NOLINT
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <vector>
// copybara:import_next_line:gemma_cpp
#include "compression/compress.h"
// copybara:import_next_line:gemma_cpp
#include "configs.h"
// copybara:import_next_line:gemma_cpp
#include "gemma.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
// copybara:import_next_line:sentencepiece
#include "src/sentencepiece_processor.h"
// #include "third_party/sentencepiece/src/util.h"
namespace gcpp {
template <class TConfig>
struct Layer {
Layer() = default;
// NOLINTBEGIN(google3-readability-class-member-naming)
static constexpr size_t n_heads = TConfig::n_heads;
static constexpr size_t dim_model = TConfig::dim_model;
static constexpr size_t dim_qkv = TConfig::dim_qkv;
static constexpr size_t dim_ffw_hidden = TConfig::dim_ffw_hidden;
static constexpr size_t size_attn_vec_einsum_w =
n_heads * dim_qkv * dim_model;
// 3x for (query, key, value)
static constexpr size_t size_qkv_einsum_w = 3 * n_heads * dim_qkv * dim_model;
// 2x for (gelu gating vector, gated vector)
static constexpr size_t size_gating_einsum_w = 2 * dim_ffw_hidden * dim_model;
static constexpr size_t size_linear_w = dim_model * dim_ffw_hidden;
std::array<float, size_attn_vec_einsum_w> attn_vec_einsum_w;
std::array<float, size_qkv_einsum_w> qkv_einsum_w;
std::array<float, size_gating_einsum_w> gating_einsum_w;
std::array<float, size_linear_w> linear_w;
std::array<float, dim_model> pre_attention_norm_scale;
std::array<float, dim_model> pre_ffw_norm_scale;
// NOLINTEND(google3-readability-class-member-naming)
};
template <class TConfig>
struct Weights {
Weights() = default;
hwy::AlignedUniquePtr<Layer<TConfig>[]> layers; // n_layers
std::array<float, TConfig::vocab_size * TConfig::dim_model>
embedder_input_embedding;
std::array<float, TConfig::dim_model> final_norm_scale;
};
// Only called if cached loading fails.
template <typename TConfig>
hwy::AlignedUniquePtr<Weights<TConfig>> LoadWeights(const Path& checkpoint) {
PROFILER_ZONE("Startup.LoadWeights");
using TWeights = Weights<TConfig>;
hwy::AlignedUniquePtr<TWeights> weights = hwy::MakeUniqueAligned<TWeights>();
weights->layers =
hwy::MakeUniqueAlignedArray<Layer<TConfig>>(TConfig::n_layers);
FILE* fptr;
fptr = fopen(checkpoint.path.c_str(), "rb");
if (fptr == nullptr) {
HWY_ABORT("Failed to open model file %s - does it exist?",
checkpoint.path.c_str());
}
bool ok = true;
ok &= 1 == fread(&(weights->embedder_input_embedding),
sizeof(weights->embedder_input_embedding), 1, fptr);
ok &= 1 == fread(&(weights->final_norm_scale),
sizeof(weights->final_norm_scale), 1, fptr);
for (size_t layer = 0; layer < TConfig::n_layers; ++layer) {
Layer<TConfig>* layer_view = &weights->layers[layer];
ok &= 1 == fread(&layer_view->attn_vec_einsum_w,
sizeof(layer_view->attn_vec_einsum_w), 1, fptr);
ok &= 1 == fread(&layer_view->qkv_einsum_w,
sizeof(layer_view->qkv_einsum_w), 1, fptr);
ok &= 1 == fread(&layer_view->gating_einsum_w,
sizeof(layer_view->gating_einsum_w), 1, fptr);
ok &= 1 ==
fread(&layer_view->linear_w, sizeof(layer_view->linear_w), 1, fptr);
ok &= 1 == fread(&layer_view->pre_attention_norm_scale,
sizeof(layer_view->pre_attention_norm_scale), 1, fptr);
ok &= 1 == fread(&layer_view->pre_ffw_norm_scale,
sizeof(layer_view->pre_ffw_norm_scale), 1, fptr);
}
if (!ok) {
HWY_ABORT("Failed to read from %s - might be a directory, or too small?",
checkpoint.path.c_str());
}
HWY_ASSERT(0 == fclose(fptr));
return weights;
}
template <class TConfig>
struct CompressedLayer {
// No ctor/dtor, allocated via AllocateAligned.
using TLayer = gcpp::Layer<TConfig>;
// # NOLINTBEGIN(google3-readability-class-member-naming)
static constexpr size_t dim_model = TConfig::dim_model;
static constexpr size_t dim_ffw_hidden = TConfig::dim_ffw_hidden;
// NOLINTEND(google3-readability-class-member-naming)
// Compressed Parameters
// We don't yet have an RMSNorm that accepts all WeightT.
CompressedArray<hwy::bfloat16_t, dim_model> c_pre_attention_norm_scale;
CompressedArray<hwy::bfloat16_t, dim_model> c_pre_ffw_norm_scale;
CompressedArray<WeightT, TLayer::size_gating_einsum_w> c_gating_einsum_w;
CompressedArray<WeightT, dim_model * dim_ffw_hidden> c_linear_w;
CompressedArray<WeightT, TLayer::size_qkv_einsum_w> c_qkv_einsum_w;
CompressedArray<WeightT, TLayer::size_attn_vec_einsum_w> c_attn_vec_einsum_w;
};
// Array instead of single large allocation for parallel mem init. Split out of
// CompressedWeights so that only these pointers are initialized, not the
// CompressedArray.
template <class TConfig>
struct CompressedLayerPointers {
explicit CompressedLayerPointers(hwy::ThreadPool& pool) {
pool.Run(0, TConfig::n_layers, [this](uint64_t task, size_t /*thread*/) {
this->c_layers[task] = hwy::AllocateAligned<CompressedLayer<TConfig>>(1);
});
}
using CLayer = CompressedLayer<TConfig>;
std::array<hwy::AlignedFreeUniquePtr<CLayer[]>, TConfig::n_layers> c_layers;
};
template <class TConfig>
struct CompressedWeights {
// No ctor/dtor, allocated via AllocateAligned.
CompressedArray<EmbedderInputT, TConfig::vocab_size * TConfig::dim_model>
c_embedder_input_embedding;
CompressedArray<hwy::bfloat16_t, TConfig::dim_model> c_final_norm_scale;
// Must be last so that the other arrays remain aligned.
CompressedLayerPointers<TConfig> c_layer_ptrs;
const CompressedLayer<TConfig>* CLayer(size_t layer) const {
return c_layer_ptrs.c_layers[layer].get();
}
CompressedLayer<TConfig>* CLayer(size_t layer) {
return c_layer_ptrs.c_layers[layer].get();
}
};
// Aligned.
template <class TConfig, size_t BatchSize>
struct Activations {
// # NOLINTBEGIN(google3-readability-class-member-naming)
static constexpr size_t batch_size = BatchSize;
using LayerConfig = Layer<TConfig>;
static constexpr size_t dim_model = TConfig::dim_model;
static constexpr size_t dim_qkv = TConfig::dim_qkv;
static constexpr size_t n_heads = TConfig::n_heads;
static constexpr size_t n_kv_heads = TConfig::n_kv_heads;
static constexpr size_t size_cache_pos =
TConfig::n_layers * n_kv_heads * dim_qkv;
static constexpr size_t size_cache_layer = n_kv_heads * dim_qkv;
// NOLINTEND(google3-readability-class-member-naming)
std::array<float, batch_size * dim_model> x; // input
std::array<float, batch_size * dim_model> pre_att_rms_out;
std::array<float, batch_size * n_heads * dim_qkv> q; // query vector
std::array<float, batch_size * n_heads * TConfig::seq_len>
att; // attention vector
std::array<float, batch_size * n_heads * dim_qkv>
att_out; // attention output
std::array<float, n_heads * batch_size * dim_model>
att_post1; // attention output after linear transformation, per head
std::array<float, batch_size * dim_model>
att_post2; // accumulation of attention outputs over heads
std::array<hwy::bfloat16_t, batch_size * dim_model> bf_pre_ffw_rms_out;
std::array<float, batch_size * TConfig::dim_ffw_hidden * 2> ffw_hidden;
// bf_ version can't be used until GeluMulToBF16 issue in FFW() is resolved.
// std::array<hwy::bfloat16_t, batch_size * 2 * TConfig::dim_ffw_hidden>
// bf_ffw_hidden;
std::array<float, batch_size * dim_model> ffw_out;
std::array<float, batch_size * TConfig::vocab_size> logits;
};
// GemmaImpl is a template and thus cannot be exposed in gemma.h, hence we
// define an abstract base class.
struct GemmaInterface {
virtual ~GemmaInterface() = default;
virtual const sentencepiece::SentencePieceProcessor& Tokenizer() const = 0;
// TODO: group pool/callbacks into struct
virtual void Generate(const InferenceArgs& args,
const std::vector<int>& prompt, size_t start_pos,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity) = 0;
};
template <class Config>
struct GemmaImpl : public GemmaInterface {
GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& pool);
~GemmaImpl() {
using CWeights = CompressedWeights<Config>;
CWeights* c_weights = reinterpret_cast<CWeights*>(compressed_weights.get());
c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
}
const sentencepiece::SentencePieceProcessor& Tokenizer() const {
return tokenizer;
}
void Generate(const InferenceArgs& args, const std::vector<int>& prompt,
size_t start_pos, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937&, int verbosity);
sentencepiece::SentencePieceProcessor tokenizer;
// CompressedWeights<Config>
hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
hwy::AlignedUniquePtr<Activations<Config, 1>> state;
KVCache kv_cache;
};
} // namespace gcpp
#endif // GEMMA_ONCE
// SIMD code, compiled once per target.
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
template <class TConfig, size_t batch_size>
HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
Activations<TConfig, batch_size>& activations,
const CompressedLayer<TConfig>* c_layer,
KVCache& kv_cache, hwy::ThreadPool& pool) {
PROFILER_ZONE("Gen.Attention");
const size_t pos = batch_start + batch_idx;
HWY_DASSERT(batch_idx < batch_size);
static constexpr size_t dim_qkv = gcpp::Activations<TConfig, 1>::dim_qkv;
static constexpr size_t size_cache_pos =
gcpp::Activations<TConfig, batch_size>::size_cache_pos;
static constexpr size_t size_cache_layer =
gcpp::Activations<TConfig, batch_size>::size_cache_layer;
static constexpr size_t dim_model =
gcpp::Activations<TConfig, batch_size>::dim_model;
static constexpr size_t n_heads = TConfig::n_heads;
const float kQueryScale = 1.0 / sqrtf(static_cast<float>(dim_qkv));
pool.Run(0, n_heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
// linear projections to QKV
const size_t head_offset =
3 * dim_qkv * dim_model; // 3x for QKV dimensions
const size_t q_offset = head * head_offset + 0 * dim_qkv * dim_model;
const size_t k_offset = head * head_offset + 1 * dim_qkv * dim_model;
const size_t v_offset = head * head_offset + 2 * dim_qkv * dim_model;
float* HWY_RESTRICT q =
activations.q.data() + head * dim_qkv + batch_idx * n_heads * dim_qkv;
const size_t batch_offset = batch_idx * dim_model;
MatVecLoop<dim_qkv, dim_model>(
c_layer->c_qkv_einsum_w, q_offset,
activations.pre_att_rms_out.data() + batch_offset, q);
const size_t kv_offset =
pos * size_cache_pos + layer * size_cache_layer + head * dim_qkv;
TwoOfsMatVecLoop<dim_qkv, dim_model>(
c_layer->c_qkv_einsum_w, k_offset, v_offset,
activations.pre_att_rms_out.data() + batch_offset,
kv_cache.key_cache.get() + kv_offset,
kv_cache.value_cache.get() + kv_offset);
// Calculate scores
float* HWY_RESTRICT head_att = activations.att.data() +
head * TConfig::seq_len +
batch_idx * n_heads * dim_qkv;
Rope(q, dim_qkv, pos);
Rope(kv_cache.key_cache.get() + kv_offset, dim_qkv, pos);
MulByConst(kQueryScale, q, dim_qkv);
// Compute Q dot K scores
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t cache_offset =
pos2 * size_cache_pos + layer * size_cache_layer + head * dim_qkv;
const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
const float score = Dot(q, k2, dim_qkv);
head_att[pos2] = score;
}
Softmax(head_att, pos + 1);
// Weighted summation
float* HWY_RESTRICT att_out = activations.att_out.data() + head * dim_qkv +
batch_idx * n_heads * dim_qkv;
hwy::ZeroBytes(att_out, dim_qkv * sizeof(*att_out));
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t cache_offset =
pos2 * size_cache_pos + layer * size_cache_layer + head * dim_qkv;
float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;
MulByConstAndAdd(head_att[pos2], v2, att_out, dim_qkv);
}
// linear projection from dim_qkv back to dim_model, sum projections
// across heads
float* HWY_RESTRICT head_out =
head == 0
? activations.att_post2.data() + batch_idx * dim_model
: activations.att_post1.data() + head * batch_size * dim_model;
MatVecLoop<dim_model, dim_qkv>(c_layer->c_attn_vec_einsum_w,
head * dim_model * dim_qkv, att_out,
head_out);
});
// accumulate output across all heads into att_post2. head 0 already wrote
// directly to att_post2.
for (size_t head = 1; head < n_heads; ++head) {
AddFrom(activations.att_post1.data() + head * batch_size * dim_model,
activations.att_post2.data() + batch_idx * dim_model, dim_model);
}
}
template <typename TConfig, size_t batch_size>
HWY_NOINLINE void FFW(Activations<TConfig, batch_size>& activations,
size_t batch_idx, const CompressedLayer<TConfig>* c_layer,
hwy::ThreadPool& pool) {
HWY_DASSERT(batch_idx < batch_size);
static constexpr size_t dim_model = TConfig::dim_model;
static constexpr size_t dim_ffw_hidden = TConfig::dim_ffw_hidden;
const size_t hidden_offset = batch_idx * dim_ffw_hidden * 2;
{
PROFILER_ZONE("Gen.FFW.GatedGELU");
const hwy::bfloat16_t* HWY_RESTRICT vec =
activations.bf_pre_ffw_rms_out.data() + batch_idx * dim_model;
float* HWY_RESTRICT out = activations.ffw_hidden.data() + hidden_offset;
float* HWY_RESTRICT out_mul = out + dim_ffw_hidden;
// Same matrix, first and second half of rows. Could fuse into one MatVec,
// but separating them could help on NUMA e.g. multiple sockets.
MatVec<dim_ffw_hidden, dim_model>(c_layer->c_gating_einsum_w,
dim_ffw_hidden * dim_model, vec, out_mul,
pool);
// Gate, will go through the nonlinearity.
MatVec<dim_ffw_hidden, dim_model>(c_layer->c_gating_einsum_w, 0, vec, out,
pool);
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
hn::Transform1(DF(), out, dim_ffw_hidden, out_mul,
[](DF df, VF v, VF mul)
HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); });
}
PROFILER_ZONE("Gen.FFW\\GatedGELU");
MatVec<dim_model, dim_ffw_hidden>(
c_layer->c_linear_w, 0, activations.ffw_hidden.data() + hidden_offset,
activations.ffw_out.data() + batch_idx * dim_model, pool);
}
template <typename TConfig, size_t batch_size>
HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
const CompressedWeights<TConfig>& c_weights,
Activations<TConfig, batch_size>& activations,
KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool) {
PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
static constexpr size_t dim_model = TConfig::dim_model;
static const float kEmbScaling = sqrtf(static_cast<float>(dim_model));
pool.Run(
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
const int token = tokens[token_idx];
Decompress(c_weights.c_embedder_input_embedding, token * dim_model,
activations.x.data() + token_idx * dim_model, dim_model);
MulByConst(kEmbScaling, activations.x.data() + token_idx * dim_model,
dim_model);
});
for (size_t layer = 0; layer < TConfig::n_layers; ++layer) {
const CompressedLayer<TConfig>* c_layer = c_weights.CLayer(layer);
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
RMSNorm(activations.x.data() + token_idx * dim_model,
c_layer->c_pre_attention_norm_scale.data(),
activations.pre_att_rms_out.data() + token_idx * dim_model,
dim_model);
Attention<TConfig, batch_size>(pos, token_idx, layer, activations,
c_layer, kv_cache, pool);
}
// TODO: sink the loop into these functions, i.e. make them matmuls.
pool.Run(
0, num_tokens,
[&](const uint64_t token_idx, size_t thread_id) HWY_ATTR {
AddFrom(activations.att_post2.data() + token_idx * dim_model,
activations.x.data() + token_idx * dim_model, dim_model);
RMSNorm(activations.x.data() + token_idx * dim_model,
c_layer->c_pre_ffw_norm_scale.data(),
activations.bf_pre_ffw_rms_out.data() + token_idx * dim_model,
dim_model);
FFW<TConfig, batch_size>(activations, token_idx, c_layer, inner_pool);
AddFrom(activations.ffw_out.data() + token_idx * dim_model,
activations.x.data() + token_idx * dim_model, dim_model);
});
} // foreach layer
pool.Run(
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
RMSNormInplace(c_weights.c_final_norm_scale.data(),
activations.x.data() + token_idx * dim_model, dim_model);
});
}
// n = 1 specialization
template <class TConfig>
void Transformer(int token, size_t pos,
const CompressedWeights<TConfig>& c_weights,
Activations<TConfig, 1>& activations, KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool) {
static constexpr size_t n_layers = TConfig::n_layers;
static constexpr size_t dim_model = TConfig::dim_model;
static const float kEmbScaling = sqrtf(static_cast<float>(dim_model));
Decompress(c_weights.c_embedder_input_embedding, token * dim_model,
activations.x.data(), dim_model);
MulByConst(kEmbScaling, activations.x.data(), dim_model);
for (size_t layer = 0; layer < n_layers; ++layer) {
const CompressedLayer<TConfig>* c_layer = c_weights.CLayer(layer);
RMSNorm(activations.x.data(), c_layer->c_pre_attention_norm_scale.data(),
activations.pre_att_rms_out.data(), dim_model);
Attention<TConfig, 1>(pos, 0, layer, activations, c_layer, kv_cache, pool);
AddFrom(activations.att_post2.data(), activations.x.data(), dim_model);
RMSNorm(activations.x.data(), c_layer->c_pre_ffw_norm_scale.data(),
activations.bf_pre_ffw_rms_out.data(), dim_model);
FFW<TConfig, 1>(activations, /* batch_idx = */ 0, c_layer, pool);
AddFrom(activations.ffw_out.data(), activations.x.data(), dim_model);
}
RMSNormInplace(c_weights.c_final_norm_scale.data(), activations.x.data(),
dim_model);
}
template <class TConfig>
void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
const std::vector<int>& prompt, size_t pos,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity) {
static constexpr size_t dim_model = TConfig::dim_model;
static constexpr size_t vocab_size = TConfig::vocab_size;
static constexpr size_t top_k = TConfig::top_k;
Activations<TConfig, 1>& activations = *gemma.state.get();
Activations<TConfig, kPrefillBatchSize>& prefill_activations =
*gemma.prefill.get();
const CompressedWeights<TConfig>& c_weights =
*reinterpret_cast<CompressedWeights<TConfig>*>(
gemma.compressed_weights.get());
KVCache& kv_cache = gemma.kv_cache;
int token;
// pos indexes the KV cache. In the first turn of a chat, pos = 0.
//
// After the first turn, pos gets passed in with > 0 corresponding to the
// current token position in the KV cache.
//
// pos_offset keeps track of the relative position within the turn, starting
// at 0 each turn. During prefill, pos_offset corresponds to the index into
// the prompt vector.
//
// In single-turn (non-chat) usage, pos and pos_offset start at 0 and are
// always equal.
size_t pos_offset = 0; // offset relative to pos
double prefill_start = hwy::platform::Now();
// Prefill stops before prompt.size() - 1 since the last prompt token is the
// first input token for generation.
while (pos_offset < prompt.size() - 1) {
const size_t end_offset =
std::min(kPrefillBatchSize, prompt.size() - 1 - pos_offset);
HWY_DASSERT(end_offset < prompt.size());
const int* batch_tokens = prompt.data() + pos_offset;
Prefill<TConfig, kPrefillBatchSize>(batch_tokens, end_offset, pos,
c_weights, prefill_activations,
kv_cache, pool, inner_pool);
for (size_t idx = 0; idx < end_offset; ++idx) {
stream_token(batch_tokens[idx], 0.0);
}
pos += end_offset;
pos_offset += end_offset;
}
if (verbosity >= 2) {
// in the future this output should not occur in GenerateImpl but instead
// should be available as observable state for frontend code to handle I/O.
double prefill_end = hwy::platform::Now();
const double prefill_tok_sec = pos_offset / (prefill_end - prefill_start);
std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]\n";
}
double gen_start = hwy::platform::Now();
HWY_DASSERT(pos_offset == prompt.size() - 1);
if (verbosity >= 2) {
// Provide usage warnings if max_new_tokens is out of range.
if (args.max_generated_tokens > args.max_tokens) {
std::cout << "Warning: max_new_tokens should be <= max_tokens"
<< std::endl;
} else if ((prompt.size() + args.max_generated_tokens) > args.max_tokens) {
std::cout << "Warning: Prompt size + max_new_tokens exceeds max_tokens."
<< std::endl;
}
}
auto pos_gen_start = pos_offset;
token = prompt.at(pos_offset);
size_t generate_pos = 0;
for (; pos < args.max_tokens && generate_pos < args.max_generated_tokens;
++pos, ++pos_offset, ++generate_pos) {
Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool);
float* final_activation = activations.x.data();
if (pos_offset >= prompt.size()) {
PROFILER_ZONE("Gen.Embedding");
// Generation phase
MatVec<vocab_size, dim_model>(c_weights.c_embedder_input_embedding, 0,
final_activation, activations.logits.data(),
pool);
// Barrier: must have all logits so we can subtract max.
Softmax(activations.logits.data(), vocab_size);
token = SampleTopK<top_k>(activations.logits.data(), vocab_size, gen,
args.temperature, accept_token);
}
if (!stream_token(token, activations.logits[token])) {
token = EOS_ID;
}
if (token == EOS_ID) {
if (verbosity >= 2) {
double gen_end = hwy::platform::Now();
const double gen_tok_sec =
(pos_offset - pos_gen_start) / (gen_end - gen_start);
std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
}
break;
}
}
}
void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, const InferenceArgs& args,
const std::vector<int>& prompt, size_t start_pos,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token, const AcceptFunc& accept_token,
std::mt19937& gen, int verbosity) {
GenerateImpl(gemma, args, prompt, start_pos, pool, inner_pool, stream_token,
accept_token, gen, verbosity);
}
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, const InferenceArgs& args,
const std::vector<int>& prompt, size_t start_pos,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token, const AcceptFunc& accept_token,
std::mt19937& gen, int verbosity) {
GenerateImpl(gemma, args, prompt, start_pos, pool, inner_pool, stream_token,
accept_token, gen, verbosity);
}
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
// if weights = null, which happens during the first call where we attempt to
// load from cache.
//
// This avoids repeating the list of tensors between loading and compressing.
template <class TConfig, class Func>
void ForEachTensor(const Weights<TConfig>* weights,
CompressedWeights<TConfig>& c_weights, Func& func) {
func("c_embedding",
weights ? weights->embedder_input_embedding.data() : nullptr,
c_weights.c_embedder_input_embedding);
func("c_final_norm", weights ? weights->final_norm_scale.data() : nullptr,
c_weights.c_final_norm_scale);
char name[16];
for (size_t layer_idx = 0; layer_idx < TConfig::n_layers; ++layer_idx) {
Layer<TConfig>* layer = weights ? &weights->layers[layer_idx] : nullptr;
CompressedLayer<TConfig>* c_layer = c_weights.CLayer(layer_idx);
snprintf(name, sizeof(name), "pre_ff_ns_%lu", layer_idx);
func(name, layer ? layer->pre_ffw_norm_scale.data() : nullptr,
c_layer->c_pre_ffw_norm_scale);
snprintf(name, sizeof(name), "gating_ein_%lu", layer_idx);
func(name, layer ? layer->gating_einsum_w.data() : nullptr,
c_layer->c_gating_einsum_w);
snprintf(name, sizeof(name), "linear_w_%lu", layer_idx);
func(name, layer ? layer->linear_w.data() : nullptr, c_layer->c_linear_w);
snprintf(name, sizeof(name), "qkv_ein_%lu", layer_idx);
func(name, layer ? layer->qkv_einsum_w.data() : nullptr,
c_layer->c_qkv_einsum_w);
snprintf(name, sizeof(name), "att_ein_%lu", layer_idx);
func(name, layer ? layer->attn_vec_einsum_w.data() : nullptr,
c_layer->c_attn_vec_einsum_w);
snprintf(name, sizeof(name), "pre_att_ns_%lu", layer_idx);
func(name, layer ? layer->pre_attention_norm_scale.data() : nullptr,
c_layer->c_pre_attention_norm_scale);
}
}
template <class TConfig>
hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeights(
const Path& model, const Path& cache, hwy::ThreadPool& pool) {
PROFILER_ZONE("Startup.LoadCache");
if (!std::filesystem::exists(model.path) &&
!std::filesystem::exists(cache.path)) {
HWY_ABORT(
"Either the model weights (--weights) or cached compressed weights "
"(--compressed_weights) must exist.");
}
// Allocate compressed weights.
using CWeights = CompressedWeights<TConfig>;
hwy::AlignedFreeUniquePtr<uint8_t[]> c_weights_u8 =
hwy::AllocateAligned<uint8_t>(sizeof(CWeights));
CWeights* c_weights = reinterpret_cast<CWeights*>(c_weights_u8.get());
new (&c_weights->c_layer_ptrs) CompressedLayerPointers<TConfig>(pool);
// First attempt to load them from cache, without requiring weights.
CacheLoader loader(cache.path.c_str());
ForEachTensor<TConfig>(nullptr, *c_weights, loader);
if (loader.ReadAll(pool)) return c_weights_u8;
// Get weights, compress, and store in cache.
hwy::AlignedUniquePtr<Weights<TConfig>> weights = LoadWeights<TConfig>(model);
Compressor compressor(pool);
ForEachTensor<TConfig>(weights.get(), *c_weights, compressor);
compressor.WriteAll(pool, cache.path.c_str());
return c_weights_u8;
}
// Type-erased because this function is called via a function pointer.
hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeightsT(
const LoaderArgs& args, hwy::ThreadPool& pool) {
switch (args.ModelType()) {
case Model::GEMMA_2B:
return GetCompressedWeights<ConfigGemma2B>(args.model, args.cache, pool);
case Model::GEMMA_7B:
return GetCompressedWeights<ConfigGemma7B>(args.model, args.cache, pool);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(args.ModelType()));
}
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(GetCompressedWeightsT);
HWY_EXPORT(Generate2B);
HWY_EXPORT(Generate7B);
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
KVCache kv_cache = {};
kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
kv_cache.value_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
return kv_cache;
}
template <class Config>
GemmaImpl<Config>::GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& pool)
: compressed_weights(
HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)),
prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()),
kv_cache(
CreateKVCache(Config::n_layers * Config::n_kv_heads * Config::dim_qkv,
Config::seq_len)) {
PROFILER_ZONE("Startup.tokenizer");
HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
}
template <>
void GemmaImpl<ConfigGemma2B>::Generate(const InferenceArgs& args,
const std::vector<int>& prompt,
size_t start_pos, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token,
const AcceptFunc& accept_token,
std::mt19937& gen, int verbosity) {
HWY_DYNAMIC_DISPATCH(Generate2B)
(*this, args, prompt, start_pos, pool, inner_pool, stream_token, accept_token,
gen, verbosity);
}
template <>
void GemmaImpl<ConfigGemma7B>::Generate(const InferenceArgs& args,
const std::vector<int>& prompt,
size_t start_pos, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token,
const AcceptFunc& accept_token,
std::mt19937& gen, int verbosity) {
HWY_DYNAMIC_DISPATCH(Generate7B)
(*this, args, prompt, start_pos, pool, inner_pool, stream_token, accept_token,
gen, verbosity);
}
Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
const Model model_type = args.ModelType();
model_training = args.ModelTraining();
switch (model_type) {
case Model::GEMMA_2B:
impl_.reset(new GemmaImpl<ConfigGemma2B>(args, pool));
break;
case Model::GEMMA_7B:
impl_.reset(new GemmaImpl<ConfigGemma7B>(args, pool));
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));
}
}
Gemma::~Gemma() = default; // after GemmaInterface is defined
const sentencepiece::SentencePieceProcessor& Gemma::Tokenizer() const {
return impl_->Tokenizer();
}
void GenerateGemma(Gemma& gemma, const InferenceArgs& args,
const std::vector<int>& prompt, size_t start_pos,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity) {
pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
gemma.impl_->Generate(args, prompt, start_pos, pool, inner_pool, stream_token,
accept_token, gen, verbosity);
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
}
} // namespace gcpp
#endif // HWY_ONCE

207
gemma.h Normal file
View File

@ -0,0 +1,207 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_H_
#include <algorithm>
#include <cctype>
#include <functional>
#include <memory>
#include <random>
#include <string>
#include <vector>
// copybara:import_next_line:gemma_cpp
#include "configs.h" // kSeqLen
// copybara:import_next_line:gemma_cpp
#include "compression/compress.h" // SfpStream/NuqStream
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // ArgsBase
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // hwy::bfloat16_t
#include "hwy/contrib/thread_pool/thread_pool.h"
// copybara:import_next_line:sentencepiece
#include "src/sentencepiece_processor.h"
namespace gcpp {
// Allowable types for GEMMA_WEIGHT_T (can be specified at compilation time):
// float, hwy::bfloat16_t, SfpStream, NuqStream
#ifndef GEMMA_WEIGHT_T
#define GEMMA_WEIGHT_T SfpStream
#endif // !GEMMA_WEIGHT_T
using WeightT = GEMMA_WEIGHT_T;
using EmbedderInputT = hwy::bfloat16_t;
constexpr size_t kPrefillBatchSize = 16;
constexpr bool kSystemPrompt = false;
struct KVCache {
hwy::AlignedFreeUniquePtr<float[]>
key_cache; // batch_size * seq_len * n_layers * n_kv_heads * dim_qkv
hwy::AlignedFreeUniquePtr<float[]>
value_cache; // batch_size * seq_len * n_layers * n_kv_heads * dim_qkv
};
// Model variants: see configs.h for details.
enum class Model { GEMMA_2B, GEMMA_7B };
enum class ModelTraining { GEMMA_IT, GEMMA_PT };
struct LoaderArgs : public ArgsBase<LoaderArgs> {
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
static std::string ToLower(const std::string& text) {
std::string result = text;
std::transform(begin(result), end(result), begin(result),
[](unsigned char c) { return std::tolower(c); });
return result;
}
gcpp::Model ModelType() const {
const std::string model_type_lc = ToLower(model_type);
if (model_type_lc == "2b-pt" || model_type_lc == "2b-it") {
return gcpp::Model::GEMMA_2B;
} else {
return gcpp::Model::GEMMA_7B;
}
}
gcpp::ModelTraining ModelTraining() const {
const std::string model_type_lc = ToLower(model_type);
if (model_type_lc == "7b-pt" || model_type_lc == "2b-pt") {
return gcpp::ModelTraining::GEMMA_PT;
} else {
return gcpp::ModelTraining::GEMMA_IT;
}
}
// Returns error string or nullptr if OK.
const char* Validate() const {
const std::string model_type_lc = ToLower(model_type);
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
model_type_lc != "2b-it" && model_type_lc != "7b-it") {
return "Model type must be 2b-pt, 7b-pt, 2b-it, or "
"7b-it.";
}
if (tokenizer.path.empty()) {
return "Missing --tokenizer flag, a file for the tokenizer is required.";
}
if (model_type.empty()) {
return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
"2b-it, or 7b-it.";
}
if (cache.path.empty()) {
return "Missing --compressed_weights flag, a file for the compressed "
"model.";
}
return nullptr;
}
Path tokenizer;
Path model; // uncompressed weights OR
Path cache; // compressed weights
std::string model_type;
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(tokenizer, "tokenizer", Path(),
"Path name of tokenizer model file. (required)");
visitor(
cache, "compressed_weights", Path(),
"Path name of compressed weights file, regenerated from `--weights` "
"file if "
"the compressed weights file does not exist. (required)");
visitor(model_type, "model", std::string(),
"Model type - can be 2b-it (2B parameters, instruction-tuned), "
"2b-pt (2B parameters, pretrained), 7b-it (7B parameters, "
"instruction-tuned), or 7b-pt (7B parameters, pretrained). "
"(required)");
visitor(model, "weights", Path(),
"Path name of model weights (.sbs) file. Only required if "
"compressed_weights file is not present and needs to be "
"regenerated. Otherwise, not needed");
}
};
struct GemmaInterface;
struct Gemma {
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
const sentencepiece::SentencePieceProcessor& Tokenizer() const;
std::unique_ptr<GemmaInterface> impl_;
gcpp::ModelTraining model_training;
};
// StreamFunc is called with (token, probability). For prompt tokens,
// probability is 0.0f.
using StreamFunc = std::function<bool(int, float)>;
using AcceptFunc = std::function<bool(int)>;
struct InferenceArgs : public ArgsBase<InferenceArgs> {
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
size_t max_tokens;
size_t max_generated_tokens;
float temperature;
bool deterministic;
bool multiturn;
// Returns error string or nullptr if OK.
const char* Validate() const {
if (max_tokens > gcpp::kSeqLen) {
return "max_tokens is larger than the maximum sequence length (see "
"configs.h).";
}
if (max_generated_tokens > max_tokens) {
return "Maximum number of generated tokens is larger than the maximum "
"total tokens.";
}
return nullptr;
}
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(max_tokens, "max_tokens", size_t{3072},
"Maximum number of tokens in prompt + generation.");
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
"Maximum number of tokens to generate.");
visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
visitor(deterministic, "deterministic", false,
"Make top-k sampling deterministic", 2);
visitor(multiturn, "multiturn", true,
"Multiturn mode (if 0, this clears the KV cache after every "
"interaction without quitting)",
2);
}
};
void GenerateGemma(Gemma& gemma, const InferenceArgs& args,
const std::vector<int>& prompt, size_t start_pos,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& g,
int verbosity);
constexpr int EOS_ID = 1;
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_H_

682
ops.h Normal file
View File

@ -0,0 +1,682 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Include guard for non-SIMD code.
#ifndef THIRD_PARTY_GEMMA_CPP_OPS_H_
#define THIRD_PARTY_GEMMA_CPP_OPS_H_
#include <stddef.h>
#include <stdint.h>
#include <array>
#include <cmath>
#include <random>
// copybara:import_next_line:gemma_cpp
#include "compression/compress.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
#endif // THIRD_PARTY_GEMMA_CPP_OPS_H_
// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE
#endif
// copybara:import_next_line:gemma_cpp
#include "compression/compress-inl.h"
#include "hwy/cache_control.h" // FlushStream
#include "hwy/contrib/algo/transform-inl.h"
#include "hwy/contrib/dot/dot-inl.h"
#include "hwy/contrib/math/math-inl.h"
#include "hwy/contrib/matvec/matvec-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
HWY_INLINE constexpr size_t MaxCols() {
// Vec + mat rows should fit into 32 KiB L1.
return 2048;
}
template <size_t kOuter>
HWY_INLINE constexpr size_t RowsPerStrip() {
// Aim for 128 work items to reduce pool overhead. Must be at least one
// vector; prefer a power of two for faster division.
constexpr size_t kRowsPerStrip =
HWY_MAX(hn::ScalableTag<float>().MaxLanes(),
1ULL << hwy::FloorLog2(kOuter / 128));
return kRowsPerStrip;
}
// Simple version without tiling nor threading.
template <size_t kOuter, size_t kInner, typename MatT, size_t kCapacity,
typename VecT>
HWY_INLINE void MatVecLoop(const CompressedArray<MatT, kCapacity>& mat,
const size_t mat_ofs,
const VecT* HWY_RESTRICT vec_aligned,
float* HWY_RESTRICT out) {
PROFILER_ZONE("MatVecLoop");
const hn::ScalableTag<float> df;
for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
const size_t row_ofs = mat_ofs + idx_row * kInner;
out[idx_row] = Dot(df, mat, row_ofs, vec_aligned, kInner);
}
}
// Simple version without tiling nor threading, but two offsets/outputs.
template <size_t kOuter, size_t kInner, typename MatT, size_t kCapacity,
typename VecT>
HWY_INLINE void TwoOfsMatVecLoop(const CompressedArray<MatT, kCapacity>& mat,
const size_t mat_ofs0, const size_t mat_ofs1,
const VecT* HWY_RESTRICT vec_aligned,
float* HWY_RESTRICT out0,
float* HWY_RESTRICT out1) {
PROFILER_ZONE("MatVecLoop");
const hn::ScalableTag<float> df;
for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
const size_t row_ofs0 = mat_ofs0 + (idx_row)*kInner;
const size_t row_ofs1 = mat_ofs1 + (idx_row)*kInner;
out0[idx_row] = Dot(df, mat, row_ofs0, vec_aligned, kInner);
out1[idx_row] = Dot(df, mat, row_ofs1, vec_aligned, kInner);
}
}
namespace detail {
// For each i = [0, num_rows), compute partial (length `num_cols`) dot product
// of row i with `vec_aligned` and add into `out[i]`. The upper-left coordinate
// of the tile is r0, c0.
template <class DF, typename MatT, size_t kCapacity, typename VecT>
HWY_INLINE void AccumulatePartialDotProducts(
DF df, const CompressedArray<MatT, kCapacity>& mat, size_t mat_ofs,
size_t mat_stride, size_t r0, size_t c0, size_t num_rows, size_t num_cols,
const VecT* HWY_RESTRICT vec_aligned, float* HWY_RESTRICT out) {
for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
out[idx_row] += Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
}
}
// Same as above, but sets out[i] to the first partial dot product, which
// avoids having to zero-initialize and accumulate.
template <class DF, typename MatT, size_t kCapacity, typename VecT>
HWY_INLINE void SetFirstPartialDotProducts(
DF df, const CompressedArray<MatT, kCapacity>& mat, size_t mat_ofs,
size_t mat_stride, size_t r0, size_t c0, size_t num_rows, size_t num_cols,
const VecT* HWY_RESTRICT vec_aligned, float* HWY_RESTRICT out) {
for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
out[idx_row] = Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
}
}
// Adds together partial dot products for all tiles with the same r0 (a
// horizontal strip of the entire matrix); the result is the full dot product
// for rows r in [r0, r0 + num_rows), which we store into in out[r - r0].
template <class DF, typename MatT, size_t kCapacity, typename VecT>
HWY_INLINE void FullDotProductsForStrip(
DF df, const CompressedArray<MatT, kCapacity>& mat, size_t mat_ofs,
size_t mat_stride, size_t r0, size_t num_rows,
const VecT* HWY_RESTRICT vec_aligned, float* HWY_RESTRICT out) {
// Tall and skinny: set `out` to the single dot product.
if (mat_stride < MaxCols()) {
SetFirstPartialDotProducts(df, mat, mat_ofs, mat_stride, r0, 0, num_rows,
mat_stride, vec_aligned, out);
return;
}
// We have at least MaxCols, so start by setting `out` to that:
SetFirstPartialDotProducts(df, mat, mat_ofs, mat_stride, r0, 0, num_rows,
MaxCols(), vec_aligned, out);
// For further multiples of MaxCols, accumulate. Remainders handled below.
size_t c0 = MaxCols();
HWY_UNROLL(1)
for (; c0 <= mat_stride - MaxCols(); c0 += MaxCols()) {
AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows,
MaxCols(), vec_aligned, out);
}
if (c0 < mat_stride) { // Final cols
AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows,
mat_stride - c0, vec_aligned, out);
}
}
} // namespace detail
// Stores dot products of rows with `vec_aligned` to a buffer, then stores them
// to `out`.
template <size_t kOuter, size_t kInner, typename MatT, size_t kCapacity,
typename VecT>
HWY_INLINE void MatVec(const CompressedArray<MatT, kCapacity>& mat,
const size_t mat_ofs,
const VecT* HWY_RESTRICT const vec_aligned,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
PROFILER_ZONE("MatVec");
const hn::ScalableTag<float> df;
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
// For each entire strip.
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
PROFILER_ZONE("MatVec.lambda");
const size_t r0 = strip * kRowsPerStrip;
detail::FullDotProductsForStrip(df, mat, mat_ofs, kInner, r0, kRowsPerStrip,
vec_aligned, out + r0);
});
// Remaining rows
const size_t r0 = kNumStrips * kRowsPerStrip;
if (r0 < kOuter) {
PROFILER_ZONE("MatVec remainder");
const size_t num_rows = kOuter - r0;
detail::FullDotProductsForStrip(df, mat, mat_ofs, kInner, r0, num_rows,
vec_aligned, out + r0);
}
}
template <class D, HWY_IF_F32_D(D)>
static HWY_INLINE hn::Vec<D> Gelu(D d, hn::Vec<D> v) {
const hn::Vec<D> kMul = Set(d, 0.044715f);
const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f);
const hn::Vec<D> kHalf = Set(d, 0.5f);
// tanh approximation matches training.
const hn::Vec<D> v3 = hn::Mul(hn::Mul(v, v), v);
const hn::Vec<D> arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v));
// 0.5 * (1 + tan) = MulAdd(0.5, tan, 0.5).
const hn::Vec<D> cdf = hn::MulAdd(kHalf, hn::Tanh(d, arg), kHalf);
return Mul(v, cdf);
}
static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x,
size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
hn::Transform(D(), x, size, [](D d, hn::Vec<D> v) { return Gelu(d, v); });
}
// out[i] = BF(mul[i] * Gelu(gelu_in[i]))
static HWY_NOINLINE HWY_MAYBE_UNUSED void GeluMulToBF16(
const float* HWY_RESTRICT gelu_in, const float* HWY_RESTRICT mul,
hwy::bfloat16_t* HWY_RESTRICT out, size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df;
const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf;
const size_t NF = hn::Lanes(df);
using VF = hn::Vec<decltype(df)>;
size_t i = 0;
if (size >= 2 * NF) {
for (; i < size - 2 * NF; i += 2 * NF) {
const VF mul0 = LoadU(df, mul + i);
const VF mul1 = LoadU(df, mul + i + NF);
const VF g0 = Mul(mul0, Gelu(df, LoadU(df, gelu_in + i)));
const VF g1 = Mul(mul1, Gelu(df, LoadU(df, gelu_in + i + NF)));
const hn::Vec<decltype(dbf)> bf = hn::OrderedDemote2To(dbf, g0, g1);
StoreU(bf, dbf, out + i);
}
}
if (i != size) {
const size_t remaining = size - i;
const VF mul0 = LoadN(df, mul + i, remaining);
const VF g0 = Mul(mul0, Gelu(df, LoadN(df, gelu_in + i, remaining)));
const hn::Half<decltype(dbf)> dbfh;
const hn::Vec<decltype(dbfh)> bfh = hn::DemoteTo(dbfh, g0);
StoreN(bfh, dbfh, out + i, remaining);
}
}
// Two matrices, same vector
// TODO(janwas): apply optimizations from MatVec/replace with above overload
template <size_t kOuter, size_t kInner, typename MatT, size_t kCapacity,
typename VecT>
HWY_NOINLINE void TwoMatVec(const CompressedArray<MatT, kCapacity>& mat0,
const CompressedArray<MatT, kCapacity>& mat1,
const size_t mat_ofs,
const VecT* HWY_RESTRICT vec_aligned,
float* HWY_RESTRICT out0, float* HWY_RESTRICT out1,
hwy::ThreadPool& pool) {
const hn::ScalableTag<float> df;
const size_t NF = hn::Lanes(df);
// Process multiple rows at a time so that we write multiples of a cache line
// to avoid false sharing (>= 64).
constexpr size_t kRowsPerStrip = 128 / sizeof(float);
const uint32_t num_strips = kOuter / kRowsPerStrip;
// No remainder handling after ThreadPool.
static_assert(kOuter % kRowsPerStrip == 0, "Add remainder handling");
// Required for Stream loop, otherwise we might have partial vectors.
HWY_DASSERT(kRowsPerStrip >= NF);
pool.Run(0, num_strips,
[&](const uint32_t strip, size_t /*thread*/) HWY_ATTR {
// MSVC workaround: duplicate to ensure constexpr.
constexpr size_t kRowsPerStrip = 128 / sizeof(float);
// Software write-combining to avoid cache pollution from out.
// Although `out` may be used later, keeping it out of the cache
// now and avoiding RFOs is a consistent 5% overall win.
HWY_ALIGN float buf0[kRowsPerStrip];
HWY_ALIGN float buf1[kRowsPerStrip];
// Only handle entire strips here because the Stream is not masked.
const size_t begin = strip * kRowsPerStrip;
for (size_t idx_row = 0; idx_row < kRowsPerStrip; ++idx_row) {
const size_t row_ofs = mat_ofs + (begin + idx_row) * kInner;
buf0[idx_row] = Dot(df, mat0, row_ofs, vec_aligned, kInner);
buf1[idx_row] = Dot(df, mat1, row_ofs, vec_aligned, kInner);
}
HWY_UNROLL(4)
for (size_t i = 0; i != kRowsPerStrip; i += NF) {
hn::Stream(hn::Load(df, buf0 + i), df, out0 + begin + i);
}
HWY_UNROLL(4)
for (size_t i = 0; i != kRowsPerStrip; i += NF) {
hn::Stream(hn::Load(df, buf1 + i), df, out1 + begin + i);
}
});
hwy::FlushStream();
}
// Baseline Naive MatMul
template <size_t kOuter, size_t kInner, size_t kBatchSize, typename MatT,
size_t kCapacity, typename VecT>
HWY_NOINLINE void MatMul(const CompressedArray<MatT, kCapacity>& mat,
const size_t mat_ofs, const VecT* HWY_RESTRICT vec,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
for (size_t i = 0; i < kBatchSize; ++i) {
MatVec<kOuter, kInner, MatT, kCapacity, VecT>(
mat, mat_ofs, vec + i * kInner, out + i * kOuter, pool);
}
}
static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a,
const float* HWY_RESTRICT b,
size_t size) {
const hn::ScalableTag<float> d;
HWY_DASSERT(size >= hn::Lanes(d));
HWY_DASSERT(size % hn::Lanes(d) == 0);
constexpr int kAssumptions =
hn::Dot::kAtLeastOneVector | hn::Dot::kMultipleOfVector;
return hn::Dot::Compute<kAssumptions>(d, a, b, size);
}
// = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT.
static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2(
const float* HWY_RESTRICT a, size_t size) {
float total = 0.f;
for (size_t i = 0; i < size; ++i) {
total += a[i] * a[i];
}
return total;
}
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight,
float* HWY_RESTRICT out, size_t size) {
constexpr float eps = 1e-6f;
float ss = SquaredL2(x, size);
ss = 1.0f / sqrtf(ss / static_cast<int>(size) + eps);
for (size_t j = 0; j < size; j++) {
// Note 1.0f centering here
out[j] = (1.0f + weight[j]) * (ss * x[j]);
}
}
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
float* HWY_RESTRICT out, size_t size) {
constexpr float eps = 1e-6f;
float ss = SquaredL2(x, size);
ss = 1.0f / sqrtf(ss / static_cast<int>(size) + eps);
for (size_t j = 0; j < size; j++) {
// Note 1.0f centering here
out[j] = (1.0f + hwy::F32FromBF16(weight[j])) * (ss * x[j]);
}
}
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) {
constexpr float eps = 1e-6f;
float ss = SquaredL2(inout, size);
ss = 1.0f / sqrtf(ss / static_cast<int>(size) + eps);
for (size_t j = 0; j < size; j++) {
// Note 1.0f centering here
inout[j] = (1.0f + weight[j]) * (ss * inout[j]);
}
}
// w=bf16 -> f
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
const hwy::bfloat16_t* HWY_RESTRICT weight, float* HWY_RESTRICT inout,
const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<hwy::bfloat16_t> dbf;
const hn::Repartition<float, decltype(dbf)> df32;
using VF = hn::Vec<decltype(df32)>;
const size_t N32 = hn::Lanes(df32);
constexpr float eps = 1e-6f;
const float ss = SquaredL2(inout, size);
const VF vss = Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) {
const hn::Vec<decltype(dbf)> w16 = hn::LoadU(dbf, weight + i);
const VF w0 = hn::PromoteLowerTo(df32, w16);
const VF w1 = hn::PromoteUpperTo(df32, w16);
const VF m0 = hn::Mul(vss, hn::LoadU(df32, inout + i));
const VF m1 = hn::Mul(vss, hn::LoadU(df32, inout + i + N32));
// (1+weight) * m = m + weight*m = one FMA.
hn::StoreU(hn::MulAdd(m0, w0, m0), df32, inout + i);
hn::StoreU(hn::MulAdd(m1, w1, m1), df32, inout + i + N32);
}
}
// f, f -> bf
// TODO(janwas): consider generic function with adapter for loading bf16/f32
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight,
hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<hwy::bfloat16_t> dbf;
const hn::Repartition<float, decltype(dbf)> df32;
using VF = hn::Vec<decltype(df32)>;
const size_t N32 = hn::Lanes(df32);
constexpr float eps = 1e-6f;
const float ss = SquaredL2(x, size);
const VF vss = Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) {
const VF w0 = hn::LoadU(df32, weight + i);
const VF w1 = hn::LoadU(df32, weight + i + N32);
const VF m0 = hn::Mul(vss, hn::LoadU(df32, x + i));
const VF m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32));
// (1+weight) * m = m + weight*m = one FMA.
const VF out0 = hn::MulAdd(m0, w0, m0);
const VF out1 = hn::MulAdd(m1, w1, m1);
hn::StoreU(hn::OrderedDemote2To(dbf, out0, out1), dbf, out + i);
}
}
// x=f, w=bf16 -> bf16 to enable W16A16 MatVec.
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<hwy::bfloat16_t> dbf;
const hn::Repartition<float, decltype(dbf)> df32;
using VF = hn::Vec<decltype(df32)>;
const size_t N32 = hn::Lanes(df32);
constexpr float eps = 1e-6f;
const float ss = SquaredL2(x, size);
const VF vss = Set(df32, 1.0f / sqrtf(ss / size + eps));
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) {
const hn::Vec<decltype(dbf)> w16 = hn::LoadU(dbf, weight + i);
const VF w0 = hn::PromoteLowerTo(df32, w16);
const VF w1 = hn::PromoteUpperTo(df32, w16);
const VF m0 = hn::Mul(vss, hn::LoadU(df32, x + i));
const VF m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32));
// (1+weight) * m = m + weight*m = one FMA.
const VF out0 = hn::MulAdd(m0, w0, m0);
const VF out1 = hn::MulAdd(m1, w1, m1);
hn::StoreU(hn::OrderedDemote2To(dbf, out0, out1), dbf, out + i);
}
}
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
float* HWY_RESTRICT x, size_t dim_model, size_t pos) {
const size_t num_timescales = dim_model / 2;
const float log_timescale_increment =
logf(10000.0f) /
(num_timescales != 0
? static_cast<float>(static_cast<int>(num_timescales) - 1)
: 1.0f);
for (size_t dim = 0; dim < num_timescales; ++dim) {
const float inv_timescale =
expf(static_cast<int>(dim) * -log_timescale_increment);
x[dim] += sinf(pos * inv_timescale);
x[num_timescales + dim] += cosf(pos * inv_timescale);
}
}
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
size_t dim_qkv, size_t pos) {
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
const float freq_exponents = static_cast<float>(2 * static_cast<int>(dim)) /
static_cast<float>(dim_qkv);
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
const float timescale = powf(10000.0f, freq_exponents);
const float theta = pos / timescale;
const float cos_val = cosf(theta);
const float sin_val = sinf(theta);
const float x0 = x[dim];
const float x1 = x[dim + half_dim_qkv];
x[dim] = x0 * cos_val - x1 * sin_val;
x[dim + half_dim_qkv] = x0 * sin_val + x1 * cos_val;
}
}
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul,
float* HWY_RESTRICT x,
size_t dim_qkv,
size_t pos) {
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
const float freq_exponents = static_cast<float>(2 * static_cast<int>(dim)) /
static_cast<float>(dim_qkv);
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
const float timescale = powf(10000.0f, freq_exponents);
const float theta = pos / timescale;
const float cos_val = cosf(theta);
const float sin_val = sinf(theta);
const float x0 = x[dim];
const float x1 = x[dim + half_dim_qkv];
x[dim] = mul * (x0 * cos_val - x1 * sin_val);
x[dim + half_dim_qkv] = mul * (x0 * sin_val + x1 * cos_val);
}
}
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
const float* HWY_RESTRICT other, float* HWY_RESTRICT x, size_t size) {
for (size_t i = 0; i < size; ++i) {
x[i] += other[i];
}
}
static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other,
float* HWY_RESTRICT x, size_t size,
size_t max_pos) {
HWY_DASSERT(max_pos <= size);
for (size_t i = 0; i < max_pos; ++i) {
x[i] *= other[i];
}
}
static HWY_INLINE HWY_MAYBE_UNUSED void MulBy(const float* HWY_RESTRICT other,
float* HWY_RESTRICT x,
size_t size) {
return MulBy(other, x, size, size);
}
static HWY_NOINLINE void MulByConst(float c, float* HWY_RESTRICT x, size_t size,
size_t max_pos) {
HWY_DASSERT(max_pos <= size);
for (size_t i = 0; i < max_pos; ++i) {
x[i] *= c;
}
}
static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(float c,
float* HWY_RESTRICT x,
size_t size) {
MulByConst(c, x, size, size);
}
static HWY_NOINLINE void MulByConstAndAdd(float c, const float* HWY_RESTRICT x,
float* HWY_RESTRICT out, size_t size,
size_t max_pos) {
for (size_t i = 0; i < max_pos; ++i) {
out[i] += x[i] * c;
}
}
static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
float c, const float* HWY_RESTRICT x, float* HWY_RESTRICT out,
size_t size) {
MulByConstAndAdd(c, x, out, size, size);
}
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, size_t size,
size_t mask_pos) {
HWY_DASSERT(size != 0);
HWY_DASSERT(mask_pos <= size);
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
using V = hn::Vec<D>;
// Find max so we can subtract it below.
const V vmin = hn::Set(d, hwy::LowestValue<float>());
V max = vmin;
hn::Foreach(d, x, mask_pos, vmin,
[&max](D d, V v) { max = hn::Max(max, v); });
max = hn::MaxOfLanes(d, max); // broadcast
// Subtract max (avoid precision loss for large exponents) and exponentiate.
V sum = hn::Zero(d);
hn::Transform(d, x, mask_pos, [&sum, max](D d, V v) {
const V out = hn::Exp(d, hn::Sub(v, max));
sum = hn::Add(sum, out);
return out;
});
// Normalize to probability distribution
const float mul = 1.0f / hn::ReduceSum(d, sum);
MulByConst(mul, x, size, mask_pos);
}
static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x,
size_t size) {
Softmax(x, size, size);
}
static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
size_t size, size_t max_pos) {
HWY_DASSERT(max_pos <= size);
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
using V = hn::Vec<D>;
const V inv_cap = hn::Set(d, 1.0f / cap);
const V vcap = hn::Set(d, cap);
hn::Transform(d, x, size, [vcap, inv_cap](D d, hn::Vec<D> v) {
return hn::Mul(vcap, hn::Tanh(d, hn::Mul(inv_cap, v)));
});
}
static HWY_INLINE HWY_MAYBE_UNUSED void LogitsSoftCap(const float cap,
float* HWY_RESTRICT x,
size_t size) {
LogitsSoftCap(cap, x, size, size);
}
static HWY_NOINLINE HWY_MAYBE_UNUSED size_t
SampleArgmax(const float* probabilities, size_t vocab_size) {
size_t max_index = 0;
float max_prob = probabilities[0];
for (size_t i = 1; i < vocab_size; ++i) {
if (probabilities[i] > max_prob) {
max_index = i;
max_prob = probabilities[i];
}
}
return max_index;
}
template <size_t k>
static HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution<int>
create_distribution(std::array<float, k>& top_k, float temperature) {
// re-normalize distribution
for (size_t i = 0; i < k; ++i) {
top_k[i] = exp(log(top_k[i]) / temperature);
}
float denominator = 0.0f;
for (size_t i = 0; i < k; ++i) {
denominator += top_k[i];
}
denominator = 1.0f / denominator;
MulByConst(denominator, top_k.data(), k);
return std::discrete_distribution<int>(std::begin(top_k), std::end(top_k));
}
template <size_t k, typename TAcceptToken>
static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
const float* HWY_RESTRICT probabilities, size_t vocab_size,
std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
static_assert(k != 0, "");
// TODO(austinvhuang): Optimize this implementation.
std::array<float, k> top_k{}; // sorted from highest [0], to lowest [k-1]
std::array<int, k> indices{};
for (size_t i = 0; i < vocab_size; ++i) {
if (probabilities[i] < top_k[k - 1] && accept_token(static_cast<int>(i))) {
continue;
}
for (size_t j = 0; j < k; ++j) {
if (probabilities[i] > top_k[j] && accept_token(static_cast<int>(i))) {
// shift elements by 1, insert the new value, move on to next value
for (size_t idx = k - 1; idx > j; --idx) {
top_k[idx] = top_k[idx - 1];
indices[idx] = indices[idx - 1];
}
top_k[j] = probabilities[i];
indices[j] = static_cast<int>(i);
break;
}
}
}
return indices[create_distribution<k>(top_k, temperature)(gen)];
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // NOLINT

261
run.cc Normal file
View File

@ -0,0 +1,261 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Command line text interface to gemma.
#include <ctime>
#include <iostream>
#include <random>
#include <string>
#include <thread> // NOLINT
#include <vector>
// copybara:import_next_line:gemma_cpp
#include "compression/compress.h"
// copybara:import_next_line:gemma_cpp
#include "gemma.h" // Gemma
// copybara:import_next_line:gemma_cpp
#include "util/app.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // HasHelp
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
#include "hwy/per_target.h"
#include "hwy/profiler.h"
#include "hwy/timer.h"
namespace gcpp {
void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
gcpp::AppArgs& app) {
fprintf(stderr,
"\ngemma.cpp\n---------\n\nTo run gemma.cpp, you need to "
"specify 3 required model loading arguments: --tokenizer, "
"--compressed_weights, "
"and --model.\n\nModel Loading Arguments\n\n");
loader.Help();
fprintf(stderr, "\nInference Arguments\n\n");
inference.Help();
fprintf(stderr, "\nApplication Arguments\n\n");
app.Help();
fprintf(stderr, "\n\n");
}
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
loader.Print(app.verbosity);
inference.Print(app.verbosity);
app.Print(app.verbosity);
if (app.verbosity >= 2) {
time_t now = time(nullptr);
char* dt = ctime(&now); // NOLINT
std::cout << "Date & Time : " << dt
<< "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize
<< "\n"
<< "Hardware concurrency : "
<< std::thread::hardware_concurrency() << std::endl
<< "Instruction set : "
<< hwy::TargetName(hwy::DispatchedTarget()) << " ("
<< hwy::VectorBytes() * 8 << " bits)" << "\n"
<< "Weight Type : "
<< gcpp::TypeName(gcpp::WeightT()) << "\n"
<< "EmbedderInput Type : "
<< gcpp::TypeName(gcpp::EmbedderInputT()) << "\n";
}
}
void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const InferenceArgs& args,
int verbosity, const gcpp::AcceptFunc& accept_token) {
PROFILER_ZONE("Gen.misc");
int abs_pos = 0; // absolute token index over all turns
int current_pos = 0; // token index within the current turn
int prompt_size{};
std::mt19937 gen;
if (args.deterministic) {
gen.seed(42);
} else {
std::random_device rd;
gen.seed(rd());
}
// callback function invoked for each generated token.
auto stream_token = [&abs_pos, &current_pos, &args, &gen, &prompt_size,
tokenizer = &model.Tokenizer(),
verbosity](int token, float) {
++abs_pos;
++current_pos;
if (current_pos < prompt_size) {
std::cerr << "." << std::flush;
} else if (token == gcpp::EOS_ID) {
if (!args.multiturn) {
abs_pos = 0;
if (args.deterministic) {
gen.seed(42);
}
}
if (verbosity >= 2) {
std::cout << "\n[ End ]" << std::endl;
}
} else {
std::string token_text;
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text).ok());
// +1 since position is incremented above
if (current_pos == prompt_size + 1) {
// first token of response
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
if (verbosity >= 1) {
std::cout << std::endl << std::endl;
}
}
// TODO(austinvhuang): is explicit space necessary?
std::cout << token_text << std::flush;
}
return true;
};
while (abs_pos < args.max_tokens) {
std::string prompt_string;
std::vector<int> prompt;
current_pos = 0;
{
PROFILER_ZONE("Gen.input");
if (verbosity >= 1) {
std::cout << "> " << std::flush;
}
std::getline(std::cin, prompt_string);
}
if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q") {
return;
}
if (model.model_training == ModelTraining::GEMMA_IT) {
// For instruction-tuned models: add control tokens.
prompt_string = "<start_of_turn>user\n" + prompt_string +
"<end_of_turn>\n<start_of_turn>model\n";
if (abs_pos > 0) {
// Prepend "<end_of_turn>" token if this is a multi-turn dialogue
// continuation.
prompt_string = "<end_of_turn>\n" + prompt_string;
}
}
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok());
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
// if needed.
if (abs_pos == 0) {
prompt.insert(prompt.begin(), 2);
}
prompt_size = prompt.size();
std::cerr << std::endl << "[ Reading prompt ] " << std::flush;
const double time_start = hwy::platform::Now();
GenerateGemma(model, args, prompt, abs_pos, pool, inner_pool, stream_token,
accept_token, gen, verbosity);
const double time_end = hwy::platform::Now();
const double tok_sec = current_pos / (time_end - time_start);
if (verbosity >= 2) {
std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
<< std::endl
<< tok_sec << " tokens / sec" << std::endl;
}
std::cout << std::endl << std::endl;
}
std::cout
<< "max_tokens (" << args.max_tokens
<< ") exceeded. Use a larger value if desired using the --max_tokens "
<< "command line flag.\n";
}
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
PROFILER_ZONE("Run.misc");
hwy::ThreadPool inner_pool(0);
hwy::ThreadPool pool(app.num_threads);
// For many-core, pinning threads to cores helps.
if (app.num_threads > 10) {
PinThreadToCore(app.num_threads - 1); // Main thread
pool.Run(0, pool.NumThreads(),
[](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
}
gcpp::Gemma model(loader, pool);
if (const char* error = inference.Validate()) {
ShowHelp(loader, inference, app);
HWY_ABORT("\nInvalid args: %s", error);
}
if (app.verbosity >= 1) {
static const std::string banner_ascii_art =
" __ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __\n"
" / _` |/ _ \\ '_ ` _ \\| '_ ` _ \\ / _` | / __| '_ \\| '_ \\\n"
"| (_| | __/ | | | | | | | | | | (_| || (__| |_) | |_) |\n"
" \\__, |\\___|_| |_| |_|_| |_| |_|\\__,_(_)___| .__/| .__/\n"
" __/ | | | | |\n"
" |___/ |_| |_|";
const std::string instructions =
"*Usage*\n"
" Enter an instruction and press enter (%Q quits).\n\n"
"*Examples*\n"
" - Write an email to grandma thanking her for the cookies.\n"
" - What are some historical attractions to visit around "
"Massachusetts?\n"
" - Compute the nth fibonacci number in javascript.\n"
" - Write a standup comedy bit about GPU programming.\n";
std::cout << "\033[2J\033[1;1H" // clear screen
<< banner_ascii_art << "\n\n";
ShowConfig(loader, inference, app);
std::cout << "\n" << instructions << "\n";
}
ReplGemma(model, pool, inner_pool, inference, app.verbosity,
/*accept_token=*/[](int) { return true; });
}
} // namespace gcpp
int main(int argc, char** argv) {
{
PROFILER_ZONE("Startup.misc");
gcpp::LoaderArgs loader(argc, argv);
gcpp::InferenceArgs inference(argc, argv);
gcpp::AppArgs app(argc, argv);
if (gcpp::HasHelp(argc, argv)) {
ShowHelp(loader, inference, app);
return 0;
}
if (const char* error = loader.Validate()) {
ShowHelp(loader, inference, app);
HWY_ABORT("\nInvalid args: %s", error);
}
gcpp::Run(loader, inference, app);
}
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
return 0;
}

85
util/app.h Normal file
View File

@ -0,0 +1,85 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Shared between various frontends.
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
#include <sched.h>
#include <stddef.h>
#include <algorithm> // std::clamp
#include <thread> // NOLINT>
// copybara:import_next_line:gemma_cpp
#include "util/args.h"
#include "hwy/base.h" // HWY_ASSERT
namespace gcpp {
static inline void PinThreadToCore(size_t cpu_index) {
#if HWY_OS_LINUX
// Forces the thread to run on the logical processor with the same number.
cpu_set_t cset; // bit array
CPU_ZERO(&cset); // clear all
CPU_SET(cpu_index, &cset); // set bit indicating which processor to run on.
HWY_ASSERT(0 == sched_setaffinity(0, sizeof(cset), &cset));
#else
(void)cpu_index;
#endif
}
class AppArgs : public ArgsBase<AppArgs> {
static constexpr size_t kDefaultNumThreads = ~size_t{0};
void ChooseNumThreads() {
if (num_threads == kDefaultNumThreads) {
// This is a rough heuristic, replace with something better in the future.
num_threads = static_cast<size_t>(std::clamp(
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
}
}
public:
AppArgs(int argc, char* argv[]) {
InitAndParse(argc, argv);
ChooseNumThreads();
}
Path log; // output
int verbosity;
size_t num_threads;
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(log, "log", Path{"/tmp/log.txt"}, "Logging file", 2);
visitor(verbosity, "verbosity", 1,
"Show verbose developer information\n 0 = only print generation "
"output\n 1 = standard user-facing terminal ui\n 2 = show "
"developer/debug info).\n Default = 1.",
2);
visitor(num_threads, "num_threads",
kDefaultNumThreads, // see ChooseNumThreads
"Number of threads to use. Default value is set based on an "
"estimate of "
"how many concurrent threads are supported.",
2);
}
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_

223
util/args.h Normal file
View File

@ -0,0 +1,223 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Command line arguments.
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_ARGS_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_ARGS_H_
#include <stdio.h>
#include <algorithm> // std::transform
#include <string>
#include "hwy/base.h" // HWY_ABORT
namespace gcpp {
// Wrapper for strings representing a path name. Differentiates vs. arbitrary
// strings and supports shortening for display purposes.
struct Path {
Path& operator=(const char* other) {
path = other;
return *this;
}
std::string Shortened() const {
constexpr size_t max_len = 48;
constexpr size_t cut_point = max_len / 2 - 5;
if (path.size() > max_len) {
return std::string(begin(path), begin(path) + cut_point) + " ... " +
std::string(end(path) - cut_point, end(path));
}
if (path.empty()) return "[no path specified]";
return path;
}
std::string path;
};
// Args is a class that provides a ForEach member function which visits each of
// its member variables. ArgsBase provides functions called by Args to
// initialize values to their defaults (passed as an argument to the visitor),
// print and parse, without having to repeat the args for each usage.
template <class Args>
class ArgsBase {
struct InitVisitor {
template <typename T>
void operator()(T& t, const char* /*name*/, const T& init,
const char* /*help*/, int /*print_verbosity*/ = 0) const {
t = init;
}
};
struct HelpVisitor {
template <typename T>
void operator()(T&, const char* name, T /*init*/, const char* help,
int /*print_verbosity*/ = 0) const {
fprintf(stderr, " --%s : %s\n", name, help);
}
};
class PrintVisitor {
public:
explicit PrintVisitor(int verbosity) : verbosity_(verbosity) {}
template <typename T>
void operator()(const T& t, const char* name, const T& /*init*/,
const char* /*help*/, int print_verbosity = 0) const {
if (verbosity_ >= print_verbosity) {
fprintf(stderr, "%-30s: %s\n", name, std::to_string(t).c_str());
}
}
void operator()(const std::string& t, const char* name,
const std::string& /*init*/, const char* /*help*/,
int print_verbosity = 0) const {
if (verbosity_ >= print_verbosity) {
fprintf(stderr, "%-30s: %s\n", name, t.c_str());
}
}
void operator()(const Path& t, const char* name, const Path& /*init*/,
const char* /*help*/, int print_verbosity = 0) const {
if (verbosity_ >= print_verbosity) {
fprintf(stderr, "%-30s: %s\n", name, t.Shortened().c_str());
}
}
private:
int verbosity_;
};
// Supported types: integer, float, std::string, bool, Path. This is O(N^2):
// for each arg, we search through argv. If there are more than a dozen args,
// consider adding a hash-map to speed this up.
class ParseVisitor {
public:
ParseVisitor(int argc, char* argv[]) : argc_(argc), argv_(argv) {}
template <typename T>
void operator()(T& t, const char* name, const T& /*init*/,
const char* /*help*/, int /*print_verbosity*/ = 0) const {
const std::string prefixed = std::string("--") + name;
for (int i = 1; i < argc_; ++i) {
if (std::string(argv_[i]) == prefixed) {
if (i + 1 >= argc_) {
HWY_ABORT("Missing value for %s\n", name);
}
if (!SetValue(argv_[i + 1], t)) {
HWY_ABORT("Invalid value for %s, got %s\n", name, argv_[i + 1]);
}
return;
}
}
}
private:
// Returns false if an invalid value is detected.
template <typename T, HWY_IF_NOT_FLOAT(T)>
static bool SetValue(const char* string, T& t) {
t = std::stoi(string);
return true;
}
template <typename T, HWY_IF_FLOAT(T)>
static bool SetValue(const char* string, T& t) {
t = std::stof(string);
return true;
}
static bool SetValue(const char* string, std::string& t) {
t = string;
return true;
}
static bool SetValue(const char* string, Path& t) {
t.path = string;
return true;
}
static bool SetValue(const char* string, bool& t) {
std::string value(string);
// Lower-case. Arg names are expected to be ASCII-only.
std::transform(value.begin(), value.end(), value.begin(), [](char c) {
return 'A' <= c && c <= 'Z' ? c - ('Z' - 'z') : c;
});
if (value == "true" || value == "on" || value == "1") {
t = true;
return true;
} else if (value == "false" || value == "off" || value == "0") {
t = false;
return true;
} else {
return false;
}
}
int argc_;
char** argv_;
}; // ParseVisitor
template <class Visitor>
void ForEach(Visitor& visitor) {
static_cast<Args*>(this)->ForEach(visitor);
}
public:
// WARNING: cannot call from ctor because the derived ctor has not yet run.
void Init() {
InitVisitor visitor;
ForEach(visitor);
}
void Help() {
HelpVisitor visitor;
ForEach(visitor);
}
void Print(int verbosity = 0) {
PrintVisitor visitor(verbosity);
ForEach(visitor);
}
void Parse(int argc, char* argv[]) {
ParseVisitor visitor(argc, argv);
ForEach(visitor);
}
// For convenience, enables single-line constructor.
void InitAndParse(int argc, char* argv[]) {
Init();
Parse(argc, argv);
}
};
static bool HasHelp(int argc, char* argv[]) {
// TODO(austinvhuang): handle case insensitivity
if (argc == 1) {
// no arguments - print help
return true;
}
for (int i = 1; i < argc; ++i) {
if (std::string(argv[i]) == "--help") {
return true;
}
}
return false;
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_ARGS_H_