mirror of https://github.com/google/gemma.cpp.git
initial commit
This commit is contained in:
commit
e29cd566cf
|
|
@ -0,0 +1,79 @@
|
|||
# Copyright 2019 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
cmake_minimum_required(VERSION 3.11)
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
project(gemma)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
|
||||
FetchContent_MakeAvailable(highway)
|
||||
|
||||
## Note: absl meeds tp be installed by sentencepiece. This will only happen if
|
||||
## cmake is invoked with -DSPM_ENABLE_SHARED=OFF and -DSPM_ABSL_PROVIDER=module
|
||||
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
|
||||
FetchContent_MakeAvailable(sentencepiece)
|
||||
|
||||
set(SOURCES
|
||||
gemma.cc
|
||||
compression/blob_store.cc
|
||||
compression/blob_store.h
|
||||
compression/compress.h
|
||||
compression/compress-inl.h
|
||||
compression/nuq.h
|
||||
compression/nuq-inl.h
|
||||
compression/sfp.h
|
||||
compression/sfp-inl.h
|
||||
util/app.h
|
||||
util/args.h
|
||||
)
|
||||
|
||||
add_compile_options($<$<CONFIG:Release>:-O2>)
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
set(CMAKE_BUILD_TYPE "Release")
|
||||
endif()
|
||||
|
||||
# Allowable types for WEIGHT_TYPE:
|
||||
# float - slow, not recommended
|
||||
# hwy::bfloat16_t - bfloat16 as impemented by https://github.com/google/highway
|
||||
# SfpStream - 8-bit switched floating point (recommended)
|
||||
# NuqStream - experimental, work-in-progress
|
||||
option(WEIGHT_TYPE "Set weight type" "")
|
||||
|
||||
if (WEIGHT_TYPE)
|
||||
add_definitions(-DGEMMA_WEIGHT_T=${WEIGHT_TYPE})
|
||||
endif()
|
||||
|
||||
# Executable Target
|
||||
|
||||
add_executable(gemma run.cc)
|
||||
target_sources(gemma PRIVATE ${SOURCES})
|
||||
set_property(TARGET gemma PROPERTY CXX_STANDARD 17)
|
||||
target_link_libraries(gemma hwy hwy_contrib sentencepiece)
|
||||
target_include_directories(gemma PRIVATE ./)
|
||||
FetchContent_GetProperties(sentencepiece)
|
||||
target_include_directories(gemma PRIVATE ${sentencepiece_SOURCE_DIR})
|
||||
|
||||
## Library Target
|
||||
|
||||
add_library(libgemma ${SOURCES})
|
||||
set_property(TARGET libgemma PROPERTY CXX_STANDARD 17)
|
||||
set_target_properties(libgemma PROPERTIES PREFIX "")
|
||||
target_include_directories(libgemma PUBLIC ./)
|
||||
target_link_libraries(libgemma hwy hwy_contrib sentencepiece)
|
||||
target_include_directories(libgemma PRIVATE ${sentencepiece_SOURCE_DIR})
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
# Developer Notes
|
||||
|
||||
## Motivation: A Minimalist C++ LLM Runtime for Research and Experimentation
|
||||
|
||||
In the past, neural network inference has been similar to a simple, opaque,
|
||||
stateless function function with a single input and output. By contrast,
|
||||
foundation model runtimes are better considered as systems with multiple forms
|
||||
of state, subsystems, and heterogeneous inputs and outputs. They are often
|
||||
integrated with a wide variety of other systems that have their own resources
|
||||
(e.g. RAG and tools) and potentially interact with an external environment. They
|
||||
have become compute engines to embed proximal tasks and goals within expansively
|
||||
broad, general-purpose world models.
|
||||
|
||||
With this in mind, we believe that developing an experimental runtime that is
|
||||
flexible and approachable will allow us to explore the design space of co-design
|
||||
between high level model concerns and low-level runtime computation.
|
||||
|
||||
## Design Priorities
|
||||
|
||||
Given these motivations, we propose the following priorities for
|
||||
making decisions regarding the direction and design of the codebase.
|
||||
|
||||
**Maximize Leverage with a Narrow Scope.** We focus on direct implementations of
|
||||
foundation models like Gemma. This allows us to focus effort on bottlenecks of
|
||||
specific models. We are willing to trade off generality to keep implementation
|
||||
code relatively simple and readable at all layers of the stack, achieve good
|
||||
performance, and maintain the velocity of a small team.
|
||||
|
||||
**Data Oriented Design.** Follow data oriented design principles where possible
|
||||
to minimize unnecessary performance pessimization. It's best to apply these
|
||||
optimizations during the initial design, or when refactoring a subcomponent. The
|
||||
first step is to think in terms of batches or tuples of plain old data (POD)
|
||||
types: separate arrays, instead of an array of structs. The second is to
|
||||
de-emphasize control flow (if statements, virtual functions and class
|
||||
hierarchies). The third step is to know intrinsic properties of data and bake
|
||||
that into the layout and algorithm.
|
||||
|
||||
**Prioritize Small Batch Latency** Since production serving solutions are
|
||||
available for large-scale serving powered by accelerators and optimizing for
|
||||
throughput, this project focuses on the possibilities of local, interactive use
|
||||
of foundation models. Although throughput remains important, low latency and
|
||||
small batch sizes are prioritized, other things being equal.
|
||||
|
||||
**Maintain a Portable Baseline** Our starting point is a portable CPU SIMD (via
|
||||
[highway](https://github.com/google/highway)). We expect to add accelerator and
|
||||
hybrid CPU/GPU support in the future, but the project should continue to allow
|
||||
builds using this portable baseline. This ensures that research-oriented and
|
||||
experimental runtimes and hardware platforms will have a minimum viable option
|
||||
to run Gemma even if specialized production-ready deployment paths are not
|
||||
available.
|
||||
|
||||
## Code Organization
|
||||
|
||||
The implementation code is roughly split into 4 layers, from high to low level:
|
||||
|
||||
1. Frontends (`run.cc`) - Either interactive interfaces or automation
|
||||
orchestration that interacts. Frontend code implements a use case objective
|
||||
in terms of invocations to model inference and generation (2). Projects that
|
||||
use gemma.cpp as a library are considered alternative frontends to `run.cc`.
|
||||
We will add examples of additional frontends in the future.
|
||||
|
||||
2. Models (`gemma.cc`, `gemma.h`, `configs.h`) - Implements the compute graph
|
||||
of the model including supporting functions such as loading and compressing
|
||||
weights using transformer operations provided by layer (3).
|
||||
|
||||
3. Operations (`ops.h`) - A minimal set of transformer and supporting
|
||||
mathematical operations implementations using compute backends (4). This
|
||||
code should be agnostic to the specifics of the compute graph of the model
|
||||
implementation (2).
|
||||
|
||||
4. Backend (`highway`) - Low-level hardware interface (SIMD in the case of
|
||||
highway) supporting the implementations in (3).
|
||||
|
|
@ -0,0 +1,202 @@
|
|||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
Copyright (c) The gemma.cpp Project Authors. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification,
|
||||
are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
|
@ -0,0 +1,335 @@
|
|||
# gemma.cpp
|
||||
|
||||
gemma.cpp is a lightweight, standalone C++ inference engine for the Gemma
|
||||
foundation models from Google.
|
||||
|
||||
For additional information about Gemma, see
|
||||
[ai.google.dev/gemma](https://ai.google.dev/gemma). Model weights, including gemma.cpp
|
||||
specific artifacts, are [available on
|
||||
kaggle](https://www.kaggle.com/models/google/gemma).
|
||||
|
||||
## Who is this project for?
|
||||
|
||||
Modern LLM inference engines are sophisticated systems, often with bespoke
|
||||
capabilities extending beyond traditional neural network runtimes. With this
|
||||
comes opportunities for research and innovation through co-design of high level
|
||||
algorithms and low-level computation. However, there is a gap between
|
||||
deployment-oriented C++ inference runtimes, which are not designed for
|
||||
experimentation, and Python-centric ML research frameworks, which abstract away
|
||||
low-level computation through compilation.
|
||||
|
||||
gemma.cpp provides a minimalist implementation of Gemma 2B and 7B models,
|
||||
focusing on simplicity and directness rather than full generality. This is
|
||||
inspired by vertically-integrated model implementations such as
|
||||
[ggml](https://github.com/ggerganov/ggml),
|
||||
[llama.c](https://github.com/karpathy/llama2.c), and
|
||||
[llama.rs](https://github.com/srush/llama2.rs).
|
||||
|
||||
gemma.cpp targets experimentation and research use cases. It is intended to be
|
||||
straightforward to embed in other projects with minimal dependencies and also
|
||||
easily modifiable with a small ~2K LoC core implementation (along with ~4K LoC
|
||||
of supporting utilities). We use the [Google
|
||||
Highway](https://github.com/google/highway) Library to take advantage of
|
||||
portable SIMD for CPU inference.
|
||||
|
||||
For production-oriented edge deployments we recommend standard deployment
|
||||
pathways using Python frameworks like JAX, Keras, PyTorch, and Transformers
|
||||
([all model variations here](https://www.kaggle.com/models/google/gemma)).
|
||||
|
||||
Community contributions large and small are welcome. This project follows
|
||||
[Google's Open Source Community
|
||||
Guidelines](https://opensource.google.com/conduct/).
|
||||
|
||||
## Quick Start
|
||||
|
||||
### System requirements
|
||||
|
||||
Before starting, you should have installed:
|
||||
|
||||
- [CMake](https://cmake.org/)
|
||||
- [Clang C++ compiler](https://clang.llvm.org/get_started.html), supporting at
|
||||
least C++17.
|
||||
- `tar` for extracting archives from Kaggle.
|
||||
|
||||
### Step 1: Obtain model weights and tokenizer from Kaggle
|
||||
|
||||
Visit [the Gemma model page on
|
||||
Kaggle](https://www.kaggle.com/models/google/gemma) and select `Model Variations
|
||||
|> Gemma C++`. On this tab, the `Variation` dropdown includes the options below.
|
||||
Note bfloat16 weights are higher fidelity, while 8-bit switched floating point
|
||||
weights enable faster inference.
|
||||
|
||||
2B instruction-tuned (`it`) and pre-trained (`pt`) models:
|
||||
|
||||
| Model name | Description |
|
||||
| ----------- | ----------- |
|
||||
| `2b-it` | 2 billion parameter instruction-tuned model, bfloat16 |
|
||||
| `2b-it-sfp` | 2 billion parameter instruction-tuned model, 8-bit switched floating point |
|
||||
| `2b-pt` | 2 billion parameter pre-trained model, bfloat16 |
|
||||
| `2b-pt-sfp` | 2 billion parameter pre-trained model, 8-bit switched floating point |
|
||||
|
||||
7B instruction-tuned (`it`) and pre-trained (`pt`) models:
|
||||
|
||||
| Model name | Description |
|
||||
| ----------- | ----------- |
|
||||
| `7b-it` | 7 billion parameter instruction-tuned model, bfloat16 |
|
||||
| `7b-it-sfp` | 7 billion parameter instruction-tuned model, 8-bit switched floating point |
|
||||
| `7b-pt` | 7 billion parameter pre-trained model, bfloat16 |
|
||||
| `7b-pt-sfp` | 7 billion parameter pre-trained model, 8-bit switched floating point |
|
||||
|
||||
> [!NOTE]
|
||||
> We *recommend starting with `2b-it-sfp`* to get up and running.
|
||||
|
||||
### Step 2: Extract Files
|
||||
|
||||
After filling out the consent form, the download should proceed to retrieve a
|
||||
tar archive file `archive.tar.gz`. Extract files from `archive.tar.gz` (this can
|
||||
take a few minutes):
|
||||
|
||||
```
|
||||
tar -xf archive.tar.gz
|
||||
```
|
||||
|
||||
This should produce a file containing model weights such as `2b-it-sfp.sbs` and
|
||||
a tokenizer file (`tokenizer.spm`). You may want to move these files to a
|
||||
convenient directory location (e.g. the `build/` directory in this repo).
|
||||
|
||||
### Step 3: Build
|
||||
|
||||
The build system uses [CMake](https://cmake.org/). To build the gemma inference
|
||||
runtime, create a build directory and generate the build files using `cmake`
|
||||
from the top-level project directory:
|
||||
|
||||
```sh
|
||||
(cd build && cmake ..)
|
||||
```
|
||||
|
||||
Then run `make` to build the `./gemma` executable:
|
||||
|
||||
```sh
|
||||
cd build
|
||||
make -j [number of parallel threads to use] gemma
|
||||
```
|
||||
|
||||
For example, `make -j 8 gemma`. If this is successful, you should now have a
|
||||
`gemma` executable in the `build/` directory.
|
||||
|
||||
> [!NOTE]
|
||||
> On Windows Subsystem for Linux (WSL) users should set the number of
|
||||
> parallel threads to 1. Using a larger number may result in errors.
|
||||
|
||||
### Step 4: Run
|
||||
|
||||
You can now run `gemma` from inside the `build/` directory.
|
||||
|
||||
`gemma` has the following required arguments:
|
||||
|
||||
| Argument | Description | Example value |
|
||||
| -------- | ----------- | ------------- |
|
||||
| `--model` | The model type. | `2b-it`, `2b-pt`, `7b-it`, `7b-pt`, ... (see above) |
|
||||
| `--compressed_weights` | The compressed weights file. | `2b-it-sfp.sbs`, ... (see above) |
|
||||
| `--tokenizer` | The tokenizer file. | `tokenizer.spm` |
|
||||
|
||||
|
||||
`gemma` is invoked as:
|
||||
|
||||
```sh
|
||||
./gemma \
|
||||
--tokenizer [tokenizer file] \
|
||||
--compressed_weights [compressed weights file] \
|
||||
--model [2b-it or 2b-pt or 7b-it or 7b-pt or ...]
|
||||
```
|
||||
|
||||
Example invocation for the following configuration:
|
||||
|
||||
- Compressed weights file `2b-it-sfp.sbs` (2B instruction-tuned model, 8-bit
|
||||
switched floating point).
|
||||
- Tokenizer file `tokenizer.spm`.
|
||||
|
||||
```sh
|
||||
./gemma \
|
||||
--tokenizer tokenizer.spm \
|
||||
--compressed_weights 2b-it-sfp.sbs \
|
||||
--model 2b-it
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
`gemma` has different usage modes, controlled by the verbosity flag.
|
||||
|
||||
All usage modes are currently interactive, triggering text generation upon
|
||||
newline input.
|
||||
|
||||
| Verbosity | Usage mode | Details |
|
||||
| --------------- | ---------- | --------------------------------------------- |
|
||||
| `--verbosity 0` | Minimal | Only prints generation output. Suitable as a CLI tool. |
|
||||
| `--verbosity 1` | Default | Standard user-facing terminal UI. |
|
||||
| `--verbosity 2` | Detailed | Shows additional developer and debug info. |
|
||||
|
||||
### Interactive Terminal App
|
||||
|
||||
By default, verbosity is set to 1, bringing up a terminal-based interactive
|
||||
interface when `gemma` is invoked:
|
||||
|
||||
```console
|
||||
$ ./gemma [...]
|
||||
__ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __
|
||||
/ _` |/ _ \ '_ ` _ \| '_ ` _ \ / _` | / __| '_ \| '_ \
|
||||
| (_| | __/ | | | | | | | | | | (_| || (__| |_) | |_) |
|
||||
\__, |\___|_| |_| |_|_| |_| |_|\__,_(_)___| .__/| .__/
|
||||
__/ | | | | |
|
||||
|___/ |_| |_|
|
||||
|
||||
tokenizer : tokenizer.spm
|
||||
compressed_weights : 2b-it-sfp.sbs
|
||||
model : 2b-it
|
||||
weights : [no path specified]
|
||||
max_tokens : 3072
|
||||
max_generated_tokens : 2048
|
||||
|
||||
*Usage*
|
||||
Enter an instruction and press enter (%Q quits).
|
||||
|
||||
*Examples*
|
||||
- Write an email to grandma thanking her for the cookies.
|
||||
- What are some historical attractions to visit around Massachusetts?
|
||||
- Compute the nth fibonacci number in javascript.
|
||||
- Write a standup comedy bit about WebGPU programming.
|
||||
|
||||
> What are some outdoorsy places to visit around Boston?
|
||||
|
||||
[ Reading prompt ] .....................
|
||||
|
||||
|
||||
**Boston Harbor and Islands:**
|
||||
|
||||
* **Boston Harbor Islands National and State Park:** Explore pristine beaches, wildlife, and maritime history.
|
||||
* **Charles River Esplanade:** Enjoy scenic views of the harbor and city skyline.
|
||||
* **Boston Harbor Cruise Company:** Take a relaxing harbor cruise and admire the city from a different perspective.
|
||||
* **Seaport Village:** Visit a charming waterfront area with shops, restaurants, and a seaport museum.
|
||||
|
||||
**Forest and Nature:**
|
||||
|
||||
* **Forest Park:** Hike through a scenic forest with diverse wildlife.
|
||||
* **Quabbin Reservoir:** Enjoy boating, fishing, and hiking in a scenic setting.
|
||||
* **Mount Forest:** Explore a mountain with breathtaking views of the city and surrounding landscape.
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
### Usage as a Command Line Tool
|
||||
|
||||
For using the `gemma` executable as a command line tool, it may be useful to
|
||||
create an alias for gemma.cpp with arguments fully specified:
|
||||
|
||||
```sh
|
||||
alias gemma2b="~/gemma.cpp/build/gemma -- --tokenizer ~/gemma.cpp/build/tokenizer.spm --compressed_weights ~/gemma.cpp/build/2b-it-sfp.sbs --model 2b-it --verbosity 0"
|
||||
```
|
||||
|
||||
Replace the above paths with your own paths to the model and tokenizer paths
|
||||
from the download.
|
||||
|
||||
Here is an example of prompting `gemma` with a truncated input
|
||||
file (using a `gemma2b` alias like defined above):
|
||||
|
||||
```sh
|
||||
cat configs.h | tail -35 | tr '\n' ' ' | xargs -0 echo "What does this C++ code do: " | gemma2b
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> CLI usage of gemma.cpp is experimental and should take context length
|
||||
> limitations into account.
|
||||
|
||||
The output of the above command should look like:
|
||||
|
||||
```console
|
||||
$ cat configs.h | tail -35 | tr '\n' ' ' | xargs -0 echo "What does this C++ code do: " | gemma2b
|
||||
[ Reading prompt ] ......................................................................................................................................................................................................................................................................................................................................................................................................................................................................................
|
||||
The code defines two C++ structs, `ConfigGemma7B` and `ConfigGemma2B`, which are used for configuring a deep learning model.
|
||||
|
||||
**ConfigGemma7B**:
|
||||
|
||||
* `seq_len`: Stores the length of the sequence to be processed. It's set to 7168.
|
||||
* `vocab_size`: Stores the size of the vocabulary, which is 256128.
|
||||
* `n_layers`: Number of layers in the deep learning model. It's set to 28.
|
||||
* `dim_model`: Dimension of the model's internal representation. It's set to 3072.
|
||||
* `dim_ffw_hidden`: Dimension of the feedforward and recurrent layers' hidden representations. It's set to 16 * 3072 / 2.
|
||||
|
||||
**ConfigGemma2B**:
|
||||
|
||||
* `seq_len`: Stores the length of the sequence to be processed. It's also set to 7168.
|
||||
* `vocab_size`: Size of the vocabulary, which is 256128.
|
||||
* `n_layers`: Number of layers in the deep learning model. It's set to 18.
|
||||
* `dim_model`: Dimension of the model's internal representation. It's set to 2048.
|
||||
* `dim_ffw_hidden`: Dimension of the feedforward and recurrent layers' hidden representations. It's set to 16 * 2048 / 2.
|
||||
|
||||
These structs are used to configure a deep learning model with specific parameters for either Gemma7B or Gemma2B architecture.
|
||||
```
|
||||
|
||||
### Incorporating gemma.cpp as a Library in your Project
|
||||
|
||||
The easiest way to incorporate gemma.cpp in your own project is to pull in
|
||||
gemma.cpp and dependencies using `FetchContent`. You can add the following to your
|
||||
CMakeLists.txt:
|
||||
|
||||
```
|
||||
include(FetchContent)
|
||||
|
||||
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
|
||||
FetchContent_MakeAvailable(sentencepiece)
|
||||
|
||||
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main)
|
||||
FetchContent_MakeAvailable(gemma)
|
||||
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
|
||||
FetchContent_MakeAvailable(highway)
|
||||
```
|
||||
|
||||
Note for the gemma.cpp `GIT_TAG`, you may replace `origin/main` for a specific
|
||||
commit hash if you would like to pin the library version.
|
||||
|
||||
After your executable is defined (substitute your executable name for
|
||||
`[Executable Name]` below):
|
||||
|
||||
```
|
||||
target_link_libraries([Executable Name] libgemma hwy hwy_contrib sentencepiece)
|
||||
FetchContent_GetProperties(gemma)
|
||||
FetchContent_GetProperties(sentencepiece)
|
||||
target_include_directories([Executable Name] PRIVATE ${gemma_SOURCE_DIR})
|
||||
target_include_directories([Executable Name] PRIVATE ${sentencepiece_SOURCE_DIR})
|
||||
```
|
||||
|
||||
### Building gemma.cpp as a Library
|
||||
|
||||
gemma.cpp can also be used as a library dependency in your own project. The
|
||||
shared library artifact can be built by modifying the make invocation to build
|
||||
the `libgemma` target instead of `gemma`.
|
||||
|
||||
> [!NOTE]
|
||||
> If you are using gemma.cpp in your own project with the `FetchContent` steps
|
||||
> in the previous section, building the library is done automatically by `cmake`
|
||||
> and this section can be skipped.
|
||||
|
||||
First, run `cmake`:
|
||||
|
||||
```sh
|
||||
(cd build && cmake ..)
|
||||
```
|
||||
|
||||
Then, run `make` with the `libgemma` target:
|
||||
|
||||
```sh
|
||||
cd build
|
||||
make -j [number of parallel threads to use] libgemma
|
||||
```
|
||||
|
||||
If this is successful, you should now have a
|
||||
`libgemma` library file in the `build/` directory. On linux the filename is `libgemma.a`.
|
||||
|
||||
## Acknowledgements and Contacts
|
||||
|
||||
gemma.cpp was started in fall 2023 by [Austin Huang](austinvhuang@google.com)
|
||||
and [Jan Wassenberg](janwas@google.com), and subsequently released February 2024
|
||||
thanks to contributions from Phil Culliton, Paul Chang, and Dan Zheng.
|
||||
|
||||
This is not an officially supported Google product.
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
*
|
||||
!.gitignore
|
||||
!.hgignore
|
||||
|
|
@ -0,0 +1,244 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Normal include guard to placate lint.
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h> // memcpy
|
||||
|
||||
#include <cmath> // std::signbit
|
||||
#include <cstdlib> // std::abs
|
||||
#include <vector>
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/distortion.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/nuq.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/stats.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/timer.h"
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_
|
||||
|
||||
// Actual per-target include guard.
|
||||
#if defined(THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE) == defined(HWY_TARGET_TOGGLE)
|
||||
#ifdef THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
|
||||
#undef THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
|
||||
#else
|
||||
#define THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
|
||||
#endif
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/nuq-inl.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/sfp-inl.h"
|
||||
#include "hwy/contrib/sort/vqsort-inl.h"
|
||||
#include "hwy/highway.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
class PerThread {
|
||||
public:
|
||||
void NotifyGroup(const float* group) {
|
||||
Stats s_group;
|
||||
for (size_t i = 0; i < kGroupSize; ++i) {
|
||||
// Skip zero so we can see the lowest actual magnitude
|
||||
if (group[i] == 0.0f || group[i] == -0.0f) continue;
|
||||
s_all_.Notify(group[i]);
|
||||
s_group.Notify(group[i]);
|
||||
|
||||
num_tiny_ += std::abs(group[i]) < 1e-3f;
|
||||
|
||||
// b_magn100_.Notify(group[i] * 40.0f + 20.0f);
|
||||
const uint32_t binary32 =
|
||||
hwy::BitCastScalar<uint32_t>(std::abs(group[i]));
|
||||
|
||||
// const int32_t exp = (binary32 >> 23) - 127;
|
||||
b_exp256_.Notify(binary32 >> 23);
|
||||
const uint32_t m4 = (binary32 & 0x7FFFFF) >> (23 - 4);
|
||||
b_m4_.Notify(m4);
|
||||
}
|
||||
s_group_ranges_.Notify(s_group.Max() - s_group.Min());
|
||||
s_group_mins_.Notify(s_group.Min());
|
||||
s_group_maxs_.Notify(s_group.Max());
|
||||
|
||||
float desc[kGroupSize];
|
||||
memcpy(desc, group, kGroupSize * sizeof(group[0]));
|
||||
hn::VQSortStatic(desc, kGroupSize, hwy::SortDescending());
|
||||
|
||||
// Find largest |max/min| (dynamic range)
|
||||
float max_ratio = 0.0f;
|
||||
for (size_t i = 0; i < kGroupSize; ++i) {
|
||||
if (desc[i] != 0.0f && desc[i] != -0.0f) {
|
||||
max_ratio = std::max(max_ratio, std::abs(desc[0] / desc[i]));
|
||||
}
|
||||
}
|
||||
s_group_max_vs_min_.Notify(max_ratio);
|
||||
|
||||
// Relative errors
|
||||
float diffs[kGroupSize];
|
||||
for (size_t i = 0; i < kGroupSize - 1; ++i) {
|
||||
// was in descending order. Avoid div by 0. Ignore sign changes.
|
||||
diffs[i] = std::abs(desc[i]) < 1e-5
|
||||
? 0
|
||||
: std::abs((desc[i] - desc[i + 1]) / desc[i]);
|
||||
}
|
||||
hn::VQSortStatic(diffs, kGroupSize, hwy::SortDescending());
|
||||
s_cut15_.Notify(diffs[15]);
|
||||
}
|
||||
|
||||
void Assimilate(const PerThread& other) {
|
||||
num_tiny_ += other.num_tiny_;
|
||||
s_all_.Assimilate(other.s_all_);
|
||||
s_group_ranges_.Assimilate(other.s_group_ranges_);
|
||||
s_group_mins_.Assimilate(other.s_group_mins_);
|
||||
s_group_maxs_.Assimilate(other.s_group_maxs_);
|
||||
s_group_max_vs_min_.Assimilate(other.s_group_max_vs_min_);
|
||||
s_erange_.Assimilate(other.s_erange_);
|
||||
s_km_1_.Assimilate(other.s_km_1_);
|
||||
s_km_2_.Assimilate(other.s_km_2_);
|
||||
s_cut15_.Assimilate(other.s_cut15_);
|
||||
b_magn100_.Assimilate(other.b_magn100_);
|
||||
b_exp256_.Assimilate(other.b_exp256_);
|
||||
b_m4_.Assimilate(other.b_m4_);
|
||||
}
|
||||
|
||||
void PrintAll() {
|
||||
const int skip = Stats::kNoGeomean;
|
||||
fprintf(stderr, "num tiny %zu\n", num_tiny_);
|
||||
fprintf(stderr, "weights %s\n", s_all_.ToString(skip).c_str());
|
||||
fprintf(stderr, " ranges %s\n", s_group_ranges_.ToString(skip).c_str());
|
||||
fprintf(stderr, " mins %s\n", s_group_mins_.ToString(skip).c_str());
|
||||
fprintf(stderr, " maxs %s\n", s_group_maxs_.ToString(skip).c_str());
|
||||
fprintf(stderr, " Mvm %s\n", s_group_max_vs_min_.ToString(skip).c_str());
|
||||
fprintf(stderr, " cut15 %s\n", s_cut15_.ToString(skip).c_str());
|
||||
fprintf(stderr, " erange %s\n", s_erange_.ToString(skip).c_str());
|
||||
fprintf(stderr, " km1 %s\n", s_km_1_.ToString(skip).c_str());
|
||||
fprintf(stderr, " km2 %s\n", s_km_2_.ToString(skip).c_str());
|
||||
|
||||
// b_magn100_.Print("magn100");
|
||||
// b_exp256_.Print("exp");
|
||||
// b_m4_.Print("mantissa bits4");
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
private:
|
||||
size_t num_tiny_ = 0;
|
||||
Stats s_all_;
|
||||
Stats s_group_ranges_;
|
||||
Stats s_group_mins_;
|
||||
Stats s_group_maxs_;
|
||||
Stats s_group_max_vs_min_;
|
||||
Stats s_erange_;
|
||||
Stats s_km_1_;
|
||||
Stats s_km_2_;
|
||||
Stats s_cut15_;
|
||||
Bins<100> b_magn100_;
|
||||
Bins<256> b_exp256_;
|
||||
Bins<16> b_m4_;
|
||||
uint8_t padding_[64]; // prevent false sharing
|
||||
};
|
||||
|
||||
class PerLayer {
|
||||
public:
|
||||
void NotifyGroup(const float* group) {
|
||||
for (size_t i = 0; i < kGroupSize; ++i) {
|
||||
s_layer_.Notify(group[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateOutliers(const float* layer, size_t weights_per_layer) {
|
||||
const float layer_mean = s_layer_.Mean();
|
||||
const float layer_sd = s_layer_.StandardDeviation();
|
||||
for (size_t i = 0; i < weights_per_layer; ++i) {
|
||||
num_outliers_ +=
|
||||
std::abs(std::abs(layer[i]) - layer_mean) >= 3.0f * layer_sd;
|
||||
}
|
||||
}
|
||||
|
||||
const Stats& GetStats() const { return s_layer_; }
|
||||
size_t Outliers() const { return num_outliers_; }
|
||||
|
||||
private:
|
||||
Stats s_layer_;
|
||||
size_t num_outliers_ = 0;
|
||||
uint8_t padding[64]; // prevent false sharing
|
||||
};
|
||||
|
||||
static HWY_NOINLINE void Analyze(const char* caption, float* mat, size_t layers,
|
||||
size_t weights_per_layer,
|
||||
hwy::ThreadPool& pool) {
|
||||
std::vector<PerThread> tls;
|
||||
std::vector<PerLayer> per_layer(layers);
|
||||
const auto init = [&](size_t num_threads) {
|
||||
tls.resize(num_threads);
|
||||
return true;
|
||||
};
|
||||
|
||||
pool.Run(0, static_cast<uint32_t>(layers), init,
|
||||
[&](uint32_t idx_layer, size_t idx_thread) {
|
||||
PerThread& self = tls[idx_thread];
|
||||
const float* layer = &mat[idx_layer * weights_per_layer];
|
||||
// For each whole group in the layer
|
||||
for (size_t group_start = 0;
|
||||
group_start + kGroupSize <= weights_per_layer;
|
||||
group_start += kGroupSize) {
|
||||
const float* group = layer + group_start;
|
||||
per_layer[idx_layer].NotifyGroup(group);
|
||||
self.NotifyGroup(group);
|
||||
}
|
||||
|
||||
per_layer[idx_layer].UpdateOutliers(layer, weights_per_layer);
|
||||
});
|
||||
|
||||
const int skip = Stats::kNoGeomean;
|
||||
fprintf(stderr, "\n------------%s\n", caption);
|
||||
|
||||
for (size_t i = 1; i < pool.NumThreads(); ++i) {
|
||||
tls[0].Assimilate(tls[i]);
|
||||
}
|
||||
tls[0].PrintAll();
|
||||
|
||||
Stats s_layer_ranges;
|
||||
Stats s_layer_outliers;
|
||||
for (size_t i = 0; i < layers; ++i) {
|
||||
fprintf(stderr, " %02zu %s\n", i,
|
||||
per_layer[i].GetStats().ToString(skip).c_str());
|
||||
const float range =
|
||||
per_layer[i].GetStats().Max() - per_layer[i].GetStats().Min();
|
||||
s_layer_ranges.Notify(range);
|
||||
s_layer_outliers.Notify((100.0 * per_layer[i].Outliers()) /
|
||||
weights_per_layer);
|
||||
}
|
||||
fprintf(stderr, "layer outliers%% %s\n",
|
||||
s_layer_outliers.ToString(skip).c_str());
|
||||
fprintf(stderr, "layer ranges %s\n", s_layer_ranges.ToString(skip).c_str());
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_
|
||||
|
|
@ -0,0 +1,348 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/blob_store.h"
|
||||
|
||||
#include <fcntl.h> // open
|
||||
#include <stdint.h>
|
||||
#include <stdio.h> // SEEK_END - unistd isn't enough for IDE.
|
||||
#include <sys/stat.h> // O_RDONLY
|
||||
#include <unistd.h> // read, close
|
||||
|
||||
#include <atomic>
|
||||
#include <vector>
|
||||
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/detect_compiler_arch.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
hwy::uint128_t MakeKey(const char* string) {
|
||||
size_t length = 0;
|
||||
for (size_t i = 0; string[i] != '\0'; ++i) {
|
||||
++length;
|
||||
}
|
||||
if (length > 16) {
|
||||
HWY_ABORT("Key %s is too long, please truncate to 16 chars.", string);
|
||||
}
|
||||
|
||||
hwy::uint128_t ret;
|
||||
hwy::ZeroBytes<sizeof(ret)>(&ret);
|
||||
hwy::CopyBytes(string, &ret, length);
|
||||
return ret;
|
||||
}
|
||||
|
||||
static void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data,
|
||||
std::vector<BlobIO>& requests) {
|
||||
// Split into chunks for load-balancing even if blob sizes vary.
|
||||
constexpr size_t kChunkSize = 4 * 1024 * 1024;
|
||||
|
||||
// Split into whole chunks and possibly one remainder.
|
||||
uint64_t pos = 0;
|
||||
if (size >= kChunkSize) {
|
||||
for (; pos <= size - kChunkSize; pos += kChunkSize) {
|
||||
requests.emplace_back(offset + pos, kChunkSize, data + pos, 0);
|
||||
}
|
||||
}
|
||||
if (pos != size) {
|
||||
requests.emplace_back(offset + pos, size - pos, data + pos, 0);
|
||||
}
|
||||
}
|
||||
|
||||
struct IO {
|
||||
// Returns size in bytes or 0.
|
||||
static uint64_t FileSize(const char* filename) {
|
||||
int fd = open(filename, O_RDONLY);
|
||||
if (fd >= 0) {
|
||||
const off_t size = lseek(fd, 0, SEEK_END);
|
||||
HWY_ASSERT(close(fd) != -1);
|
||||
if (size != static_cast<off_t>(-1)) {
|
||||
return static_cast<uint64_t>(size);
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static bool Read(int fd, uint64_t offset, uint64_t size, void* to) {
|
||||
uint8_t* bytes = reinterpret_cast<uint8_t*>(to);
|
||||
uint64_t pos = 0;
|
||||
for (;;) {
|
||||
// pread seems to be faster than lseek + read when parallelized.
|
||||
const auto bytes_read = pread(fd, bytes + pos, size - pos, offset + pos);
|
||||
if (bytes_read <= 0) break;
|
||||
pos += bytes_read;
|
||||
HWY_ASSERT(pos <= size);
|
||||
if (pos == size) break;
|
||||
}
|
||||
return pos == size; // success if managed to read desired size
|
||||
}
|
||||
|
||||
static bool Write(const void* from, uint64_t size, uint64_t offset, int fd) {
|
||||
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(from);
|
||||
uint64_t pos = 0;
|
||||
for (;;) {
|
||||
const auto bytes_written =
|
||||
pwrite(fd, bytes + pos, size - pos, offset + pos);
|
||||
if (bytes_written <= 0) break;
|
||||
pos += bytes_written;
|
||||
HWY_ASSERT(pos <= size);
|
||||
if (pos == size) break;
|
||||
}
|
||||
return pos == size; // success if managed to write desired size
|
||||
}
|
||||
}; // IO
|
||||
|
||||
static_assert(HWY_IS_LITTLE_ENDIAN, "Assumes little endian");
|
||||
|
||||
// On-disk representation (little-endian).
|
||||
//
|
||||
// Deliberately omits a version number because this file format is unchanging.
|
||||
// Additional data may be added only inside new blobs. Changes to the blob
|
||||
// contents or type should be handled by renaming keys.
|
||||
#pragma pack(push, 1)
|
||||
class BlobStore {
|
||||
static constexpr uint32_t kMagic = 0x0A534253; // SBS\n
|
||||
|
||||
// Blob offsets on disk and memory addresses are a multiple of this, because
|
||||
// we pad the header and each blob's size. This matches CUDA alignment and the
|
||||
// maximum SVE vector size, and exceeds typical x86 cache line sizes (64 or
|
||||
// 128), which can help performance.
|
||||
static constexpr size_t kAlign = 256;
|
||||
|
||||
public:
|
||||
// NOT including padding, so that we can also use ZeroFillPadding after
|
||||
// copying the header.
|
||||
static constexpr size_t HeaderSize(size_t num_blobs) {
|
||||
// 16-byte fixed fields plus per-blob: 16-byte key, 16-byte offset/size.
|
||||
return 16 + 32 * num_blobs;
|
||||
}
|
||||
|
||||
// Returns how many bytes to allocate for the header without the subsequent
|
||||
// blobs. Requires num_blobs_ to already be set, typically by reading
|
||||
// sizeof(BlobStore) bytes from disk.
|
||||
size_t PaddedHeaderSize() const {
|
||||
return hwy::RoundUpTo(HeaderSize(num_blobs_), kAlign);
|
||||
}
|
||||
|
||||
// Returns aligned offset and zero-fills between that and `offset`.
|
||||
uint64_t ZeroFillPadding(uint64_t offset) {
|
||||
uint8_t* const bytes = reinterpret_cast<uint8_t*>(this);
|
||||
const uint64_t padded = hwy::RoundUpTo(offset, kAlign);
|
||||
hwy::ZeroBytes(bytes + offset, padded - offset);
|
||||
return padded;
|
||||
}
|
||||
|
||||
BlobError CheckValidity(const uint64_t file_size) {
|
||||
if (magic_ != kMagic) return __LINE__;
|
||||
if (num_blobs_ == 0) return __LINE__;
|
||||
if (file_size_ != file_size) return __LINE__;
|
||||
|
||||
// Ensure blobs are back to back, and zero-pad.
|
||||
uint64_t offset = ZeroFillPadding(HeaderSize(num_blobs_));
|
||||
for (size_t i = 0; i < num_blobs_; ++i) {
|
||||
const hwy::uint128_t val = keys_[num_blobs_ + i];
|
||||
if (val.lo != offset) return __LINE__;
|
||||
offset = ZeroFillPadding(offset + val.hi);
|
||||
}
|
||||
|
||||
if (offset != file_size_) return __LINE__;
|
||||
|
||||
return 0; // all OK
|
||||
}
|
||||
|
||||
static BlobStorePtr Allocate(uint64_t total_size) {
|
||||
uint8_t* bytes =
|
||||
static_cast<uint8_t*>(hwy::AllocateAlignedBytes(total_size));
|
||||
if (!bytes) return BlobStorePtr();
|
||||
return BlobStorePtr(new (bytes) BlobStore(), hwy::AlignedFreer());
|
||||
}
|
||||
|
||||
static std::vector<BlobIO> PrepareWriteRequests(
|
||||
const hwy::uint128_t keys[], const hwy::Span<uint8_t> blobs[],
|
||||
size_t num_blobs) {
|
||||
// Sanity check and ensure the cast below is safe.
|
||||
HWY_ASSERT(num_blobs < (1ULL << 20));
|
||||
|
||||
// Allocate var-length header.
|
||||
const size_t header_size = HeaderSize(num_blobs);
|
||||
const size_t padded_header_size = hwy::RoundUpTo(header_size, kAlign);
|
||||
BlobStorePtr bs = Allocate(padded_header_size);
|
||||
const uint64_t padded_header_end = bs->ZeroFillPadding(header_size);
|
||||
HWY_ASSERT(padded_header_end == padded_header_size);
|
||||
|
||||
// All-zero buffer used to write padding to the file without copying the
|
||||
// input blobs.
|
||||
static uint8_t zeros[kAlign] = {0};
|
||||
|
||||
// Total file size will be the header plus all padded blobs.
|
||||
uint64_t payload = 0;
|
||||
for (size_t i = 0; i < num_blobs; ++i) {
|
||||
payload += hwy::RoundUpTo(blobs[i].size(), kAlign);
|
||||
}
|
||||
const size_t total_size = padded_header_size + payload;
|
||||
|
||||
// Fill header.
|
||||
bs->magic_ = kMagic;
|
||||
bs->num_blobs_ = static_cast<uint32_t>(num_blobs);
|
||||
bs->file_size_ = total_size;
|
||||
hwy::CopyBytes(keys, bs->keys_, num_blobs * sizeof(keys[0]));
|
||||
|
||||
// First IO request is for the header (not yet filled!).
|
||||
std::vector<BlobIO> requests;
|
||||
requests.reserve(1 + 2 * num_blobs);
|
||||
requests.emplace_back(/*offset=*/0, padded_header_size,
|
||||
reinterpret_cast<uint8_t*>(bs.get()), 0);
|
||||
|
||||
// Fill second half of keys_ with offset/size and prepare IO requests.
|
||||
uint64_t offset = padded_header_end;
|
||||
for (size_t i = 0; i < num_blobs; ++i) {
|
||||
bs->keys_[num_blobs + i].lo = offset;
|
||||
bs->keys_[num_blobs + i].hi = blobs[i].size();
|
||||
|
||||
EnqueueChunkRequests(offset, blobs[i].size(), blobs[i].data(), requests);
|
||||
offset += blobs[i].size();
|
||||
const size_t padded_size = hwy::RoundUpTo(blobs[i].size(), kAlign);
|
||||
if (padded_size != blobs[i].size()) {
|
||||
const size_t padding = padded_size - blobs[i].size();
|
||||
HWY_ASSERT(padding <= kAlign);
|
||||
requests.emplace_back(offset, padding, zeros, 0);
|
||||
offset += padding;
|
||||
}
|
||||
}
|
||||
|
||||
HWY_ASSERT(offset == total_size);
|
||||
return requests;
|
||||
}
|
||||
|
||||
bool FindKey(const hwy::uint128_t key, uint64_t& offset, size_t& size) const {
|
||||
for (size_t i = 0; i < num_blobs_; ++i) {
|
||||
if (keys_[i] == key) {
|
||||
const hwy::uint128_t val = keys_[num_blobs_ + i];
|
||||
offset = val.lo;
|
||||
size = val.hi;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
uint32_t magic_;
|
||||
uint32_t num_blobs_; // never 0
|
||||
uint64_t file_size_; // must match actual size of file
|
||||
hwy::uint128_t keys_[1]; // length: 2 * num_blobs
|
||||
// Padding, then the blob identified by keys[0], then padding etc.
|
||||
};
|
||||
#pragma pack(pop)
|
||||
|
||||
BlobError BlobReader::Open(const char* filename) {
|
||||
fd_ = open(filename, O_RDONLY);
|
||||
if (fd_ < 0) return __LINE__;
|
||||
|
||||
#if _POSIX_C_SOURCE >= 200112L
|
||||
// Doubles the readahead window, which seems slightly faster when cached.
|
||||
(void)posix_fadvise(fd_, 0, 0, POSIX_FADV_SEQUENTIAL);
|
||||
#endif
|
||||
|
||||
// Read first part of header to get actual size.
|
||||
BlobStore bs;
|
||||
if (!IO::Read(fd_, 0, sizeof(bs), &bs)) return __LINE__;
|
||||
const size_t padded_size = bs.PaddedHeaderSize();
|
||||
HWY_ASSERT(padded_size >= sizeof(bs));
|
||||
|
||||
// Allocate full header.
|
||||
blob_store_ = BlobStore::Allocate(padded_size);
|
||||
if (!blob_store_) return __LINE__;
|
||||
|
||||
// Copy what we already read (more efficient than seek + re-read).
|
||||
hwy::CopySameSize(&bs, blob_store_.get());
|
||||
// Read the rest of the header, but not the full file.
|
||||
uint8_t* bytes = reinterpret_cast<uint8_t*>(blob_store_.get());
|
||||
if (!IO::Read(fd_, sizeof(bs), padded_size - sizeof(bs),
|
||||
bytes + sizeof(bs))) {
|
||||
return __LINE__;
|
||||
}
|
||||
|
||||
return blob_store_->CheckValidity(IO::FileSize(filename));
|
||||
}
|
||||
|
||||
BlobReader::~BlobReader() {
|
||||
if (fd_ >= 0) {
|
||||
HWY_ASSERT(close(fd_) != -1);
|
||||
}
|
||||
}
|
||||
|
||||
BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) {
|
||||
uint64_t offset;
|
||||
size_t actual_size;
|
||||
if (!blob_store_->FindKey(key, offset, actual_size)) return __LINE__;
|
||||
if (actual_size != size) return __LINE__;
|
||||
|
||||
EnqueueChunkRequests(offset, actual_size, reinterpret_cast<uint8_t*>(data),
|
||||
requests_);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Parallel synchronous I/O. Alternatives considered:
|
||||
// - readv is limited to 0x7FFFF000 bytes on Linux (even 64-bit). Note that
|
||||
// pread calls preadv with a single iovec.
|
||||
// - O_DIRECT seems undesirable because we do want to use the OS cache
|
||||
// between consecutive runs.
|
||||
// - memory-mapped I/O is less predictable and adds noise to measurements.
|
||||
BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
|
||||
const int fd = fd_;
|
||||
const auto& requests = requests_;
|
||||
std::atomic_flag err = ATOMIC_FLAG_INIT;
|
||||
// >5x speedup from parallel reads when cached.
|
||||
pool.Run(0, requests.size(),
|
||||
[fd, &requests, &err](uint64_t i, size_t /*thread*/) {
|
||||
if (!IO::Read(fd, requests[i].offset, requests[i].size,
|
||||
requests[i].data)) {
|
||||
err.test_and_set();
|
||||
}
|
||||
});
|
||||
if (err.test_and_set()) return __LINE__;
|
||||
return 0;
|
||||
}
|
||||
|
||||
BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool,
|
||||
const char* filename) const {
|
||||
HWY_ASSERT(keys_.size() == blobs_.size());
|
||||
|
||||
// Concatenate blobs in memory.
|
||||
std::vector<BlobIO> requests = BlobStore::PrepareWriteRequests(
|
||||
keys_.data(), blobs_.data(), keys_.size());
|
||||
|
||||
// Create/replace existing file.
|
||||
const int fd = open(filename, O_CREAT | O_RDWR | O_TRUNC, 0644);
|
||||
if (fd < 0) return __LINE__;
|
||||
|
||||
std::atomic_flag err = ATOMIC_FLAG_INIT;
|
||||
pool.Run(0, requests.size(),
|
||||
[fd, &requests, &err](uint64_t i, size_t /*thread*/) {
|
||||
if (!IO::Write(requests[i].data, requests[i].size,
|
||||
requests[i].offset, fd)) {
|
||||
err.test_and_set();
|
||||
}
|
||||
});
|
||||
if (err.test_and_set()) return __LINE__;
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@@ -0,0 +1,90 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_BLOB_STORE_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_BLOB_STORE_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // hwy::uint128_t
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Convenient way to construct a key from a string (<= 16 chars).
|
||||
hwy::uint128_t MakeKey(const char* string);
|
||||
|
||||
// Ordered list of opaque blobs (~hundreds), identified by unique opaque
|
||||
// 128-bit keys.
|
||||
class BlobStore;
|
||||
|
||||
// Incomplete type, so dtor will not be called.
|
||||
using BlobStorePtr = hwy::AlignedFreeUniquePtr<BlobStore>;
|
||||
|
||||
// 0 if successful, otherwise the line number of the failing check.
|
||||
using BlobError = int;
|
||||
|
||||
struct BlobIO {
|
||||
BlobIO(uint64_t offset, size_t size, void* data, uint64_t padding)
|
||||
: offset(offset), size(size), data(data), padding(padding) {}
|
||||
|
||||
uint64_t offset;
|
||||
size_t size;
|
||||
void* data;
|
||||
uint64_t padding;
|
||||
};
|
||||
|
||||
class BlobReader {
|
||||
public:
|
||||
BlobReader() { requests_.reserve(500); }
|
||||
~BlobReader();
|
||||
|
||||
// Opens `filename` and reads its header.
|
||||
BlobError Open(const char* filename);
|
||||
|
||||
// Enqueues read requests if `key` is found and its size matches `size`.
|
||||
BlobError Enqueue(hwy::uint128_t key, void* data, size_t size);
|
||||
|
||||
// Reads all enqueued requests.
|
||||
BlobError ReadAll(hwy::ThreadPool& pool);
|
||||
|
||||
private:
|
||||
BlobStorePtr blob_store_; // holds header, not the entire file
|
||||
std::vector<BlobIO> requests_;
|
||||
int fd_ = 0;
|
||||
};
|
||||
|
||||
class BlobWriter {
|
||||
public:
|
||||
void Add(hwy::uint128_t key, void* data, size_t size) {
|
||||
keys_.push_back(key);
|
||||
blobs_.emplace_back(static_cast<uint8_t*>(data), size);
|
||||
}
|
||||
|
||||
// Stores all blobs to disk in the given order with padding for alignment.
|
||||
BlobError WriteAll(hwy::ThreadPool& pool, const char* filename) const;
|
||||
|
||||
private:
|
||||
std::vector<hwy::uint128_t> keys_;
|
||||
std::vector<hwy::Span<uint8_t>> blobs_;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_BLOB_STORE_H_
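// Minimal round-trip sketch using the API above (illustrative only; the key
// string, file name, sizes and error handling are placeholders, not library
// defaults):
//
//   hwy::ThreadPool pool(4);
//   std::vector<float> weights(4096, 0.5f);
//
//   BlobWriter writer;
//   writer.Add(MakeKey("example"), weights.data(),
//              weights.size() * sizeof(float));
//   HWY_ASSERT(writer.WriteAll(pool, "/tmp/example.sbs") == 0);
//
//   std::vector<float> loaded(weights.size());
//   BlobReader reader;
//   HWY_ASSERT(reader.Open("/tmp/example.sbs") == 0);
//   HWY_ASSERT(reader.Enqueue(MakeKey("example"), loaded.data(),
//                             loaded.size() * sizeof(float)) == 0);
//   HWY_ASSERT(reader.ReadAll(pool) == 0);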
|
||||
|
|
@@ -0,0 +1,467 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Include guard for headers.
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_INL_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_INL_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <array>
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/blob_store.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/compress.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/distortion.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/timer.h"
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_INL_H_
|
||||
|
||||
// Include guard for (potentially) SIMD code.
|
||||
#if defined(THIRD_PARTY_GEMMA_CPP_COMPRESS_TOGGLE) == defined(HWY_TARGET_TOGGLE)
|
||||
#ifdef THIRD_PARTY_GEMMA_CPP_COMPRESS_TOGGLE
|
||||
#undef THIRD_PARTY_GEMMA_CPP_COMPRESS_TOGGLE
|
||||
#else
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESS_TOGGLE
|
||||
#endif
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/nuq-inl.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/sfp-inl.h"
|
||||
#include "hwy/contrib/dot/dot-inl.h"
|
||||
#include "hwy/highway.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
||||
// Enables generic code independent of compression type.
|
||||
template <typename T> // primary, must specialize
|
||||
struct CompressTraits {};
|
||||
|
||||
template <>
|
||||
struct CompressTraits<float> {
|
||||
using MatT = float;
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
|
||||
size_t num, CompressPerThread& tls,
|
||||
size_t /*out_capacity*/,
|
||||
MatT* HWY_RESTRICT out, size_t out_ofs) {
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
const size_t N = hn::Lanes(df);
|
||||
HWY_DASSERT(num >= 2 * N && num % (2 * N) == 0);
|
||||
|
||||
for (size_t i = 0; i < num; i += 2 * N) {
|
||||
const VF in0 = hn::LoadU(df, in + i);
|
||||
const VF in1 = hn::LoadU(df, in + i + N);
|
||||
hn::StoreU(in0, df, out + out_ofs + i);
|
||||
hn::StoreU(in1, df, out + out_ofs + i + N);
|
||||
}
|
||||
}
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Decompress(DF df, size_t /*in_capacity*/,
|
||||
const MatT* HWY_RESTRICT in, size_t in_ofs,
|
||||
float* HWY_RESTRICT out, size_t num) {
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
const size_t N = hn::Lanes(df);
|
||||
HWY_DASSERT(num >= 2 * N && num % (2 * N) == 0);
|
||||
|
||||
for (size_t i = 0; i < num; i += 2 * N) {
|
||||
const VF in0 = hn::LoadU(df, in + in_ofs + i);
|
||||
const VF in1 = hn::LoadU(df, in + in_ofs + i + N);
|
||||
hn::StoreU(in0, df, out + i);
|
||||
hn::StoreU(in1, df, out + i + N);
|
||||
}
|
||||
}
|
||||
|
||||
// VecT can be float or hwy::bfloat16_t.
|
||||
template <class DF, typename VecT, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE float Dot(DF df, size_t /*in_capacity*/,
|
||||
const MatT* HWY_RESTRICT in, size_t in_ofs,
|
||||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
size_t num) {
|
||||
HWY_DASSERT(num >= hn::Lanes(df) && (num % hn::Lanes(df)) == 0);
|
||||
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
|
||||
constexpr int kAssumptions =
|
||||
hn::Dot::kAtLeastOneVector | hn::Dot::kMultipleOfVector;
|
||||
// vec_aligned must be the second argument because hn::Dot supports f32*bf16
|
||||
// and f32*f32.
|
||||
return hn::Dot::Compute<kAssumptions>(df, in + in_ofs, vec_aligned, num);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CompressTraits<hwy::bfloat16_t> {
|
||||
using MatT = hwy::bfloat16_t;
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
|
||||
size_t num, CompressPerThread& tls,
|
||||
size_t /*out_capacity*/,
|
||||
MatT* HWY_RESTRICT out, size_t out_ofs) {
|
||||
const hn::RebindToUnsigned<decltype(df)> du;
|
||||
const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
const size_t N = hn::Lanes(df);
|
||||
|
||||
hn::Vec<decltype(du)> or_sum = hn::Zero(du);
|
||||
|
||||
size_t i = 0;
|
||||
if (num >= 2 * N) {
|
||||
for (; i <= num - 2 * N; i += 2 * N) {
|
||||
const VF in0 = hn::LoadU(df, in + i);
|
||||
const VF in1 = hn::LoadU(df, in + i + N);
|
||||
|
||||
// Sticky bits so we can warn if any lower bits were set.
|
||||
or_sum = hn::Or3(or_sum, hn::BitCast(du, in0), hn::BitCast(du, in1));
|
||||
hn::StoreU(hn::OrderedDemote2To(dbf, in0, in1), dbf, out + out_ofs + i);
|
||||
|
||||
if (COMPRESS_STATS) {
|
||||
DistortionStats stats;
|
||||
for (size_t j = 0; j < 2 * N; ++j) {
|
||||
stats.Notify(in[i + j], hwy::F32FromBF16(out[out_ofs + i + j]));
|
||||
}
|
||||
tls.stats.Notify(stats);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t remaining = num - i;
|
||||
if (remaining != 0) {
|
||||
const VF in0 = hn::LoadN(df, in + i, remaining);
|
||||
// Avoid wraparound, potentially load nothing beyond in0's N lanes.
const size_t remaining1 = remaining - HWY_MIN(remaining, N);
|
||||
const VF in1 = hn::LoadN(df, in + i + N, remaining1);
|
||||
|
||||
// Sticky bits so we can warn if any lower bits were set.
|
||||
or_sum = hn::Or3(or_sum, hn::BitCast(du, in0), hn::BitCast(du, in1));
|
||||
hn::StoreU(hn::OrderedDemote2To(dbf, in0, in1), dbf, out + out_ofs + i);
|
||||
|
||||
if (COMPRESS_STATS) {
|
||||
DistortionStats stats;
|
||||
for (size_t j = 0; j < remaining; ++j) {
|
||||
stats.Notify(in[i + j], hwy::F32FromBF16(out[out_ofs + i + j]));
|
||||
}
|
||||
tls.stats.Notify(stats);
|
||||
}
|
||||
}
|
||||
|
||||
// If the lower 16 bits are not zero, we should implement rounding.
|
||||
or_sum = hn::And(or_sum, hn::Set(du, 0xFFFF));
|
||||
if (!hn::AllTrue(du, hn::Eq(or_sum, hn::Zero(du)))) {
|
||||
// fprintf(stderr, "Warning: Lossy truncation.");
|
||||
}
|
||||
}
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Decompress(DF df, size_t /*in_capacity*/,
|
||||
const MatT* HWY_RESTRICT in, size_t in_ofs,
|
||||
float* HWY_RESTRICT out, size_t num) {
|
||||
const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf;
|
||||
using VBF = hn::Vec<decltype(dbf)>;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
const size_t N16 = hn::Lanes(dbf);
|
||||
|
||||
size_t i = 0;
|
||||
if (num >= N16) {
|
||||
for (i = 0; i <= num - N16; i += N16) {
|
||||
const VBF in16 = hn::LoadU(dbf, in + in_ofs + i);
|
||||
const VF in0 = hn::PromoteLowerTo(df, in16);
|
||||
const VF in1 = hn::PromoteUpperTo(df, in16);
|
||||
hn::StoreU(in0, df, out + i);
|
||||
hn::StoreU(in1, df, out + i + N16 / 2);
|
||||
}
|
||||
}
|
||||
|
||||
size_t remaining = num - i;
|
||||
if (remaining != 0) {
|
||||
const VBF in16 = hn::LoadN(dbf, in + in_ofs + i, remaining);
|
||||
const VF in0 = hn::PromoteLowerTo(df, in16);
|
||||
const VF in1 = hn::PromoteUpperTo(df, in16);
|
||||
hn::StoreN(in0, df, out + i, remaining);
|
||||
// Avoid wraparound, potentially store nothing.
|
||||
const size_t remaining1 = remaining - HWY_MIN(remaining, N16 / 2);
|
||||
hn::StoreN(in1, df, out + i + N16 / 2, remaining1);
|
||||
}
|
||||
}
|
||||
|
||||
// VecT can be float or hwy::bfloat16_t.
|
||||
template <class DF, typename VecT, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE float Dot(DF df, size_t /*in_capacity*/,
|
||||
const MatT* HWY_RESTRICT in, size_t in_ofs,
|
||||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
size_t num) {
|
||||
HWY_DASSERT(num >= hn::Lanes(df) && (num % hn::Lanes(df)) == 0);
|
||||
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
|
||||
|
||||
const hn::Repartition<VecT, decltype(df)> d_vec;
|
||||
|
||||
constexpr int kAssumptions =
|
||||
hn::Dot::kAtLeastOneVector | hn::Dot::kMultipleOfVector;
|
||||
// vec_aligned must be first argument because hn::Dot supports f32*bf16 and
|
||||
// bf16*bf16.
|
||||
return hn::Dot::Compute<kAssumptions>(d_vec, vec_aligned, in + in_ofs, num);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CompressTraits<SfpStream> {
|
||||
using MatT = SfpStream;
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Compress(DF df, const float* in, size_t num,
|
||||
CompressPerThread& tls,
|
||||
size_t /*out_capacity*/, MatT* out,
|
||||
size_t out_ofs) {
|
||||
SfpCodec::Enc(df, in, num, out + out_ofs);
|
||||
|
||||
if (COMPRESS_STATS) {
|
||||
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
|
||||
auto distorted = hwy::AllocateAligned<hwy::bfloat16_t>(num);
|
||||
SfpCodec::Dec(dbf, out + out_ofs, num, distorted.get());
|
||||
DistortionStats stats;
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
stats.Notify(in[i], hwy::F32FromBF16(distorted[i]));
|
||||
}
|
||||
tls.stats.Notify(stats);
|
||||
}
|
||||
}
|
||||
|
||||
template <class D, typename OutT>
|
||||
static HWY_INLINE void Decompress(D d, size_t /*in_capacity*/, const MatT* in,
|
||||
size_t in_ofs, OutT* out, size_t num) {
|
||||
SfpCodec::Dec(d, in + in_ofs, num, out);
|
||||
}
|
||||
|
||||
template <class DF, typename VecT, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE float Dot(DF df, size_t /*in_capacity*/, const MatT* in,
|
||||
size_t in_ofs, const VecT* vec_aligned,
|
||||
size_t num) {
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
VF sum0 = hn::Zero(df);
|
||||
VF sum1 = hn::Zero(df);
|
||||
VF sum2 = hn::Zero(df);
|
||||
VF sum3 = hn::Zero(df);
|
||||
|
||||
SfpCodec::Dot(df, in + in_ofs, num, vec_aligned, sum0, sum1, sum2, sum3);
|
||||
|
||||
// Reduction tree: sum of all accumulators, then their lanes
|
||||
sum0 = hn::Add(sum0, sum1);
|
||||
sum2 = hn::Add(sum2, sum3);
|
||||
sum0 = hn::Add(sum0, sum2);
|
||||
return hn::ReduceSum(df, sum0);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CompressTraits<NuqStream> {
|
||||
using MatT = NuqStream;
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Compress(DF df, const float* in, size_t num,
|
||||
CompressPerThread& tls, size_t out_capacity,
|
||||
MatT* out, size_t out_ofs) {
|
||||
NuqCodec::Enc(df, in, num, tls.buf, out_capacity, out, out_ofs);
|
||||
|
||||
if (COMPRESS_STATS) {
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
tls.stats.NotifyIn(in[i] * 100 + 500);
|
||||
}
|
||||
|
||||
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
|
||||
auto distorted = hwy::AllocateAligned<hwy::bfloat16_t>(num);
|
||||
NuqCodec::Dec(dbf, out_capacity, out, out_ofs, distorted.get(), num);
|
||||
DistortionStats stats;
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
stats.Notify(in[i], hwy::F32FromBF16(distorted[i]));
|
||||
}
|
||||
tls.stats.Notify(stats);
|
||||
}
|
||||
}
|
||||
|
||||
template <class D, typename OutT>
|
||||
static HWY_INLINE void Decompress(D d, size_t in_capacity, const MatT* in,
|
||||
size_t in_ofs, OutT* out, size_t num) {
|
||||
NuqCodec::Dec(d, in_capacity, in, in_ofs, out, num);
|
||||
}
|
||||
|
||||
template <class DF, typename VecT, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE float Dot(DF df, size_t in_capacity, const MatT* in,
|
||||
size_t in_ofs,
|
||||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
size_t num) {
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
VF sum0 = hn::Zero(df);
|
||||
VF sum1 = hn::Zero(df);
|
||||
VF sum2 = hn::Zero(df);
|
||||
VF sum3 = hn::Zero(df);
|
||||
|
||||
NuqCodec::Dot(df, in_capacity, in, in_ofs, vec_aligned, num, sum0, sum1,
|
||||
sum2, sum3);
|
||||
|
||||
// Reduction tree: sum of all accumulators, then their lanes
|
||||
sum0 = hn::Add(hn::Add(sum0, sum1), hn::Add(sum2, sum3));
|
||||
return hn::ReduceSum(df, sum0);
|
||||
}
|
||||
};
|
||||
|
||||
// Compresses `num` inputs to `out` starting at `out_ofs`. This can be used for
|
||||
// compressing sub-regions of an array.
|
||||
template <typename MatT>
|
||||
HWY_NOINLINE void Compress(const float* in, size_t num,
|
||||
CompressWorkingSet& work, size_t out_capacity,
|
||||
MatT* out, size_t out_ofs, hwy::ThreadPool& pool) {
|
||||
HWY_DASSERT(out_ofs + num <= out_capacity);
|
||||
work.tls.resize(pool.NumThreads());
|
||||
if (COMPRESS_STATS) {
|
||||
for (auto& tls : work.tls) {
|
||||
tls.stats.Reset();
|
||||
}
|
||||
}
|
||||
|
||||
const double t0 = hwy::platform::Now();
|
||||
|
||||
using Traits = CompressTraits<MatT>;
|
||||
constexpr size_t kBatch = 8192;
|
||||
const size_t num_batches = hwy::DivCeil(num, kBatch);
|
||||
pool.Run(0, num_batches,
|
||||
[&](const uint32_t idx_batch, size_t thread) HWY_ATTR {
|
||||
const hn::ScalableTag<float> df;
|
||||
|
||||
const size_t in_ofs = idx_batch * kBatch;
|
||||
const size_t my_num =
|
||||
idx_batch == num_batches - 1 ? (num - in_ofs) : kBatch;
|
||||
Traits::Compress(df, in + in_ofs, my_num, work.tls[thread],
|
||||
out_capacity, out, out_ofs + in_ofs);
|
||||
});
|
||||
|
||||
const double t1 = hwy::platform::Now();
|
||||
const double mb = num * sizeof(in[0]) * 1E-6;
|
||||
const double mbps = mb / (t1 - t0);
|
||||
fprintf(stderr, "Compress %.1f MB/s\n", mbps);
|
||||
|
||||
if (COMPRESS_STATS) {
|
||||
for (size_t i = 1; i < work.tls.size(); ++i) {
|
||||
work.tls[0].stats.Assimilate(work.tls[i].stats);
|
||||
}
|
||||
work.tls[0].stats.PrintAll();
|
||||
}
|
||||
}
|
||||
|
||||
// Compresses an entire std::array into `out`, which is assumed to have exactly
|
||||
// that much capacity.
|
||||
template <size_t kCapacity, typename MatT>
|
||||
HWY_INLINE void Compress(const std::array<float, kCapacity>& in,
|
||||
CompressWorkingSet& work,
|
||||
CompressedArray<MatT, kCapacity>& compressed,
|
||||
hwy::ThreadPool& pool) {
|
||||
Compress(in.data(), kCapacity, work, kCapacity, compressed.data(), 0, pool);
|
||||
}
|
||||
|
||||
// Decompresses `num` values from `compressed` starting at `compressed_ofs`.
|
||||
template <typename MatT, size_t kCapacity, typename OutT>
|
||||
HWY_NOINLINE void Decompress(const CompressedArray<MatT, kCapacity>& compressed,
|
||||
size_t compressed_ofs, OutT* out, size_t num) {
|
||||
HWY_DASSERT(compressed_ofs + num <= compressed.NumElements());
|
||||
const hn::ScalableTag<OutT> d;
|
||||
using Traits = CompressTraits<MatT>;
|
||||
Traits::Decompress(d, kCapacity, compressed.data(), compressed_ofs, out, num);
|
||||
}
|
||||
|
||||
// As above, but with threading and benchmarking.
|
||||
template <typename MatT, size_t kCapacity, typename OutT>
|
||||
HWY_INLINE void Decompress(const CompressedArray<MatT, kCapacity>& compressed,
|
||||
size_t compressed_ofs, OutT* out, size_t num,
|
||||
hwy::ThreadPool& pool) {
|
||||
HWY_DASSERT(compressed_ofs + num <= compressed.NumElements());
|
||||
const double t0 = hwy::platform::Now();
|
||||
|
||||
using Traits = CompressTraits<MatT>;
|
||||
constexpr size_t kBatch = 8192;
|
||||
const size_t num_batches = hwy::DivCeil(num, kBatch);
|
||||
pool.Run(
|
||||
0, num_batches, [&](const uint32_t idx_batch, size_t thread) HWY_ATTR {
|
||||
const hn::ScalableTag<OutT> d;
|
||||
|
||||
const size_t ofs = idx_batch * kBatch;
|
||||
const size_t my_num = idx_batch == num_batches - 1 ? (num - ofs) : kBatch;
|
||||
Traits::Decompress(d, compressed.NumElements(), compressed.data(),
|
||||
compressed_ofs + ofs, out + ofs, my_num);
|
||||
});
|
||||
|
||||
const double t1 = hwy::platform::Now();
|
||||
const double mb = num * sizeof(MatT) * 1E-6;
|
||||
const double mbps = mb / (t1 - t0);
|
||||
fprintf(stderr, "Decompress %.1f MB/s\n", mbps);
|
||||
}
|
||||
|
||||
// Returns dot product with `vec_aligned` of length `num`.
|
||||
template <class DF, typename MatT, size_t kCapacity, typename VecT>
|
||||
HWY_INLINE float Dot(DF df, const CompressedArray<MatT, kCapacity>& compressed,
|
||||
size_t compressed_ofs, const VecT* vec_aligned,
|
||||
size_t num) {
|
||||
HWY_DASSERT(compressed_ofs + num <= compressed.NumElements());
|
||||
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
|
||||
using Traits = CompressTraits<MatT>;
|
||||
return Traits::Dot(df, kCapacity, compressed.data(), compressed_ofs,
|
||||
vec_aligned, num);
|
||||
}
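// Sketch of how the three entry points above fit together (illustrative; the
// tensor size 4096 and the pool size are placeholders):
//
//   hwy::ThreadPool pool(4);
//   CompressWorkingSet work;
//   std::array<float, 4096> weights;            // filled by the caller
//   CompressedArray<SfpStream, 4096> packed;
//   Compress(weights, work, packed, pool);      // f32 -> 8-bit SFP
//
//   HWY_ALIGN float vec[4096];                  // activations, aligned
//   const hn::ScalableTag<float> df;
//   const float dot = Dot(df, packed, /*compressed_ofs=*/0, vec, 4096);
//
//   float roundtrip[4096];
//   Decompress(packed, /*compressed_ofs=*/0, roundtrip, 4096, pool);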
|
||||
|
||||
// Callback used by ForeachTensor.
|
||||
class Compressor {
|
||||
public:
|
||||
explicit Compressor(hwy::ThreadPool& pool) : pool_(pool) {}
|
||||
|
||||
// Called for each tensor; compresses it and stores to the cache.
|
||||
template <typename MatT, size_t kCapacity>
|
||||
void operator()(const char* name, const float* weights,
|
||||
CompressedArray<MatT, kCapacity>& compressed) {
|
||||
fprintf(stderr, "Regenerating %s (%zuM), please wait\n", name,
|
||||
kCapacity / (1000 * 1000));
|
||||
Compress(weights, kCapacity, work_, kCapacity, compressed.data(), 0, pool_);
|
||||
writer_.Add(CacheKey<MatT>(name), compressed.data(),
|
||||
compressed.CompressedSize());
|
||||
}
|
||||
|
||||
void WriteAll(hwy::ThreadPool& pool, const char* blob_filename) {
|
||||
const BlobError err = writer_.WriteAll(pool, blob_filename);
|
||||
if (err != 0) {
|
||||
fprintf(stderr, "Failed to write blobs to %s (error %d)\n", blob_filename,
|
||||
err);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
CompressWorkingSet work_;
|
||||
hwy::ThreadPool& pool_;
|
||||
BlobWriter writer_;
|
||||
};
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#endif // NOLINT
|
||||
|
|
@@ -0,0 +1,215 @@
|
|||
// Copyright 2023 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Target-independent definitions.
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
|
||||
|
||||
#define COMPRESS_STATS 0
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// IWYU pragma: begin_exports
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/blob_store.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/nuq.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/sfp.h"
|
||||
// IWYU pragma: end_exports
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/distortion.h"
|
||||
#include "hwy/base.h" // hwy::bfloat16_t
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#if COMPRESS_STATS
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/stats.h"
|
||||
#endif
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
static inline const char* TypeName(float) { return "f32"; }
|
||||
static inline const char* TypeName(hwy::bfloat16_t) { return "b16"; }
|
||||
|
||||
namespace detail {
|
||||
// How many MatT are required to store `capacity` weights. For all but
|
||||
// NuqStream, this is the same as `capacity`. For use by CompressedArray.
|
||||
template <typename MatT>
|
||||
constexpr size_t CompressedArrayLen(size_t capacity) {
|
||||
return capacity;
|
||||
}
|
||||
template <>
|
||||
constexpr size_t CompressedArrayLen<NuqStream>(size_t capacity) {
|
||||
return NuqStream::PackedEnd(capacity);
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
// Compressed representation of floating-point elements. The array length may
|
||||
// differ from the number of elements. Associated operations such as Dot are
|
||||
// implemented in SIMD code and are thus non-member functions.
|
||||
template <typename MatT, size_t kCapacity>
|
||||
class CompressedArray {
|
||||
static constexpr size_t NumCompressed() {
|
||||
return detail::CompressedArrayLen<MatT>(kCapacity);
|
||||
}
|
||||
|
||||
public:
|
||||
MatT* data() { return data_.data(); }
|
||||
const MatT* data() const { return data_.data(); }
|
||||
|
||||
constexpr size_t NumElements() const { return kCapacity; }
|
||||
|
||||
constexpr size_t CompressedSize() const {
|
||||
return NumCompressed() * sizeof(MatT);
|
||||
}
|
||||
|
||||
private:
|
||||
std::array<MatT, NumCompressed()> data_;
|
||||
};
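// Worked example of the two sizes (SfpStream and NuqStream are 1-byte types;
// NuqStream::PackedEnd is defined in nuq.h):
//   CompressedArray<float, 4096>            NumElements() 4096, 16384 bytes
//   CompressedArray<hwy::bfloat16_t, 4096>  NumElements() 4096,  8192 bytes
//   CompressedArray<SfpStream, 4096>        NumElements() 4096,  4096 bytes
//   CompressedArray<NuqStream, 4096>        NumElements() 4096,
//       CompressedSize() = NuqStream::PackedEnd(4096) bytes (per-group
//       tables plus packed 4-bit indices, i.e. roughly half a byte each).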
|
||||
|
||||
#if COMPRESS_STATS
|
||||
class CompressStats {
|
||||
public:
|
||||
void Notify(const DistortionStats& stats) {
|
||||
const float pnorm = stats.PNorm();
|
||||
const float snr = stats.GeomeanValueDivL1();
|
||||
num_exact_ += stats.NumExact();
|
||||
s_pnorm_.Notify(pnorm);
|
||||
// No loss - skip to avoid dragging down the average.
|
||||
if (snr != 0.0f) {
|
||||
s_snr_.Notify(snr);
|
||||
}
|
||||
}
|
||||
|
||||
void NotifyIn(int sfp) { hist_weights_.Notify(sfp); }
|
||||
|
||||
void Assimilate(const CompressStats& other) {
|
||||
s_pnorm_.Assimilate(other.s_pnorm_);
|
||||
s_snr_.Assimilate(other.s_snr_);
|
||||
num_exact_ += other.num_exact_;
|
||||
hist_weights_.Assimilate(other.hist_weights_);
|
||||
}
|
||||
|
||||
void PrintAll() {
|
||||
const int skip = Stats::kNoGeomean;
|
||||
fprintf(stderr, " pnorm %s\n", s_pnorm_.ToString(skip).c_str());
|
||||
fprintf(stderr, " SNR %s\n", s_snr_.ToString(skip).c_str());
|
||||
fprintf(stderr, " #exact %.3E\n", static_cast<double>(num_exact_));
|
||||
// hist_weights_.Print("indices");
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
s_pnorm_.Reset();
|
||||
s_snr_.Reset();
|
||||
num_exact_ = 0;
|
||||
hist_weights_.Reset();
|
||||
}
|
||||
|
||||
private:
|
||||
Stats s_pnorm_;
|
||||
Stats s_snr_;
|
||||
size_t num_exact_ = 0;
|
||||
Bins<1000> hist_weights_;
|
||||
char padding_[64]; // prevent false sharing
|
||||
};
|
||||
#else
|
||||
struct CompressStats {
|
||||
void Notify(const DistortionStats&) {}
|
||||
void NotifyIn(int) {}
|
||||
void Assimilate(const CompressStats&) {}
|
||||
void PrintAll() {}
|
||||
void Reset() {}
|
||||
};
|
||||
#endif // COMPRESS_STATS
|
||||
|
||||
struct CompressPerThread {
|
||||
CompressStats stats;
|
||||
ClusterBuf buf;
|
||||
};
|
||||
|
||||
struct CompressWorkingSet {
|
||||
std::vector<CompressPerThread> tls;
|
||||
};
|
||||
|
||||
// Returns key for the given tensor name. Also encodes the type, so that
|
||||
// changing the representation automatically invalidates prior cached files
|
||||
// (the new blob name will not be found).
|
||||
template <typename MatT>
|
||||
hwy::uint128_t CacheKey(const char* name) {
|
||||
// Already used/retired: s, S, n, 1
|
||||
const char prefix = hwy::IsSame<MatT, float>() ? 'F'
|
||||
: hwy::IsSame<MatT, hwy::bfloat16_t>() ? 'B'
|
||||
: hwy::IsSame<MatT, SfpStream>() ? '$'
|
||||
: hwy::IsSame<MatT, NuqStream>() ? '2'
|
||||
: '?';
|
||||
|
||||
return MakeKey((std::string(1, prefix) + name).c_str());
|
||||
}
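// Example: CacheKey<SfpStream>("c_embedding") equals MakeKey("$c_embedding"),
// whereas CacheKey<hwy::bfloat16_t>("c_embedding") equals
// MakeKey("Bc_embedding"), so switching the weight type simply misses the old
// cache entry and triggers recompression. (The tensor name is illustrative.)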
|
||||
|
||||
class CacheLoader {
|
||||
public:
|
||||
explicit CacheLoader(const char* blob_filename) {
|
||||
err_ = reader_.Open(blob_filename);
|
||||
if (err_ != 0) {
|
||||
fprintf(stderr,
|
||||
"Cached compressed weights does not exist yet (code %d), "
|
||||
"compressing weights and creating file: %s.\n",
|
||||
err_, blob_filename);
|
||||
}
|
||||
}
|
||||
|
||||
// Called for each tensor, enqueues read requests.
|
||||
template <typename MatT, size_t kCapacity>
|
||||
void operator()(const char* name, const float* null,
|
||||
CompressedArray<MatT, kCapacity>& compressed) {
|
||||
HWY_DASSERT(null == nullptr);
|
||||
|
||||
// Skip if reader_ is invalid or any load failed: we will regenerate
|
||||
// everything because it's rare to update only a few tensors.
|
||||
if (err_ != 0) return;
|
||||
|
||||
err_ = reader_.Enqueue(CacheKey<MatT>(name), compressed.data(),
|
||||
compressed.CompressedSize());
|
||||
if (err_ != 0) {
|
||||
fprintf(stderr, "Failed to read cache %s (error %d)\n", name, err_);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns whether all tensors are successfully loaded from cache.
|
||||
bool ReadAll(hwy::ThreadPool& pool) {
|
||||
// reader_ invalid or any Enqueue failed
|
||||
if (err_ != 0) return false;
|
||||
|
||||
err_ = reader_.ReadAll(pool);
|
||||
if (err_ != 0) {
|
||||
fprintf(stderr, "Failed to read all tensors (error %d)\n", err_);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
BlobReader reader_;
|
||||
BlobError err_ = 0;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
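// Typical flow combining CacheLoader (above) with the Compressor defined in
// compress-inl.h (a sketch; the tensor name, file path and `weights`/`raw`
// structs are placeholders for the caller's ForeachTensor-style traversal):
//
//   hwy::ThreadPool pool(4);
//   CacheLoader loader("/path/to/weights.sbs");
//   loader("c_embedding", nullptr, weights.c_embedding);   // enqueue read
//   // ... one call per tensor ...
//   if (!loader.ReadAll(pool)) {
//     // Cache absent or stale: recompress from float weights and store it.
//     Compressor compressor(pool);
//     compressor("c_embedding", raw.c_embedding, weights.c_embedding);
//     // ... one call per tensor ...
//     compressor.WriteAll(pool, "/path/to/weights.sbs");
//   }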
|
||||
|
|
@@ -0,0 +1,99 @@
|
|||
// Copyright 2023 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_DISTORTION_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_DISTORTION_H_
|
||||
#include <math.h> // pow
|
||||
#include <stddef.h>
|
||||
|
||||
#include "hwy/base.h" // ScalarAbs
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
class DistortionStats {
|
||||
public:
|
||||
void Notify(float original, float distorted) {
|
||||
const double l1 = hwy::ScalarAbs(original - distorted);
|
||||
|
||||
if (l1 > max_l1_) {
|
||||
max_l1_ = l1;
|
||||
max_idx_ = n_;
|
||||
}
|
||||
|
||||
const double pow3 = l1 * l1 * l1;
|
||||
sum_pow3_ += pow3;
|
||||
sum_pow6_ += pow3 * pow3;
|
||||
n_ += 1;
|
||||
|
||||
// Avoid division by zero, which happens when there is no error. NumExact()
|
||||
// reports the number of times this happens.
|
||||
if (l1 != 0.0) {
|
||||
const double rel = 1.0 + hwy::ScalarAbs(original) / l1;
|
||||
// Logarithm is required to prevent overflow. A hierarchical geomean
|
||||
// could also work, but that is more complex and not necessarily better.
|
||||
sum_log_rel_ += log(rel);
|
||||
num_rel_ += 1;
|
||||
}
|
||||
}
|
||||
|
||||
void Assimilate(const DistortionStats& other) {
|
||||
if (other.max_l1_ > max_l1_) {
|
||||
max_l1_ = other.max_l1_;
|
||||
max_idx_ = other.max_idx_;
|
||||
}
|
||||
|
||||
sum_pow3_ += other.sum_pow3_;
|
||||
sum_pow6_ += other.sum_pow6_;
|
||||
n_ += other.n_;
|
||||
|
||||
sum_log_rel_ += other.sum_log_rel_;
|
||||
num_rel_ += other.num_rel_;
|
||||
}
|
||||
|
||||
size_t NumExact() const { return n_ - num_rel_; }
|
||||
|
||||
double GeomeanValueDivL1() const {
|
||||
if (num_rel_ == 0) return 0.0;
|
||||
return exp(sum_log_rel_ / num_rel_);
|
||||
}
|
||||
|
||||
double PNorm() const {
|
||||
// p-norms are a compromise between max-norm (penalizes the largest error
|
||||
// without dilution, but does not notice any other errors) and L1 (all
|
||||
// errors contribute, but large errors are diluted by smaller ones).
|
||||
const double norm3 = pow(sum_pow3_ / n_, 1.0 / 3);
|
||||
const double norm6 = pow(sum_pow6_ / n_, 1.0 / 6);
|
||||
return 0.5 * (norm3 + norm6);
|
||||
}
|
||||
|
||||
size_t MaxIndex() const { return max_idx_; }
|
||||
double MaxL1() const { return max_l1_; }
|
||||
|
||||
private:
|
||||
size_t n_ = 0;
|
||||
size_t max_idx_ = 0; // index that had l1 = max_l1_.
|
||||
double max_l1_ = -1.0;
|
||||
|
||||
double sum_pow3_ = 0.0;
|
||||
double sum_pow6_ = 0.0;
|
||||
|
||||
double sum_log_rel_ = 0.0;
|
||||
size_t num_rel_ = 0;
|
||||
double padding_; // prevents false sharing
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_DISTORTION_H_
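// Usage sketch (array names and length are placeholders): feed each
// (original, decoded) pair to Notify, then read the summary metrics.
//
//   DistortionStats stats;
//   for (size_t i = 0; i < num; ++i) {
//     stats.Notify(original[i], decoded[i]);
//   }
//   // PNorm(): blend of 3- and 6-norms of the absolute errors;
//   // GeomeanValueDivL1(): SNR-like ratio, larger is better;
//   // NumExact(): pairs reproduced without any error.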
|
||||
|
|
@@ -0,0 +1,730 @@
|
|||
// Copyright 2023 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Normal include guard.
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/nuq.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/sfp.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_
|
||||
|
||||
// Actual per-target include guard.
|
||||
#if defined(THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE) == \
|
||||
defined(HWY_TARGET_TOGGLE)
|
||||
#ifdef THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE
|
||||
#undef THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE
|
||||
#else
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE
|
||||
#endif
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/sfp-inl.h"
|
||||
#include "hwy/contrib/sort/vqsort-inl.h"
|
||||
#include "hwy/highway.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
||||
// For internal use by NuqCodec.
|
||||
class NuqClustering {
|
||||
// To go from sorted order back to the original order in O(1), we store the
|
||||
// original index in the lower bits of the float32 mantissa, which means they
|
||||
// are sorted alongside the value.
|
||||
struct FloatPayload {
|
||||
// Resets payload to zero; useful for displaying the actual value.
|
||||
static HWY_INLINE float Clear(float f) {
|
||||
const uint32_t binary32 = hwy::BitCastScalar<uint32_t>(f);
|
||||
return hwy::BitCastScalar<float>(binary32 &
|
||||
~static_cast<uint32_t>(kGroupSize - 1));
|
||||
}
|
||||
|
||||
// Sets payload to `bits`.
|
||||
static HWY_INLINE float Set(float f, size_t bits) {
|
||||
HWY_DASSERT(bits < kGroupSize);
|
||||
const uint32_t binary32 = hwy::BitCastScalar<uint32_t>(Clear(f));
|
||||
return hwy::BitCastScalar<float>(static_cast<uint32_t>(binary32 | bits));
|
||||
}
|
||||
|
||||
// Obtains the payload (index) previously set by `Set`.
|
||||
static HWY_INLINE size_t Get(float f) {
|
||||
return hwy::BitCastScalar<uint32_t>(f) &
|
||||
static_cast<uint32_t>(kGroupSize - 1);
|
||||
}
|
||||
};
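// Example: with kGroupSize = 128 (see nuq.h), the payload mask is
// kGroupSize - 1 = 0x7F, i.e. the low 7 mantissa bits. Set(x, 5) perturbs x
// only in those bits, so sorting is (almost) unaffected, Get() recovers 5,
// and Clear() restores a clean value before it is used in arithmetic.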
|
||||
|
||||
// Cumulative sums for O(1) mean and interval sums.
|
||||
class ClusterCost {
|
||||
public:
|
||||
explicit ClusterCost(const float* sorted) {
|
||||
cumsum_[0] = cumsum2_[0] = 0.0;
|
||||
for (size_t i = 0; i < kGroupSize; ++i) {
|
||||
const float x = FloatPayload::Clear(sorted[i]);
|
||||
cumsum_[1 + i] = x + cumsum_[i];
|
||||
cumsum2_[1 + i] = x * x + cumsum2_[i];
|
||||
}
|
||||
|
||||
inv_len_[0] = 0.0f; // unused
|
||||
for (size_t i = 0; i <= kGroupSize; ++i) {
|
||||
inv_len_[i] = 1.0f / i;
|
||||
}
|
||||
}
|
||||
|
||||
float SumOfSorted(size_t first, size_t last) const {
|
||||
return cumsum_[last + 1] - cumsum_[first];
|
||||
}
|
||||
|
||||
// Returns the cost of clustering first..last with their mean, i.e. the sum of
|
||||
// squared distances, for a vector of consecutive `last` values (one per
|
||||
// lane). O(1) thanks to cumulative sums, which works for Lp norms with p > 1;
|
||||
// we choose p=2 for simplicity (fewer terms).
|
||||
template <class DF>
|
||||
hn::Vec<DF> operator()(DF df, size_t first, size_t last) const {
|
||||
// Callers are responsible for ignoring lanes where last < first.
|
||||
HWY_DASSERT(first < kGroupSize);
|
||||
HWY_DASSERT(last < kGroupSize);
|
||||
const size_t len = last - first + 1;
|
||||
const hn::Vec<DF> vlen =
|
||||
hn::Iota(df, static_cast<float>(static_cast<int>(len)));
|
||||
|
||||
const hn::Vec<DF> u_lo = hn::Set(df, cumsum_[first]);
|
||||
const hn::Vec<DF> u_lo2 = hn::Set(df, cumsum2_[first]);
|
||||
const hn::Vec<DF> hi = hn::LoadU(df, cumsum_ + last + 1);
|
||||
const hn::Vec<DF> hi2 = hn::LoadU(df, cumsum2_ + last + 1);
|
||||
const hn::Vec<DF> sum = hn::Sub(hi, u_lo);
|
||||
const hn::Vec<DF> sum2 = hn::Sub(hi2, u_lo2);
|
||||
|
||||
// Compute mean: table lookup is faster than division.
|
||||
const hn::Vec<DF> mu = hn::Mul(sum, hn::LoadU(df, inv_len_ + len));
|
||||
|
||||
// (x - mu)^2 = sum2 - 2mu*sum + mu^2
|
||||
const hn::Vec<DF> mu2 = hn::Mul(mu, mu);
|
||||
const hn::Vec<DF> two_mu = hn::Add(mu, mu);
|
||||
return hn::NegMulAdd(two_mu, sum, hn::MulAdd(vlen, mu2, sum2));
|
||||
}
|
||||
|
||||
private:
|
||||
// Float has enough precision for our relatively small kGroupSize (128).
|
||||
float cumsum_[kGroupSize + 1];
|
||||
float cumsum2_[kGroupSize + 1];
|
||||
float inv_len_[kGroupSize + 1];
|
||||
};
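// In scalar terms, the cost returned by operator() for one lane is the sum of
// squared distances from the cluster mean, expanded so only the cumulative
// sums are needed:
//   sum_i (x_i - mu)^2 = sum2 - 2*mu*sum + len*mu^2,
// with sum = cumsum_[last+1] - cumsum_[first] (likewise sum2 from cumsum2_),
// len = last - first + 1 and mu = sum / len (via the inv_len_ table).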
|
||||
|
||||
// Cost of clustering 0..last, where the rightmost cluster is j..last. This is
|
||||
// called in a loop over j, and we return the vector of costs for a batch of
|
||||
// last = [last, last + N).
|
||||
template <class DF>
|
||||
static HWY_INLINE hn::Vec<DF> ClusterDynProg(
|
||||
DF df, const AlignedMatrix<float>& D, const ClusterCost& cc,
|
||||
const size_t num_clusters, const size_t last, const size_t j) {
|
||||
HWY_DASSERT(last < kGroupSize);
|
||||
HWY_DASSERT(0 != j && j < kGroupSize);
|
||||
|
||||
const hn::RebindToSigned<decltype(df)> di;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
using VI = hn::Vec<decltype(di)>;
|
||||
using MI = hn::Mask<decltype(di)>;
|
||||
|
||||
const VI vlast = hn::Iota(di, static_cast<int32_t>(last));
|
||||
|
||||
// We have a non-empty rightmost cluster if j <= last <==> j-1 < last.
|
||||
const MI valid = hn::Lt(hn::Set(di, static_cast<int32_t>(j) - 1), vlast);
|
||||
// If not valid, return an arbitrary high cost, which will not be the min.
|
||||
const VF max = hn::Set(df, 1E38f);
|
||||
// Cost of clustering 0..j-1 with one fewer cluster than now.
|
||||
const VF vd = hn::Set(df, D(num_clusters - 1, j - 1));
|
||||
// Eq2: add to that the cost of another cluster from j..last.
|
||||
return hn::MaskedAddOr(max, RebindMask(df, valid), vd, cc(df, j, last));
|
||||
}
|
||||
|
||||
public:
|
||||
// Clusters `kGroupSize` values in `x`, which need not be sorted already nor
|
||||
// aligned, by choosing and filling `centers` (size `kClusters`, ascending
|
||||
// order, not necessarily equal to one of the `x`). Fills `indices` with the
|
||||
// index of the cluster to which each `x` belongs (16-bit for bit-packing).
|
||||
// `buf` is per-thread.
|
||||
//
|
||||
// Returns the number of unused clusters, i.e., the starting index within
|
||||
// `centers`; prior centers are zero-initialized.
|
||||
//
|
||||
// O(kClusters * kGroupSize * kGroupSize), but the constant factors are so low
|
||||
// that this is about 10 times as fast as the O(kClusters * kGroupSize) SMAWK
|
||||
// as implemented in FAISS, for our kGroupSize <= 128.
|
||||
template <class DF>
|
||||
static HWY_NOINLINE size_t ClusterExactL2(DF df, const float* x,
|
||||
ClusterBuf& buf,
|
||||
float* HWY_RESTRICT centers,
|
||||
uint16_t* HWY_RESTRICT indices) {
|
||||
const hn::RebindToSigned<decltype(df)> di;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
using VI = hn::Vec<decltype(di)>;
|
||||
const VI k1 = hn::Set(di, 1);
|
||||
const size_t N = hn::Lanes(df);
|
||||
|
||||
HWY_ALIGN float sorted_and_i[kGroupSize];
|
||||
for (size_t i = 0; i < kGroupSize; ++i) {
|
||||
sorted_and_i[i] = FloatPayload::Set(x[i], i);
|
||||
}
|
||||
hn::VQSortStatic(sorted_and_i, kGroupSize, hwy::SortAscending());
|
||||
ClusterCost cc(sorted_and_i);
|
||||
|
||||
// Reference: https://arxiv.org/abs/1701.07204
|
||||
// D[k-1][m] is the lowest cost of clustering x1..m into k clusters.
|
||||
AlignedMatrix<float>& D = buf.d;
|
||||
// T[k][m] is the starting index within sorted_and_i[] of the k-th cluster.
|
||||
AlignedMatrix<int32_t>& T = buf.t;
|
||||
|
||||
// Initialize the first rows for a single cluster.
|
||||
for (size_t last = 0; last < kGroupSize; last += N) {
|
||||
hn::Store(cc(df, 0, last), df, &D(0, last)); // Cost of 0..last
|
||||
hn::Store(Zero(di), di, &T(0, last)); // Cluster index = 0
|
||||
}
|
||||
|
||||
for (size_t num_clusters = 1; num_clusters < kClusters; ++num_clusters) {
|
||||
// For each batch starting at `last`, one per lane:
|
||||
for (size_t last = 0; last < kGroupSize; last += N) {
|
||||
VF min = cc(df, 0, last);
|
||||
VI arg = hn::Zero(di);
|
||||
// For each j (start of rightmost cluster):
|
||||
VI vj = k1;
|
||||
for (size_t j = 1; j < last + N; ++j, vj = Add(vj, k1)) {
|
||||
const VF c = ClusterDynProg(df, D, cc, num_clusters, last, j);
|
||||
|
||||
// Retain the min cost and the j index that caused it.
|
||||
const auto less = hn::Lt(c, min);
|
||||
min = hn::IfThenElse(less, c, min);
|
||||
arg = hn::IfThenElse(RebindMask(di, less), vj, arg);
|
||||
}
|
||||
hn::Store(min, df, &D(num_clusters, last));
|
||||
hn::Store(arg, di, &T(num_clusters, last));
|
||||
}
|
||||
}
|
||||
|
||||
// Backtrack to find centers. Clusters are [T(k, last), last].
|
||||
size_t last = kGroupSize - 1;
|
||||
size_t unused_clusters = 0;
|
||||
for (size_t k = kClusters - 1; k < kClusters; --k) {
|
||||
const size_t start = static_cast<size_t>(T(k, last));
|
||||
// Center = mean, O(1) thanks to cumulative sums.
|
||||
const float sum = cc.SumOfSorted(start, last);
|
||||
const int size = static_cast<int>(last) - static_cast<int>(start) + 1;
|
||||
HWY_DASSERT(0 < size && size <= kGroupSize);
|
||||
centers[k] = sum / size;
|
||||
|
||||
// We know the range inside sorted_and_i[]; translate to original indices,
|
||||
// which are stored inside each of the sorted_and_i mantissas.
|
||||
for (size_t i = start; i <= last; ++i) {
|
||||
const size_t idx_x = FloatPayload::Get(sorted_and_i[i]);
|
||||
HWY_DASSERT(idx_x < kGroupSize);
|
||||
indices[idx_x] = static_cast<uint16_t>(k);
|
||||
}
|
||||
|
||||
// Not using all clusters. Avoid out of bounds accesses by stopping early.
|
||||
if (start == 0) {
|
||||
unused_clusters = k;
|
||||
for (size_t cluster = 0; cluster < unused_clusters; ++cluster) {
|
||||
centers[cluster] = 0.0f;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
last = start - 1;
|
||||
HWY_DASSERT(last < kGroupSize);
|
||||
}
|
||||
|
||||
if (HWY_IS_DEBUG_BUILD) {
|
||||
// Centers are in ascending order.
|
||||
for (size_t i = unused_clusters + 1; i < kClusters; ++i) {
|
||||
HWY_DASSERT(centers[i] >= centers[i - 1]);
|
||||
}
|
||||
}
|
||||
return unused_clusters;
|
||||
}
|
||||
}; // NuqClustering
|
||||
|
||||
// Bit-packing 4-bit values is trivial if we have 2 or 4 independent vectors:
|
||||
// simply shift+OR them together into a full vector of 8 or 16-bit lanes.
|
||||
// However, the order then depends on the vector length, which is unacceptable
|
||||
// because we may store the encoding to disk and decode on another CPU.
|
||||
//
|
||||
// The dependency on vector length could be removed by introducing fixed-size
|
||||
// packets and loading the next vector from a fixed offset known to be at
|
||||
// least the vector length. However, this may require packets that are larger
|
||||
// than the seek granularity of the application (e.g. matrix rows).
|
||||
//
|
||||
// We instead choose a continuous stream layout, which seems to entail the
|
||||
// nibbles being stored and decoded in-order. This involves nontrivial shuffle
|
||||
// operations which benefit from special-casing for target and vector length.
|
||||
class NibbleCodec {
|
||||
public:
|
||||
// Packs four u16 vectors' lanes to nibbles within one vector, in order, and
|
||||
// stores that vector to `out`.
|
||||
template <class D16, class V16 = hn::Vec<D16>>
|
||||
static HWY_INLINE void OrderedPackU16(D16 d16, V16 in0, V16 in1, V16 in2,
|
||||
V16 in3, uint8_t* HWY_RESTRICT out) {
|
||||
const hn::Repartition<uint8_t, D16> d8;
|
||||
const hn::Repartition<uint32_t, D16> d32;
|
||||
const hn::Repartition<uint64_t, D16> d64;
|
||||
using V8 = hn::Vec<decltype(d8)>;
|
||||
|
||||
// Pairwise compaction of a single vector so nibbles are packed in-order.
|
||||
// v16 lanes hold a 4-bit value; OR together adjacent pairs into the lower
|
||||
// byte of *even* u16.
|
||||
const auto combine_u16_pair_to_8 = [d16, d32](V16 v16) HWY_ATTR {
|
||||
return hn::Xor(
|
||||
v16, hn::BitCast(d16, hn::ShiftRight<12>(hn::BitCast(d32, v16))));
|
||||
};
|
||||
|
||||
const V16 u8_0 = combine_u16_pair_to_8(in0);
|
||||
const V16 u8_1 = combine_u16_pair_to_8(in1);
|
||||
const V16 u8_2 = combine_u16_pair_to_8(in2);
|
||||
const V16 u8_3 = combine_u16_pair_to_8(in3);
|
||||
V8 packed;
|
||||
if (HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86) {
|
||||
// 8-bit ConcatEven is efficient. Let digits denote eight u8 lanes
|
||||
// of u8_1/0: ?d?3 ?c?2 / ?b?1 ?a?0. 8-bit ConcatEven = d3c2 b1a0, and
|
||||
// again with the second x2_1 gives 7654 3210.
|
||||
const V8 x2_0 = hn::ConcatEven(d8, BitCast(d8, u8_1), BitCast(d8, u8_0));
|
||||
const V8 x2_1 = hn::ConcatEven(d8, BitCast(d8, u8_3), BitCast(d8, u8_2));
|
||||
packed = hn::ConcatEven(d8, x2_1, x2_0);
|
||||
} else {
|
||||
// To avoid expensive 8-bit ConcatEven, compact pairs of u32 into the
|
||||
// lower 16 bits in each u64, with other bits undefined.
|
||||
const auto combine_u32_pair_to_16 = [d16, d64](V16 v16) HWY_ATTR {
|
||||
return hn::Xor(
|
||||
v16, hn::BitCast(d16, hn::ShiftRight<24>(hn::BitCast(d64, v16))));
|
||||
};
|
||||
const V16 u16_0 = combine_u32_pair_to_16(u8_0);
|
||||
const V16 u16_1 = combine_u32_pair_to_16(u8_1);
|
||||
const V16 u16_2 = combine_u32_pair_to_16(u8_2);
|
||||
const V16 u16_3 = combine_u32_pair_to_16(u8_3);
|
||||
// In-order compaction of four vectors into one, keeping only the low
|
||||
// u16 of every u64. This is the same as above but with 16-bit Concat.
|
||||
const V16 x2_0 = hn::ConcatEven(d16, u16_1, u16_0);
|
||||
const V16 x2_1 = hn::ConcatEven(d16, u16_3, u16_2);
|
||||
packed = hn::BitCast(d8, hn::ConcatEven(d16, x2_1, x2_0));
|
||||
}
|
||||
hn::StoreU(packed, d8, out);
|
||||
}
|
||||
|
||||
// Unpacks `Lanes(d16)` nibbles to u16 lanes. The first comes from the low
|
||||
// nibble of packed[0], then its high nibble, then the next low nibble, etc.
|
||||
template <class D16, class V16 = hn::Vec<D16>>
|
||||
static HWY_INLINE V16 OrderedUnpackU16(D16 d16, const uint8_t* packed) {
|
||||
const hn::Repartition<uint8_t, D16> d8;
|
||||
using V8 = hn::Vec<decltype(d8)>;
|
||||
const hn::CappedTag<uint8_t, d16.MaxBytes() / 4> d_load;
|
||||
|
||||
// We replicate each byte 4x, so that its two nibbles propagate to both
|
||||
// u16 lanes that they will initialize. The only performance-portable op to
|
||||
// replicate bytes is TableLookupBytes, which shuffles 128-bit blocks
|
||||
// independently. Thus each block receives 4 packed bytes, replicates them
|
||||
// 4x, shifts/masks, and casts to 8 u16 lanes.
|
||||
//
|
||||
// Loading 16 bytes via LoadDup128 only works on AVX3; for smaller vectors,
|
||||
// it may trigger asan errors from overrunning the end. We thus special-case
|
||||
// vector lengths, handling any non-constexpr, and constexpr <= 512 bit.
|
||||
V8 rep4;
|
||||
if (HWY_HAVE_SCALABLE) {
|
||||
// Non constexpr length: 4 per whole block equals size/4.
|
||||
const size_t num_bytes = HWY_MAX(1, hn::Lanes(d8) / 4);
|
||||
const V8 bytes = hn::LoadN(d8, packed, num_bytes);
|
||||
// Replicate bytes 4x: lowest 4 = 0, next 4 = 1 etc.
|
||||
const V8 idx = hn::And(hn::Iota(d8, 0), hn::Set(d8, 0xFCu));
|
||||
rep4 = hn::TableLookupBytes(bytes, idx);
|
||||
} else if (hn::MaxLanes(d16) <= 8) { // <= 128-bit
|
||||
const V8 bytes = hn::ResizeBitCast(d8, hn::LoadU(d_load, packed));
|
||||
alignas(16) static constexpr uint8_t kRep4[16] = {
|
||||
HWY_REP4(0), HWY_REP4(1), HWY_REP4(2), HWY_REP4(3)};
|
||||
rep4 = hn::TableLookupBytes(bytes, hn::Load(d8, kRep4));
|
||||
} else if (HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86) {
|
||||
// Plain load, can do 256..512-bit permute across blocks.
|
||||
const V8 bytes = hn::ResizeBitCast(d8, hn::LoadU(d_load, packed));
|
||||
alignas(64) static constexpr uint8_t kRep4[64] = {
|
||||
HWY_REP4(0), HWY_REP4(1), HWY_REP4(2), HWY_REP4(3),
|
||||
HWY_REP4(4), HWY_REP4(5), HWY_REP4(6), HWY_REP4(7),
|
||||
HWY_REP4(8), HWY_REP4(9), HWY_REP4(10), HWY_REP4(11),
|
||||
HWY_REP4(12), HWY_REP4(13), HWY_REP4(14), HWY_REP4(15)};
|
||||
rep4 = hn::TableLookupLanes(bytes, hn::SetTableIndices(d8, kRep4));
|
||||
} else if (hn::MaxLanes(d16) == 16) { // 256-bit
|
||||
const V8 bytes = hn::ResizeBitCast(d8, hn::LoadU(d_load, packed));
|
||||
// First copy to upper block for TableLookupBytes. This is slightly
|
||||
// faster than 64-bit BroadcastLane.
|
||||
const V8 bcast = hn::ConcatLowerLower(d8, bytes, bytes);
|
||||
alignas(32) static constexpr uint8_t kRep4[32] = {
|
||||
HWY_REP4(0), HWY_REP4(1), HWY_REP4(2), HWY_REP4(3),
|
||||
HWY_REP4(4), HWY_REP4(5), HWY_REP4(6), HWY_REP4(7)};
|
||||
rep4 = hn::TableLookupBytes(bcast, hn::Load(d8, kRep4));
|
||||
} else if (hn::MaxLanes(d16) == 32) { // 512-bit
|
||||
const V8 bytes = hn::LoadDup128(d8, packed);
|
||||
alignas(64) static constexpr uint8_t kRep4[64] = {
|
||||
HWY_REP4(0), HWY_REP4(1), HWY_REP4(2), HWY_REP4(3),
|
||||
HWY_REP4(4), HWY_REP4(5), HWY_REP4(6), HWY_REP4(7),
|
||||
HWY_REP4(8), HWY_REP4(9), HWY_REP4(10), HWY_REP4(11),
|
||||
HWY_REP4(12), HWY_REP4(13), HWY_REP4(14), HWY_REP4(15)};
|
||||
rep4 = hn::TableLookupBytes(bytes, hn::Load(d8, kRep4));
|
||||
} else {
|
||||
HWY_DASSERT(false);
|
||||
}
|
||||
|
||||
const V16 mask4 = hn::Set(d16, 0xF);
|
||||
const V16 u16 = BitCast(d16, rep4);
|
||||
// In-order unpack. Right-shift odd u16 by 4. Example with two packed
|
||||
// bytes, one digit representing a nibble:
|
||||
// 32 32 32 32 | 10 10 10 10 u16
|
||||
// z3 23 32 32 | z1 01 10 10 OddEven+ShiftRight
|
||||
// zz z3 zz z2 | zz z1 zz z0 And (unpacked result)
|
||||
return hn::And(mask4, hn::OddEven(hn::ShiftRight<4>(u16), u16));
|
||||
}
|
||||
};
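// Worked example of the stream order: packing the sixteen 4-bit indices
// 0,1,2,...,15 produces the bytes 0x10 0x32 0x54 0x76 0x98 0xBA 0xDC 0xFE,
// i.e. each byte holds the earlier index in its low nibble and the later one
// in its high nibble. OrderedUnpackU16 restores exactly this order.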
|
||||
|
||||
// Encode/decode functions.
|
||||
class NuqCodec {
|
||||
// 256-bit vectors can hold 16 bf16, otherwise we require 2x128-bit.
|
||||
template <class DU>
|
||||
static constexpr size_t NumTables(DU du) {
|
||||
return (!HWY_HAVE_SCALABLE && du.MaxBytes() >= 32) ? 1 : 2;
|
||||
}
|
||||
|
||||
// Unpacks `centers` from SFP into bf16 and loads them into one or two vectors
|
||||
// for use by [Two]TableLookups. Returns as u16 because TableLookupLanes might
|
||||
// not be available for bf16.
|
||||
template <class DU, HWY_IF_U16_D(DU)>
|
||||
static HWY_INLINE hn::Vec<DU> LoadTable(DU du, const uint8_t* centers,
|
||||
hn::Vec<DU>* HWY_RESTRICT tbl1) {
|
||||
// Cap to the table size (kClusters) for decoding SFP - sufficient, and may
|
||||
// be faster than a large vector.
|
||||
const hn::CappedTag<hwy::bfloat16_t, kClusters> d_table;
|
||||
// We ResizeCast tables to DU: if DU is bigger, table lookups will only
|
||||
// access lanes < kClusters. If DU is smaller (128-bit), we have 2 tables.
|
||||
HWY_DASSERT(hn::Lanes(du) >= hn::Lanes(d_table) || NumTables(du) == 2);
|
||||
|
||||
HWY_ALIGN hwy::bfloat16_t table[kClusters];
|
||||
SfpCodec::Dec(d_table, reinterpret_cast<const SfpStream*>(centers),
|
||||
kClusters, table);
|
||||
|
||||
// If we assume >= 128-bit vectors, we can use [Two]TableLookupLanes
|
||||
// instead of TableLookupBytes, which requires extra interleaving of lo/hi.
|
||||
HWY_DASSERT(hn::Lanes(du) >= 8);
|
||||
|
||||
if (NumTables(du) == 2) {
|
||||
// Reduce cap for second half to avoid loading past the end of the table.
|
||||
const hn::CappedTag<hwy::bfloat16_t, kClusters / 2> d_table2;
|
||||
*tbl1 = hn::ResizeBitCast(du, hn::LoadU(d_table2, table + kClusters / 2));
|
||||
}
|
||||
return hn::ResizeBitCast(du, hn::Load(d_table, table));
|
||||
}
|
||||
|
||||
// Unpacks per-weight indices and sets c0/c1 to the corresponding centers.
|
||||
template <class DU>
|
||||
static HWY_INLINE void TableLookups(DU du, hn::Vec<DU> tbl0, hn::Vec<DU> tbl1,
|
||||
const uint8_t* packed, hn::Vec<DU>& c0,
|
||||
hn::Vec<DU>& c1) {
|
||||
using V16 = hn::Vec<decltype(du)>;
|
||||
const size_t N16 = hn::Lanes(du);
|
||||
|
||||
const V16 idx0 = NibbleCodec::OrderedUnpackU16(du, packed);
|
||||
const V16 idx1 = NibbleCodec::OrderedUnpackU16(du, packed + N16 / 2);
|
||||
|
||||
const auto indices0 = hn::IndicesFromVec(du, idx0);
|
||||
const auto indices1 = hn::IndicesFromVec(du, idx1);
|
||||
|
||||
if (NumTables(du) == 1) {
|
||||
(void)tbl1;
|
||||
c0 = hn::TableLookupLanes(tbl0, indices0);
|
||||
c1 = hn::TableLookupLanes(tbl0, indices1);
|
||||
} else {
|
||||
c0 = hn::TwoTablesLookupLanes(du, tbl0, tbl1, indices0);
|
||||
c1 = hn::TwoTablesLookupLanes(du, tbl0, tbl1, indices1);
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
// Encodes `num` floats starting from `in`. `out` points to compressed
|
||||
// storage for `out_capacity` values and `out_ofs` indicates the destination
|
||||
// offset within it, in units of float values, for parallel encoding by
|
||||
// multiple threads. `num`, `out_capacity`, and `out_ofs` must all be
|
||||
// multiples of `kGroupSize`. Returns the total number of unused clusters,
|
||||
// which is expected to be zero.
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE size_t Enc(DF df, const float* const in, const size_t num,
|
||||
ClusterBuf& buf, const size_t out_capacity,
|
||||
NuqStream* const out, const size_t out_ofs) {
|
||||
const hn::Repartition<uint8_t, DF> d8;
|
||||
const hn::Repartition<uint16_t, DF> d16;
|
||||
using V8 = hn::Vec<decltype(d8)>;
|
||||
using V16 = hn::Vec<decltype(d16)>;
|
||||
|
||||
const size_t N16 = hn::Lanes(d16);
|
||||
HWY_ASSERT(kGroupSize >= 4 * N16);
|
||||
|
||||
HWY_ASSERT(out_ofs + num <= out_capacity);
|
||||
buf.Resize(num);
|
||||
HWY_ASSERT(num % kGroupSize == 0);
|
||||
HWY_ASSERT(out_capacity % kGroupSize == 0);
|
||||
HWY_ASSERT(out_ofs % kGroupSize == 0);
|
||||
const size_t num_groups = num / kGroupSize;
|
||||
const size_t ofs_groups = out_ofs / kGroupSize;
|
||||
|
||||
size_t unused_clusters = 0;
|
||||
for (size_t g = 0; g < num_groups; ++g) {
|
||||
const float* HWY_RESTRICT g_in = in + g * kGroupSize;
|
||||
float* HWY_RESTRICT g_centers = buf.centers.get() + g * kClusters;
|
||||
uint16_t* HWY_RESTRICT g_idx = buf.idx.get() + g * kGroupSize;
|
||||
unused_clusters +=
|
||||
NuqClustering::ClusterExactL2(df, g_in, buf, g_centers, g_idx);
|
||||
}
|
||||
|
||||
uint8_t* centers = &out->byte + ofs_groups * kClusters;
|
||||
SfpCodec::Enc(df, buf.centers.get(), num_groups * kClusters,
|
||||
reinterpret_cast<SfpStream*>(centers));
|
||||
uint8_t* packed_start = &out->byte + NuqStream::PackedStart(out_capacity) +
|
||||
ofs_groups * kGroupSize / 2;
|
||||
|
||||
HWY_UNROLL(1)
|
||||
for (size_t g = 0; g < num_groups; ++g) {
|
||||
const uint16_t* HWY_RESTRICT g_idx = buf.idx.get() + g * kGroupSize;
|
||||
uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2;
|
||||
|
||||
HWY_UNROLL(1)
|
||||
for (size_t i = 0; i < kGroupSize; i += 4 * N16) {
|
||||
const V16 idx0 = hn::LoadU(d16, g_idx + i + N16 * 0);
|
||||
const V16 idx1 = hn::LoadU(d16, g_idx + i + N16 * 1);
|
||||
const V16 idx2 = hn::LoadU(d16, g_idx + i + N16 * 2);
|
||||
const V16 idx3 = hn::LoadU(d16, g_idx + i + N16 * 3);
|
||||
NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3,
|
||||
g_packed + i / 2);
|
||||
}
|
||||
}
|
||||
|
||||
return unused_clusters;
|
||||
}
|
||||
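// Example call (an illustrative sketch mirroring nuq_test.cc; `in` and `nuq`
// are placeholder pointers to `num` floats and NuqStream::PackedEnd(num)
// bytes respectively):
//   ClusterBuf buf;
//   const hn::ScalableTag<float> df;
//   const size_t unused_clusters =
//       NuqCodec::Enc(df, in, num, buf, /*out_capacity=*/num, nuq, /*out_ofs=*/0);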
|
||||
// Decodes `num` values from the stream `in`, starting at the offset `in_ofs`
|
||||
// (in units of values), to bf16 in `out`. `in_capacity`, `in_ofs` and `num`
|
||||
// must all be multiples of `kGroupSize`.
|
||||
template <class DF, HWY_IF_BF16_D(DF)>
|
||||
static HWY_INLINE void Dec(DF dbf, const size_t in_capacity,
|
||||
const NuqStream* const in, const size_t in_ofs,
|
||||
hwy::bfloat16_t* const out, const size_t num) {
|
||||
const hn::RebindToUnsigned<decltype(dbf)> d16;
|
||||
using V16 = hn::Vec<decltype(d16)>;
|
||||
|
||||
const size_t N16 = hn::Lanes(d16);
|
||||
HWY_DASSERT(kGroupSize >= 4 * N16);
|
||||
|
||||
HWY_DASSERT(in_ofs + num <= in_capacity);
|
||||
HWY_DASSERT(in_capacity % kGroupSize == 0);
|
||||
HWY_DASSERT(in_ofs % kGroupSize == 0);
|
||||
HWY_DASSERT(num % kGroupSize == 0);
|
||||
const size_t num_groups = num / kGroupSize;
|
||||
const size_t ofs_groups = in_ofs / kGroupSize;
|
||||
const uint8_t* tables = &in->byte + ofs_groups * kClusters;
|
||||
const uint8_t* packed_start = &in->byte +
|
||||
NuqStream::PackedStart(in_capacity) +
|
||||
ofs_groups * kGroupSize / 2;
|
||||
|
||||
HWY_UNROLL(1)
|
||||
for (size_t g = 0; g < num_groups; ++g) {
|
||||
const uint8_t* g_centers = tables + g * kClusters;
|
||||
const uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2;
|
||||
hwy::bfloat16_t* HWY_RESTRICT g_out = out + g * kGroupSize;
|
||||
|
||||
V16 tbl1 = Zero(d16);
|
||||
const V16 tbl0 = LoadTable(d16, g_centers, &tbl1);
|
||||
|
||||
HWY_UNROLL(1)
|
||||
for (size_t i = 0; i < kGroupSize; i += 2 * N16) {
|
||||
V16 c0, c1;
|
||||
TableLookups(d16, tbl0, tbl1, g_packed + i / 2, c0, c1);
|
||||
hn::StoreU(BitCast(dbf, c0), dbf, g_out + i + N16 * 0);
|
||||
hn::StoreU(BitCast(dbf, c1), dbf, g_out + i + N16 * 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
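// Example call (illustrative, as in nuq_test.cc): decode the whole stream,
// starting at offset 0, back into a bf16 buffer `out_bf` of `num` elements.
//   const hn::ScalableTag<hwy::bfloat16_t> dbf;
//   NuqCodec::Dec(dbf, /*in_capacity=*/num, nuq, /*in_ofs=*/0, out_bf, num);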
|
||||
// Decodes `num` values from the stream `in`, starting at the offset
|
||||
// `in_ofs` (in units of values), to f32 in `out`. `in_capacity`,
|
||||
// `in_ofs` and `num` must all be multiples of `kGroupSize`.
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Dec(DF df, const size_t in_capacity,
|
||||
const NuqStream* const in, const size_t in_ofs,
|
||||
float* const out, const size_t num) {
|
||||
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
|
||||
const hn::RebindToUnsigned<decltype(dbf)> d16;
|
||||
using V16 = hn::Vec<decltype(d16)>;
|
||||
using VF = hn::Vec<DF>;
|
||||
|
||||
const size_t NF = hn::Lanes(df);
|
||||
HWY_DASSERT(kGroupSize >= 4 * NF);
|
||||
|
||||
HWY_DASSERT(in_ofs + num <= in_capacity);
|
||||
HWY_DASSERT(in_capacity % kGroupSize == 0);
|
||||
HWY_DASSERT(in_ofs % kGroupSize == 0);
|
||||
HWY_DASSERT(num % kGroupSize == 0);
|
||||
const size_t ofs_groups = in_ofs / kGroupSize;
|
||||
const size_t num_groups = num / kGroupSize;
|
||||
const uint8_t* tables = &in->byte + ofs_groups * kClusters;
|
||||
const uint8_t* packed_start = &in->byte +
|
||||
NuqStream::PackedStart(in_capacity) +
|
||||
ofs_groups * kGroupSize / 2;
|
||||
|
||||
HWY_UNROLL(1)
|
||||
for (size_t g = 0; g < num_groups; ++g) {
|
||||
const uint8_t* g_centers = tables + g * kClusters;
|
||||
const uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2;
|
||||
float* HWY_RESTRICT g_out = out + g * kGroupSize;
|
||||
|
||||
V16 tbl1 = Zero(d16);
|
||||
const V16 tbl0 = LoadTable(d16, g_centers, &tbl1);
|
||||
|
||||
HWY_UNROLL(1)
|
||||
for (size_t i = 0; i < kGroupSize; i += 4 * NF) {
|
||||
V16 c0, c1;
|
||||
TableLookups(d16, tbl0, tbl1, g_packed + i / 2, c0, c1);
|
||||
const VF f0 = hn::PromoteLowerTo(df, BitCast(dbf, c0));
|
||||
const VF f1 = hn::PromoteUpperTo(df, BitCast(dbf, c0));
|
||||
const VF f2 = hn::PromoteLowerTo(df, BitCast(dbf, c1));
|
||||
const VF f3 = hn::PromoteUpperTo(df, BitCast(dbf, c1));
|
||||
hn::StoreU(f0, df, g_out + i + NF * 0);
|
||||
hn::StoreU(f1, df, g_out + i + NF * 1);
|
||||
hn::StoreU(f2, df, g_out + i + NF * 2);
|
||||
hn::StoreU(f3, df, g_out + i + NF * 3);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulates into `sum0..3` dot products of decoded values with `num` bf16
|
||||
// from `vec_aligned`. DF is f32 because sum0..3 are also f32. `in_capacity`,
|
||||
// `in_ofs` and `num` must all be multiples of `kGroupSize`.
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Dot(DF df, const size_t in_capacity,
|
||||
const NuqStream* const in, const size_t in_ofs,
|
||||
const hwy::bfloat16_t* const vec_aligned,
|
||||
const size_t num, hn::Vec<DF>& sum0,
|
||||
hn::Vec<DF>& sum1, hn::Vec<DF>& sum2,
|
||||
hn::Vec<DF>& sum3) {
|
||||
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
|
||||
const hn::RebindToUnsigned<decltype(dbf)> d16;
|
||||
using VBF = hn::Vec<decltype(dbf)>;
|
||||
using V16 = hn::Vec<decltype(d16)>;
|
||||
const size_t N16 = hn::Lanes(d16);
|
||||
HWY_DASSERT(kGroupSize >= 4 * N16);
|
||||
|
||||
HWY_DASSERT(in_ofs + num <= in_capacity);
|
||||
HWY_DASSERT(in_capacity % kGroupSize == 0);
|
||||
HWY_DASSERT(in_ofs % kGroupSize == 0);
|
||||
HWY_DASSERT(num % kGroupSize == 0);
|
||||
const size_t ofs_groups = in_ofs / kGroupSize;
|
||||
const size_t num_groups = num / kGroupSize;
|
||||
const uint8_t* tables = &in->byte + ofs_groups * kClusters;
|
||||
const uint8_t* packed_start = &in->byte +
|
||||
NuqStream::PackedStart(in_capacity) +
|
||||
ofs_groups * kGroupSize / 2;
|
||||
|
||||
HWY_UNROLL(1)
|
||||
for (size_t g = 0; g < num_groups; ++g) {
|
||||
const uint8_t* g_centers = tables + g * kClusters;
|
||||
const uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2;
|
||||
const hwy::bfloat16_t* HWY_RESTRICT g_in = vec_aligned + g * kGroupSize;
|
||||
|
||||
V16 tbl1 = Zero(d16);
|
||||
const V16 tbl0 = LoadTable(d16, g_centers, &tbl1);
|
||||
|
||||
HWY_UNROLL(1)
|
||||
for (size_t i = 0; i < kGroupSize; i += 2 * N16) {
|
||||
V16 c0, c1;
|
||||
TableLookups(d16, tbl0, tbl1, g_packed + i / 2, c0, c1);
|
||||
const VBF in0 = hn::Load(dbf, g_in + i + N16 * 0);
|
||||
const VBF in1 = hn::Load(dbf, g_in + i + N16 * 1);
|
||||
sum0 = hn::ReorderWidenMulAccumulate(df, in0, BitCast(dbf, c0), sum0,
|
||||
sum1);
|
||||
sum2 = hn::ReorderWidenMulAccumulate(df, in1, BitCast(dbf, c1), sum2,
|
||||
sum3);
|
||||
}
|
||||
}
|
||||
}
|
||||
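// Example (an illustrative sketch based on nuq_test.cc): accumulate partial
// sums, then reduce them to a single dot product. `nuq` and `vec_bf16` are
// placeholder pointers.
//   const hn::ScalableTag<float> df;
//   hn::Vec<decltype(df)> sum0 = hn::Zero(df), sum1 = sum0, sum2 = sum0, sum3 = sum0;
//   NuqCodec::Dot(df, num, nuq, 0, vec_bf16, num, sum0, sum1, sum2, sum3);
//   const float dot =
//       hn::ReduceSum(df, hn::Add(hn::Add(sum0, sum1), hn::Add(sum2, sum3)));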
|
||||
// Accumulates into `sum0..3` dot products of decoded values with `num` f32
|
||||
// from `vec_aligned`. `in_capacity`, `in_ofs` and `num` must all be
|
||||
// multiples of `kGroupSize`.
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Dot(DF df, const size_t in_capacity,
|
||||
const NuqStream* const in, const size_t in_ofs,
|
||||
const float* const vec_aligned, const size_t num,
|
||||
hn::Vec<DF>& sum0, hn::Vec<DF>& sum1,
|
||||
hn::Vec<DF>& sum2, hn::Vec<DF>& sum3) {
|
||||
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
|
||||
const hn::RebindToUnsigned<decltype(dbf)> d16;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
using V16 = hn::Vec<decltype(d16)>;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
HWY_DASSERT(kGroupSize >= 4 * NF);
|
||||
|
||||
HWY_DASSERT(in_ofs + num <= in_capacity);
|
||||
HWY_DASSERT(in_capacity % kGroupSize == 0);
|
||||
HWY_DASSERT(in_ofs % kGroupSize == 0);
|
||||
HWY_DASSERT(num % kGroupSize == 0);
|
||||
const size_t ofs_groups = in_ofs / kGroupSize;
|
||||
const size_t num_groups = num / kGroupSize;
|
||||
const uint8_t* tables = &in->byte + ofs_groups * kClusters;
|
||||
const uint8_t* packed_start = &in->byte +
|
||||
NuqStream::PackedStart(in_capacity) +
|
||||
ofs_groups * kGroupSize / 2;
|
||||
|
||||
HWY_UNROLL(1)
|
||||
for (size_t g = 0; g < num_groups; ++g) {
|
||||
const uint8_t* g_centers = tables + g * kClusters;
|
||||
const uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2;
|
||||
const float* HWY_RESTRICT g_in = vec_aligned + g * kGroupSize;
|
||||
|
||||
V16 tbl1 = Zero(d16);
|
||||
const V16 tbl0 = LoadTable(d16, g_centers, &tbl1);
|
||||
|
||||
HWY_UNROLL(1)
|
||||
for (size_t i = 0; i < kGroupSize; i += 4 * NF) {
|
||||
V16 c0, c1;
|
||||
TableLookups(d16, tbl0, tbl1, g_packed + i / 2, c0, c1);
|
||||
const VF in0 = hn::LoadU(df, g_in + i + NF * 0);
|
||||
const VF in1 = hn::LoadU(df, g_in + i + NF * 1);
|
||||
const VF in2 = hn::LoadU(df, g_in + i + NF * 2);
|
||||
const VF in3 = hn::LoadU(df, g_in + i + NF * 3);
|
||||
const VF f0 = hn::PromoteLowerTo(df, BitCast(dbf, c0));
|
||||
const VF f1 = hn::PromoteUpperTo(df, BitCast(dbf, c0));
|
||||
const VF f2 = hn::PromoteLowerTo(df, BitCast(dbf, c1));
|
||||
const VF f3 = hn::PromoteUpperTo(df, BitCast(dbf, c1));
|
||||
sum0 = hn::MulAdd(in0, f0, sum0);
|
||||
sum1 = hn::MulAdd(in1, f1, sum1);
|
||||
sum2 = hn::MulAdd(in2, f2, sum2);
|
||||
sum3 = hn::MulAdd(in3, f3, sum3);
|
||||
}
|
||||
}
|
||||
}
|
||||
}; // NuqCodec
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_
|
||||
|
|
@ -0,0 +1,116 @@
|
|||
// Copyright 2023 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_H_
|
||||
|
||||
// Non-uniform quantization: a compressed representation of f32 inputs that
|
||||
// supports seeking at a granularity of kGroupSize, decoding to bf16/f32, and a
|
||||
// fused decode/dot product with bf16/f32 vectors.
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // HWY_INLINE
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// 4-bit indices are a sweet spot in terms of quality per size.
|
||||
static constexpr size_t kClusters = 16;
|
||||
|
||||
// Number of weights that share a table. Larger = slower encode, higher error,
|
||||
// smaller size (table amortized over more weights). This is the minimum
|
||||
// granularity for seeking/decoding in the stream, and must be at least four
|
||||
// times the number of bf16 elements per vector.
|
||||
static constexpr size_t kGroupSize = 256;
|
||||
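// Rough size estimate (illustrative; ignores the cache-line padding applied by
// NuqStream::PackedStart below): each group stores kClusters = 16 one-byte SFP
// table entries plus kGroupSize / 2 = 128 bytes of packed 4-bit indices, i.e.
// 144 bytes per 256 weights, or about 4.5 bits per weight.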
|
||||
// Points to the *start* of a NUQ stream. Aligning the allocation (see
|
||||
// aligned_allocator.h) may speed up decoding but is not required.
|
||||
//
|
||||
// See go/streaming-weight-decode for background and design. Layout: first one
|
||||
// table of kClusters entries per group, in ascending order of group index,
|
||||
// then two packed indices per byte.
|
||||
//
|
||||
// Indices are stored in-order to enable vector-length agnostic decode, because
|
||||
// streams may be persisted to disk and used by other CPUs.
|
||||
//
|
||||
// To enable parallel encoding and decoding, Enc/Dec have `offset` parameters
|
||||
// which refer to the stream, NOT the raw from/to pointers, which point directly
|
||||
// to the source/destination. Offsets are in units of values, NOT compressed
|
||||
// bytes within the stream.
|
||||
#pragma pack(push, 1)
|
||||
struct NuqStream {
|
||||
// Returns offset of packed indices from the start of the stream. This matches
|
||||
// the (padded) total table size because table entries are bytes. `capacity`
|
||||
// is already a multiple of `kGroupSize`.
|
||||
static constexpr size_t PackedStart(size_t capacity) {
|
||||
// Round up to avoid cache-line splits when loading indices. No effect on
|
||||
// size as long as capacity / kGroupSize is a multiple of 4.
|
||||
return hwy::RoundUpTo((capacity / kGroupSize) * kClusters, 64);
|
||||
}
|
||||
|
||||
// Returns the number of NuqStream elements to allocate for the stream, which
|
||||
// matches its size in bytes. `capacity` is already a multiple of `kGroupSize`.
|
||||
static constexpr size_t PackedEnd(size_t capacity) {
|
||||
return PackedStart(capacity) + capacity / 2; // two 4-bit indices per byte.
|
||||
}
|
||||
|
||||
uint8_t byte;
|
||||
};
|
||||
#pragma pack(pop)
|
||||
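// Example (an illustrative sketch mirroring the allocations in nuq_test.cc):
// sizing and allocating storage for a stream of `num` values, where `num` is a
// multiple of kGroupSize. For num = 512 (two groups), PackedStart(512) =
// RoundUpTo(2 * 16, 64) = 64 bytes of tables and PackedEnd(512) = 64 + 256 =
// 320 bytes in total.
//
//   const size_t num = 512;
//   hwy::AlignedFreeUniquePtr<NuqStream[]> nuq =
//       hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(num));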
|
||||
static inline const char* TypeName(NuqStream) { return "NUQ"; }
|
||||
|
||||
// Storage for dynamic programming. There are two matrices; we use separate
|
||||
// allocations to avoid type punning.
|
||||
template <class T>
|
||||
class AlignedMatrix {
|
||||
public:
|
||||
AlignedMatrix() : mem_(hwy::AllocateAligned<T>(kClusters * kGroupSize)) {}
|
||||
|
||||
HWY_INLINE const T& operator()(size_t row, size_t col) const {
|
||||
return mem_[row * kGroupSize + col];
|
||||
}
|
||||
|
||||
HWY_INLINE T& operator()(size_t row, size_t col) {
|
||||
return mem_[row * kGroupSize + col];
|
||||
}
|
||||
|
||||
private:
|
||||
hwy::AlignedFreeUniquePtr<T[]> mem_;
|
||||
};
|
||||
|
||||
// Reuse memory across calls to Enc to avoid per-call allocations.
|
||||
struct ClusterBuf {
|
||||
void Resize(size_t new_num) {
|
||||
if (new_num < num) return;
|
||||
|
||||
num = new_num;
|
||||
const size_t num_groups = hwy::DivCeil(num, kGroupSize);
|
||||
centers = hwy::AllocateAligned<float>(num_groups * kClusters);
|
||||
idx = hwy::AllocateAligned<uint16_t>(num);
|
||||
}
|
||||
|
||||
AlignedMatrix<float> d;
|
||||
AlignedMatrix<int32_t> t;
|
||||
|
||||
size_t num = 0;
|
||||
hwy::AlignedFreeUniquePtr<float[]> centers;
|
||||
hwy::AlignedFreeUniquePtr<uint16_t[]> idx;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_H_
|
||||
|
|
@ -0,0 +1,428 @@
|
|||
// Copyright 2023 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <algorithm> // std::shuffle
|
||||
#include <random>
|
||||
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE \
|
||||
"third_party/gemma_cpp/compression/nuq_test.cc" // NOLINT
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
// Other headers that include Highway must come after foreach_target.h
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/distortion.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/nuq-inl.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/nuq.h"
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
#include "hwy/tests/test_util-inl.h"
|
||||
#include "hwy/timer.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
// All-equal inputs: only one cluster
|
||||
struct TestFlat {
|
||||
template <typename T, class DF>
|
||||
HWY_INLINE void operator()(T /*unused*/, DF df) {
|
||||
// Run this simple test only once to save time/debug output.
|
||||
if (!(HWY_ONCE && hn::Lanes(df) == hn::Lanes(hn::ScalableTag<float>()))) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto in = hwy::AllocateAligned<float>(kGroupSize);
|
||||
HWY_ASSERT(in);
|
||||
for (size_t i = 0; i < kGroupSize; ++i) {
|
||||
in[i] = 0.5f;
|
||||
}
|
||||
ClusterBuf buf;
|
||||
float centers[kClusters];
|
||||
uint16_t indices[kGroupSize];
|
||||
const size_t unused_clusters =
|
||||
NuqClustering::ClusterExactL2(df, in.get(), buf, centers, indices);
|
||||
HWY_ASSERT(unused_clusters == kClusters - 1);
|
||||
|
||||
for (size_t i = 0; i < unused_clusters; ++i) {
|
||||
HWY_ASSERT(centers[i] == 0.0f);
|
||||
}
|
||||
HWY_ASSERT(centers[unused_clusters] == 0.5f);
|
||||
for (size_t i = 0; i < kGroupSize; ++i) {
|
||||
HWY_ASSERT(indices[i] == unused_clusters);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void TestAllFlat() { hn::ForGEVectors<64, TestFlat>()(float()); }
|
||||
|
||||
// Generate shuffled plateaus, one per cluster
|
||||
struct TestPlateaus {
|
||||
template <typename T, class DF>
|
||||
HWY_INLINE void operator()(T /*unused*/, DF df) {
|
||||
// Run this simple test only once to save time/debug output.
|
||||
if (!(HWY_ONCE && hn::Lanes(df) == hn::Lanes(hn::ScalableTag<float>()))) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto in = hwy::AllocateAligned<float>(kGroupSize);
|
||||
HWY_ASSERT(in);
|
||||
|
||||
for (size_t i = 0; i < kGroupSize; ++i) {
|
||||
const size_t idx_cluster = i / (kGroupSize / kClusters);
|
||||
HWY_ASSERT(idx_cluster < kClusters);
|
||||
in[i] = (1.0f * idx_cluster / kClusters) - 0.5f;
|
||||
HWY_ASSERT(-0.5f <= in[i] && in[i] < 0.5f);
|
||||
}
|
||||
|
||||
std::random_device rd;
|
||||
std::mt19937 rng(rd());
|
||||
std::shuffle(in.get(), in.get() + kGroupSize, rng);
|
||||
|
||||
ClusterBuf buf;
|
||||
float centers[kClusters];
|
||||
uint16_t indices[kGroupSize];
|
||||
const size_t unused_clusters =
|
||||
NuqClustering::ClusterExactL2(df, in.get(), buf, centers, indices);
|
||||
HWY_ASSERT(unused_clusters == 0);
|
||||
|
||||
DistortionStats stats;
|
||||
for (size_t i = 0; i < kGroupSize; ++i) {
|
||||
HWY_ASSERT(indices[i] < kClusters);
|
||||
stats.Notify(in[i], centers[indices[i]]);
|
||||
}
|
||||
const float pnorm = stats.PNorm();
|
||||
const float snr = stats.GeomeanValueDivL1();
|
||||
fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr,
|
||||
stats.MaxIndex(), stats.MaxL1());
|
||||
HWY_ASSERT(pnorm == 0.0f);
|
||||
HWY_ASSERT(snr == 0.0f);
|
||||
}
|
||||
};
|
||||
|
||||
void TestAllPlateaus() { hn::ForGEVectors<64, TestPlateaus>()(float()); }
|
||||
|
||||
struct TestRamp {
|
||||
template <typename T, class DF>
|
||||
HWY_INLINE void operator()(T /*unused*/, DF df) {
|
||||
// Run this simple test only once to save time/debug output.
|
||||
if (!(HWY_ONCE && hn::Lanes(df) == hn::Lanes(hn::ScalableTag<float>()))) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto in = hwy::AllocateAligned<float>(kGroupSize);
|
||||
HWY_ASSERT(in);
|
||||
|
||||
for (size_t i = 0; i < kGroupSize; ++i) {
|
||||
in[i] = (1.0f * i / kGroupSize) - 0.45f; // slightly asymmetric
|
||||
HWY_ASSERT(-0.45f <= in[i] && in[i] < 0.55f);
|
||||
}
|
||||
|
||||
std::random_device rd;
|
||||
std::mt19937 rng(rd());
|
||||
std::shuffle(in.get(), in.get() + kGroupSize, rng);
|
||||
|
||||
ClusterBuf buf;
|
||||
float centers[kClusters];
|
||||
uint16_t indices[kGroupSize];
|
||||
const size_t unused_clusters =
|
||||
NuqClustering::ClusterExactL2(df, in.get(), buf, centers, indices);
|
||||
HWY_ASSERT(unused_clusters == 0);
|
||||
|
||||
DistortionStats stats;
|
||||
for (size_t i = 0; i < kGroupSize; ++i) {
|
||||
HWY_ASSERT(indices[i] < kClusters);
|
||||
stats.Notify(in[i], centers[indices[i]]);
|
||||
}
|
||||
const float pnorm = stats.PNorm();
|
||||
const float snr = stats.GeomeanValueDivL1();
|
||||
fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr,
|
||||
stats.MaxIndex(), stats.MaxL1());
|
||||
static_assert(kGroupSize == 128 || kGroupSize == 256, "Update expected");
|
||||
|
||||
const float expected_pnorm = kGroupSize == 128 ? 2.08E-2f : 2.1E-2f;
|
||||
const float expected_snr = kGroupSize == 128 ? 16.9f : 17.6f;
|
||||
HWY_ASSERT(expected_pnorm <= pnorm && pnorm < 1.02f * expected_pnorm);
|
||||
HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr);
|
||||
}
|
||||
};
|
||||
|
||||
void TestAllRamp() { hn::ForGEVectors<64, TestRamp>()(float()); }
|
||||
|
||||
struct TestNormal {
|
||||
template <typename T, class DF>
|
||||
HWY_INLINE void operator()(T /*unused*/, DF df) {
|
||||
auto in = hwy::AllocateAligned<float>(kGroupSize);
|
||||
HWY_ASSERT(in);
|
||||
|
||||
std::mt19937 rng(123);
|
||||
std::normal_distribution<float> dist{0.001f, 0.3f};
|
||||
for (size_t i = 0; i < kGroupSize; ++i) {
|
||||
in[i] = dist(rng);
|
||||
}
|
||||
std::shuffle(in.get(), in.get() + kGroupSize, rng);
|
||||
|
||||
ClusterBuf buf;
|
||||
float centers[kClusters];
|
||||
uint16_t indices[kGroupSize];
|
||||
double elapsed = hwy::HighestValue<double>();
|
||||
for (size_t rep = 0; rep < 100; ++rep) {
|
||||
const double t0 = hwy::platform::Now();
|
||||
const size_t unused_clusters =
|
||||
NuqClustering::ClusterExactL2(df, in.get(), buf, centers, indices);
|
||||
HWY_ASSERT(unused_clusters == 0);
|
||||
const double t1 = hwy::platform::Now();
|
||||
elapsed = HWY_MIN(elapsed, t1 - t0);
|
||||
}
|
||||
fprintf(stderr, "Vec %zu Enc %.2f MB/s\n", Lanes(df) * 4,
|
||||
kGroupSize * sizeof(float) * 1E-6 / elapsed);
|
||||
|
||||
DistortionStats stats;
|
||||
for (size_t i = 0; i < kGroupSize; ++i) {
|
||||
HWY_ASSERT(indices[i] < kClusters);
|
||||
stats.Notify(in[i], centers[indices[i]]);
|
||||
}
|
||||
const float pnorm = stats.PNorm();
|
||||
const float snr = stats.GeomeanValueDivL1();
|
||||
fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr,
|
||||
stats.MaxIndex(), stats.MaxL1());
|
||||
static_assert(kGroupSize == 128 || kGroupSize == 256, "Update expected");
|
||||
const float expected_pnorm = kGroupSize == 128 ? 3E-2f : 3.4E-2f;
|
||||
const float expected_snr = kGroupSize == 128 ? 17.4f : 13.1f;
|
||||
HWY_ASSERT(expected_pnorm <= pnorm && pnorm < 1.02f * expected_pnorm);
|
||||
HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr);
|
||||
}
|
||||
};
|
||||
|
||||
void TestAllNormal() { hn::ForGEVectors<64, TestNormal>()(float()); }
|
||||
|
||||
// Can encode and decode sub-regions.
|
||||
struct TestOffset {
|
||||
template <typename T, class D>
|
||||
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||
const hn::Repartition<float, D> df;
|
||||
const size_t total = 10 * kGroupSize;
|
||||
const size_t kMidLen = 2 * kGroupSize; // length of middle piece
|
||||
|
||||
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
|
||||
auto dec1 = hwy::AllocateAligned<T>(total);
|
||||
auto dec2 = hwy::AllocateAligned<T>(kMidLen);
|
||||
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
|
||||
HWY_ASSERT(in && dec1 && dec2 && nuq);
|
||||
|
||||
std::mt19937 rng(123);
|
||||
std::normal_distribution<float> dist{0.001f, 0.3f};
|
||||
for (size_t i = 0; i < total; ++i) {
|
||||
in[i] = dist(rng);
|
||||
}
|
||||
|
||||
// Encode + decode everything
|
||||
ClusterBuf buf;
|
||||
(void)NuqCodec::Enc(df, in.get(), total, buf, total, nuq.get(), 0);
|
||||
NuqCodec::Dec(d, total, nuq.get(), 0, dec1.get(), total);
|
||||
|
||||
// Overwrite middle with first inputs
|
||||
const size_t offset = 5 * kGroupSize;
|
||||
(void)NuqCodec::Enc(df, in.get(), kMidLen, buf, total, nuq.get(), offset);
|
||||
|
||||
// Decoded middle now matches previously decoded first
|
||||
NuqCodec::Dec(d, total, nuq.get(), offset, dec2.get(), kMidLen);
|
||||
for (size_t i = 0; i < kMidLen; ++i) {
|
||||
HWY_ASSERT(dec1[i] == dec2[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void TestAllOffsetF32() {
|
||||
const hn::ForGEVectors<128, TestOffset> test;
|
||||
test(float());
|
||||
}
|
||||
|
||||
void TestAllOffsetBF16() {
|
||||
const hn::ForGEVectors<128, TestOffset> test;
|
||||
test(hwy::bfloat16_t());
|
||||
}
|
||||
|
||||
struct TestStream {
|
||||
template <typename T, class D>
|
||||
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||
const hn::Repartition<float, D> df;
|
||||
const size_t num = 4 * kGroupSize;
|
||||
auto in = hwy::AllocateAligned<float>(num); // Enc() requires f32
|
||||
auto out = hwy::AllocateAligned<T>(num);
|
||||
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(num));
|
||||
HWY_ASSERT(in && out && nuq);
|
||||
|
||||
std::mt19937 rng(123);
|
||||
std::normal_distribution<float> dist{0.001f, 0.3f};
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
in[i] = dist(rng);
|
||||
}
|
||||
|
||||
ClusterBuf buf;
|
||||
double elapsed = hwy::HighestValue<double>();
|
||||
for (size_t rep = 0; rep < 100; ++rep) {
|
||||
const double t0 = hwy::platform::Now();
|
||||
const size_t unused_clusters =
|
||||
NuqCodec::Enc(df, in.get(), num, buf, num, nuq.get(), 0);
|
||||
HWY_ASSERT(unused_clusters == 0);
|
||||
const double t1 = hwy::platform::Now();
|
||||
elapsed = HWY_MIN(elapsed, t1 - t0);
|
||||
}
|
||||
fprintf(stderr, "Vec %zu Enc %.2f MB/s\n", Lanes(d) * sizeof(T),
|
||||
num * sizeof(float) * 1E-6 / elapsed);
|
||||
|
||||
elapsed = hwy::HighestValue<double>();
|
||||
for (size_t rep = 0; rep < 100; ++rep) {
|
||||
const double t0 = hwy::platform::Now();
|
||||
NuqCodec::Dec(d, num, nuq.get(), 0, out.get(), num);
|
||||
const double t1 = hwy::platform::Now();
|
||||
elapsed = HWY_MIN(elapsed, t1 - t0);
|
||||
}
|
||||
fprintf(stderr, "Vec %zu Dec %.2f MB/s\n", Lanes(d) * sizeof(T),
|
||||
num * sizeof(T) * 1E-6 / elapsed);
|
||||
|
||||
DistortionStats stats;
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
stats.Notify(in[i], hwy::ConvertScalarTo<float>(out[i]));
|
||||
}
|
||||
const float pnorm = stats.PNorm();
|
||||
const float snr = stats.GeomeanValueDivL1();
|
||||
fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr,
|
||||
stats.MaxIndex(), stats.MaxL1());
|
||||
static_assert(kGroupSize == 128 || kGroupSize == 256, "Update expected");
|
||||
const float expected_pnorm = kGroupSize == 128 ? 3.44E-2f : 3.88E-2f;
|
||||
const float expected_snr = kGroupSize == 128 ? 15.0f : 13.3f;
|
||||
HWY_ASSERT(expected_pnorm <= pnorm && pnorm < 1.02f * expected_pnorm);
|
||||
HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr);
|
||||
}
|
||||
};
|
||||
|
||||
void TestAllStreamF32() {
|
||||
const hn::ForGEVectors<128, TestStream> test;
|
||||
test(float());
|
||||
}
|
||||
|
||||
void TestAllStreamBF16() {
|
||||
const hn::ForGEVectors<128, TestStream> test;
|
||||
test(hwy::bfloat16_t());
|
||||
}
|
||||
|
||||
struct TestDot {
|
||||
template <typename T, class D>
|
||||
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||
const hn::Repartition<float, D> df;
|
||||
const size_t num = 4 * kGroupSize;
|
||||
auto in = hwy::AllocateAligned<float>(num);
|
||||
auto dec = hwy::AllocateAligned<float>(num);
|
||||
auto vec = hwy::AllocateAligned<T>(num);
|
||||
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(num));
|
||||
HWY_ASSERT(in && dec && vec && nuq);
|
||||
|
||||
std::mt19937 rng(123);
|
||||
std::normal_distribution<float> dist{0.001f, 0.3f};
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
in[i] = dist(rng);
|
||||
vec[i] = hwy::ConvertScalarTo<T>(dist(rng));
|
||||
}
|
||||
// This changes the correlation between in and vec, which considerably
|
||||
// affects the error of the result.
|
||||
std::shuffle(in.get(), in.get() + num, rng);
|
||||
|
||||
ClusterBuf buf;
|
||||
const size_t unused_clusters =
|
||||
NuqCodec::Enc(df, in.get(), num, buf, num, nuq.get(), 0);
|
||||
HWY_ASSERT(unused_clusters == 0);
|
||||
|
||||
double actual = 0.0;
|
||||
double elapsed = hwy::HighestValue<double>();
|
||||
for (size_t rep = 0; rep < 20; ++rep) {
|
||||
hn::Vec<decltype(df)> sum0 = hn::Zero(df);
|
||||
hn::Vec<decltype(df)> sum1 = hn::Zero(df);
|
||||
hn::Vec<decltype(df)> sum2 = hn::Zero(df);
|
||||
hn::Vec<decltype(df)> sum3 = hn::Zero(df);
|
||||
const double t0 = hwy::platform::Now();
|
||||
NuqCodec::Dot(df, num, nuq.get(), 0, vec.get(), num, sum0, sum1, sum2,
|
||||
sum3);
|
||||
const double t1 = hwy::platform::Now();
|
||||
elapsed = HWY_MIN(elapsed, t1 - t0);
|
||||
sum0 = hn::Add(hn::Add(sum0, sum1), hn::Add(sum2, sum3));
|
||||
actual = hn::ReduceSum(df, sum0);
|
||||
}
|
||||
|
||||
NuqCodec::Dec(df, num, nuq.get(), 0, dec.get(), num);
|
||||
fprintf(stderr, "Vec %zu Dec %.2f MB/s\n", Lanes(d) * sizeof(T),
|
||||
num * sizeof(in[0]) * 1E-6 / elapsed);
|
||||
|
||||
double expected = 0.0; // using original input
|
||||
double expected2 = 0.0; // using decoded NUQ
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
expected += in[i] * hwy::ConvertScalarTo<double>(vec[i]);
|
||||
expected2 += dec[i] * hwy::ConvertScalarTo<double>(vec[i]);
|
||||
}
|
||||
const double l1 = hwy::ScalarAbs(expected - actual);
|
||||
const double snr = 1.0 + hwy::ScalarAbs(expected) / l1;
|
||||
fprintf(stderr, "expected %.3f e2 %.4f actual %.4f l1 %E snr %.2f\n",
|
||||
expected, expected2, actual, l1, snr);
|
||||
HWY_ASSERT(hwy::ScalarAbs(expected2 - actual) < 1E-4);
|
||||
static_assert(kGroupSize == 128 || kGroupSize == 256, "Update expected");
|
||||
const double expected_l1 = kGroupSize == 128 ? 7.3E-2 : 4.34E-2;
|
||||
const double expected_snr = kGroupSize == 128 ? 9.7f
|
||||
: sizeof(T) == 2 ? 14.5f
|
||||
: 14.9f;
|
||||
HWY_ASSERT(expected_l1 <= l1 && l1 < 1.02f * expected_l1);
|
||||
HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr);
|
||||
}
|
||||
};
|
||||
|
||||
void TestAllDotF32() {
|
||||
const hn::ForGEVectors<128, TestDot> test;
|
||||
test(float());
|
||||
}
|
||||
void TestAllDotBF16() {
|
||||
const hn::ForGEVectors<128, TestDot> test;
|
||||
test(hwy::bfloat16_t());
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#if HWY_ONCE
|
||||
|
||||
namespace gcpp {
|
||||
HWY_BEFORE_TEST(NuqTest);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllFlat);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllPlateaus);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllRamp);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNormal);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllOffsetF32);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllOffsetBF16);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamF32);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamBF16);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllDotF32);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllDotBF16);
|
||||
} // namespace gcpp
|
||||
|
||||
#endif
|
||||
|
|
@ -0,0 +1,515 @@
|
|||
// Copyright 2023 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Normal include guard to placate lint.
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/sfp.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_
|
||||
|
||||
// Actual per-target include guard.
|
||||
#if defined(THIRD_PARTY_GEMMA_CPP_SFP_INL_TOGGLE) == defined(HWY_TARGET_TOGGLE)
|
||||
#ifdef THIRD_PARTY_GEMMA_CPP_SFP_INL_TOGGLE
|
||||
#undef THIRD_PARTY_GEMMA_CPP_SFP_INL_TOGGLE
|
||||
#else
|
||||
#define THIRD_PARTY_GEMMA_CPP_SFP_INL_TOGGLE
|
||||
#endif
|
||||
|
||||
#include "hwy/highway.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
||||
// For unsigned numbers with MSB zero, signed comparison is faster on x86.
|
||||
template <class DU>
|
||||
HWY_INLINE hn::Mask<DU> SignedGt(DU du, hn::Vec<DU> a, hn::Vec<DU> b) {
|
||||
const hn::RebindToSigned<DU> di;
|
||||
return hn::RebindMask(du, hn::Gt(BitCast(di, a), hn::BitCast(di, b)));
|
||||
}
|
||||
template <class DU>
|
||||
HWY_INLINE hn::Mask<DU> SignedLt(DU du, hn::Vec<DU> a, hn::Vec<DU> b) {
|
||||
return SignedGt(du, b, a);
|
||||
}
|
||||
|
||||
// Encode/decode functions.
|
||||
class SfpCodec {
|
||||
public:
|
||||
// Returns 8-bit packed representation of `lo` and `hi` bytes of bf16. 31 ops.
|
||||
// Implementation detail, public because called by test.
|
||||
template <class D, HWY_IF_U8_D(D)>
|
||||
static HWY_INLINE hn::Vec<D> EncBytes(D d, const hn::Vec<D> lo,
|
||||
const hn::Vec<D> hi) {
|
||||
const hn::Vec<D> k1 = hn::Set(d, 1u);
|
||||
const hn::Vec<D> k80 = hn::Set(d, 0x80u);
|
||||
|
||||
// Copy sign for later insertion.
|
||||
const hn::Vec<D> sign_in_msb = hi;
|
||||
// Biased exponent = lower 7 bits of hi and MSB of lo. Modified below.
|
||||
hn::Vec<D> biased_e = hn::Or(hn::Add(hi, hi), hn::ShiftRight<7>(lo));
|
||||
HWY_ASSERT(hn::AllTrue(d, hn::Lt(biased_e, k80))); // <= 2^0
|
||||
|
||||
// Clear MSB to isolate the mantissa and enable signed comparisons, then
|
||||
// shift right by *one* (plus 1 to undo the prior add/left-shift) to leave
|
||||
// headroom for overflow during rounding.
|
||||
const hn::Vec<D> m6 = hn::ShiftRight<2>(hn::Add(lo, lo));
|
||||
|
||||
// The place to round depends on whether the exponent is large (>= -7) - if
|
||||
// so, we retain three mantissa bits, otherwise two. However, rounding can
|
||||
// also cause the exponent to increase. We first choose a threshold that
|
||||
// rounds up to 1.0*2^-7 for both two and three bit mantissas:
|
||||
// >= 1.1111 * 2^-8 (0.007568359375). This entails the exponent being
|
||||
// greater, or equal with the mantissa > (1111000 >> 1) - 1 = 0x3B.
|
||||
const hn::Vec<D> kMinLargeE = hn::Set(d, 127 - 8);
|
||||
const hn::Mask<D> is_large_before_round = hn::Or(
|
||||
SignedGt(d, biased_e, kMinLargeE),
|
||||
hn::And(hn::Eq(biased_e, kMinLargeE), SignedGt(d, m6, Set(d, 0x3B))));
|
||||
|
||||
// To retain the most-significant 3 or 2 mantissa bits, we will right-shift
|
||||
// by is_large_before_round ? 3 : 4. Variable Shr is expensive for 8-bit
|
||||
// elements, so (<< 1) if is_large_before_round, then always (>> 4).
|
||||
const hn::Vec<D> m_shl4 =
|
||||
hn::MaskedAddOr(m6, is_large_before_round, m6, m6);
|
||||
|
||||
// Before shifting (truncation), round to nearest even to reduce bias. If
|
||||
// the lowest remaining mantissa bit is odd, increase the offset. Example
|
||||
// with the lowest remaining bit (left) and next lower two bits; the
|
||||
// latter, plus two more, will be truncated.
|
||||
// 0[00] + 1 = 0[01]
|
||||
// 0[01] + 1 = 0[10]
|
||||
// 0[10] + 1 = 0[11] (round down toward even)
|
||||
// 0[11] + 1 = 1[00] (round up)
|
||||
// 1[00] + 10 = 1[10]
|
||||
// 1[01] + 10 = 1[11]
|
||||
// 1[10] + 10 = C0[00] (round up toward even with C=1 carry out)
|
||||
// 1[11] + 10 = C0[01] (round up toward even with C=1 carry out)
|
||||
const hn::Vec<D> odd_bit = hn::And(hn::ShiftRight<4>(m_shl4), k1);
|
||||
const hn::Vec<D> rounded = hn::Add(m_shl4, hn::Add(odd_bit, Set(d, 7)));
|
||||
// Update the exponent if rounding overflowed.
|
||||
const hn::Vec<D> carry_bit =
|
||||
hn::IfThenElse(is_large_before_round, k80, hn::Set(d, 0x40u));
|
||||
const hn::Vec<D> carry_clear = hn::AndNot(carry_bit, rounded);
|
||||
HWY_DASSERT(hn::AllTrue(d, hn::Lt(carry_clear, carry_bit)));
|
||||
const hn::Mask<D> is_overflow = hn::Ne(carry_clear, rounded);
|
||||
biased_e = hn::MaskedAddOr(biased_e, is_overflow, biased_e, k1);
|
||||
HWY_DASSERT(hn::AllTrue(d, hn::Lt(biased_e, Set(d, 128))));
|
||||
|
||||
// Detect if zero or the min exponent.
|
||||
const hn::Vec<D> kMinNormal = hn::Set(d, 127 - 23);
|
||||
const hn::Mask<D> is_zero = SignedLt(d, biased_e, kMinNormal);
|
||||
const hn::Mask<D> is_min = hn::Eq(biased_e, kMinNormal);
|
||||
|
||||
// 1.1110xxx * 2^-8 was considered small above, and thus rounded up to 2^-7,
|
||||
// which the decoder will consider large, and expect 3 mantissa bits. If we
|
||||
// set the threshold above to 1.111, then it does NOT round up. Thus we
|
||||
// check exponent >= -7 *after* rounding.
|
||||
const hn::Mask<D> is_large = SignedGt(d, biased_e, hn::Set(d, 127 - 8));
|
||||
|
||||
// To extract and pack the mantissa, only is_large matters. Either it
|
||||
// matches is_large_before_round, or the rounding resulted in mantissa=0, so
|
||||
// we either extract two or three bits by shifting out the lower 5..6 bits.
|
||||
// is_large_before  is_large  rounded     want
|
||||
//        0             0      0Cmm????    mm
|
||||
//        0             1      0100????    000
|
||||
//        1             0      impossible  -
|
||||
//        1             1      Cmmm???0    mmm
|
||||
hn::Vec<D> m = hn::ShiftRight<4>(carry_clear);
|
||||
HWY_DASSERT(hn::AllTrue(
|
||||
d, SignedLt(d, m,
|
||||
hn::IfThenElse(is_large, hn::Set(d, 8), hn::Set(d, 4)))));
|
||||
|
||||
// 1.0 * 2^-23 has the same encoding as zero, so round it up to 1.01.
|
||||
m = hn::MaskedMaxOr(m, is_min, m, k1);
|
||||
|
||||
const hn::Vec<D> e_bias = hn::IfThenElse(
|
||||
is_large,
|
||||
hn::Set(d, hwy::BitCastScalar<uint8_t>(static_cast<int8_t>(15 - 127))),
|
||||
hn::Set(d, hwy::BitCastScalar<uint8_t>(static_cast<int8_t>(23 - 127))));
|
||||
const hn::Vec<D> e = hn::Add(biased_e, e_bias);
|
||||
HWY_DASSERT(
|
||||
hn::AllTrue(d, hn::Lt(hn::IfThenZeroElse(is_zero, e), hn::Set(d, 16))));
|
||||
|
||||
// Shift exponent left 2 or 3 bits to make space for `m`.
|
||||
const hn::Vec<D> em =
|
||||
hn::Or(m, hn::ShiftLeft<2>(hn::MaskedAddOr(e, is_large, e, e)));
|
||||
HWY_DASSERT(hn::AllTrue(d, hn::Lt(hn::IfThenZeroElse(is_zero, em), k80)));
|
||||
const hn::Vec<D> encoded = hn::BitwiseIfThenElse(k80, sign_in_msb, em);
|
||||
// Doing this last ensures -0 is replaced with 0.
|
||||
return hn::IfThenZeroElse(is_zero, encoded);
|
||||
}
|
||||
|
||||
// Decodes u8 `encoded` into `lo` and `hi` bytes of bf16. 12 ops.
|
||||
// Implementation detail, public because called by test.
|
||||
template <class D, HWY_IF_U8_D(D)>
|
||||
static HWY_INLINE void DecBytes(D d, hn::Vec<D> encoded, hn::Vec<D>& lo,
|
||||
hn::Vec<D>& hi) {
|
||||
const hn::Vec<D> k0 = hn::Zero(d);
|
||||
const hn::Vec<D> k80 = hn::Set(d, 0x80u);
|
||||
|
||||
HWY_DASSERT(hn::AllTrue(d, hn::Ne(encoded, k80))); // -0 is reserved
|
||||
// Copy sign for later insertion via BitwiseIfThenElse.
|
||||
const hn::Vec<D> sign_in_msb = encoded;
|
||||
encoded = hn::AndNot(k80, encoded);
|
||||
|
||||
// Special-case zero, negated so we can use MaskedAddOr. Signed comparison
|
||||
// is fine because we have cleared the sign bit.
|
||||
const hn::Mask<D> is_nonzero = SignedGt(d, encoded, k0);
|
||||
// If MSB is clear, we have two mantissa bits, otherwise three.
|
||||
const hn::Mask<D> is_small_e = SignedLt(d, encoded, hn::Set(d, 64));
|
||||
// If is_small_e, add/left-shift 0xxxx.mm to 0xxxx.mm0; else keep 1xxx.mmm.
|
||||
const hn::Vec<D> e4m3 =
|
||||
hn::MaskedAddOr(encoded, is_small_e, encoded, encoded);
|
||||
HWY_DASSERT(hn::AllTrue(d, hn::Lt(e4m3, k80)));
|
||||
const hn::Vec<D> e = hn::ShiftRight<3>(e4m3); // 4-bit exponent only
|
||||
HWY_DASSERT(hn::AllTrue(d, hn::Lt(e, Set(d, 16u))));
|
||||
// The encoded exponent for 2^0 is 15, so subtract 15. Add 127 for the
|
||||
// binary32/bf16 bias. Subtract another 8 if is_small_e because its lowest
|
||||
// encoded value (0) should be less than the lowest 'large' exponent 2^-7.
|
||||
const hn::Vec<D> e_bias = hn::IfThenElse(
|
||||
is_small_e, hn::Set(d, 127u - 15u - 8u), hn::Set(d, 127u - 15u));
|
||||
// Special-case zero or add e_bias. If encoded=0, e and e4m3 are zero, but
|
||||
// we must zero e_bias to get the desired all-zero bf16.
|
||||
const hn::Vec<D> biased_e = hn::MaskedAddOr(k0, is_nonzero, e_bias, e);
|
||||
// The decoded binary32 exponent should be at most 2^0.
|
||||
HWY_DASSERT(hn::AllTrue(d, hn::Lt(biased_e, k80)));
|
||||
|
||||
// Shift the MSB of e4m3's mantissa into the MSB of the bf16 mantissa.
|
||||
const hn::Vec<D> m7 = hn::ShiftLeft<4>(e4m3);
|
||||
// Lower byte of bf16 = exponent LSB || mantissa.
|
||||
lo = hn::BitwiseIfThenElse(k80, hn::ShiftLeft<7>(biased_e), m7);
|
||||
// Upper byte of bf16 = sign || lower 7 bits of exponent.
|
||||
hi = hn::BitwiseIfThenElse(k80, sign_in_msb, hn::ShiftRight<1>(biased_e));
|
||||
}
|
||||
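// Note: sfp_test.cc defines a scalar decoder for this format (F32FromSFP8),
// which may be easier to follow when reasoning about the bit manipulation
// above.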
|
||||
// Encodes `num` bf16 values from `in_bf` to `out_packed`.
|
||||
template <class DBF, HWY_IF_BF16_D(DBF)>
|
||||
static HWY_INLINE void Enc(DBF dbf, const hwy::bfloat16_t* HWY_RESTRICT in_bf,
|
||||
size_t num, SfpStream* HWY_RESTRICT out_packed) {
|
||||
const hn::Repartition<uint8_t, DBF> d8;
|
||||
using V8 = hn::Vec<decltype(d8)>;
|
||||
const size_t N16 = hn::Lanes(dbf);
|
||||
|
||||
size_t i = 0;
|
||||
if (num >= 2 * N16) {
|
||||
HWY_UNROLL(1)
|
||||
for (; i <= num - 2 * N16; i += 2 * N16) {
|
||||
const V8 packed = Enc2B(dbf, in_bf + i);
|
||||
hn::StoreU(packed, d8, &out_packed->byte + i);
|
||||
}
|
||||
}
|
||||
|
||||
const size_t remaining = num - i;
|
||||
HWY_DASSERT(remaining < 2 * N16);
|
||||
if (remaining != 0) {
|
||||
HWY_ALIGN hwy::bfloat16_t padded[2 * hn::MaxLanes(dbf)];
|
||||
hwy::ZeroBytes(padded, sizeof(padded));
|
||||
hwy::CopyBytes(in_bf + i, padded, remaining * sizeof(padded[0]));
|
||||
const V8 packed = Enc2B(dbf, padded);
|
||||
hn::StoreN(packed, d8, &out_packed->byte + i, remaining);
|
||||
}
|
||||
}
|
||||
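// Example (illustrative; `weights_bf16` and `sfp` are placeholder pointers to
// `num` bf16 values and at least `num` SfpStream bytes):
//   const hn::ScalableTag<hwy::bfloat16_t> dbf;
//   SfpCodec::Enc(dbf, weights_bf16, num, sfp);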
|
||||
// Encodes `num` f32 values from `in_f` to `packed`.
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Enc(DF df, const float* HWY_RESTRICT in_f, size_t num,
|
||||
SfpStream* HWY_RESTRICT out_packed) {
|
||||
const hn::Repartition<uint8_t, DF> d8;
|
||||
using V8 = hn::Vec<decltype(d8)>;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
|
||||
size_t i = 0;
|
||||
if (num >= 4 * NF) {
|
||||
HWY_UNROLL(1)
|
||||
for (; i <= num - 4 * NF; i += 4 * NF) {
|
||||
const V8 packed = Enc4F(df, in_f + i);
|
||||
hn::StoreU(packed, d8, &out_packed->byte + i);
|
||||
}
|
||||
}
|
||||
|
||||
const size_t remaining = num - i;
|
||||
HWY_DASSERT(remaining < 4 * NF);
|
||||
if (remaining != 0) {
|
||||
HWY_ALIGN float padded[4 * hn::MaxLanes(df)];
|
||||
hwy::ZeroBytes(padded, sizeof(padded));
|
||||
hwy::CopyBytes(in_f + i, padded, remaining * sizeof(padded[0]));
|
||||
const V8 packed = Enc4F(df, padded);
|
||||
hn::StoreN(packed, d8, &out_packed->byte + i, remaining);
|
||||
}
|
||||
}
|
||||
|
||||
// Decodes `num` values from `in_packed` to `out_bf`.
|
||||
template <class DBF, HWY_IF_BF16_D(DBF)>
|
||||
static HWY_INLINE void Dec(DBF dbf, const SfpStream* HWY_RESTRICT in_packed,
|
||||
size_t num, hwy::bfloat16_t* HWY_RESTRICT out_bf) {
|
||||
const hn::Repartition<uint8_t, DBF> d8;
|
||||
using V8 = hn::Vec<decltype(d8)>;
|
||||
using VBF = hn::Vec<decltype(dbf)>;
|
||||
const size_t N16 = hn::Lanes(dbf);
|
||||
|
||||
size_t i = 0;
|
||||
if (num >= 2 * N16) {
|
||||
HWY_UNROLL(1)
|
||||
for (; i <= num - 2 * N16; i += 2 * N16) {
|
||||
const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
|
||||
VBF bf0, bf1;
|
||||
Dec2B(dbf, packed, bf0, bf1);
|
||||
hn::StoreU(bf0, dbf, out_bf + i);
|
||||
hn::StoreU(bf1, dbf, out_bf + i + N16);
|
||||
}
|
||||
}
|
||||
|
||||
const size_t remaining = num - i;
|
||||
HWY_DASSERT(remaining < 2 * N16);
|
||||
if (remaining != 0) {
|
||||
const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining);
|
||||
HWY_ALIGN hwy::bfloat16_t padded[2 * hn::MaxLanes(dbf)];
|
||||
VBF bf0, bf1;
|
||||
Dec2B(dbf, packed, bf0, bf1);
|
||||
hn::StoreU(bf0, dbf, padded);
|
||||
hn::StoreU(bf1, dbf, padded + N16);
|
||||
hwy::CopyBytes(padded, out_bf + i, remaining * sizeof(padded[0]));
|
||||
}
|
||||
}
|
||||
|
||||
// Decodes `num` values from `in_packed` to `out_f`.
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Dec(DF df, const SfpStream* HWY_RESTRICT in_packed,
|
||||
size_t num, float* HWY_RESTRICT out_f) {
|
||||
const hn::Repartition<uint8_t, DF> d8;
|
||||
using V8 = hn::Vec<decltype(d8)>;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
|
||||
size_t i = 0;
|
||||
if (num >= 4 * NF) {
|
||||
HWY_UNROLL(1)
|
||||
for (; i <= num - 4 * NF; i += 4 * NF) {
|
||||
const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
|
||||
VF f0, f1, f2, f3;
|
||||
Dec4F(df, packed, f0, f1, f2, f3);
|
||||
hn::StoreU(f0, df, out_f + i + NF * 0);
|
||||
hn::StoreU(f1, df, out_f + i + NF * 1);
|
||||
hn::StoreU(f2, df, out_f + i + NF * 2);
|
||||
hn::StoreU(f3, df, out_f + i + NF * 3);
|
||||
}
|
||||
}
|
||||
|
||||
const size_t remaining = num - i;
|
||||
HWY_DASSERT(remaining < 4 * NF);
|
||||
if (remaining != 0) {
|
||||
const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining);
|
||||
HWY_ALIGN float padded[4 * hn::MaxLanes(df)];
|
||||
VF f0, f1, f2, f3;
|
||||
Dec4F(df, packed, f0, f1, f2, f3);
|
||||
hn::StoreU(f0, df, padded + NF * 0);
|
||||
hn::StoreU(f1, df, padded + NF * 1);
|
||||
hn::StoreU(f2, df, padded + NF * 2);
|
||||
hn::StoreU(f3, df, padded + NF * 3);
|
||||
hwy::CopyBytes(padded, out_f + i, remaining * sizeof(padded[0]));
|
||||
}
|
||||
}
|
||||
|
||||
// Fused decode and dot product with bf16 into four output accumulators.
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Dot(DF df, const SfpStream* HWY_RESTRICT in_packed,
|
||||
size_t num,
|
||||
const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
|
||||
hn::Vec<DF>& sum0, hn::Vec<DF>& sum1,
|
||||
hn::Vec<DF>& sum2, hn::Vec<DF>& sum3) {
|
||||
const hn::Repartition<uint8_t, DF> d8;
|
||||
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
|
||||
using V8 = hn::Vec<decltype(d8)>;
|
||||
using VBF = hn::Vec<decltype(dbf)>;
|
||||
const size_t N16 = hn::Lanes(dbf);
|
||||
|
||||
size_t i = 0;
|
||||
if (num >= 2 * N16) {
|
||||
HWY_UNROLL(1)
|
||||
for (; i <= num - 2 * N16; i += 2 * N16) {
|
||||
const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
|
||||
const VBF v0 = hn::LoadU(dbf, vec_aligned + i);
|
||||
const VBF v1 = hn::LoadU(dbf, vec_aligned + i + N16);
|
||||
VBF bf0, bf1;
|
||||
Dec2B(dbf, packed, bf0, bf1);
|
||||
sum0 = hn::ReorderWidenMulAccumulate(df, bf0, v0, sum0, sum1);
|
||||
sum2 = hn::ReorderWidenMulAccumulate(df, bf1, v1, sum2, sum3);
|
||||
}
|
||||
}
|
||||
|
||||
const size_t remaining = num - i;
|
||||
if (remaining != 0) {
|
||||
const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining);
|
||||
HWY_ALIGN hwy::bfloat16_t padded[2 * hn::MaxLanes(dbf)];
|
||||
hwy::ZeroBytes(padded, sizeof(padded));
|
||||
hwy::CopyBytes(vec_aligned + i, padded, remaining * sizeof(padded[0]));
|
||||
const VBF v0 = hn::LoadU(dbf, padded);
|
||||
const VBF v1 = hn::LoadU(dbf, padded + N16);
|
||||
VBF bf0, bf1;
|
||||
Dec2B(dbf, packed, bf0, bf1);
|
||||
sum0 = hn::ReorderWidenMulAccumulate(df, bf0, v0, sum0, sum1);
|
||||
sum2 = hn::ReorderWidenMulAccumulate(df, bf1, v1, sum2, sum3);
|
||||
}
|
||||
}
|
||||
|
||||
// Fused decode and dot product with f32 into four output accumulators.
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Dot(DF df, const SfpStream* HWY_RESTRICT in_packed,
|
||||
size_t num, const float* HWY_RESTRICT vec_aligned,
|
||||
hn::Vec<DF>& sum0, hn::Vec<DF>& sum1,
|
||||
hn::Vec<DF>& sum2, hn::Vec<DF>& sum3) {
|
||||
const hn::Repartition<uint8_t, DF> d8;
|
||||
using V8 = hn::Vec<decltype(d8)>;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
|
||||
size_t i = 0;
|
||||
if (num >= 4 * NF) {
|
||||
HWY_UNROLL(1)
|
||||
for (; i <= num - 4 * NF; i += 4 * NF) {
|
||||
const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
|
||||
const VF v0 = hn::LoadU(df, vec_aligned + i + NF * 0);
|
||||
const VF v1 = hn::LoadU(df, vec_aligned + i + NF * 1);
|
||||
const VF v2 = hn::LoadU(df, vec_aligned + i + NF * 2);
|
||||
const VF v3 = hn::LoadU(df, vec_aligned + i + NF * 3);
|
||||
VF f0, f1, f2, f3;
|
||||
Dec4F(df, packed, f0, f1, f2, f3);
|
||||
sum0 = hn::MulAdd(f0, v0, sum0);
|
||||
sum1 = hn::MulAdd(f1, v1, sum1);
|
||||
sum2 = hn::MulAdd(f2, v2, sum2);
|
||||
sum3 = hn::MulAdd(f3, v3, sum3);
|
||||
}
|
||||
}
|
||||
|
||||
const size_t remaining = num - i;
|
||||
if (remaining != 0) {
|
||||
const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining);
|
||||
HWY_ALIGN float padded[4 * hn::MaxLanes(df)];
|
||||
hwy::ZeroBytes(padded, sizeof(padded));
|
||||
hwy::CopyBytes(vec_aligned + i, padded, remaining * sizeof(padded[0]));
|
||||
const VF v0 = hn::LoadU(df, padded + NF * 0);
|
||||
const VF v1 = hn::LoadU(df, padded + NF * 1);
|
||||
const VF v2 = hn::LoadU(df, padded + NF * 2);
|
||||
const VF v3 = hn::LoadU(df, padded + NF * 3);
|
||||
VF f0, f1, f2, f3;
|
||||
Dec4F(df, packed, f0, f1, f2, f3);
|
||||
sum0 = hn::MulAdd(f0, v0, sum0);
|
||||
sum1 = hn::MulAdd(f1, v1, sum1);
|
||||
sum2 = hn::MulAdd(f2, v2, sum2);
|
||||
sum3 = hn::MulAdd(f3, v3, sum3);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// Wrappers to avoid code duplication across float/bf16 input types and
|
||||
// the main loop/remainder.
|
||||
|
||||
// Returns vector of packed bytes for callers to StoreU or StoreN.
|
||||
template <class D16, HWY_IF_U16_D(D16),
|
||||
class V8 = hn::Vec<hn::Repartition<uint8_t, D16>>>
|
||||
static HWY_INLINE V8 Enc2U(D16 d16, const hn::Vec<D16> w0,
|
||||
const hn::Vec<D16> w1) {
|
||||
const hn::Repartition<uint8_t, D16> d8;
|
||||
|
||||
// Although more expensive on AVX3, in-order packing enables streaming
|
||||
// decompression without fixed-size packets.
|
||||
const V8 lo = hn::ConcatEven(d8, hn::BitCast(d8, w1), hn::BitCast(d8, w0));
|
||||
const V8 hi = hn::ConcatOdd(d8, hn::BitCast(d8, w1), hn::BitCast(d8, w0));
|
||||
return EncBytes(d8, lo, hi);
|
||||
}
|
||||
|
||||
template <class DBF, HWY_IF_BF16_D(DBF),
|
||||
class V8 = hn::Vec<hn::Repartition<uint8_t, DBF>>>
|
||||
static HWY_INLINE V8 Enc2B(DBF dbf, const hwy::bfloat16_t* HWY_RESTRICT in) {
|
||||
const hn::Repartition<uint16_t, DBF> d16;
|
||||
const size_t N16 = hn::Lanes(d16);
|
||||
using V16 = hn::Vec<decltype(d16)>;
|
||||
|
||||
const V16 w0 = hn::BitCast(d16, hn::LoadU(dbf, in));
|
||||
const V16 w1 = hn::BitCast(d16, hn::LoadU(dbf, in + N16));
|
||||
return Enc2U(d16, w0, w1);
|
||||
}
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF),
|
||||
class V8 = hn::Vec<hn::Repartition<uint8_t, DF>>>
|
||||
static HWY_INLINE V8 Enc4F(DF df, const float* HWY_RESTRICT in) {
|
||||
const hn::Repartition<uint16_t, DF> d16;
|
||||
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
using V16 = hn::Vec<decltype(d16)>;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
|
||||
const VF f0 = hn::LoadU(df, in + NF * 0);
|
||||
const VF f1 = hn::LoadU(df, in + NF * 1);
|
||||
const VF f2 = hn::LoadU(df, in + NF * 2);
|
||||
const VF f3 = hn::LoadU(df, in + NF * 3);
|
||||
// Chop off the lower 16 bits; EncBytes still rounds properly.
|
||||
const V16 w0 = hn::BitCast(d16, hn::OrderedDemote2To(dbf, f0, f1));
|
||||
const V16 w1 = hn::BitCast(d16, hn::OrderedDemote2To(dbf, f2, f3));
|
||||
return Enc2U(d16, w0, w1);
|
||||
}
|
||||
|
||||
template <class D16, HWY_IF_U16_D(D16),
|
||||
class V8 = hn::Vec<hn::Repartition<uint8_t, D16>>>
|
||||
static HWY_INLINE void Dec2U(D16 d16, V8 packed, hn::Vec<D16>& w0,
|
||||
hn::Vec<D16>& w1) {
|
||||
const hn::Repartition<uint8_t, D16> d8;
|
||||
V8 lo, hi;
|
||||
DecBytes(d8, packed, lo, hi);
|
||||
w0 = hn::BitCast(d16, hn::InterleaveWholeLower(d8, lo, hi));
|
||||
w1 = hn::BitCast(d16, hn::InterleaveWholeUpper(d8, lo, hi));
|
||||
}
|
||||
|
||||
template <class DBF, HWY_IF_BF16_D(DBF),
|
||||
class V8 = hn::Vec<hn::Repartition<uint8_t, DBF>>>
|
||||
static HWY_INLINE void Dec2B(DBF dbf, V8 packed, hn::Vec<DBF>& bf0,
|
||||
hn::Vec<DBF>& bf1) {
|
||||
const hn::Repartition<uint16_t, DBF> d16;
|
||||
using V16 = hn::Vec<decltype(d16)>;
|
||||
V16 w0, w1;
|
||||
Dec2U(d16, packed, w0, w1);
|
||||
bf0 = hn::BitCast(dbf, w0);
|
||||
bf1 = hn::BitCast(dbf, w1);
|
||||
}
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF),
|
||||
class V8 = hn::Vec<hn::Repartition<uint8_t, DF>>>
|
||||
static HWY_INLINE void Dec4F(DF df, V8 packed, hn::Vec<DF>& f0,
|
||||
hn::Vec<DF>& f1, hn::Vec<DF>& f2,
|
||||
hn::Vec<DF>& f3) {
|
||||
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
|
||||
using VBF = hn::Vec<decltype(dbf)>;
|
||||
VBF bf0, bf1;
|
||||
Dec2B(dbf, packed, bf0, bf1);
|
||||
f0 = hn::PromoteLowerTo(df, bf0);
|
||||
f1 = hn::PromoteUpperTo(df, bf0);
|
||||
f2 = hn::PromoteLowerTo(df, bf1);
|
||||
f3 = hn::PromoteUpperTo(df, bf1);
|
||||
}
|
||||
}; // SfpCodec
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
// Copyright 2023 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_H_
|
||||
|
||||
// Switching Floating Point: a hybrid 8-bit float representation of bf16/f32
|
||||
// inputs that combines the advantages of e4m3 and e5m2 into a single format.
|
||||
// It supports seeking at a granularity of 1, decoding to bf16/f32, and a
|
||||
// fused decode/dot product with bf16/f32 vectors.
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Points to the *start* of an SFP stream. Values are stored in-order to enable
|
||||
// vector-length agnostic seeking, because streams may be written to disk for
|
||||
// loading on other CPUs.
|
||||
//
|
||||
// Characteristics:
|
||||
// - 24-bit dynamic range, with max exponent 2^0.
|
||||
// - 3 bit mantissa for values >= 2^-7, otherwise 2.
|
||||
//
|
||||
// This is faster to decode than a straightforward implementation of eXmY, in
|
||||
// part because SFP does not require subnormals. Unlike OCP MX, it also does not
|
||||
// require side information (shared exponents).
|
||||
//
|
||||
// Although the representation could probably be shrunk to 6-7 bits, more
|
||||
// savings can be had by non-uniform clustering - see nuq.h.
|
||||
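// Illustrative bit layout (a sketch derived from SfpCodec and the scalar
// decoder F32FromSFP8 in sfp_test.cc, not a normative spec): bit 7 is the
// sign; the remaining 7 bits hold the exponent and mantissa. If those 7 bits
// are >= 64, the value is "large" (>= 2^-7) and carries a 3-bit mantissa,
// otherwise a 2-bit mantissa. For example, 0x7F encodes +1.875 (the maximum
// magnitude), 0x40 encodes 2^-7, 0x00 encodes zero, and 0x80 (-0) is reserved.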
#pragma pack(push, 1)
|
||||
struct SfpStream {
|
||||
uint8_t byte;
|
||||
};
|
||||
#pragma pack(pop)
|
||||
|
||||
static inline const char* TypeName(SfpStream) { return "SFP"; }
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_H_
|
||||
|
|
@ -0,0 +1,440 @@
|
|||
// Copyright 2023 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/sfp.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <random>
|
||||
#include <set>
|
||||
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE \
|
||||
"third_party/gemma_cpp/compression/sfp_test.cc" // NOLINT
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
// Any highway.h must come after foreach_target.h
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/distortion.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/sfp-inl.h"
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
#include "hwy/tests/test_util-inl.h"
|
||||
#include "hwy/timer.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
// Decode
|
||||
float F32FromSFP8(uint32_t sfp) {
|
||||
HWY_ASSERT(sfp < 256);
|
||||
HWY_ASSERT(sfp != 0x80); // -0 is reserved
|
||||
|
||||
const uint32_t sign32 = (sfp & 0x80) << 24;
|
||||
sfp &= 0x7F;
|
||||
const bool large_e = sfp >= 64;
|
||||
const size_t m_bits = large_e ? 3 : 2;
|
||||
uint32_t m = sfp & ((1u << m_bits) - 1u);
|
||||
size_t e = sfp >> m_bits;
|
||||
if (sfp == 0) return 0.0f;
|
||||
const uint32_t e_bias = large_e ? 15 : 23;
|
||||
const uint32_t exp32 = static_cast<uint32_t>(127 + e - e_bias) << 23;
|
||||
const uint32_t mnt32 = m << (23 - m_bits);
|
||||
const uint32_t binary32 = sign32 | exp32 | mnt32;
|
||||
float result;
|
||||
hwy::CopySameSize(&binary32, &result);
|
||||
return result;
|
||||
}
|
||||
|
||||
void TestAllUnique() {
|
||||
std::set<float> unique;
|
||||
for (uint32_t sfp = 0; sfp < 256; ++sfp) {
|
||||
if (sfp == 0x80) continue; // -0 is reserved
|
||||
unique.insert(F32FromSFP8(sfp));
|
||||
}
|
||||
HWY_ASSERT_EQ(size_t{255}, unique.size());
|
||||
if (false) {
|
||||
for (float f : unique) {
|
||||
fprintf(stderr, "%e\n", f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------ Foreach compressed representation
|
||||
|
||||
// Encode
|
||||
HWY_INLINE uint32_t SFP8FromF32(float f) {
|
||||
HWY_ASSERT(-1.875f <= f && f <= 1.875f);
|
||||
|
||||
constexpr uint32_t kMaskM = hwy::MantissaMask<float>();
|
||||
uint32_t binary32;
|
||||
hwy::CopySameSize(&f, &binary32);
|
||||
const uint32_t s = (binary32 & hwy::SignMask<float>()) >> 24;
|
||||
binary32 &= ~hwy::SignMask<float>();
|
||||
f = hwy::ScalarAbs(f);
|
||||
|
||||
// >= 1.1111 * 2^-8 rounds up to 1.0*2^-7.
|
||||
bool large_e = (f >= 0.007568359375f);
|
||||
|
||||
const uint32_t org_binary32 = binary32;
|
||||
const uint32_t m32 = binary32 & kMaskM;
|
||||
binary32 = (binary32 & ~kMaskM) | m32;
|
||||
size_t m_bits = large_e ? 3 : 2;
|
||||
const uint32_t is_odd = (m32 >> (23 - m_bits)) & 1;
|
||||
const uint32_t round = is_odd + (1u << (23 - m_bits - 1)) - 1;
|
||||
const uint32_t rounded = binary32 + round;
|
||||
|
||||
// >= 1.111 also rounds up, but only if it was considered !large_e before.
|
||||
if (f >= 0.00732421875f) {
|
||||
large_e = true;
|
||||
m_bits = 3;
|
||||
}
|
||||
|
||||
uint32_t m = (kMaskM & rounded) >> (23 - m_bits);
|
||||
int32_t e = (rounded >> 23) - 127;
|
||||
|
||||
if (e <= -23) {
|
||||
// 2^-23 is the smallest normal exponent. Zero has e = -127. Do not set the
|
||||
// SFP sign bit because the encoding for -0 is reserved.
|
||||
if (e < -23) return 0;
|
||||
// e = 2^-23: round up mantissa because m=0 encodes 0.0f.
|
||||
if (m == 0) m = 1;
|
||||
}
|
||||
|
||||
if (false) {
|
||||
fprintf(stderr, "in %x round %x rounded %x e %d m %x large_e %d\n",
|
||||
org_binary32, round, rounded, e, m, large_e);
|
||||
}
|
||||
uint32_t e_sfp = e + (large_e ? 15 : 23);
|
||||
HWY_ASSERT(e_sfp < 16);
|
||||
|
||||
const uint32_t encoded = (e_sfp << m_bits) | m | s;
|
||||
HWY_ASSERT(encoded < 256);
|
||||
return encoded;
|
||||
}
|
||||
|
||||
// For every possible encoding: ensure re-encoding the decoded value matches it.
|
||||
struct TestDecEnc {
|
||||
template <class T, class D>
|
||||
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||
const hn::RepartitionToWide<D> d16;
|
||||
const hn::Rebind<hwy::bfloat16_t, decltype(d16)> dbf;
|
||||
const hn::Repartition<float, D> df;
|
||||
for (uint32_t encoded = 0; encoded < 256; ++encoded) {
|
||||
if (encoded == 0x80) continue; // -0 is reserved
|
||||
const float decoded = F32FromSFP8(encoded);
|
||||
const uint32_t encoded2 = SFP8FromF32(decoded);
|
||||
|
||||
hn::Vec<D> dec_lo, dec_hi;
|
||||
SfpCodec::DecBytes(d, hn::Set(d, encoded), dec_lo, dec_hi);
|
||||
const hn::Vec<decltype(dbf)> dec =
|
||||
hn::BitCast(dbf, hn::ZipLower(d16, dec_lo, dec_hi));
|
||||
const float vdecoded = hn::GetLane(hn::PromoteLowerTo(df, dec));
|
||||
const uint32_t vencoded2 =
|
||||
hn::GetLane(SfpCodec::EncBytes(d, dec_lo, dec_hi));
|
||||
|
||||
if (decoded != vdecoded || encoded2 != vencoded2 || encoded != encoded2) {
|
||||
HWY_ABORT("enc %u -> dec %E=%x=%E -> enc %u %u\n", encoded, decoded,
|
||||
hwy::BitCastScalar<uint32_t>(decoded), vdecoded, encoded2,
|
||||
vencoded2);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void TestAllDecEnc() { hn::ForGEVectors<32, TestDecEnc>()(uint8_t()); }
|
||||
|
||||
// ------------------------------ Golden (known values)
|
||||
|
||||
// Encode known input values, decode, and compare against the expected golden values.
|
||||
struct TestGolden {
|
||||
template <class T, class D>
|
||||
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||
const hn::Repartition<float, D> df;
|
||||
const hn::Repartition<hwy::bfloat16_t, D> dbf;
|
||||
const hn::RebindToUnsigned<decltype(dbf)> d16;
|
||||
|
||||
struct Golden {
|
||||
float in;
|
||||
float out;
|
||||
};
|
||||
const Golden golden[] = {
|
||||
// All mantissa bits set, all discarded zero (no rounding)
|
||||
{0.46875f, 0.46875f},
|
||||
{0.9375f, 0.9375f},
|
||||
// All mantissa bits set, one below it set (round up to pow2)
|
||||
{0.484375f, 0.5f},
|
||||
{0.96875f, 1.0f},
|
||||
// Lowest mantissa bit set, all discarded zero (no rounding)
|
||||
{0.28125f, 0.28125f},
|
||||
{0.5625f, 0.5625f},
|
||||
// Lowest mantissa bit set, one below it set (round up to even)
|
||||
{0.296875f, 0.3125f},
|
||||
{0.59375f, 0.625f},
|
||||
// All mantissa zero, all discarded set (round up)
|
||||
{0.279296875f, 0.28125f},
|
||||
{0.55859375f, 0.5625f},
|
||||
// All mantissa zero, one below it set (round DOWN to pow2)
|
||||
{0.265625f, 0.25f},
|
||||
{0.53125f, 0.5f},
|
||||
|
||||
// At inflection point: 1.max*2^-8 rounds up to 1.0*2^-7
|
||||
{0.0068359375f, 0.0068359375f}, // 1.11 -> 1.11
|
||||
{0.00732421875f, 0.0078125f}, // 1.111 -> 1.11[1] -> 1.0
|
||||
{0.007568359375f, 0.0078125f}, // 1.1111 -> 1.0
|
||||
|
||||
// Above 1.0: no longer special-cased.
|
||||
{1.0f, 1.0f},
|
||||
{1.0625f, 1.0f}, // 1.000100
|
||||
|
||||
// Smallest normal exponents - we no longer use subnormals.
|
||||
{2.384185791015625E-7f, 2.384185791015625E-7f}, // 1.00p-22
|
||||
{1.49011611938E-07f, 1.49011611938E-07f}, // 1.01p-23
|
||||
{1.19209289551E-07f, 1.49011611938E-07f}, // 1.00p-23 -> 1.01p-23
|
||||
{5.96046447754E-08f, 0.0f}, // 1.00p-24 -> 0
|
||||
{8.94069671631E-08f, 0.0f}, // 1.10p-24 -> 0
|
||||
{1.11758708954E-07f, 1.49011611938E-07f}, // 1.111p-24-> 1.01p-23
|
||||
|
||||
// 1100_010 * 2^-7 rounds down to 110
|
||||
{0.013841f, 0.013671875f},
|
||||
};
|
||||
constexpr size_t kNumGolden = sizeof(golden) / sizeof(Golden);
|
||||
for (uint32_t s : {0, 1}) {
|
||||
for (size_t i = 0; i < kNumGolden; ++i) {
|
||||
const float in = s ? -golden[i].in : golden[i].in;
|
||||
const float out = s ? -golden[i].out : golden[i].out;
|
||||
const hn::Vec<decltype(dbf)> in_bf =
|
||||
hn::OrderedDemote2To(dbf, hn::Set(df, in), hn::Set(df, in));
|
||||
const uint32_t encoded = SFP8FromF32(in);
|
||||
const uint32_t vencoded = hn::GetLane(SfpCodec::EncBytes(
|
||||
d, hn::BitCast(d, in_bf),
|
||||
hn::BitCast(d, hn::ShiftRight<8>(hn::BitCast(d16, in_bf)))));
|
||||
const float decoded = F32FromSFP8(encoded);
|
||||
hn::Vec<D> dec_lo, dec_hi;
|
||||
SfpCodec::DecBytes(d, hn::Set(d, encoded), dec_lo, dec_hi);
|
||||
const hn::Vec<decltype(dbf)> dec =
|
||||
hn::BitCast(dbf, hn::ZipLower(d16, dec_lo, dec_hi));
|
||||
const float vdecoded = hn::GetLane(hn::PromoteLowerTo(df, dec));
|
||||
|
||||
if (decoded != vdecoded || decoded != out || encoded != vencoded) {
|
||||
HWY_ABORT("@%zu in %E dec %E %E golden %E\n", i, in, decoded,
|
||||
vdecoded, golden[i].out);
|
||||
}
|
||||
} // i
|
||||
} // s
|
||||
}
|
||||
};
|
||||
|
||||
void TestAllGolden() {
|
||||
// Full vectors only, other tests cover partial vectors.
|
||||
TestGolden()(uint8_t(), hn::ScalableTag<uint8_t>());
|
||||
}
|
||||
|
||||
// ------------------------------ Foreach bf16 input
|
||||
|
||||
// Generate all values, encode, decode back.
|
||||
struct TestEncDec {
|
||||
template <class T, class DBF>
|
||||
HWY_INLINE void operator()(T /*unused*/, DBF dbf) {
|
||||
const hn::Repartition<uint8_t, DBF> du8;
|
||||
|
||||
// We only use the upper 4 of 7 bf16 mantissa bits, so force the lower three
|
||||
// bits to zero to reduce the number of inputs.
|
||||
constexpr size_t kStep = 8;
|
||||
const size_t max = 0x8000 / 8;
|
||||
|
||||
auto in = hwy::AllocateAligned<T>(max);
|
||||
auto packed = hwy::AllocateAligned<SfpStream>(max);
|
||||
auto dec = hwy::AllocateAligned<T>(max);
|
||||
HWY_ASSERT(in && packed && dec);
|
||||
size_t num = 0;
|
||||
for (size_t i = 0; i < max; ++i) {
|
||||
const uint16_t bits = i * kStep;
|
||||
const float f = hwy::F32FromBF16(hwy::BitCastScalar<T>(bits));
|
||||
// Keep if within range
|
||||
if (hwy::ScalarIsFinite(f) && f <= 1.875f) {
|
||||
in[num] = hwy::BF16FromF32(f);
|
||||
in[num + 1] = hwy::BF16FromF32(-f);
|
||||
num += 2;
|
||||
}
|
||||
}
|
||||
|
||||
double enc_elapsed = hwy::HighestValue<double>();
|
||||
double dec_elapsed = hwy::HighestValue<double>();
|
||||
for (size_t rep = 0; rep < 100; ++rep) {
|
||||
const double t0 = hwy::platform::Now();
|
||||
SfpCodec::Enc(dbf, in.get(), num, packed.get());
|
||||
const double t1 = hwy::platform::Now();
|
||||
SfpCodec::Dec(dbf, packed.get(), num, dec.get());
|
||||
const double t2 = hwy::platform::Now();
|
||||
enc_elapsed = HWY_MIN(enc_elapsed, t1 - t0);
|
||||
dec_elapsed = HWY_MIN(dec_elapsed, t2 - t1);
|
||||
}
|
||||
const double enc_mbs = num * sizeof(T) * 1E-6 / enc_elapsed;
|
||||
const double dec_mbs = num * sizeof(T) * 1E-6 / dec_elapsed;
|
||||
fprintf(stderr, "Vec size %zu Enc %.2f MB/s Dec %.2f MB/s\n", Lanes(du8),
|
||||
enc_mbs, dec_mbs);
|
||||
|
||||
{
|
||||
double sum = 0.0;
|
||||
DistortionStats stats;
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
const float out = hwy::F32FromBF16(dec[i]);
|
||||
sum += hwy::ConvertScalarTo<double>(hwy::ScalarAbs(in[i]));
|
||||
stats.Notify(in[i], out);
|
||||
}
|
||||
const double avg = sum / num;
|
||||
fprintf(stderr, "Avg magnitude %.3E, p-norm %.3E snr %.2f @%zu = %.4E\n",
|
||||
avg, stats.PNorm(), stats.GeomeanValueDivL1(), stats.MaxIndex(),
|
||||
stats.MaxL1());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void TestAllEncDec() { hn::ForGEVectors<32, TestEncDec>()(hwy::bfloat16_t()); }
|
||||
|
||||
// ------------------------------ Order
|
||||
|
||||
// Store 8-bit iota, decode, encode, check iota == packed. This ensures
|
||||
// Enc/Dec preserve element order independently of the vector length.
|
||||
struct TestOrder {
|
||||
template <class T, class DBF>
|
||||
HWY_INLINE void operator()(T /*unused*/, DBF dbf) {
|
||||
const hn::Repartition<uint8_t, DBF> du8;
|
||||
|
||||
const size_t num = 10 * hn::Lanes(du8) / 3;
|
||||
|
||||
auto iota = hwy::AllocateAligned<SfpStream>(num);
|
||||
auto packed = hwy::AllocateAligned<SfpStream>(num);
|
||||
auto bf = hwy::AllocateAligned<hwy::bfloat16_t>(num);
|
||||
HWY_ASSERT(iota && packed && bf);
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
// Clear sign bit so we can also check that bf is in ascending order.
|
||||
iota[i].byte = i & 127;
|
||||
}
|
||||
|
||||
SfpCodec::Dec(dbf, iota.get(), num, bf.get());
|
||||
SfpCodec::Enc(dbf, bf.get(), num, packed.get());
|
||||
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
if (iota[i].byte != packed[i].byte) {
|
||||
HWY_ABORT("@%zu: %d %d\n", i, iota[i].byte, packed[i].byte);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void TestAllOrder() { hn::ForGEVectors<32, TestOrder>()(hwy::bfloat16_t()); }
|
||||
|
||||
// ------------------------------ Dot
|
||||
|
||||
struct TestDot {
|
||||
template <typename T, class D>
|
||||
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||
const hn::Repartition<float, D> df;
|
||||
const size_t num = 384;
|
||||
auto in = hwy::AllocateAligned<T>(num);
|
||||
auto dec = hwy::AllocateAligned<T>(num);
|
||||
auto vec = hwy::AllocateAligned<T>(num);
|
||||
auto sfp = hwy::AllocateAligned<SfpStream>(num);
|
||||
HWY_ASSERT(in && dec && vec && sfp);
|
||||
|
||||
std::mt19937 rng(123);
|
||||
std::normal_distribution<float> dist{0.001f, 0.3f};
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
in[i] = hwy::ConvertScalarTo<T>(dist(rng));
|
||||
vec[i] = hwy::ConvertScalarTo<T>(dist(rng));
|
||||
}
|
||||
// This changes the correlation between in and vec, which considerably
|
||||
// affects the error of the result.
|
||||
std::shuffle(in.get(), in.get() + num, rng);
|
||||
|
||||
SfpCodec::Enc(d, in.get(), num, sfp.get());
|
||||
|
||||
double actual = 0.0;
|
||||
double elapsed = hwy::HighestValue<double>();
|
||||
for (size_t rep = 0; rep < 200; ++rep) {
|
||||
hn::Vec<decltype(df)> sum0 = hn::Zero(df);
|
||||
hn::Vec<decltype(df)> sum1 = hn::Zero(df);
|
||||
hn::Vec<decltype(df)> sum2 = hn::Zero(df);
|
||||
hn::Vec<decltype(df)> sum3 = hn::Zero(df);
|
||||
const double t0 = hwy::platform::Now();
|
||||
SfpCodec::Dot(df, sfp.get(), num, vec.get(), sum0, sum1, sum2, sum3);
|
||||
const double t1 = hwy::platform::Now();
|
||||
elapsed = HWY_MIN(elapsed, t1 - t0);
|
||||
sum0 = hn::Add(hn::Add(sum0, sum1), hn::Add(sum2, sum3));
|
||||
actual = hn::ReduceSum(df, sum0);
|
||||
}
|
||||
|
||||
SfpCodec::Dec(d, sfp.get(), num, dec.get());
|
||||
fprintf(stderr, "Vec %zu Dot %.2f MB/s\n", Lanes(d) * sizeof(T),
|
||||
num * sizeof(T) * 1E-6 / elapsed);
|
||||
|
||||
double expected = 0.0; // using original input
|
||||
double expected2 = 0.0; // using decoded SFP
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
expected += hwy::ConvertScalarTo<double>(in[i]) *
|
||||
hwy::ConvertScalarTo<double>(vec[i]);
|
||||
expected2 += hwy::ConvertScalarTo<double>(dec[i]) *
|
||||
hwy::ConvertScalarTo<double>(vec[i]);
|
||||
}
|
||||
const double l1 = hwy::ScalarAbs(expected - actual);
|
||||
const double snr = 1.0 + hwy::ScalarAbs(expected) / l1;
|
||||
fprintf(stderr, "expected %.3f e2 %.4f actual %.4f l1 %E snr %.2f\n",
|
||||
expected, expected2, actual, l1, snr);
|
||||
HWY_ASSERT(hwy::ScalarAbs(expected2 - actual) < 1E-4);
|
||||
const double expected_l1 = sizeof(T) == 2 ? 1.52E-2 : 1.15E-2;
|
||||
const double expected_snr = sizeof(T) == 2 ? 80.1f : 104.9f;
|
||||
HWY_ASSERT(expected_l1 <= l1 && l1 < 1.02f * expected_l1);
|
||||
HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr);
|
||||
}
|
||||
};
|
||||
|
||||
void TestAllDotF32() {
|
||||
const hn::ForGEVectors<128, TestDot> test;
|
||||
test(float());
|
||||
}
|
||||
void TestAllDotBF16() {
|
||||
const hn::ForGEVectors<128, TestDot> test;
|
||||
test(hwy::bfloat16_t());
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#if HWY_ONCE
|
||||
|
||||
namespace gcpp {
|
||||
HWY_BEFORE_TEST(SfpTest);
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllUnique);
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDecEnc);
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllGolden);
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllEncDec);
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllOrder);
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDotF32);
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDotBF16);
|
||||
} // namespace gcpp
|
||||
|
||||
#endif
|
||||
|
|
@@ -0,0 +1,117 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/stats.h"
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include <algorithm> // std::min
|
||||
#include <string>
|
||||
|
||||
#include "hwy/base.h" // HWY_ASSERT
|
||||
|
||||
void Stats::Assimilate(const Stats& other) {
|
||||
const int64_t total_n = n_ + other.n_;
|
||||
if (total_n == 0) return; // Nothing to do; prevents div by zero.
|
||||
|
||||
min_ = std::min(min_, other.min_);
|
||||
max_ = std::max(max_, other.max_);
|
||||
|
||||
product_ *= other.product_;
|
||||
|
||||
const double product_n = n_ * other.n_;
|
||||
const double n2 = n_ * n_;
|
||||
const double other_n2 = other.n_ * other.n_;
|
||||
const int64_t total_n2 = total_n * total_n;
|
||||
const double total_n3 = static_cast<double>(total_n2) * total_n;
|
||||
// Precompute reciprocal for speed - used at least twice.
|
||||
const double inv_total_n = 1.0 / total_n;
|
||||
const double inv_total_n2 = 1.0 / total_n2;
|
||||
|
||||
const double delta = other.m1_ - m1_;
|
||||
const double delta2 = delta * delta;
|
||||
const double delta3 = delta * delta2;
|
||||
const double delta4 = delta2 * delta2;
|
||||
|
||||
m1_ = (n_ * m1_ + other.n_ * other.m1_) * inv_total_n;
|
||||
|
||||
const double new_m2 = m2_ + other.m2_ + delta2 * product_n * inv_total_n;
|
||||
|
||||
const double new_m3 =
|
||||
m3_ + other.m3_ + delta3 * product_n * (n_ - other.n_) * inv_total_n2 +
|
||||
3.0 * delta * (n_ * other.m2_ - other.n_ * m2_) * inv_total_n;
|
||||
|
||||
m4_ += other.m4_ +
|
||||
delta4 * product_n * (n2 - product_n + other_n2) / total_n3 +
|
||||
6.0 * delta2 * (n2 * other.m2_ + other_n2 * m2_) * inv_total_n2 +
|
||||
4.0 * delta * (n_ * other.m3_ - other.n_ * m3_) * inv_total_n;
|
||||
|
||||
m2_ = new_m2;
|
||||
m3_ = new_m3;
|
||||
n_ = total_n;
|
||||
}
|
||||
|
||||
std::string Stats::ToString(int exclude) const {
|
||||
if (Count() == 0) return std::string("(none)");
|
||||
|
||||
char buf[300];
|
||||
int pos = 0;
|
||||
int ret; // snprintf - bytes written or negative for error.
|
||||
|
||||
if ((exclude & kNoCount) == 0) {
|
||||
ret = snprintf(buf + pos, sizeof(buf) - pos, "Count=%9zu ",
|
||||
static_cast<size_t>(Count()));
|
||||
HWY_ASSERT(ret > 0);
|
||||
pos += ret;
|
||||
}
|
||||
|
||||
if ((exclude & kNoMeanSD) == 0) {
|
||||
const float sd = StandardDeviation();
|
||||
if (sd > 100) {
|
||||
ret = snprintf(buf + pos, sizeof(buf) - pos, "Mean=%8.2E SD=%7.1E ",
|
||||
Mean(), sd);
|
||||
} else {
|
||||
ret = snprintf(buf + pos, sizeof(buf) - pos, "Mean=%8.6f SD=%7.5f ",
|
||||
Mean(), sd);
|
||||
}
|
||||
HWY_ASSERT(ret > 0);
|
||||
pos += ret;
|
||||
}
|
||||
|
||||
if ((exclude & kNoMinMax) == 0) {
|
||||
ret = snprintf(buf + pos, sizeof(buf) - pos, "Min=%8.5e Max=%8.5e ", Min(),
|
||||
Max());
|
||||
HWY_ASSERT(ret > 0);
|
||||
pos += ret;
|
||||
}
|
||||
|
||||
if ((exclude & kNoSkewKurt) == 0) {
|
||||
ret = snprintf(buf + pos, sizeof(buf) - pos, "Skew=%5.2f Kurt=%7.2f ",
|
||||
Skewness(), Kurtosis());
|
||||
HWY_ASSERT(ret > 0);
|
||||
pos += ret;
|
||||
}
|
||||
|
||||
if ((exclude & kNoGeomean) == 0) {
|
||||
ret = snprintf(buf + pos, sizeof(buf) - pos, "GeoMean=%9.6f ",
|
||||
GeometricMean());
|
||||
HWY_ASSERT(ret > 0);
|
||||
pos += ret;
|
||||
}
|
||||
|
||||
HWY_ASSERT(static_cast<size_t>(pos) < sizeof(buf));
|
||||
return buf;
|
||||
}
|
||||
|
|
@@ -0,0 +1,190 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_STATS_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_STATS_H_
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
|
||||
#include "hwy/base.h" // HWY_ASSERT
|
||||
|
||||
// Thread-compatible.
|
||||
template <size_t N>
|
||||
class Bins {
|
||||
public:
|
||||
Bins() { Reset(); }
|
||||
|
||||
template <typename T>
|
||||
void Notify(T bin) {
|
||||
HWY_ASSERT(T{0} <= bin && bin < static_cast<T>(N));
|
||||
counts_[static_cast<int32_t>(bin)]++;
|
||||
}
|
||||
|
||||
void Assimilate(const Bins<N>& other) {
|
||||
for (size_t i = 0; i < N; ++i) {
|
||||
counts_[i] += other.counts_[i];
|
||||
}
|
||||
}
|
||||
|
||||
void Print(const char* caption) {
|
||||
fprintf(stderr, "\n%s [%zu]\n", caption, N);
|
||||
size_t last_nonzero = 0;
|
||||
// Scan backwards for the last non-zero bin; unsigned wraparound ends the loop.
for (size_t i = N - 1; i < N; --i) {
|
||||
if (counts_[i] != 0) {
|
||||
last_nonzero = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i <= last_nonzero; ++i) {
|
||||
fprintf(stderr, " %zu\n", counts_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
for (size_t i = 0; i < N; ++i) {
|
||||
counts_[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
size_t counts_[N];
|
||||
};
|
||||
|
||||
// Descriptive statistics of a variable (4 moments). Thread-compatible.
|
||||
class Stats {
|
||||
public:
|
||||
Stats() { Reset(); }
|
||||
|
||||
void Notify(const float x) {
|
||||
++n_;
|
||||
|
||||
min_ = std::min(min_, x);
|
||||
max_ = std::max(max_, x);
|
||||
|
||||
product_ *= x;
|
||||
|
||||
// Online moments. Reference: https://goo.gl/9ha694
|
||||
const double d = x - m1_;
|
||||
const double d_div_n = d / n_;
|
||||
const double d2n1_div_n = d * (n_ - 1) * d_div_n;
|
||||
const int64_t n_poly = n_ * n_ - 3 * n_ + 3;
|
||||
m1_ += d_div_n;
|
||||
m4_ += d_div_n * (d_div_n * (d2n1_div_n * n_poly + 6.0 * m2_) - 4.0 * m3_);
|
||||
m3_ += d_div_n * (d2n1_div_n * (n_ - 2) - 3.0 * m2_);
|
||||
m2_ += d2n1_div_n;
|
||||
}
|
||||
|
||||
void Assimilate(const Stats& other);
|
||||
|
||||
int64_t Count() const { return n_; }
|
||||
|
||||
float Min() const { return min_; }
|
||||
float Max() const { return max_; }
|
||||
|
||||
double GeometricMean() const {
|
||||
return n_ == 0 ? 0.0 : pow(product_, 1.0 / n_);
|
||||
}
|
||||
|
||||
double Mean() const { return m1_; }
|
||||
// Same as Mu2. Assumes n_ is large.
|
||||
double SampleVariance() const {
|
||||
return n_ == 0 ? 0.0 : m2_ / static_cast<int>(n_);
|
||||
}
|
||||
// Unbiased estimator for population variance even for smaller n_.
|
||||
double Variance() const {
|
||||
if (n_ == 0) return 0.0;
|
||||
if (n_ == 1) return m2_;
|
||||
return m2_ / static_cast<int>(n_ - 1);
|
||||
}
|
||||
double StandardDeviation() const { return std::sqrt(Variance()); }
|
||||
// Near zero for normal distributions; if positive on a unimodal distribution,
|
||||
// the right tail is fatter. Assumes n_ is large.
|
||||
double SampleSkewness() const {
|
||||
if (std::abs(m2_) < 1E-7) return 0.0;
|
||||
return m3_ * std::sqrt(static_cast<double>(n_)) / std::pow(m2_, 1.5);
|
||||
}
|
||||
// Corrected for bias (same as Wikipedia and Minitab but not Excel).
|
||||
double Skewness() const {
|
||||
if (n_ == 0) return 0.0;
|
||||
const double biased = SampleSkewness();
|
||||
const double r = (n_ - 1.0) / n_;
|
||||
return biased * std::pow(r, 1.5);
|
||||
}
|
||||
// Near zero for normal distributions; smaller values indicate fewer/smaller
|
||||
// outliers and larger values indicate more/larger outliers. Assumes n_ is large.
|
||||
double SampleKurtosis() const {
|
||||
if (std::abs(m2_) < 1E-7) return 0.0;
|
||||
return m4_ * n_ / (m2_ * m2_);
|
||||
}
|
||||
// Corrected for bias (same as Wikipedia and Minitab but not Excel).
|
||||
double Kurtosis() const {
|
||||
if (n_ == 0) return 0.0;
|
||||
const double biased = SampleKurtosis();
|
||||
const double r = (n_ - 1.0) / n_;
|
||||
return biased * r * r;
|
||||
}
|
||||
|
||||
// Central moments, useful for "method of moments"-based parameter estimation
|
||||
// of a mixture of two Gaussians. Assumes Count() != 0.
|
||||
double Mu1() const { return m1_; }
|
||||
double Mu2() const { return m2_ / static_cast<int>(n_); }
|
||||
double Mu3() const { return m3_ / static_cast<int>(n_); }
|
||||
double Mu4() const { return m4_ / static_cast<int>(n_); }
|
||||
|
||||
// Which statistics to EXCLUDE in ToString
|
||||
enum {
|
||||
kNoCount = 1,
|
||||
kNoMeanSD = 2,
|
||||
kNoMinMax = 4,
|
||||
kNoSkewKurt = 8,
|
||||
kNoGeomean = 16
|
||||
};
|
||||
std::string ToString(int exclude = 0) const;
|
||||
|
||||
void Reset() {
|
||||
n_ = 0;
|
||||
|
||||
min_ = hwy::HighestValue<float>();
|
||||
max_ = hwy::LowestValue<float>();
|
||||
|
||||
product_ = 1.0;
|
||||
|
||||
m1_ = 0.0;
|
||||
m2_ = 0.0;
|
||||
m3_ = 0.0;
|
||||
m4_ = 0.0;
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t n_; // signed for faster conversion + safe subtraction
|
||||
|
||||
float min_;
|
||||
float max_;
|
||||
|
||||
double product_; // for geomean
|
||||
|
||||
// Moments
|
||||
double m1_;
|
||||
double m2_;
|
||||
double m3_;
|
||||
double m4_;
|
||||
};
|
||||
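// Added usage sketch, not part of the original header: how Bins and Stats are
// typically driven together. It uses only the interfaces declared above; the
// function name and bin count are illustrative.
inline void ExampleStatsUsage(const float* values, size_t num) {
  Stats stats;
  Bins<16> bins;
  for (size_t i = 0; i < num; ++i) {
    stats.Notify(values[i]);
    // Clamp into the valid bin range [0, 16) before notifying the histogram.
    const int bin = std::min(15, std::max(0, static_cast<int>(values[i])));
    bins.Notify(bin);
  }
  fprintf(stderr, "%s\n", stats.ToString(Stats::kNoGeomean).c_str());
  bins.Print("example bins");
}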
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_STATS_H_
|
||||
|
|
@@ -0,0 +1,57 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Model configurations
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
static constexpr size_t kSeqLen = 7168;
|
||||
|
||||
struct ConfigGemma7B {
|
||||
// NOLINTBEGIN(google3-readability-class-member-naming)
|
||||
static constexpr int seq_len = kSeqLen;
|
||||
static constexpr int vocab_size = 256128;
|
||||
static constexpr int n_layers = 28;
|
||||
static constexpr int dim_model = 3072;
|
||||
static constexpr int dim_ffw_hidden = 16 * 3072 / 2; // = 24576
|
||||
static constexpr int n_heads = 16;
|
||||
static constexpr int n_kv_heads = 16; // standard MHA, no GQA or MQA
|
||||
static constexpr int dim_qkv = 256; // query size == key size == value size
|
||||
static constexpr int top_k = 1;
|
||||
// NOLINTEND(google3-readability-class-member-naming)
|
||||
};
|
||||
|
||||
struct ConfigGemma2B {
|
||||
// NOLINTBEGIN(google3-readability-class-member-naming)
|
||||
static constexpr int seq_len = kSeqLen;
|
||||
static constexpr int vocab_size = 256128;
|
||||
static constexpr int n_layers = 18;
|
||||
static constexpr int dim_model = 2048;
|
||||
static constexpr int dim_ffw_hidden = 16 * 2048 / 2; // = 16384
|
||||
static constexpr int n_heads = 8;
|
||||
static constexpr int n_kv_heads = 8; // TODO(austinvhuang): add MQA support
|
||||
static constexpr int dim_qkv = 256; // query size == key size == value size
|
||||
static constexpr int top_k = 1;
|
||||
// NOLINTEND(google3-readability-class-member-naming)
|
||||
};
|
||||
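// Added sanity checks, not part of the original header: spell out the derived
// sizes stated in the comments above so that an inconsistent edit to a config
// constant fails at compile time.
static_assert(ConfigGemma7B::dim_ffw_hidden == 24576,
              "7B FFW hidden dim should match the comment above");
static_assert(ConfigGemma2B::dim_ffw_hidden == 16384,
              "2B FFW hidden dim should match the comment above");
static_assert(ConfigGemma7B::n_heads * ConfigGemma7B::dim_qkv == 4096,
              "7B per-token attention width is n_heads * dim_qkv");
static_assert(ConfigGemma2B::n_heads * ConfigGemma2B::dim_qkv ==
                  ConfigGemma2B::dim_model,
              "2B attention width happens to equal dim_model");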
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
|
||||
|
|
@@ -0,0 +1,32 @@
|
|||
# How to Contribute
|
||||
|
||||
We would love to accept your patches and contributions to this project.
|
||||
|
||||
## Before you begin
|
||||
|
||||
### Sign our Contributor License Agreement
|
||||
|
||||
Contributions to this project must be accompanied by a
|
||||
[Contributor License Agreement](https://cla.developers.google.com/about) (CLA).
|
||||
You (or your employer) retain the copyright to your contribution; this simply
|
||||
gives us permission to use and redistribute your contributions as part of the
|
||||
project.
|
||||
|
||||
If you or your current employer have already signed the Google CLA (even if it
|
||||
was for a different project), you probably don't need to do it again.
|
||||
|
||||
Visit <https://cla.developers.google.com/> to see your current agreements or to
|
||||
sign a new one.
|
||||
|
||||
### Review our Community Guidelines
|
||||
|
||||
This project follows [Google's Open Source Community
|
||||
Guidelines](https://opensource.google/conduct/).
|
||||
|
||||
## Contribution process
|
||||
|
||||
### Code Reviews
|
||||
|
||||
All submissions, including submissions by project members, require review. We
|
||||
use [GitHub pull requests](https://docs.github.com/articles/about-pull-requests)
|
||||
for this purpose.
|
||||
|
|
@@ -0,0 +1,811 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Lightweight C++ implementation of the gemma model.
|
||||
|
||||
// This file is compiled for multiple architectures via "foreach_target.h", to
|
||||
// which we pass the filename via the HWY_TARGET_INCLUDE macro below.
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE "gemma.cc" // NOLINT
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
// Must come after foreach_target.h to avoid redefinition errors.
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/compress-inl.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "ops.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/args.h" // Path
|
||||
#include "hwy/contrib/matvec/matvec-inl.h"
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/profiler.h"
|
||||
#include "hwy/timer.h"
|
||||
|
||||
// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
|
||||
// compile pass, whereas we want this defined in the first.
|
||||
#ifndef GEMMA_ONCE
|
||||
#define GEMMA_ONCE
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <filesystem> // NOLINT
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/compress.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "configs.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "gemma.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
// copybara:import_next_line:sentencepiece
|
||||
#include "src/sentencepiece_processor.h"
|
||||
// #include "third_party/sentencepiece/src/util.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
template <class TConfig>
|
||||
struct Layer {
|
||||
Layer() = default;
|
||||
// NOLINTBEGIN(google3-readability-class-member-naming)
|
||||
static constexpr size_t n_heads = TConfig::n_heads;
|
||||
static constexpr size_t dim_model = TConfig::dim_model;
|
||||
static constexpr size_t dim_qkv = TConfig::dim_qkv;
|
||||
static constexpr size_t dim_ffw_hidden = TConfig::dim_ffw_hidden;
|
||||
static constexpr size_t size_attn_vec_einsum_w =
|
||||
n_heads * dim_qkv * dim_model;
|
||||
// 3x for (query, key, value)
|
||||
static constexpr size_t size_qkv_einsum_w = 3 * n_heads * dim_qkv * dim_model;
|
||||
// 2x for (gelu gating vector, gated vector)
|
||||
static constexpr size_t size_gating_einsum_w = 2 * dim_ffw_hidden * dim_model;
|
||||
static constexpr size_t size_linear_w = dim_model * dim_ffw_hidden;
|
||||
std::array<float, size_attn_vec_einsum_w> attn_vec_einsum_w;
|
||||
std::array<float, size_qkv_einsum_w> qkv_einsum_w;
|
||||
std::array<float, size_gating_einsum_w> gating_einsum_w;
|
||||
std::array<float, size_linear_w> linear_w;
|
||||
std::array<float, dim_model> pre_attention_norm_scale;
|
||||
std::array<float, dim_model> pre_ffw_norm_scale;
|
||||
// NOLINTEND(google3-readability-class-member-naming)
|
||||
};
|
||||
|
||||
template <class TConfig>
|
||||
struct Weights {
|
||||
Weights() = default;
|
||||
|
||||
hwy::AlignedUniquePtr<Layer<TConfig>[]> layers; // n_layers
|
||||
|
||||
std::array<float, TConfig::vocab_size * TConfig::dim_model>
|
||||
embedder_input_embedding;
|
||||
|
||||
std::array<float, TConfig::dim_model> final_norm_scale;
|
||||
};
|
||||
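// Added note, not in the original source: LoadWeights below implies that the
// uncompressed checkpoint is a raw dump of float tensors in this fixed order,
// with no header or padding:
//   embedder_input_embedding   [vocab_size * dim_model]
//   final_norm_scale           [dim_model]
//   then for each layer 0..n_layers-1:
//     attn_vec_einsum_w        [n_heads * dim_qkv * dim_model]
//     qkv_einsum_w             [3 * n_heads * dim_qkv * dim_model]
//     gating_einsum_w          [2 * dim_ffw_hidden * dim_model]
//     linear_w                 [dim_model * dim_ffw_hidden]
//     pre_attention_norm_scale [dim_model]
//     pre_ffw_norm_scale       [dim_model]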
|
||||
// Only called if cached loading fails.
|
||||
template <typename TConfig>
|
||||
hwy::AlignedUniquePtr<Weights<TConfig>> LoadWeights(const Path& checkpoint) {
|
||||
PROFILER_ZONE("Startup.LoadWeights");
|
||||
using TWeights = Weights<TConfig>;
|
||||
hwy::AlignedUniquePtr<TWeights> weights = hwy::MakeUniqueAligned<TWeights>();
|
||||
weights->layers =
|
||||
hwy::MakeUniqueAlignedArray<Layer<TConfig>>(TConfig::n_layers);
|
||||
|
||||
FILE* fptr;
|
||||
fptr = fopen(checkpoint.path.c_str(), "rb");
|
||||
if (fptr == nullptr) {
|
||||
HWY_ABORT("Failed to open model file %s - does it exist?",
|
||||
checkpoint.path.c_str());
|
||||
}
|
||||
bool ok = true;
|
||||
ok &= 1 == fread(&(weights->embedder_input_embedding),
|
||||
sizeof(weights->embedder_input_embedding), 1, fptr);
|
||||
ok &= 1 == fread(&(weights->final_norm_scale),
|
||||
sizeof(weights->final_norm_scale), 1, fptr);
|
||||
for (size_t layer = 0; layer < TConfig::n_layers; ++layer) {
|
||||
Layer<TConfig>* layer_view = &weights->layers[layer];
|
||||
ok &= 1 == fread(&layer_view->attn_vec_einsum_w,
|
||||
sizeof(layer_view->attn_vec_einsum_w), 1, fptr);
|
||||
ok &= 1 == fread(&layer_view->qkv_einsum_w,
|
||||
sizeof(layer_view->qkv_einsum_w), 1, fptr);
|
||||
ok &= 1 == fread(&layer_view->gating_einsum_w,
|
||||
sizeof(layer_view->gating_einsum_w), 1, fptr);
|
||||
ok &= 1 ==
|
||||
fread(&layer_view->linear_w, sizeof(layer_view->linear_w), 1, fptr);
|
||||
ok &= 1 == fread(&layer_view->pre_attention_norm_scale,
|
||||
sizeof(layer_view->pre_attention_norm_scale), 1, fptr);
|
||||
ok &= 1 == fread(&layer_view->pre_ffw_norm_scale,
|
||||
sizeof(layer_view->pre_ffw_norm_scale), 1, fptr);
|
||||
}
|
||||
if (!ok) {
|
||||
HWY_ABORT("Failed to read from %s - might be a directory, or too small?",
|
||||
checkpoint.path.c_str());
|
||||
}
|
||||
HWY_ASSERT(0 == fclose(fptr));
|
||||
return weights;
|
||||
}
|
||||
|
||||
template <class TConfig>
|
||||
struct CompressedLayer {
|
||||
// No ctor/dtor, allocated via AllocateAligned.
|
||||
|
||||
using TLayer = gcpp::Layer<TConfig>;
|
||||
|
||||
// NOLINTBEGIN(google3-readability-class-member-naming)
|
||||
static constexpr size_t dim_model = TConfig::dim_model;
|
||||
static constexpr size_t dim_ffw_hidden = TConfig::dim_ffw_hidden;
|
||||
// NOLINTEND(google3-readability-class-member-naming)
|
||||
|
||||
// Compressed Parameters
|
||||
// We don't yet have an RMSNorm that accepts all WeightT.
|
||||
CompressedArray<hwy::bfloat16_t, dim_model> c_pre_attention_norm_scale;
|
||||
CompressedArray<hwy::bfloat16_t, dim_model> c_pre_ffw_norm_scale;
|
||||
CompressedArray<WeightT, TLayer::size_gating_einsum_w> c_gating_einsum_w;
|
||||
CompressedArray<WeightT, dim_model * dim_ffw_hidden> c_linear_w;
|
||||
CompressedArray<WeightT, TLayer::size_qkv_einsum_w> c_qkv_einsum_w;
|
||||
CompressedArray<WeightT, TLayer::size_attn_vec_einsum_w> c_attn_vec_einsum_w;
|
||||
};
|
||||
|
||||
// Array instead of single large allocation for parallel mem init. Split out of
|
||||
// CompressedWeights so that only these pointers are initialized, not the
|
||||
// CompressedArray.
|
||||
template <class TConfig>
|
||||
struct CompressedLayerPointers {
|
||||
explicit CompressedLayerPointers(hwy::ThreadPool& pool) {
|
||||
pool.Run(0, TConfig::n_layers, [this](uint64_t task, size_t /*thread*/) {
|
||||
this->c_layers[task] = hwy::AllocateAligned<CompressedLayer<TConfig>>(1);
|
||||
});
|
||||
}
|
||||
|
||||
using CLayer = CompressedLayer<TConfig>;
|
||||
std::array<hwy::AlignedFreeUniquePtr<CLayer[]>, TConfig::n_layers> c_layers;
|
||||
};
|
||||
|
||||
template <class TConfig>
|
||||
struct CompressedWeights {
|
||||
// No ctor/dtor, allocated via AllocateAligned.
|
||||
|
||||
CompressedArray<EmbedderInputT, TConfig::vocab_size * TConfig::dim_model>
|
||||
c_embedder_input_embedding;
|
||||
|
||||
CompressedArray<hwy::bfloat16_t, TConfig::dim_model> c_final_norm_scale;
|
||||
|
||||
// Must be last so that the other arrays remain aligned.
|
||||
CompressedLayerPointers<TConfig> c_layer_ptrs;
|
||||
|
||||
const CompressedLayer<TConfig>* CLayer(size_t layer) const {
|
||||
return c_layer_ptrs.c_layers[layer].get();
|
||||
}
|
||||
CompressedLayer<TConfig>* CLayer(size_t layer) {
|
||||
return c_layer_ptrs.c_layers[layer].get();
|
||||
}
|
||||
};
|
||||
|
||||
// Aligned.
|
||||
template <class TConfig, size_t BatchSize>
|
||||
struct Activations {
|
||||
// NOLINTBEGIN(google3-readability-class-member-naming)
|
||||
static constexpr size_t batch_size = BatchSize;
|
||||
using LayerConfig = Layer<TConfig>;
|
||||
static constexpr size_t dim_model = TConfig::dim_model;
|
||||
static constexpr size_t dim_qkv = TConfig::dim_qkv;
|
||||
static constexpr size_t n_heads = TConfig::n_heads;
|
||||
static constexpr size_t n_kv_heads = TConfig::n_kv_heads;
|
||||
static constexpr size_t size_cache_pos =
|
||||
TConfig::n_layers * n_kv_heads * dim_qkv;
|
||||
static constexpr size_t size_cache_layer = n_kv_heads * dim_qkv;
|
||||
// NOLINTEND(google3-readability-class-member-naming)
|
||||
|
||||
std::array<float, batch_size * dim_model> x; // input
|
||||
std::array<float, batch_size * dim_model> pre_att_rms_out;
|
||||
std::array<float, batch_size * n_heads * dim_qkv> q; // query vector
|
||||
std::array<float, batch_size * n_heads * TConfig::seq_len>
|
||||
att; // attention vector
|
||||
std::array<float, batch_size * n_heads * dim_qkv>
|
||||
att_out; // attention output
|
||||
std::array<float, n_heads * batch_size * dim_model>
|
||||
att_post1; // attention output after linear transformation, per head
|
||||
std::array<float, batch_size * dim_model>
|
||||
att_post2; // accumulation of attention outputs over heads
|
||||
std::array<hwy::bfloat16_t, batch_size * dim_model> bf_pre_ffw_rms_out;
|
||||
std::array<float, batch_size * TConfig::dim_ffw_hidden * 2> ffw_hidden;
|
||||
// bf_ version can't be used until GeluMulToBF16 issue in FFW() is resolved.
|
||||
// std::array<hwy::bfloat16_t, batch_size * 2 * TConfig::dim_ffw_hidden>
|
||||
// bf_ffw_hidden;
|
||||
std::array<float, batch_size * dim_model> ffw_out;
|
||||
std::array<float, batch_size * TConfig::vocab_size> logits;
|
||||
};
|
||||
|
||||
// GemmaImpl is a template and thus cannot be exposed in gemma.h, hence we
|
||||
// define an abstract base class.
|
||||
struct GemmaInterface {
|
||||
virtual ~GemmaInterface() = default;
|
||||
|
||||
virtual const sentencepiece::SentencePieceProcessor& Tokenizer() const = 0;
|
||||
|
||||
// TODO: group pool/callbacks into struct
|
||||
virtual void Generate(const InferenceArgs& args,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity) = 0;
|
||||
};
|
||||
|
||||
template <class Config>
|
||||
struct GemmaImpl : public GemmaInterface {
|
||||
GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& pool);
|
||||
|
||||
~GemmaImpl() {
|
||||
using CWeights = CompressedWeights<Config>;
|
||||
CWeights* c_weights = reinterpret_cast<CWeights*>(compressed_weights.get());
|
||||
c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
|
||||
}
|
||||
|
||||
const sentencepiece::SentencePieceProcessor& Tokenizer() const {
|
||||
return tokenizer;
|
||||
}
|
||||
|
||||
void Generate(const InferenceArgs& args, const std::vector<int>& prompt,
|
||||
size_t start_pos, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937&, int verbosity);
|
||||
|
||||
sentencepiece::SentencePieceProcessor tokenizer;
|
||||
|
||||
// CompressedWeights<Config>
|
||||
hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
|
||||
hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
|
||||
hwy::AlignedUniquePtr<Activations<Config, 1>> state;
|
||||
KVCache kv_cache;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // GEMMA_ONCE
|
||||
|
||||
// SIMD code, compiled once per target.
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
template <class TConfig, size_t batch_size>
|
||||
HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
||||
Activations<TConfig, batch_size>& activations,
|
||||
const CompressedLayer<TConfig>* c_layer,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool) {
|
||||
PROFILER_ZONE("Gen.Attention");
|
||||
const size_t pos = batch_start + batch_idx;
|
||||
HWY_DASSERT(batch_idx < batch_size);
|
||||
static constexpr size_t dim_qkv = gcpp::Activations<TConfig, 1>::dim_qkv;
|
||||
static constexpr size_t size_cache_pos =
|
||||
gcpp::Activations<TConfig, batch_size>::size_cache_pos;
|
||||
static constexpr size_t size_cache_layer =
|
||||
gcpp::Activations<TConfig, batch_size>::size_cache_layer;
|
||||
static constexpr size_t dim_model =
|
||||
gcpp::Activations<TConfig, batch_size>::dim_model;
|
||||
static constexpr size_t n_heads = TConfig::n_heads;
|
||||
const float kQueryScale = 1.0 / sqrtf(static_cast<float>(dim_qkv));
|
||||
|
||||
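// Added summary, not in the original source: for each head, the loop body
// below computes q = W_q x, writes k = W_k x and v = W_v x into the KV cache,
// applies RoPE to q and k, and then forms
//   att[p] = softmax_p(dot(q, k[p]) * kQueryScale)  for p <= pos,
//   out    = sum_p att[p] * v[p],
// before projecting out back to dim_model via attn_vec_einsum_w.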
pool.Run(0, n_heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
||||
// linear projections to QKV
|
||||
const size_t head_offset =
|
||||
3 * dim_qkv * dim_model; // 3x for QKV dimensions
|
||||
const size_t q_offset = head * head_offset + 0 * dim_qkv * dim_model;
|
||||
const size_t k_offset = head * head_offset + 1 * dim_qkv * dim_model;
|
||||
const size_t v_offset = head * head_offset + 2 * dim_qkv * dim_model;
|
||||
|
||||
float* HWY_RESTRICT q =
|
||||
activations.q.data() + head * dim_qkv + batch_idx * n_heads * dim_qkv;
|
||||
|
||||
const size_t batch_offset = batch_idx * dim_model;
|
||||
|
||||
MatVecLoop<dim_qkv, dim_model>(
|
||||
c_layer->c_qkv_einsum_w, q_offset,
|
||||
activations.pre_att_rms_out.data() + batch_offset, q);
|
||||
|
||||
const size_t kv_offset =
|
||||
pos * size_cache_pos + layer * size_cache_layer + head * dim_qkv;
|
||||
|
||||
TwoOfsMatVecLoop<dim_qkv, dim_model>(
|
||||
c_layer->c_qkv_einsum_w, k_offset, v_offset,
|
||||
activations.pre_att_rms_out.data() + batch_offset,
|
||||
kv_cache.key_cache.get() + kv_offset,
|
||||
kv_cache.value_cache.get() + kv_offset);
|
||||
|
||||
// Calculate scores
|
||||
float* HWY_RESTRICT head_att = activations.att.data() +
|
||||
head * TConfig::seq_len +
|
||||
batch_idx * n_heads * dim_qkv;
|
||||
|
||||
Rope(q, dim_qkv, pos);
|
||||
Rope(kv_cache.key_cache.get() + kv_offset, dim_qkv, pos);
|
||||
MulByConst(kQueryScale, q, dim_qkv);
|
||||
// Compute Q dot K scores
|
||||
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
||||
const size_t cache_offset =
|
||||
pos2 * size_cache_pos + layer * size_cache_layer + head * dim_qkv;
|
||||
const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
|
||||
const float score = Dot(q, k2, dim_qkv);
|
||||
head_att[pos2] = score;
|
||||
}
|
||||
Softmax(head_att, pos + 1);
|
||||
|
||||
// Weighted summation
|
||||
float* HWY_RESTRICT att_out = activations.att_out.data() + head * dim_qkv +
|
||||
batch_idx * n_heads * dim_qkv;
|
||||
hwy::ZeroBytes(att_out, dim_qkv * sizeof(*att_out));
|
||||
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
||||
const size_t cache_offset =
|
||||
pos2 * size_cache_pos + layer * size_cache_layer + head * dim_qkv;
|
||||
float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;
|
||||
MulByConstAndAdd(head_att[pos2], v2, att_out, dim_qkv);
|
||||
}
|
||||
// linear projection from dim_qkv back to dim_model, sum projections
|
||||
// across heads
|
||||
float* HWY_RESTRICT head_out =
|
||||
head == 0
|
||||
? activations.att_post2.data() + batch_idx * dim_model
|
||||
: activations.att_post1.data() + head * batch_size * dim_model;
|
||||
MatVecLoop<dim_model, dim_qkv>(c_layer->c_attn_vec_einsum_w,
|
||||
head * dim_model * dim_qkv, att_out,
|
||||
head_out);
|
||||
});
|
||||
|
||||
// accumulate output across all heads into att_post2. head 0 already wrote
|
||||
// directly to att_post2.
|
||||
for (size_t head = 1; head < n_heads; ++head) {
|
||||
AddFrom(activations.att_post1.data() + head * batch_size * dim_model,
|
||||
activations.att_post2.data() + batch_idx * dim_model, dim_model);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TConfig, size_t batch_size>
|
||||
HWY_NOINLINE void FFW(Activations<TConfig, batch_size>& activations,
|
||||
size_t batch_idx, const CompressedLayer<TConfig>* c_layer,
|
||||
hwy::ThreadPool& pool) {
|
||||
HWY_DASSERT(batch_idx < batch_size);
|
||||
static constexpr size_t dim_model = TConfig::dim_model;
|
||||
static constexpr size_t dim_ffw_hidden = TConfig::dim_ffw_hidden;
|
||||
const size_t hidden_offset = batch_idx * dim_ffw_hidden * 2;
|
||||
|
||||
{
|
||||
PROFILER_ZONE("Gen.FFW.GatedGELU");
|
||||
const hwy::bfloat16_t* HWY_RESTRICT vec =
|
||||
activations.bf_pre_ffw_rms_out.data() + batch_idx * dim_model;
|
||||
float* HWY_RESTRICT out = activations.ffw_hidden.data() + hidden_offset;
|
||||
float* HWY_RESTRICT out_mul = out + dim_ffw_hidden;
|
||||
|
||||
// Same matrix, first and second half of rows. Could fuse into one MatVec,
|
||||
// but separating them could help on NUMA e.g. multiple sockets.
|
||||
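// Added note, not in the original source: the two MatVecs plus the Transform1
// below compute the gated-GELU feed-forward
//   ffw_hidden = GELU(vec * W[0 : dim_ffw_hidden]) *
//                (vec * W[dim_ffw_hidden : 2 * dim_ffw_hidden]),
// where W denotes c_gating_einsum_w viewed as 2 * dim_ffw_hidden rows.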
MatVec<dim_ffw_hidden, dim_model>(c_layer->c_gating_einsum_w,
|
||||
dim_ffw_hidden * dim_model, vec, out_mul,
|
||||
pool);
|
||||
|
||||
// Gate, will go through the nonlinearity.
|
||||
MatVec<dim_ffw_hidden, dim_model>(c_layer->c_gating_einsum_w, 0, vec, out,
|
||||
pool);
|
||||
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using DF = hn::ScalableTag<float>;
|
||||
using VF = hn::Vec<DF>;
|
||||
hn::Transform1(DF(), out, dim_ffw_hidden, out_mul,
|
||||
[](DF df, VF v, VF mul)
|
||||
HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); });
|
||||
}
|
||||
|
||||
PROFILER_ZONE("Gen.FFW\\GatedGELU");
|
||||
MatVec<dim_model, dim_ffw_hidden>(
|
||||
c_layer->c_linear_w, 0, activations.ffw_hidden.data() + hidden_offset,
|
||||
activations.ffw_out.data() + batch_idx * dim_model, pool);
|
||||
}
|
||||
|
||||
template <typename TConfig, size_t batch_size>
|
||||
HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
|
||||
const CompressedWeights<TConfig>& c_weights,
|
||||
Activations<TConfig, batch_size>& activations,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool) {
|
||||
PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
|
||||
static constexpr size_t dim_model = TConfig::dim_model;
|
||||
static const float kEmbScaling = sqrtf(static_cast<float>(dim_model));
|
||||
|
||||
pool.Run(
|
||||
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
|
||||
const int token = tokens[token_idx];
|
||||
Decompress(c_weights.c_embedder_input_embedding, token * dim_model,
|
||||
activations.x.data() + token_idx * dim_model, dim_model);
|
||||
MulByConst(kEmbScaling, activations.x.data() + token_idx * dim_model,
|
||||
dim_model);
|
||||
});
|
||||
|
||||
for (size_t layer = 0; layer < TConfig::n_layers; ++layer) {
|
||||
const CompressedLayer<TConfig>* c_layer = c_weights.CLayer(layer);
|
||||
|
||||
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
RMSNorm(activations.x.data() + token_idx * dim_model,
|
||||
c_layer->c_pre_attention_norm_scale.data(),
|
||||
activations.pre_att_rms_out.data() + token_idx * dim_model,
|
||||
dim_model);
|
||||
Attention<TConfig, batch_size>(pos, token_idx, layer, activations,
|
||||
c_layer, kv_cache, pool);
|
||||
}
|
||||
|
||||
// TODO: sink the loop into these functions, i.e. make them matmuls.
|
||||
pool.Run(
|
||||
0, num_tokens,
|
||||
[&](const uint64_t token_idx, size_t thread_id) HWY_ATTR {
|
||||
AddFrom(activations.att_post2.data() + token_idx * dim_model,
|
||||
activations.x.data() + token_idx * dim_model, dim_model);
|
||||
RMSNorm(activations.x.data() + token_idx * dim_model,
|
||||
c_layer->c_pre_ffw_norm_scale.data(),
|
||||
activations.bf_pre_ffw_rms_out.data() + token_idx * dim_model,
|
||||
dim_model);
|
||||
FFW<TConfig, batch_size>(activations, token_idx, c_layer, inner_pool);
|
||||
AddFrom(activations.ffw_out.data() + token_idx * dim_model,
|
||||
activations.x.data() + token_idx * dim_model, dim_model);
|
||||
});
|
||||
} // foreach layer
|
||||
|
||||
pool.Run(
|
||||
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
|
||||
RMSNormInplace(c_weights.c_final_norm_scale.data(),
|
||||
activations.x.data() + token_idx * dim_model, dim_model);
|
||||
});
|
||||
}
|
||||
|
||||
// n = 1 specialization
|
||||
template <class TConfig>
|
||||
void Transformer(int token, size_t pos,
|
||||
const CompressedWeights<TConfig>& c_weights,
|
||||
Activations<TConfig, 1>& activations, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool) {
|
||||
static constexpr size_t n_layers = TConfig::n_layers;
|
||||
static constexpr size_t dim_model = TConfig::dim_model;
|
||||
|
||||
static const float kEmbScaling = sqrtf(static_cast<float>(dim_model));
|
||||
|
||||
Decompress(c_weights.c_embedder_input_embedding, token * dim_model,
|
||||
activations.x.data(), dim_model);
|
||||
|
||||
MulByConst(kEmbScaling, activations.x.data(), dim_model);
|
||||
|
||||
for (size_t layer = 0; layer < n_layers; ++layer) {
|
||||
const CompressedLayer<TConfig>* c_layer = c_weights.CLayer(layer);
|
||||
RMSNorm(activations.x.data(), c_layer->c_pre_attention_norm_scale.data(),
|
||||
activations.pre_att_rms_out.data(), dim_model);
|
||||
Attention<TConfig, 1>(pos, 0, layer, activations, c_layer, kv_cache, pool);
|
||||
AddFrom(activations.att_post2.data(), activations.x.data(), dim_model);
|
||||
RMSNorm(activations.x.data(), c_layer->c_pre_ffw_norm_scale.data(),
|
||||
activations.bf_pre_ffw_rms_out.data(), dim_model);
|
||||
FFW<TConfig, 1>(activations, /* batch_idx = */ 0, c_layer, pool);
|
||||
AddFrom(activations.ffw_out.data(), activations.x.data(), dim_model);
|
||||
}
|
||||
RMSNormInplace(c_weights.c_final_norm_scale.data(), activations.x.data(),
|
||||
dim_model);
|
||||
}
|
||||
|
||||
template <class TConfig>
|
||||
void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
|
||||
const std::vector<int>& prompt, size_t pos,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity) {
|
||||
static constexpr size_t dim_model = TConfig::dim_model;
|
||||
static constexpr size_t vocab_size = TConfig::vocab_size;
|
||||
static constexpr size_t top_k = TConfig::top_k;
|
||||
Activations<TConfig, 1>& activations = *gemma.state.get();
|
||||
Activations<TConfig, kPrefillBatchSize>& prefill_activations =
|
||||
*gemma.prefill.get();
|
||||
const CompressedWeights<TConfig>& c_weights =
|
||||
*reinterpret_cast<CompressedWeights<TConfig>*>(
|
||||
gemma.compressed_weights.get());
|
||||
KVCache& kv_cache = gemma.kv_cache;
|
||||
int token;
|
||||
|
||||
// pos indexes the KV cache. In the first turn of a chat, pos = 0.
|
||||
//
|
||||
// After the first turn, pos gets passed in with > 0 corresponding to the
|
||||
// current token position in the KV cache.
|
||||
//
|
||||
// pos_offset keeps track of the relative position within the turn, starting
|
||||
// at 0 each turn. During prefill, pos_offset corresponds to the index into
|
||||
// the prompt vector.
|
||||
//
|
||||
// In single-turn (non-chat) usage, pos and pos_offset start at 0 and are
|
||||
// always equal.
|
||||
size_t pos_offset = 0; // offset relative to pos
|
||||
double prefill_start = hwy::platform::Now();
|
||||
|
||||
// Prefill stops before prompt.size() - 1 since the last prompt token is the
|
||||
// first input token for generation.
|
||||
while (pos_offset < prompt.size() - 1) {
|
||||
const size_t end_offset =
|
||||
std::min(kPrefillBatchSize, prompt.size() - 1 - pos_offset);
|
||||
HWY_DASSERT(end_offset < prompt.size());
|
||||
const int* batch_tokens = prompt.data() + pos_offset;
|
||||
Prefill<TConfig, kPrefillBatchSize>(batch_tokens, end_offset, pos,
|
||||
c_weights, prefill_activations,
|
||||
kv_cache, pool, inner_pool);
|
||||
for (size_t idx = 0; idx < end_offset; ++idx) {
|
||||
stream_token(batch_tokens[idx], 0.0);
|
||||
}
|
||||
pos += end_offset;
|
||||
pos_offset += end_offset;
|
||||
}
|
||||
|
||||
if (verbosity >= 2) {
|
||||
// In the future, this output should not occur in GenerateImpl; instead it
|
||||
// should be exposed as observable state for frontend code to handle I/O.
|
||||
double prefill_end = hwy::platform::Now();
|
||||
const double prefill_tok_sec = pos_offset / (prefill_end - prefill_start);
|
||||
std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]\n";
|
||||
}
|
||||
|
||||
double gen_start = hwy::platform::Now();
|
||||
|
||||
HWY_DASSERT(pos_offset == prompt.size() - 1);
|
||||
|
||||
if (verbosity >= 2) {
|
||||
// Provide usage warnings if max_new_tokens is out of range.
|
||||
if (args.max_generated_tokens > args.max_tokens) {
|
||||
std::cout << "Warning: max_new_tokens should be <= max_tokens"
|
||||
<< std::endl;
|
||||
} else if ((prompt.size() + args.max_generated_tokens) > args.max_tokens) {
|
||||
std::cout << "Warning: Prompt size + max_new_tokens exceeds max_tokens."
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
auto pos_gen_start = pos_offset;
|
||||
token = prompt.at(pos_offset);
|
||||
size_t generate_pos = 0;
|
||||
for (; pos < args.max_tokens && generate_pos < args.max_generated_tokens;
|
||||
++pos, ++pos_offset, ++generate_pos) {
|
||||
Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool);
|
||||
float* final_activation = activations.x.data();
|
||||
if (pos_offset >= prompt.size()) {
|
||||
PROFILER_ZONE("Gen.Embedding");
|
||||
// Generation phase
|
||||
MatVec<vocab_size, dim_model>(c_weights.c_embedder_input_embedding, 0,
|
||||
final_activation, activations.logits.data(),
|
||||
pool);
|
||||
// Barrier: must have all logits so we can subtract max.
|
||||
Softmax(activations.logits.data(), vocab_size);
|
||||
token = SampleTopK<top_k>(activations.logits.data(), vocab_size, gen,
|
||||
args.temperature, accept_token);
|
||||
}
|
||||
if (!stream_token(token, activations.logits[token])) {
|
||||
token = EOS_ID;
|
||||
}
|
||||
if (token == EOS_ID) {
|
||||
if (verbosity >= 2) {
|
||||
double gen_end = hwy::platform::Now();
|
||||
const double gen_tok_sec =
|
||||
(pos_offset - pos_gen_start) / (gen_end - gen_start);
|
||||
std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, const InferenceArgs& args,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity) {
|
||||
GenerateImpl(gemma, args, prompt, start_pos, pool, inner_pool, stream_token,
|
||||
accept_token, gen, verbosity);
|
||||
}
|
||||
|
||||
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, const InferenceArgs& args,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity) {
|
||||
GenerateImpl(gemma, args, prompt, start_pos, pool, inner_pool, stream_token,
|
||||
accept_token, gen, verbosity);
|
||||
}
|
||||
|
||||
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
|
||||
// if weights = null, which happens during the first call where we attempt to
|
||||
// load from cache.
|
||||
//
|
||||
// This avoids repeating the list of tensors between loading and compressing.
|
||||
template <class TConfig, class Func>
|
||||
void ForEachTensor(const Weights<TConfig>* weights,
|
||||
CompressedWeights<TConfig>& c_weights, Func& func) {
|
||||
func("c_embedding",
|
||||
weights ? weights->embedder_input_embedding.data() : nullptr,
|
||||
c_weights.c_embedder_input_embedding);
|
||||
func("c_final_norm", weights ? weights->final_norm_scale.data() : nullptr,
|
||||
c_weights.c_final_norm_scale);
|
||||
|
||||
char name[16];
|
||||
for (size_t layer_idx = 0; layer_idx < TConfig::n_layers; ++layer_idx) {
|
||||
Layer<TConfig>* layer = weights ? &weights->layers[layer_idx] : nullptr;
|
||||
CompressedLayer<TConfig>* c_layer = c_weights.CLayer(layer_idx);
|
||||
|
||||
snprintf(name, sizeof(name), "pre_ff_ns_%zu", layer_idx);
func(name, layer ? layer->pre_ffw_norm_scale.data() : nullptr,
c_layer->c_pre_ffw_norm_scale);

snprintf(name, sizeof(name), "gating_ein_%zu", layer_idx);
func(name, layer ? layer->gating_einsum_w.data() : nullptr,
c_layer->c_gating_einsum_w);

snprintf(name, sizeof(name), "linear_w_%zu", layer_idx);
func(name, layer ? layer->linear_w.data() : nullptr, c_layer->c_linear_w);

snprintf(name, sizeof(name), "qkv_ein_%zu", layer_idx);
func(name, layer ? layer->qkv_einsum_w.data() : nullptr,
c_layer->c_qkv_einsum_w);

snprintf(name, sizeof(name), "att_ein_%zu", layer_idx);
func(name, layer ? layer->attn_vec_einsum_w.data() : nullptr,
c_layer->c_attn_vec_einsum_w);

snprintf(name, sizeof(name), "pre_att_ns_%zu", layer_idx);
func(name, layer ? layer->pre_attention_norm_scale.data() : nullptr,
c_layer->c_pre_attention_norm_scale);
|
||||
}
|
||||
}
|
||||
|
||||
template <class TConfig>
|
||||
hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeights(
|
||||
const Path& model, const Path& cache, hwy::ThreadPool& pool) {
|
||||
PROFILER_ZONE("Startup.LoadCache");
|
||||
|
||||
if (!std::filesystem::exists(model.path) &&
|
||||
!std::filesystem::exists(cache.path)) {
|
||||
HWY_ABORT(
|
||||
"Either the model weights (--weights) or cached compressed weights "
|
||||
"(--compressed_weights) must exist.");
|
||||
}
|
||||
|
||||
// Allocate compressed weights.
|
||||
using CWeights = CompressedWeights<TConfig>;
|
||||
hwy::AlignedFreeUniquePtr<uint8_t[]> c_weights_u8 =
|
||||
hwy::AllocateAligned<uint8_t>(sizeof(CWeights));
|
||||
CWeights* c_weights = reinterpret_cast<CWeights*>(c_weights_u8.get());
|
||||
new (&c_weights->c_layer_ptrs) CompressedLayerPointers<TConfig>(pool);
|
||||
|
||||
// First attempt to load them from cache, without requiring weights.
|
||||
CacheLoader loader(cache.path.c_str());
|
||||
ForEachTensor<TConfig>(nullptr, *c_weights, loader);
|
||||
if (loader.ReadAll(pool)) return c_weights_u8;
|
||||
|
||||
// Get weights, compress, and store in cache.
|
||||
hwy::AlignedUniquePtr<Weights<TConfig>> weights = LoadWeights<TConfig>(model);
|
||||
Compressor compressor(pool);
|
||||
ForEachTensor<TConfig>(weights.get(), *c_weights, compressor);
|
||||
compressor.WriteAll(pool, cache.path.c_str());
|
||||
|
||||
return c_weights_u8;
|
||||
}
|
||||
|
||||
// Type-erased because this function is called via a function pointer.
|
||||
hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeightsT(
|
||||
const LoaderArgs& args, hwy::ThreadPool& pool) {
|
||||
switch (args.ModelType()) {
|
||||
case Model::GEMMA_2B:
|
||||
return GetCompressedWeights<ConfigGemma2B>(args.model, args.cache, pool);
|
||||
case Model::GEMMA_7B:
|
||||
return GetCompressedWeights<ConfigGemma7B>(args.model, args.cache, pool);
|
||||
default:
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(args.ModelType()));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#if HWY_ONCE
|
||||
namespace gcpp {
|
||||
|
||||
HWY_EXPORT(GetCompressedWeightsT);
|
||||
HWY_EXPORT(Generate2B);
|
||||
HWY_EXPORT(Generate7B);
|
||||
|
||||
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
|
||||
KVCache kv_cache = {};
|
||||
kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
||||
kv_cache.value_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
||||
return kv_cache;
|
||||
}
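// Note: GemmaImpl below passes size_cache_pos = n_layers * n_kv_heads *
// dim_qkv, so each of key_cache and value_cache holds
// seq_len * n_layers * n_kv_heads * dim_qkv floats.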
|
||||
|
||||
template <class Config>
|
||||
GemmaImpl<Config>::GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& pool)
|
||||
: compressed_weights(
|
||||
HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)),
|
||||
prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
|
||||
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()),
|
||||
kv_cache(
|
||||
CreateKVCache(Config::n_layers * Config::n_kv_heads * Config::dim_qkv,
|
||||
Config::seq_len)) {
|
||||
PROFILER_ZONE("Startup.tokenizer");
|
||||
|
||||
HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
|
||||
}
|
||||
|
||||
template <>
|
||||
void GemmaImpl<ConfigGemma2B>::Generate(const InferenceArgs& args,
|
||||
const std::vector<int>& prompt,
|
||||
size_t start_pos, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity) {
|
||||
HWY_DYNAMIC_DISPATCH(Generate2B)
|
||||
(*this, args, prompt, start_pos, pool, inner_pool, stream_token, accept_token,
|
||||
gen, verbosity);
|
||||
}
|
||||
template <>
|
||||
void GemmaImpl<ConfigGemma7B>::Generate(const InferenceArgs& args,
|
||||
const std::vector<int>& prompt,
|
||||
size_t start_pos, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity) {
|
||||
HWY_DYNAMIC_DISPATCH(Generate7B)
|
||||
(*this, args, prompt, start_pos, pool, inner_pool, stream_token, accept_token,
|
||||
gen, verbosity);
|
||||
}
|
||||
|
||||
Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
|
||||
const Model model_type = args.ModelType();
|
||||
model_training = args.ModelTraining();
|
||||
switch (model_type) {
|
||||
case Model::GEMMA_2B:
|
||||
impl_.reset(new GemmaImpl<ConfigGemma2B>(args, pool));
|
||||
break;
|
||||
case Model::GEMMA_7B:
|
||||
impl_.reset(new GemmaImpl<ConfigGemma7B>(args, pool));
|
||||
break;
|
||||
default:
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));
|
||||
}
|
||||
}
|
||||
Gemma::~Gemma() = default; // after GemmaInterface is defined
|
||||
|
||||
const sentencepiece::SentencePieceProcessor& Gemma::Tokenizer() const {
|
||||
return impl_->Tokenizer();
|
||||
}
|
||||
|
||||
void GenerateGemma(Gemma& gemma, const InferenceArgs& args,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity) {
|
||||
pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||
gemma.impl_->Generate(args, prompt, start_pos, pool, inner_pool, stream_token,
|
||||
accept_token, gen, verbosity);
|
||||
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // HWY_ONCE
@ -0,0 +1,207 @@
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "configs.h" // kSeqLen
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/compress.h" // SfpStream/NuqStream
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/args.h" // ArgsBase
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // hwy::bfloat16_t
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
// copybara:import_next_line:sentencepiece
|
||||
#include "src/sentencepiece_processor.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Allowable types for GEMMA_WEIGHT_T (can be specified at compilation time):
|
||||
// float, hwy::bfloat16_t, SfpStream, NuqStream
|
||||
#ifndef GEMMA_WEIGHT_T
|
||||
#define GEMMA_WEIGHT_T SfpStream
|
||||
#endif // !GEMMA_WEIGHT_T
|
||||
using WeightT = GEMMA_WEIGHT_T;
|
||||
|
||||
using EmbedderInputT = hwy::bfloat16_t;
|
||||
constexpr size_t kPrefillBatchSize = 16;
|
||||
constexpr bool kSystemPrompt = false;
|
||||
|
||||
struct KVCache {
|
||||
hwy::AlignedFreeUniquePtr<float[]>
|
||||
key_cache; // batch_size * seq_len * n_layers * n_kv_heads * dim_qkv
|
||||
hwy::AlignedFreeUniquePtr<float[]>
|
||||
value_cache; // batch_size * seq_len * n_layers * n_kv_heads * dim_qkv
|
||||
};
|
||||
|
||||
// Model variants: see configs.h for details.
|
||||
enum class Model { GEMMA_2B, GEMMA_7B };
|
||||
enum class ModelTraining { GEMMA_IT, GEMMA_PT };
|
||||
|
||||
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
|
||||
static std::string ToLower(const std::string& text) {
|
||||
std::string result = text;
|
||||
std::transform(begin(result), end(result), begin(result),
|
||||
[](unsigned char c) { return std::tolower(c); });
|
||||
return result;
|
||||
}
|
||||
|
||||
gcpp::Model ModelType() const {
|
||||
const std::string model_type_lc = ToLower(model_type);
|
||||
if (model_type_lc == "2b-pt" || model_type_lc == "2b-it") {
|
||||
return gcpp::Model::GEMMA_2B;
|
||||
} else {
|
||||
return gcpp::Model::GEMMA_7B;
|
||||
}
|
||||
}
|
||||
|
||||
gcpp::ModelTraining ModelTraining() const {
|
||||
const std::string model_type_lc = ToLower(model_type);
|
||||
if (model_type_lc == "7b-pt" || model_type_lc == "2b-pt") {
|
||||
return gcpp::ModelTraining::GEMMA_PT;
|
||||
} else {
|
||||
return gcpp::ModelTraining::GEMMA_IT;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns error string or nullptr if OK.
|
||||
const char* Validate() const {
|
||||
const std::string model_type_lc = ToLower(model_type);
|
||||
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
|
||||
model_type_lc != "2b-it" && model_type_lc != "7b-it") {
|
||||
return "Model type must be 2b-pt, 7b-pt, 2b-it, or "
|
||||
"7b-it.";
|
||||
}
|
||||
if (tokenizer.path.empty()) {
|
||||
return "Missing --tokenizer flag, a file for the tokenizer is required.";
|
||||
}
|
||||
if (model_type.empty()) {
|
||||
return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
|
||||
"2b-it, or 7b-it.";
|
||||
}
|
||||
if (cache.path.empty()) {
|
||||
return "Missing --compressed_weights flag, a file for the compressed "
|
||||
"model.";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Path tokenizer;
|
||||
Path model; // uncompressed weights OR
|
||||
Path cache; // compressed weights
|
||||
std::string model_type;
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(tokenizer, "tokenizer", Path(),
|
||||
"Path name of tokenizer model file. (required)");
|
||||
visitor(
|
||||
cache, "compressed_weights", Path(),
|
||||
"Path name of compressed weights file, regenerated from `--weights` "
|
||||
"file if "
|
||||
"the compressed weights file does not exist. (required)");
|
||||
visitor(model_type, "model", std::string(),
|
||||
"Model type - can be 2b-it (2B parameters, instruction-tuned), "
|
||||
"2b-pt (2B parameters, pretrained), 7b-it (7B parameters, "
|
||||
"instruction-tuned), or 7b-pt (7B parameters, pretrained). "
|
||||
"(required)");
|
||||
visitor(model, "weights", Path(),
|
||||
"Path name of model weights (.sbs) file. Only required if "
|
||||
"compressed_weights file is not present and needs to be "
|
||||
"regenerated. Otherwise, not needed");
|
||||
}
|
||||
};
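// Typical invocation of a frontend using these flags (file names are
// placeholders):
//   ./gemma --tokenizer tokenizer.spm --compressed_weights 2b-it.sbs \
//       --model 2b-it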
|
||||
|
||||
struct GemmaInterface;
|
||||
|
||||
struct Gemma {
|
||||
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
|
||||
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
||||
|
||||
const sentencepiece::SentencePieceProcessor& Tokenizer() const;
|
||||
|
||||
std::unique_ptr<GemmaInterface> impl_;
|
||||
gcpp::ModelTraining model_training;
|
||||
};
|
||||
|
||||
// StreamFunc is called with (token, probability). For prompt tokens,
|
||||
// probability is 0.0f.
|
||||
using StreamFunc = std::function<bool(int, float)>;
|
||||
using AcceptFunc = std::function<bool(int)>;
|
||||
|
||||
struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
|
||||
size_t max_tokens;
|
||||
size_t max_generated_tokens;
|
||||
|
||||
float temperature;
|
||||
bool deterministic;
|
||||
bool multiturn;
|
||||
|
||||
// Returns error string or nullptr if OK.
|
||||
const char* Validate() const {
|
||||
if (max_tokens > gcpp::kSeqLen) {
|
||||
return "max_tokens is larger than the maximum sequence length (see "
|
||||
"configs.h).";
|
||||
}
|
||||
if (max_generated_tokens > max_tokens) {
|
||||
return "Maximum number of generated tokens is larger than the maximum "
|
||||
"total tokens.";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(max_tokens, "max_tokens", size_t{3072},
|
||||
"Maximum number of tokens in prompt + generation.");
|
||||
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
|
||||
"Maximum number of tokens to generate.");
|
||||
|
||||
visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
|
||||
visitor(deterministic, "deterministic", false,
|
||||
"Make top-k sampling deterministic", 2);
|
||||
visitor(multiturn, "multiturn", true,
|
||||
"Multiturn mode (if 0, this clears the KV cache after every "
|
||||
"interaction without quitting)",
|
||||
2);
|
||||
}
|
||||
};
|
||||
|
||||
void GenerateGemma(Gemma& gemma, const InferenceArgs& args,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& g,
|
||||
int verbosity);
|
||||
|
||||
constexpr int EOS_ID = 1;
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_H_
@ -0,0 +1,682 @@
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Include guard for non-SIMD code.
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_OPS_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_OPS_H_
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <random>
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/compress.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/profiler.h"
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_OPS_H_
|
||||
|
||||
// Include guard for (potentially) SIMD code.
|
||||
#if defined(THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE) == defined(HWY_TARGET_TOGGLE)
|
||||
#ifdef THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE
|
||||
#undef THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE
|
||||
#else
|
||||
#define THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE
|
||||
#endif
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/compress-inl.h"
|
||||
#include "hwy/cache_control.h" // FlushStream
|
||||
#include "hwy/contrib/algo/transform-inl.h"
|
||||
#include "hwy/contrib/dot/dot-inl.h"
|
||||
#include "hwy/contrib/math/math-inl.h"
|
||||
#include "hwy/contrib/matvec/matvec-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
||||
HWY_INLINE constexpr size_t MaxCols() {
|
||||
// Vec + mat rows should fit into 32 KiB L1.
|
||||
return 2048;
|
||||
}
|
||||
|
||||
template <size_t kOuter>
|
||||
HWY_INLINE constexpr size_t RowsPerStrip() {
|
||||
// Aim for 128 work items to reduce pool overhead. Must be at least one
|
||||
// vector; prefer a power of two for faster division.
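// Example (illustrative): kOuter = 2048 with 8 float lanes gives
// max(8, 1 << FloorLog2(2048 / 128)) = max(8, 16) = 16 rows per strip,
// i.e. 128 strips.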
|
||||
constexpr size_t kRowsPerStrip =
|
||||
HWY_MAX(hn::ScalableTag<float>().MaxLanes(),
|
||||
1ULL << hwy::FloorLog2(kOuter / 128));
|
||||
return kRowsPerStrip;
|
||||
}
|
||||
|
||||
// Simple version without tiling nor threading.
|
||||
template <size_t kOuter, size_t kInner, typename MatT, size_t kCapacity,
|
||||
typename VecT>
|
||||
HWY_INLINE void MatVecLoop(const CompressedArray<MatT, kCapacity>& mat,
|
||||
const size_t mat_ofs,
|
||||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
float* HWY_RESTRICT out) {
|
||||
PROFILER_ZONE("MatVecLoop");
|
||||
const hn::ScalableTag<float> df;
|
||||
|
||||
for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
|
||||
const size_t row_ofs = mat_ofs + idx_row * kInner;
|
||||
out[idx_row] = Dot(df, mat, row_ofs, vec_aligned, kInner);
|
||||
}
|
||||
}
|
||||
|
||||
// Simple version without tiling nor threading, but two offsets/outputs.
|
||||
template <size_t kOuter, size_t kInner, typename MatT, size_t kCapacity,
|
||||
typename VecT>
|
||||
HWY_INLINE void TwoOfsMatVecLoop(const CompressedArray<MatT, kCapacity>& mat,
|
||||
const size_t mat_ofs0, const size_t mat_ofs1,
|
||||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
float* HWY_RESTRICT out0,
|
||||
float* HWY_RESTRICT out1) {
|
||||
PROFILER_ZONE("MatVecLoop");
|
||||
const hn::ScalableTag<float> df;
|
||||
|
||||
for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
|
||||
const size_t row_ofs0 = mat_ofs0 + (idx_row)*kInner;
|
||||
const size_t row_ofs1 = mat_ofs1 + (idx_row)*kInner;
|
||||
out0[idx_row] = Dot(df, mat, row_ofs0, vec_aligned, kInner);
|
||||
out1[idx_row] = Dot(df, mat, row_ofs1, vec_aligned, kInner);
|
||||
}
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
// For each i = [0, num_rows), compute partial (length `num_cols`) dot product
|
||||
// of row i with `vec_aligned` and add into `out[i]`. The upper-left coordinate
|
||||
// of the tile is r0, c0.
|
||||
template <class DF, typename MatT, size_t kCapacity, typename VecT>
|
||||
HWY_INLINE void AccumulatePartialDotProducts(
|
||||
DF df, const CompressedArray<MatT, kCapacity>& mat, size_t mat_ofs,
|
||||
size_t mat_stride, size_t r0, size_t c0, size_t num_rows, size_t num_cols,
|
||||
const VecT* HWY_RESTRICT vec_aligned, float* HWY_RESTRICT out) {
|
||||
for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
|
||||
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
|
||||
out[idx_row] += Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
|
||||
}
|
||||
}
|
||||
|
||||
// Same as above, but sets out[i] to the first partial dot product, which
|
||||
// avoids having to zero-initialize and accumulate.
|
||||
template <class DF, typename MatT, size_t kCapacity, typename VecT>
|
||||
HWY_INLINE void SetFirstPartialDotProducts(
|
||||
DF df, const CompressedArray<MatT, kCapacity>& mat, size_t mat_ofs,
|
||||
size_t mat_stride, size_t r0, size_t c0, size_t num_rows, size_t num_cols,
|
||||
const VecT* HWY_RESTRICT vec_aligned, float* HWY_RESTRICT out) {
|
||||
for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
|
||||
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
|
||||
out[idx_row] = Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
|
||||
}
|
||||
}
|
||||
|
||||
// Adds together partial dot products for all tiles with the same r0 (a
|
||||
// horizontal strip of the entire matrix); the result is the full dot product
|
||||
// for rows r in [r0, r0 + num_rows), which we store into in out[r - r0].
|
||||
template <class DF, typename MatT, size_t kCapacity, typename VecT>
|
||||
HWY_INLINE void FullDotProductsForStrip(
|
||||
DF df, const CompressedArray<MatT, kCapacity>& mat, size_t mat_ofs,
|
||||
size_t mat_stride, size_t r0, size_t num_rows,
|
||||
const VecT* HWY_RESTRICT vec_aligned, float* HWY_RESTRICT out) {
|
||||
// Tall and skinny: set `out` to the single dot product.
|
||||
if (mat_stride < MaxCols()) {
|
||||
SetFirstPartialDotProducts(df, mat, mat_ofs, mat_stride, r0, 0, num_rows,
|
||||
mat_stride, vec_aligned, out);
|
||||
return;
|
||||
}
|
||||
|
||||
// We have at least MaxCols, so start by setting `out` to that:
|
||||
SetFirstPartialDotProducts(df, mat, mat_ofs, mat_stride, r0, 0, num_rows,
|
||||
MaxCols(), vec_aligned, out);
|
||||
// For further multiples of MaxCols, accumulate. Remainders handled below.
|
||||
size_t c0 = MaxCols();
|
||||
HWY_UNROLL(1)
|
||||
for (; c0 <= mat_stride - MaxCols(); c0 += MaxCols()) {
|
||||
AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows,
|
||||
MaxCols(), vec_aligned, out);
|
||||
}
|
||||
|
||||
if (c0 < mat_stride) { // Final cols
|
||||
AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows,
|
||||
mat_stride - c0, vec_aligned, out);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Computes dot products of all rows of `mat` with `vec_aligned` and stores
// them to `out`, parallelizing over horizontal strips of rows.
|
||||
template <size_t kOuter, size_t kInner, typename MatT, size_t kCapacity,
|
||||
typename VecT>
|
||||
HWY_INLINE void MatVec(const CompressedArray<MatT, kCapacity>& mat,
|
||||
const size_t mat_ofs,
|
||||
const VecT* HWY_RESTRICT const vec_aligned,
|
||||
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
|
||||
PROFILER_ZONE("MatVec");
|
||||
|
||||
const hn::ScalableTag<float> df;
|
||||
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
|
||||
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
|
||||
|
||||
// For each entire strip.
|
||||
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
|
||||
PROFILER_ZONE("MatVec.lambda");
|
||||
const size_t r0 = strip * kRowsPerStrip;
|
||||
detail::FullDotProductsForStrip(df, mat, mat_ofs, kInner, r0, kRowsPerStrip,
|
||||
vec_aligned, out + r0);
|
||||
});
|
||||
|
||||
// Remaining rows
|
||||
const size_t r0 = kNumStrips * kRowsPerStrip;
|
||||
if (r0 < kOuter) {
|
||||
PROFILER_ZONE("MatVec remainder");
|
||||
const size_t num_rows = kOuter - r0;
|
||||
detail::FullDotProductsForStrip(df, mat, mat_ofs, kInner, r0, num_rows,
|
||||
vec_aligned, out + r0);
|
||||
}
|
||||
}
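// The vectorized Gelu below uses the tanh approximation
//   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
// with sqrt(2/pi) ~= 0.7978845608.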
|
||||
|
||||
template <class D, HWY_IF_F32_D(D)>
|
||||
static HWY_INLINE hn::Vec<D> Gelu(D d, hn::Vec<D> v) {
|
||||
const hn::Vec<D> kMul = Set(d, 0.044715f);
|
||||
const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f);
|
||||
const hn::Vec<D> kHalf = Set(d, 0.5f);
|
||||
|
||||
// tanh approximation matches training.
|
||||
const hn::Vec<D> v3 = hn::Mul(hn::Mul(v, v), v);
|
||||
const hn::Vec<D> arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v));
|
||||
// 0.5 * (1 + tan) = MulAdd(0.5, tan, 0.5).
|
||||
const hn::Vec<D> cdf = hn::MulAdd(kHalf, hn::Tanh(d, arg), kHalf);
|
||||
return Mul(v, cdf);
|
||||
}
|
||||
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x,
|
||||
size_t size) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
hn::Transform(D(), x, size, [](D d, hn::Vec<D> v) { return Gelu(d, v); });
|
||||
}
|
||||
|
||||
// out[i] = BF(mul[i] * Gelu(gelu_in[i]))
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void GeluMulToBF16(
|
||||
const float* HWY_RESTRICT gelu_in, const float* HWY_RESTRICT mul,
|
||||
hwy::bfloat16_t* HWY_RESTRICT out, size_t size) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
const hn::ScalableTag<float> df;
|
||||
const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
|
||||
size_t i = 0;
|
||||
if (size >= 2 * NF) {
|
||||
for (; i < size - 2 * NF; i += 2 * NF) {
|
||||
const VF mul0 = LoadU(df, mul + i);
|
||||
const VF mul1 = LoadU(df, mul + i + NF);
|
||||
const VF g0 = Mul(mul0, Gelu(df, LoadU(df, gelu_in + i)));
|
||||
const VF g1 = Mul(mul1, Gelu(df, LoadU(df, gelu_in + i + NF)));
|
||||
const hn::Vec<decltype(dbf)> bf = hn::OrderedDemote2To(dbf, g0, g1);
|
||||
StoreU(bf, dbf, out + i);
|
||||
}
|
||||
}
|
||||
if (i != size) {
|
||||
const size_t remaining = size - i;
|
||||
const VF mul0 = LoadN(df, mul + i, remaining);
|
||||
const VF g0 = Mul(mul0, Gelu(df, LoadN(df, gelu_in + i, remaining)));
|
||||
const hn::Half<decltype(dbf)> dbfh;
|
||||
const hn::Vec<decltype(dbfh)> bfh = hn::DemoteTo(dbfh, g0);
|
||||
StoreN(bfh, dbfh, out + i, remaining);
|
||||
}
|
||||
}
|
||||
|
||||
// Two matrices, same vector
|
||||
// TODO(janwas): apply optimizations from MatVec/replace with above overload
|
||||
template <size_t kOuter, size_t kInner, typename MatT, size_t kCapacity,
|
||||
typename VecT>
|
||||
HWY_NOINLINE void TwoMatVec(const CompressedArray<MatT, kCapacity>& mat0,
|
||||
const CompressedArray<MatT, kCapacity>& mat1,
|
||||
const size_t mat_ofs,
|
||||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
float* HWY_RESTRICT out0, float* HWY_RESTRICT out1,
|
||||
hwy::ThreadPool& pool) {
|
||||
const hn::ScalableTag<float> df;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
|
||||
// Process multiple rows at a time so that we write multiples of a cache line
|
||||
// to avoid false sharing (>= 64).
|
||||
constexpr size_t kRowsPerStrip = 128 / sizeof(float);
|
||||
const uint32_t num_strips = kOuter / kRowsPerStrip;
|
||||
|
||||
// No remainder handling after ThreadPool.
|
||||
static_assert(kOuter % kRowsPerStrip == 0, "Add remainder handling");
|
||||
|
||||
// Required for Stream loop, otherwise we might have partial vectors.
|
||||
HWY_DASSERT(kRowsPerStrip >= NF);
|
||||
pool.Run(0, num_strips,
|
||||
[&](const uint32_t strip, size_t /*thread*/) HWY_ATTR {
|
||||
// MSVC workaround: duplicate to ensure constexpr.
|
||||
constexpr size_t kRowsPerStrip = 128 / sizeof(float);
|
||||
// Software write-combining to avoid cache pollution from out.
|
||||
// Although `out` may be used later, keeping it out of the cache
|
||||
// now and avoiding RFOs is a consistent 5% overall win.
|
||||
HWY_ALIGN float buf0[kRowsPerStrip];
|
||||
HWY_ALIGN float buf1[kRowsPerStrip];
|
||||
|
||||
// Only handle entire strips here because the Stream is not masked.
|
||||
const size_t begin = strip * kRowsPerStrip;
|
||||
for (size_t idx_row = 0; idx_row < kRowsPerStrip; ++idx_row) {
|
||||
const size_t row_ofs = mat_ofs + (begin + idx_row) * kInner;
|
||||
buf0[idx_row] = Dot(df, mat0, row_ofs, vec_aligned, kInner);
|
||||
buf1[idx_row] = Dot(df, mat1, row_ofs, vec_aligned, kInner);
|
||||
}
|
||||
|
||||
HWY_UNROLL(4)
|
||||
for (size_t i = 0; i != kRowsPerStrip; i += NF) {
|
||||
hn::Stream(hn::Load(df, buf0 + i), df, out0 + begin + i);
|
||||
}
|
||||
HWY_UNROLL(4)
|
||||
for (size_t i = 0; i != kRowsPerStrip; i += NF) {
|
||||
hn::Stream(hn::Load(df, buf1 + i), df, out1 + begin + i);
|
||||
}
|
||||
});
|
||||
hwy::FlushStream();
|
||||
}
|
||||
|
||||
// Baseline Naive MatMul
|
||||
template <size_t kOuter, size_t kInner, size_t kBatchSize, typename MatT,
|
||||
size_t kCapacity, typename VecT>
|
||||
HWY_NOINLINE void MatMul(const CompressedArray<MatT, kCapacity>& mat,
|
||||
const size_t mat_ofs, const VecT* HWY_RESTRICT vec,
|
||||
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
|
||||
for (size_t i = 0; i < kBatchSize; ++i) {
|
||||
MatVec<kOuter, kInner, MatT, kCapacity, VecT>(
|
||||
mat, mat_ofs, vec + i * kInner, out + i * kOuter, pool);
|
||||
}
|
||||
}
|
||||
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a,
|
||||
const float* HWY_RESTRICT b,
|
||||
size_t size) {
|
||||
const hn::ScalableTag<float> d;
|
||||
HWY_DASSERT(size >= hn::Lanes(d));
|
||||
HWY_DASSERT(size % hn::Lanes(d) == 0);
|
||||
constexpr int kAssumptions =
|
||||
hn::Dot::kAtLeastOneVector | hn::Dot::kMultipleOfVector;
|
||||
return hn::Dot::Compute<kAssumptions>(d, a, b, size);
|
||||
}
|
||||
|
||||
// = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT.
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2(
|
||||
const float* HWY_RESTRICT a, size_t size) {
|
||||
float total = 0.f;
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
total += a[i] * a[i];
|
||||
}
|
||||
return total;
|
||||
}
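// The RMSNorm variants below all compute
//   out[j] = (1 + weight[j]) * x[j] / sqrt(mean(x^2) + eps)
// for eps = 1e-6; the "+1" is the centering noted in the per-element comments.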
|
||||
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||
const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight,
|
||||
float* HWY_RESTRICT out, size_t size) {
|
||||
constexpr float eps = 1e-6f;
|
||||
float ss = SquaredL2(x, size);
|
||||
ss = 1.0f / sqrtf(ss / static_cast<int>(size) + eps);
|
||||
for (size_t j = 0; j < size; j++) {
|
||||
// Note 1.0f centering here
|
||||
out[j] = (1.0f + weight[j]) * (ss * x[j]);
|
||||
}
|
||||
}
|
||||
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||
const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
|
||||
float* HWY_RESTRICT out, size_t size) {
|
||||
constexpr float eps = 1e-6f;
|
||||
float ss = SquaredL2(x, size);
|
||||
ss = 1.0f / sqrtf(ss / static_cast<int>(size) + eps);
|
||||
for (size_t j = 0; j < size; j++) {
|
||||
// Note 1.0f centering here
|
||||
out[j] = (1.0f + hwy::F32FromBF16(weight[j])) * (ss * x[j]);
|
||||
}
|
||||
}
|
||||
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
||||
const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) {
|
||||
constexpr float eps = 1e-6f;
|
||||
float ss = SquaredL2(inout, size);
|
||||
ss = 1.0f / sqrtf(ss / static_cast<int>(size) + eps);
|
||||
for (size_t j = 0; j < size; j++) {
|
||||
// Note 1.0f centering here
|
||||
inout[j] = (1.0f + weight[j]) * (ss * inout[j]);
|
||||
}
|
||||
}
|
||||
|
||||
// w=bf16 -> f
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
||||
const hwy::bfloat16_t* HWY_RESTRICT weight, float* HWY_RESTRICT inout,
|
||||
const size_t size) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
const hn::ScalableTag<hwy::bfloat16_t> dbf;
|
||||
const hn::Repartition<float, decltype(dbf)> df32;
|
||||
using VF = hn::Vec<decltype(df32)>;
|
||||
const size_t N32 = hn::Lanes(df32);
|
||||
|
||||
constexpr float eps = 1e-6f;
|
||||
const float ss = SquaredL2(inout, size);
|
||||
const VF vss = Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
|
||||
|
||||
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
|
||||
for (size_t i = 0; i < size; i += 2 * N32) {
|
||||
const hn::Vec<decltype(dbf)> w16 = hn::LoadU(dbf, weight + i);
|
||||
const VF w0 = hn::PromoteLowerTo(df32, w16);
|
||||
const VF w1 = hn::PromoteUpperTo(df32, w16);
|
||||
const VF m0 = hn::Mul(vss, hn::LoadU(df32, inout + i));
|
||||
const VF m1 = hn::Mul(vss, hn::LoadU(df32, inout + i + N32));
|
||||
// (1+weight) * m = m + weight*m = one FMA.
|
||||
hn::StoreU(hn::MulAdd(m0, w0, m0), df32, inout + i);
|
||||
hn::StoreU(hn::MulAdd(m1, w1, m1), df32, inout + i + N32);
|
||||
}
|
||||
}
|
||||
|
||||
// f, f -> bf
|
||||
// TODO(janwas): consider generic function with adapter for loading bf16/f32
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||
const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight,
|
||||
hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
const hn::ScalableTag<hwy::bfloat16_t> dbf;
|
||||
const hn::Repartition<float, decltype(dbf)> df32;
|
||||
using VF = hn::Vec<decltype(df32)>;
|
||||
const size_t N32 = hn::Lanes(df32);
|
||||
|
||||
constexpr float eps = 1e-6f;
|
||||
const float ss = SquaredL2(x, size);
|
||||
const VF vss = Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
|
||||
|
||||
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
|
||||
for (size_t i = 0; i < size; i += 2 * N32) {
|
||||
const VF w0 = hn::LoadU(df32, weight + i);
|
||||
const VF w1 = hn::LoadU(df32, weight + i + N32);
|
||||
const VF m0 = hn::Mul(vss, hn::LoadU(df32, x + i));
|
||||
const VF m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32));
|
||||
// (1+weight) * m = m + weight*m = one FMA.
|
||||
const VF out0 = hn::MulAdd(m0, w0, m0);
|
||||
const VF out1 = hn::MulAdd(m1, w1, m1);
|
||||
hn::StoreU(hn::OrderedDemote2To(dbf, out0, out1), dbf, out + i);
|
||||
}
|
||||
}
|
||||
|
||||
// x=f, w=bf16 -> bf16 to enable W16A16 MatVec.
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||
const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
|
||||
hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
const hn::ScalableTag<hwy::bfloat16_t> dbf;
|
||||
const hn::Repartition<float, decltype(dbf)> df32;
|
||||
using VF = hn::Vec<decltype(df32)>;
|
||||
const size_t N32 = hn::Lanes(df32);
|
||||
|
||||
constexpr float eps = 1e-6f;
|
||||
const float ss = SquaredL2(x, size);
|
||||
const VF vss = Set(df32, 1.0f / sqrtf(ss / size + eps));
|
||||
|
||||
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
|
||||
for (size_t i = 0; i < size; i += 2 * N32) {
|
||||
const hn::Vec<decltype(dbf)> w16 = hn::LoadU(dbf, weight + i);
|
||||
const VF w0 = hn::PromoteLowerTo(df32, w16);
|
||||
const VF w1 = hn::PromoteUpperTo(df32, w16);
|
||||
const VF m0 = hn::Mul(vss, hn::LoadU(df32, x + i));
|
||||
const VF m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32));
|
||||
// (1+weight) * m = m + weight*m = one FMA.
|
||||
const VF out0 = hn::MulAdd(m0, w0, m0);
|
||||
const VF out1 = hn::MulAdd(m1, w1, m1);
|
||||
hn::StoreU(hn::OrderedDemote2To(dbf, out0, out1), dbf, out + i);
|
||||
}
|
||||
}
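// The function below adds transformer-style sinusoidal absolute positional
// embeddings: x[d] += sin(pos * inv_timescale_d) and
// x[num_timescales + d] += cos(pos * inv_timescale_d), with timescales spaced
// geometrically from 1 to 10000.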
|
||||
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
|
||||
float* HWY_RESTRICT x, size_t dim_model, size_t pos) {
|
||||
const size_t num_timescales = dim_model / 2;
|
||||
const float log_timescale_increment =
|
||||
logf(10000.0f) /
|
||||
(num_timescales != 0
|
||||
? static_cast<float>(static_cast<int>(num_timescales) - 1)
|
||||
: 1.0f);
|
||||
for (size_t dim = 0; dim < num_timescales; ++dim) {
|
||||
const float inv_timescale =
|
||||
expf(static_cast<int>(dim) * -log_timescale_increment);
|
||||
x[dim] += sinf(pos * inv_timescale);
|
||||
x[num_timescales + dim] += cosf(pos * inv_timescale);
|
||||
}
|
||||
}
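// Rotary positional embedding (RoPE): for D = dim_qkv, each pair
// (x[d], x[d + D/2]) is rotated by theta = pos / 10000^(2d / D):
//   x[d]       -> x0 * cos(theta) - x1 * sin(theta)
//   x[d + D/2] -> x0 * sin(theta) + x1 * cos(theta)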
|
||||
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
|
||||
size_t dim_qkv, size_t pos) {
|
||||
HWY_DASSERT(dim_qkv % 2 == 0);
|
||||
const size_t half_dim_qkv = dim_qkv / 2;
|
||||
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
|
||||
const float freq_exponents = static_cast<float>(2 * static_cast<int>(dim)) /
|
||||
static_cast<float>(dim_qkv);
|
||||
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
|
||||
const float timescale = powf(10000.0f, freq_exponents);
|
||||
const float theta = pos / timescale;
|
||||
const float cos_val = cosf(theta);
|
||||
const float sin_val = sinf(theta);
|
||||
const float x0 = x[dim];
|
||||
const float x1 = x[dim + half_dim_qkv];
|
||||
x[dim] = x0 * cos_val - x1 * sin_val;
|
||||
x[dim + half_dim_qkv] = x0 * sin_val + x1 * cos_val;
|
||||
}
|
||||
}
|
||||
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul,
|
||||
float* HWY_RESTRICT x,
|
||||
size_t dim_qkv,
|
||||
size_t pos) {
|
||||
HWY_DASSERT(dim_qkv % 2 == 0);
|
||||
const size_t half_dim_qkv = dim_qkv / 2;
|
||||
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
|
||||
const float freq_exponents = static_cast<float>(2 * static_cast<int>(dim)) /
|
||||
static_cast<float>(dim_qkv);
|
||||
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
|
||||
const float timescale = powf(10000.0f, freq_exponents);
|
||||
const float theta = pos / timescale;
|
||||
const float cos_val = cosf(theta);
|
||||
const float sin_val = sinf(theta);
|
||||
const float x0 = x[dim];
|
||||
const float x1 = x[dim + half_dim_qkv];
|
||||
x[dim] = mul * (x0 * cos_val - x1 * sin_val);
|
||||
x[dim + half_dim_qkv] = mul * (x0 * sin_val + x1 * cos_val);
|
||||
}
|
||||
}
|
||||
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
|
||||
const float* HWY_RESTRICT other, float* HWY_RESTRICT x, size_t size) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
x[i] += other[i];
|
||||
}
|
||||
}
|
||||
|
||||
static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other,
|
||||
float* HWY_RESTRICT x, size_t size,
|
||||
size_t max_pos) {
|
||||
HWY_DASSERT(max_pos <= size);
|
||||
for (size_t i = 0; i < max_pos; ++i) {
|
||||
x[i] *= other[i];
|
||||
}
|
||||
}
|
||||
|
||||
static HWY_INLINE HWY_MAYBE_UNUSED void MulBy(const float* HWY_RESTRICT other,
|
||||
float* HWY_RESTRICT x,
|
||||
size_t size) {
|
||||
return MulBy(other, x, size, size);
|
||||
}
|
||||
|
||||
static HWY_NOINLINE void MulByConst(float c, float* HWY_RESTRICT x, size_t size,
|
||||
size_t max_pos) {
|
||||
HWY_DASSERT(max_pos <= size);
|
||||
for (size_t i = 0; i < max_pos; ++i) {
|
||||
x[i] *= c;
|
||||
}
|
||||
}
|
||||
|
||||
static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(float c,
|
||||
float* HWY_RESTRICT x,
|
||||
size_t size) {
|
||||
MulByConst(c, x, size, size);
|
||||
}
|
||||
|
||||
static HWY_NOINLINE void MulByConstAndAdd(float c, const float* HWY_RESTRICT x,
|
||||
float* HWY_RESTRICT out, size_t size,
|
||||
size_t max_pos) {
|
||||
for (size_t i = 0; i < max_pos; ++i) {
|
||||
out[i] += x[i] * c;
|
||||
}
|
||||
}
|
||||
|
||||
static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
||||
float c, const float* HWY_RESTRICT x, float* HWY_RESTRICT out,
|
||||
size_t size) {
|
||||
MulByConstAndAdd(c, x, out, size, size);
|
||||
}
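// Numerically stable softmax over x[0, mask_pos): subtracts the maximum
// before exponentiating, i.e. x[i] = exp(x[i] - max) / sum_j exp(x[j] - max),
// then rescales the first mask_pos entries to sum to 1.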
|
||||
|
||||
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, size_t size,
|
||||
size_t mask_pos) {
|
||||
HWY_DASSERT(size != 0);
|
||||
HWY_DASSERT(mask_pos <= size);
|
||||
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
const D d;
|
||||
using V = hn::Vec<D>;
|
||||
|
||||
// Find max so we can subtract it below.
|
||||
const V vmin = hn::Set(d, hwy::LowestValue<float>());
|
||||
V max = vmin;
|
||||
hn::Foreach(d, x, mask_pos, vmin,
|
||||
[&max](D d, V v) { max = hn::Max(max, v); });
|
||||
max = hn::MaxOfLanes(d, max); // broadcast
|
||||
|
||||
// Subtract max (avoid precision loss for large exponents) and exponentiate.
|
||||
V sum = hn::Zero(d);
|
||||
hn::Transform(d, x, mask_pos, [&sum, max](D d, V v) {
|
||||
const V out = hn::Exp(d, hn::Sub(v, max));
|
||||
sum = hn::Add(sum, out);
|
||||
return out;
|
||||
});
|
||||
|
||||
// Normalize to probability distribution
|
||||
const float mul = 1.0f / hn::ReduceSum(d, sum);
|
||||
MulByConst(mul, x, size, mask_pos);
|
||||
}
|
||||
|
||||
static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x,
|
||||
size_t size) {
|
||||
Softmax(x, size, size);
|
||||
}
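// Soft-caps values into (-cap, cap): x[i] = cap * tanh(x[i] / cap), applied
// by the function below.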
|
||||
|
||||
static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
|
||||
size_t size, size_t max_pos) {
|
||||
HWY_DASSERT(max_pos <= size);
|
||||
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
const D d;
|
||||
using V = hn::Vec<D>;
|
||||
|
||||
const V inv_cap = hn::Set(d, 1.0f / cap);
|
||||
const V vcap = hn::Set(d, cap);
|
||||
|
||||
hn::Transform(d, x, size, [vcap, inv_cap](D d, hn::Vec<D> v) {
|
||||
return hn::Mul(vcap, hn::Tanh(d, hn::Mul(inv_cap, v)));
|
||||
});
|
||||
}
|
||||
|
||||
static HWY_INLINE HWY_MAYBE_UNUSED void LogitsSoftCap(const float cap,
|
||||
float* HWY_RESTRICT x,
|
||||
size_t size) {
|
||||
LogitsSoftCap(cap, x, size, size);
|
||||
}
|
||||
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED size_t
|
||||
SampleArgmax(const float* probabilities, size_t vocab_size) {
|
||||
size_t max_index = 0;
|
||||
float max_prob = probabilities[0];
|
||||
for (size_t i = 1; i < vocab_size; ++i) {
|
||||
if (probabilities[i] > max_prob) {
|
||||
max_index = i;
|
||||
max_prob = probabilities[i];
|
||||
}
|
||||
}
|
||||
return max_index;
|
||||
}
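// The helper below turns the top-k probabilities into a sampling distribution
// with temperature: p_i' = exp(log(p_i) / temperature), renormalized to sum
// to 1. temperature = 1 leaves the relative probabilities unchanged.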
|
||||
|
||||
template <size_t k>
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution<int>
|
||||
create_distribution(std::array<float, k>& top_k, float temperature) {
|
||||
// re-normalize distribution
|
||||
for (size_t i = 0; i < k; ++i) {
|
||||
top_k[i] = exp(log(top_k[i]) / temperature);
|
||||
}
|
||||
float denominator = 0.0f;
|
||||
for (size_t i = 0; i < k; ++i) {
|
||||
denominator += top_k[i];
|
||||
}
|
||||
denominator = 1.0f / denominator;
|
||||
MulByConst(denominator, top_k.data(), k);
|
||||
return std::discrete_distribution<int>(std::begin(top_k), std::end(top_k));
|
||||
}
|
||||
|
||||
template <size_t k, typename TAcceptToken>
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
|
||||
const float* HWY_RESTRICT probabilities, size_t vocab_size,
|
||||
std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
|
||||
static_assert(k != 0, "");
|
||||
// TODO(austinvhuang): Optimize this implementation.
|
||||
std::array<float, k> top_k{}; // sorted from highest [0], to lowest [k-1]
|
||||
std::array<int, k> indices{};
|
||||
for (size_t i = 0; i < vocab_size; ++i) {
|
||||
if (probabilities[i] < top_k[k - 1] && accept_token(static_cast<int>(i))) {
|
||||
continue;
|
||||
}
|
||||
for (size_t j = 0; j < k; ++j) {
|
||||
if (probabilities[i] > top_k[j] && accept_token(static_cast<int>(i))) {
|
||||
// shift elements by 1, insert the new value, move on to next value
|
||||
for (size_t idx = k - 1; idx > j; --idx) {
|
||||
top_k[idx] = top_k[idx - 1];
|
||||
indices[idx] = indices[idx - 1];
|
||||
}
|
||||
top_k[j] = probabilities[i];
|
||||
indices[j] = static_cast<int>(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return indices[create_distribution<k>(top_k, temperature)(gen)];
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#endif // NOLINT
@ -0,0 +1,261 @@
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Command line text interface to gemma.
|
||||
|
||||
#include <ctime>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <thread> // NOLINT
|
||||
#include <vector>
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/compress.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "gemma.h" // Gemma
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/app.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/args.h" // HasHelp
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/per_target.h"
|
||||
#include "hwy/profiler.h"
|
||||
#include "hwy/timer.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
|
||||
gcpp::AppArgs& app) {
|
||||
fprintf(stderr,
|
||||
"\ngemma.cpp\n---------\n\nTo run gemma.cpp, you need to "
|
||||
"specify 3 required model loading arguments: --tokenizer, "
|
||||
"--compressed_weights, "
|
||||
"and --model.\n\nModel Loading Arguments\n\n");
|
||||
loader.Help();
|
||||
fprintf(stderr, "\nInference Arguments\n\n");
|
||||
inference.Help();
|
||||
fprintf(stderr, "\nApplication Arguments\n\n");
|
||||
app.Help();
|
||||
fprintf(stderr, "\n\n");
|
||||
}
|
||||
|
||||
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||
loader.Print(app.verbosity);
|
||||
inference.Print(app.verbosity);
|
||||
app.Print(app.verbosity);
|
||||
|
||||
if (app.verbosity >= 2) {
|
||||
time_t now = time(nullptr);
|
||||
char* dt = ctime(&now); // NOLINT
|
||||
std::cout << "Date & Time : " << dt
|
||||
<< "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize
|
||||
<< "\n"
|
||||
<< "Hardware concurrency : "
|
||||
<< std::thread::hardware_concurrency() << std::endl
|
||||
<< "Instruction set : "
|
||||
<< hwy::TargetName(hwy::DispatchedTarget()) << " ("
|
||||
<< hwy::VectorBytes() * 8 << " bits)" << "\n"
|
||||
<< "Weight Type : "
|
||||
<< gcpp::TypeName(gcpp::WeightT()) << "\n"
|
||||
<< "EmbedderInput Type : "
|
||||
<< gcpp::TypeName(gcpp::EmbedderInputT()) << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const InferenceArgs& args,
|
||||
int verbosity, const gcpp::AcceptFunc& accept_token) {
|
||||
PROFILER_ZONE("Gen.misc");
|
||||
int abs_pos = 0; // absolute token index over all turns
|
||||
int current_pos = 0; // token index within the current turn
|
||||
int prompt_size{};
|
||||
|
||||
std::mt19937 gen;
|
||||
if (args.deterministic) {
|
||||
gen.seed(42);
|
||||
} else {
|
||||
std::random_device rd;
|
||||
gen.seed(rd());
|
||||
}
|
||||
|
||||
// callback function invoked for each generated token.
|
||||
auto stream_token = [&abs_pos, ¤t_pos, &args, &gen, &prompt_size,
|
||||
tokenizer = &model.Tokenizer(),
|
||||
verbosity](int token, float) {
|
||||
++abs_pos;
|
||||
++current_pos;
|
||||
if (current_pos < prompt_size) {
|
||||
std::cerr << "." << std::flush;
|
||||
} else if (token == gcpp::EOS_ID) {
|
||||
if (!args.multiturn) {
|
||||
abs_pos = 0;
|
||||
if (args.deterministic) {
|
||||
gen.seed(42);
|
||||
}
|
||||
}
|
||||
if (verbosity >= 2) {
|
||||
std::cout << "\n[ End ]" << std::endl;
|
||||
}
|
||||
} else {
|
||||
std::string token_text;
|
||||
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text).ok());
|
||||
// +1 since position is incremented above
|
||||
if (current_pos == prompt_size + 1) {
|
||||
// first token of response
|
||||
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
|
||||
if (verbosity >= 1) {
|
||||
std::cout << std::endl << std::endl;
|
||||
}
|
||||
}
|
||||
// TODO(austinvhuang): is explicit space necessary?
|
||||
std::cout << token_text << std::flush;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
while (abs_pos < args.max_tokens) {
|
||||
std::string prompt_string;
|
||||
std::vector<int> prompt;
|
||||
current_pos = 0;
|
||||
{
|
||||
PROFILER_ZONE("Gen.input");
|
||||
if (verbosity >= 1) {
|
||||
std::cout << "> " << std::flush;
|
||||
}
|
||||
std::getline(std::cin, prompt_string);
|
||||
}
|
||||
|
||||
if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q") {
|
||||
return;
|
||||
}
|
||||
|
||||
if (model.model_training == ModelTraining::GEMMA_IT) {
|
||||
// For instruction-tuned models: add control tokens.
|
||||
prompt_string = "<start_of_turn>user\n" + prompt_string +
|
||||
"<end_of_turn>\n<start_of_turn>model\n";
|
||||
if (abs_pos > 0) {
|
||||
// Prepend "<end_of_turn>" token if this is a multi-turn dialogue
|
||||
// continuation.
|
||||
prompt_string = "<end_of_turn>\n" + prompt_string;
|
||||
}
|
||||
}
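// For example, the prompt "Hi" becomes
//   <start_of_turn>user\nHi<end_of_turn>\n<start_of_turn>model\n
// (with a leading <end_of_turn>\n on continuation turns).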
|
||||
|
||||
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok());
|
||||
|
||||
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
|
||||
// if needed.
|
||||
if (abs_pos == 0) {
|
||||
prompt.insert(prompt.begin(), 2);  // 2 is the <bos> token id.
|
||||
}
|
||||
|
||||
prompt_size = prompt.size();
|
||||
|
||||
std::cerr << std::endl << "[ Reading prompt ] " << std::flush;
|
||||
|
||||
const double time_start = hwy::platform::Now();
|
||||
GenerateGemma(model, args, prompt, abs_pos, pool, inner_pool, stream_token,
|
||||
accept_token, gen, verbosity);
|
||||
const double time_end = hwy::platform::Now();
|
||||
const double tok_sec = current_pos / (time_end - time_start);
|
||||
if (verbosity >= 2) {
|
||||
std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
|
||||
<< std::endl
|
||||
<< tok_sec << " tokens / sec" << std::endl;
|
||||
}
|
||||
std::cout << std::endl << std::endl;
|
||||
}
|
||||
std::cout
|
||||
<< "max_tokens (" << args.max_tokens
|
||||
<< ") exceeded. Use a larger value if desired using the --max_tokens "
|
||||
<< "command line flag.\n";
|
||||
}
|
||||
|
||||
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||
PROFILER_ZONE("Run.misc");
|
||||
|
||||
hwy::ThreadPool inner_pool(0);
|
||||
hwy::ThreadPool pool(app.num_threads);
|
||||
// For many-core, pinning threads to cores helps.
|
||||
if (app.num_threads > 10) {
|
||||
PinThreadToCore(app.num_threads - 1); // Main thread
|
||||
|
||||
pool.Run(0, pool.NumThreads(),
|
||||
[](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
|
||||
}
|
||||
|
||||
gcpp::Gemma model(loader, pool);
|
||||
|
||||
if (const char* error = inference.Validate()) {
|
||||
ShowHelp(loader, inference, app);
|
||||
HWY_ABORT("\nInvalid args: %s", error);
|
||||
}
|
||||
|
||||
if (app.verbosity >= 1) {
|
||||
static const std::string banner_ascii_art =
|
||||
" __ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __\n"
|
||||
" / _` |/ _ \\ '_ ` _ \\| '_ ` _ \\ / _` | / __| '_ \\| '_ \\\n"
|
||||
"| (_| | __/ | | | | | | | | | | (_| || (__| |_) | |_) |\n"
|
||||
" \\__, |\\___|_| |_| |_|_| |_| |_|\\__,_(_)___| .__/| .__/\n"
|
||||
" __/ | | | | |\n"
|
||||
" |___/ |_| |_|";
|
||||
|
||||
const std::string instructions =
|
||||
"*Usage*\n"
|
||||
" Enter an instruction and press enter (%Q quits).\n\n"
|
||||
"*Examples*\n"
|
||||
" - Write an email to grandma thanking her for the cookies.\n"
|
||||
" - What are some historical attractions to visit around "
|
||||
"Massachusetts?\n"
|
||||
" - Compute the nth fibonacci number in javascript.\n"
|
||||
" - Write a standup comedy bit about GPU programming.\n";
|
||||
|
||||
std::cout << "\033[2J\033[1;1H" // clear screen
|
||||
<< banner_ascii_art << "\n\n";
|
||||
ShowConfig(loader, inference, app);
|
||||
std::cout << "\n" << instructions << "\n";
|
||||
}
|
||||
|
||||
ReplGemma(model, pool, inner_pool, inference, app.verbosity,
|
||||
/*accept_token=*/[](int) { return true; });
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
{
|
||||
PROFILER_ZONE("Startup.misc");
|
||||
|
||||
gcpp::LoaderArgs loader(argc, argv);
|
||||
gcpp::InferenceArgs inference(argc, argv);
|
||||
gcpp::AppArgs app(argc, argv);
|
||||
|
||||
if (gcpp::HasHelp(argc, argv)) {
|
||||
ShowHelp(loader, inference, app);
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (const char* error = loader.Validate()) {
|
||||
ShowHelp(loader, inference, app);
|
||||
HWY_ABORT("\nInvalid args: %s", error);
|
||||
}
|
||||
|
||||
gcpp::Run(loader, inference, app);
|
||||
}
|
||||
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
|
||||
return 0;
|
||||
}
@ -0,0 +1,85 @@
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Shared between various frontends.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
|
||||
|
||||
#include <sched.h>
|
||||
#include <stddef.h>
|
||||
|
||||
#include <algorithm> // std::clamp
|
||||
#include <thread> // NOLINT>
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/args.h"
|
||||
#include "hwy/base.h" // HWY_ASSERT
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
static inline void PinThreadToCore(size_t cpu_index) {
|
||||
#if HWY_OS_LINUX
|
||||
// Forces the thread to run on the logical processor with the same number.
|
||||
cpu_set_t cset; // bit array
|
||||
CPU_ZERO(&cset); // clear all
|
||||
CPU_SET(cpu_index, &cset); // set bit indicating which processor to run on.
|
||||
HWY_ASSERT(0 == sched_setaffinity(0, sizeof(cset), &cset));
|
||||
#else
|
||||
(void)cpu_index;
|
||||
#endif
|
||||
}
|
||||
|
||||
class AppArgs : public ArgsBase<AppArgs> {
|
||||
static constexpr size_t kDefaultNumThreads = ~size_t{0};
|
||||
|
||||
void ChooseNumThreads() {
|
||||
if (num_threads == kDefaultNumThreads) {
|
||||
// This is a rough heuristic, replace with something better in the future.
|
||||
num_threads = static_cast<size_t>(std::clamp(
|
||||
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
AppArgs(int argc, char* argv[]) {
|
||||
InitAndParse(argc, argv);
|
||||
ChooseNumThreads();
|
||||
}
|
||||
|
||||
Path log; // output
|
||||
int verbosity;
|
||||
size_t num_threads;
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(log, "log", Path{"/tmp/log.txt"}, "Logging file", 2);
|
||||
visitor(verbosity, "verbosity", 1,
|
||||
"Show verbose developer information\n 0 = only print generation "
|
||||
"output\n 1 = standard user-facing terminal ui\n 2 = show "
|
||||
"developer/debug info).\n Default = 1.",
|
||||
2);
|
||||
visitor(num_threads, "num_threads",
|
||||
kDefaultNumThreads, // see ChooseNumThreads
|
||||
"Number of threads to use. Default value is set based on an "
|
||||
"estimate of "
|
||||
"how many concurrent threads are supported.",
|
||||
2);
|
||||
}
|
||||
};
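
// Example invocation (illustrative) using only the flags declared in ForEach
// above; unspecified flags keep their defaults and --num_threads falls back to
// the ChooseNumThreads() heuristic:
//
//   ./gemma --verbosity 2 --num_threads 16 --log /tmp/gemma.log
//
// (Model and tokenizer flags come from LoaderArgs, defined elsewhere.)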

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_

@@ -0,0 +1,223 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Command line arguments.

#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_ARGS_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_ARGS_H_

#include <stdio.h>

#include <algorithm>  // std::transform
#include <string>

#include "hwy/base.h"  // HWY_ABORT

namespace gcpp {

// Wrapper for strings representing a path name. Differentiates vs. arbitrary
// strings and supports shortening for display purposes.
struct Path {
  Path& operator=(const char* other) {
    path = other;
    return *this;
  }

  std::string Shortened() const {
    constexpr size_t max_len = 48;
    constexpr size_t cut_point = max_len / 2 - 5;
    if (path.size() > max_len) {
      return std::string(begin(path), begin(path) + cut_point) + " ... " +
             std::string(end(path) - cut_point, end(path));
    }
    if (path.empty()) return "[no path specified]";
    return path;
  }

  std::string path;
};
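
// Worked example for Path::Shortened() (illustrative): with max_len = 48 and
// cut_point = 48 / 2 - 5 = 19, a 60-character path is displayed as its first
// 19 characters, then " ... ", then its last 19 characters (43 characters in
// total). Paths of at most 48 characters are returned unchanged.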

// Args is a class that provides a ForEach member function which visits each of
// its member variables. ArgsBase provides functions called by Args to
// initialize values to their defaults (passed as an argument to the visitor),
// print and parse, without having to repeat the args for each usage.
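//
// For example (illustrative sketch; the member names below are hypothetical,
// but AppArgs in util/app.h follows the same pattern):
//
//   struct MyArgs : public ArgsBase<MyArgs> {
//     MyArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
//
//     size_t max_tokens;
//     float temperature;
//
//     template <class Visitor>
//     void ForEach(const Visitor& visitor) {
//       visitor(max_tokens, "max_tokens", size_t{1024},
//               "Maximum number of tokens to generate.");
//       visitor(temperature, "temperature", 1.0f, "Sampling temperature.");
//     }
//   };
//
// Constructing MyArgs(argc, argv) then applies the defaults and overrides them
// with any matching --flag value pairs found on the command line.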

template <class Args>
class ArgsBase {
  struct InitVisitor {
    template <typename T>
    void operator()(T& t, const char* /*name*/, const T& init,
                    const char* /*help*/, int /*print_verbosity*/ = 0) const {
      t = init;
    }
  };

  struct HelpVisitor {
    template <typename T>
    void operator()(T&, const char* name, T /*init*/, const char* help,
                    int /*print_verbosity*/ = 0) const {
      fprintf(stderr, " --%s : %s\n", name, help);
    }
  };

  class PrintVisitor {
   public:
    explicit PrintVisitor(int verbosity) : verbosity_(verbosity) {}

    template <typename T>
    void operator()(const T& t, const char* name, const T& /*init*/,
                    const char* /*help*/, int print_verbosity = 0) const {
      if (verbosity_ >= print_verbosity) {
        fprintf(stderr, "%-30s: %s\n", name, std::to_string(t).c_str());
      }
    }

    void operator()(const std::string& t, const char* name,
                    const std::string& /*init*/, const char* /*help*/,
                    int print_verbosity = 0) const {
      if (verbosity_ >= print_verbosity) {
        fprintf(stderr, "%-30s: %s\n", name, t.c_str());
      }
    }

    void operator()(const Path& t, const char* name, const Path& /*init*/,
                    const char* /*help*/, int print_verbosity = 0) const {
      if (verbosity_ >= print_verbosity) {
        fprintf(stderr, "%-30s: %s\n", name, t.Shortened().c_str());
      }
    }

   private:
    int verbosity_;
  };

  // Supported types: integer, float, std::string, bool, Path. This is O(N^2):
  // for each arg, we search through argv. If there are more than a dozen args,
  // consider adding a hash-map to speed this up.
  class ParseVisitor {
   public:
    ParseVisitor(int argc, char* argv[]) : argc_(argc), argv_(argv) {}

    template <typename T>
    void operator()(T& t, const char* name, const T& /*init*/,
                    const char* /*help*/, int /*print_verbosity*/ = 0) const {
      const std::string prefixed = std::string("--") + name;
      for (int i = 1; i < argc_; ++i) {
        if (std::string(argv_[i]) == prefixed) {
          if (i + 1 >= argc_) {
            HWY_ABORT("Missing value for %s\n", name);
          }
          if (!SetValue(argv_[i + 1], t)) {
            HWY_ABORT("Invalid value for %s, got %s\n", name, argv_[i + 1]);
          }
          return;
        }
      }
    }
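
    // Example (illustrative): with argv = {"gemma", "--num_threads", "16"},
    // visiting num_threads builds prefixed = "--num_threads", finds it at
    // i == 1 and calls SetValue("16", num_threads), which parses via std::stoi.
    // Flags not present in argv keep the value assigned by InitVisitor.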

   private:
    // Returns false if an invalid value is detected.
    template <typename T, HWY_IF_NOT_FLOAT(T)>
    static bool SetValue(const char* string, T& t) {
      t = std::stoi(string);
      return true;
    }

    template <typename T, HWY_IF_FLOAT(T)>
    static bool SetValue(const char* string, T& t) {
      t = std::stof(string);
      return true;
    }

    static bool SetValue(const char* string, std::string& t) {
      t = string;
      return true;
    }

    static bool SetValue(const char* string, Path& t) {
      t.path = string;
      return true;
    }

    static bool SetValue(const char* string, bool& t) {
      std::string value(string);
      // Lower-case; values are expected to be ASCII-only.
      std::transform(value.begin(), value.end(), value.begin(), [](char c) {
        return 'A' <= c && c <= 'Z' ? c - ('Z' - 'z') : c;
      });

      if (value == "true" || value == "on" || value == "1") {
        t = true;
        return true;
      } else if (value == "false" || value == "off" || value == "0") {
        t = false;
        return true;
      } else {
        return false;
      }
    }

    int argc_;
    char** argv_;
  };  // ParseVisitor

  template <class Visitor>
  void ForEach(Visitor& visitor) {
    static_cast<Args*>(this)->ForEach(visitor);
  }

 public:
  // WARNING: cannot call from ctor because the derived ctor has not yet run.
  void Init() {
    InitVisitor visitor;
    ForEach(visitor);
  }

  void Help() {
    HelpVisitor visitor;
    ForEach(visitor);
  }

  void Print(int verbosity = 0) {
    PrintVisitor visitor(verbosity);
    ForEach(visitor);
  }

  void Parse(int argc, char* argv[]) {
    ParseVisitor visitor(argc, argv);
    ForEach(visitor);
  }

  // For convenience, enables single-line constructor.
  void InitAndParse(int argc, char* argv[]) {
    Init();
    Parse(argc, argv);
  }
};

static bool HasHelp(int argc, char* argv[]) {
  // TODO(austinvhuang): handle case insensitivity
  if (argc == 1) {
    // no arguments - print help
    return true;
  }
  for (int i = 1; i < argc; ++i) {
    if (std::string(argv[i]) == "--help") {
      return true;
    }
  }
  return false;
}

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_UTIL_ARGS_H_