commit e29cd566cf3367671e8f59419a04e308796a7c57 Author: Austin Huang Date: Tue Feb 13 06:30:41 2024 +0000 initial commit diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..3858968 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,79 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.11) + +include(FetchContent) + +project(gemma) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f) +FetchContent_MakeAvailable(highway) + +## Note: absl meeds tp be installed by sentencepiece. This will only happen if +## cmake is invoked with -DSPM_ENABLE_SHARED=OFF and -DSPM_ABSL_PROVIDER=module +FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c) +FetchContent_MakeAvailable(sentencepiece) + +set(SOURCES + gemma.cc + compression/blob_store.cc + compression/blob_store.h + compression/compress.h + compression/compress-inl.h + compression/nuq.h + compression/nuq-inl.h + compression/sfp.h + compression/sfp-inl.h + util/app.h + util/args.h + ) + +add_compile_options($<$:-O2>) +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "Release") +endif() + +# Allowable types for WEIGHT_TYPE: +# float - slow, not recommended +# hwy::bfloat16_t - bfloat16 as impemented by https://github.com/google/highway +# SfpStream - 8-bit switched floating point (recommended) +# NuqStream - experimental, work-in-progress +option(WEIGHT_TYPE "Set weight type" "") + +if (WEIGHT_TYPE) + add_definitions(-DGEMMA_WEIGHT_T=${WEIGHT_TYPE}) +endif() + +# Executable Target + +add_executable(gemma run.cc) +target_sources(gemma PRIVATE ${SOURCES}) +set_property(TARGET gemma PROPERTY CXX_STANDARD 17) +target_link_libraries(gemma hwy hwy_contrib sentencepiece) +target_include_directories(gemma PRIVATE ./) +FetchContent_GetProperties(sentencepiece) +target_include_directories(gemma PRIVATE ${sentencepiece_SOURCE_DIR}) + +## Library Target + +add_library(libgemma ${SOURCES}) +set_property(TARGET libgemma PROPERTY CXX_STANDARD 17) +set_target_properties(libgemma PROPERTIES PREFIX "") +target_include_directories(libgemma PUBLIC ./) +target_link_libraries(libgemma hwy hwy_contrib sentencepiece) +target_include_directories(libgemma PRIVATE ${sentencepiece_SOURCE_DIR}) diff --git a/DEVELOPERS.md b/DEVELOPERS.md new file mode 100644 index 0000000..d06b0f8 --- /dev/null +++ b/DEVELOPERS.md @@ -0,0 +1,72 @@ +# Developer Notes + +## Motivation: A Minimalist C++ LLM Runtime for Research and Experimentation + +In the past, neural network inference has been similar to a simple, opaque, +stateless function function with a single input and output. By contrast, +foundation model runtimes are better considered as systems with multiple forms +of state, subsystems, and heterogeneous inputs and outputs. 
They are often +integrated with a wide variety of other systems that have their own resources +(e.g. RAG and tools) and potentially interact with an external environment. They +have become compute engines to embed proximal tasks and goals within expansively +broad, general-purpose world models. + +With this in mind, we believe that developing an experimental runtime that is +flexible and approachable will allow us to explore the design space of co-design +between high level model concerns and low-level runtime computation. + +## Design Priorities + +Given these motivations, we propose the following priorities for +making decisions regarding the direction and design of the codebase. + +**Maximize Leverage with a Narrow Scope.** We focus on direct implementations of +foundation models like Gemma. This allows us to focus effort on bottlenecks of +specific models. We are willing to trade off generality to keep implementation +code relatively simple and readable at all layers of the stack, achieve good +performance, and maintain the velocity of a small team. + +**Data Oriented Design.** Follow data oriented design principles where possible +to minimize unnecessary performance pessimization. It's best to apply these +optimizations during the initial design, or when refactoring a subcomponent. The +first step is to think in terms of batches or tuples of plain old data (POD) +types: separate arrays, instead of an array of structs. The second is to +de-emphasize control flow (if statements, virtual functions and class +hierarchies). The third step is to know intrinsic properties of data and bake +that into the layout and algorithm. + +**Prioritize Small Batch Latency** Since production serving solutions are +available for large-scale serving powered by accelerators and optimizing for +throughput, this project focuses on the possibilities of local, interactive use +of foundation models. Although throughput remains important, low latency and +small batch sizes are prioritized, other things being equal. + +**Maintain a Portable Baseline** Our starting point is a portable CPU SIMD (via +[highway](https://github.com/google/highway)). We expect to add accelerator and +hybrid CPU/GPU support in the future, but the project should continue to allow +builds using this portable baseline. This ensures that research-oriented and +experimental runtimes and hardware platforms will have a minimum viable option +to run Gemma even if specialized production-ready deployment paths are not +available. + +## Code Organization + +The implementation code is roughly split into 4 layers, from high to low level: + +1. Frontends (`run.cc`) - Either interactive interfaces or automation + orchestration that interacts. Frontend code implements a use case objective + in terms of invocations to model inference and generation (2). Projects that + use gemma.cpp as a library are considered alternative frontends to `run.cc`. + We will add examples of additional frontends in the future. + +2. Models (`gemma.cc`, `gemma.h`, `configs.h`) - Implements the compute graph + of the model including supporting functions such as loading and compressing + weights using transformer operations provided by layer (3). + +3. Operations (`ops.h`) - A minimal set of transformer and supporting + mathematical operations implementations using compute backends (4). This + code should be agnostic to the specifics of the compute graph of the model + implementation (2). + +4. 
Backend (`highway`) - Low-level hardware interface (SIMD in the case of + highway) supporting the implementations in (3). diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..7a4a3ea --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/LICENSE-BSD3 b/LICENSE-BSD3 new file mode 100644 index 0000000..778ea4f --- /dev/null +++ b/LICENSE-BSD3 @@ -0,0 +1,26 @@ +Copyright (c) The gemma.cpp Project Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..de0315e --- /dev/null +++ b/README.md @@ -0,0 +1,335 @@ +# gemma.cpp + +gemma.cpp is a lightweight, standalone C++ inference engine for the Gemma +foundation models from Google. + +For additional information about Gemma, see +[ai.google.dev/gemma](https://ai.google.dev/gemma). Model weights, including gemma.cpp +specific artifacts, are [available on +kaggle](https://www.kaggle.com/models/google/gemma). + +## Who is this project for? + +Modern LLM inference engines are sophisticated systems, often with bespoke +capabilities extending beyond traditional neural network runtimes. With this +comes opportunities for research and innovation through co-design of high level +algorithms and low-level computation. However, there is a gap between +deployment-oriented C++ inference runtimes, which are not designed for +experimentation, and Python-centric ML research frameworks, which abstract away +low-level computation through compilation. + +gemma.cpp provides a minimalist implementation of Gemma 2B and 7B models, +focusing on simplicity and directness rather than full generality. This is +inspired by vertically-integrated model implementations such as +[ggml](https://github.com/ggerganov/ggml), +[llama.c](https://github.com/karpathy/llama2.c), and +[llama.rs](https://github.com/srush/llama2.rs). 
+ +gemma.cpp targets experimentation and research use cases. It is intended to be +straightforward to embed in other projects with minimal dependencies and also +easily modifiable with a small ~2K LoC core implementation (along with ~4K LoC +of supporting utilities). We use the [Google +Highway](https://github.com/google/highway) Library to take advantage of +portable SIMD for CPU inference. + +For production-oriented edge deployments we recommend standard deployment +pathways using Python frameworks like JAX, Keras, PyTorch, and Transformers +([all model variations here](https://www.kaggle.com/models/google/gemma)). + +Community contributions large and small are welcome. This project follows +[Google's Open Source Community +Guidelines](https://opensource.google.com/conduct/). + +## Quick Start + +### System requirements + +Before starting, you should have installed: + +- [CMake](https://cmake.org/) +- [Clang C++ compiler](https://clang.llvm.org/get_started.html), supporting at + least C++17. +- `tar` for extracting archives from Kaggle. + +### Step 1: Obtain model weights and tokenizer from Kaggle + +Visit [the Gemma model page on +Kaggle](https://www.kaggle.com/models/google/gemma) and select `Model Variations +|> Gemma C++`. On this tab, the `Variation` dropdown includes the options below. +Note bfloat16 weights are higher fidelity, while 8-bit switched floating point +weights enable faster inference. + +2B instruction-tuned (`it`) and pre-trained (`pt`) models: + +| Model name | Description | +| ----------- | ----------- | +| `2b-it` | 2 billion parameter instruction-tuned model, bfloat16 | +| `2b-it-sfp` | 2 billion parameter instruction-tuned model, 8-bit switched floating point | +| `2b-pt` | 2 billion parameter pre-trained model, bfloat16 | +| `2b-pt-sfp` | 2 billion parameter pre-trained model, 8-bit switched floating point | + +7B instruction-tuned (`it`) and pre-trained (`pt`) models: + +| Model name | Description | +| ----------- | ----------- | +| `7b-it` | 7 billion parameter instruction-tuned model, bfloat16 | +| `7b-it-sfp` | 7 billion parameter instruction-tuned model, 8-bit switched floating point | +| `7b-pt` | 7 billion parameter pre-trained model, bfloat16 | +| `7b-pt-sfp` | 7 billion parameter pre-trained model, 8-bit switched floating point | + +> [!NOTE] +> We *recommend starting with `2b-it-sfp`* to get up and running. + +### Step 2: Extract Files + +After filling out the consent form, the download should proceed to retrieve a +tar archive file `archive.tar.gz`. Extract files from `archive.tar.gz` (this can +take a few minutes): + +``` +tar -xf archive.tar.gz +``` + +This should produce a file containing model weights such as `2b-it-sfp.sbs` and +a tokenizer file (`tokenizer.spm`). You may want to move these files to a +convenient directory location (e.g. the `build/` directory in this repo). + +### Step 3: Build + +The build system uses [CMake](https://cmake.org/). To build the gemma inference +runtime, create a build directory and generate the build files using `cmake` +from the top-level project directory: + +```sh +(cd build && cmake ..) +``` + +Then run `make` to build the `./gemma` executable: + +```sh +cd build +make -j [number of parallel threads to use] gemma +``` + +For example, `make -j 8 gemma`. If this is successful, you should now have a +`gemma` executable in the `build/` directory. + +> [!NOTE] +> On Windows Subsystem for Linux (WSL) users should set the number of +> parallel threads to 1. Using a larger number may result in errors. 
+ +### Step 4: Run + +You can now run `gemma` from inside the `build/` directory. + +`gemma` has the following required arguments: + +| Argument | Description | Example value | +| -------- | ----------- | ------------- | +| `--model` | The model type. | `2b-it`, `2b-pt`, `7b-it`, `7b-pt`, ... (see above) | +| `--compressed_weights` | The compressed weights file. | `2b-it-sfp.sbs`, ... (see above) | +| `--tokenizer` | The tokenizer file. | `tokenizer.spm` | + + +`gemma` is invoked as: + +```sh +./gemma \ +--tokenizer [tokenizer file] \ +--compressed_weights [compressed weights file] \ +--model [2b-it or 2b-pt or 7b-it or 7b-pt or ...] +``` + +Example invocation for the following configuration: + +- Compressed weights file `2b-it-sfp.sbs` (2B instruction-tuned model, 8-bit + switched floating point). +- Tokenizer file `tokenizer.spm`. + +```sh +./gemma \ +--tokenizer tokenizer.spm \ +--compressed_weights 2b-it-sfp.sbs \ +--model 2b-it +``` + +## Usage + +`gemma` has different usage modes, controlled by the verbosity flag. + +All usage modes are currently interactive, triggering text generation upon +newline input. + +| Verbosity | Usage mode | Details | +| --------------- | ---------- | --------------------------------------------- | +| `--verbosity 0` | Minimal | Only prints generation output. Suitable as a CLI tool. | +| `--verbosity 1` | Default | Standard user-facing terminal UI. | +| `--verbosity 2` | Detailed | Shows additional developer and debug info. | + +### Interactive Terminal App + +By default, verbosity is set to 1, bringing up a terminal-based interactive +interface when `gemma` is invoked: + +```console +$ ./gemma [...] + __ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __ + / _` |/ _ \ '_ ` _ \| '_ ` _ \ / _` | / __| '_ \| '_ \ +| (_| | __/ | | | | | | | | | | (_| || (__| |_) | |_) | + \__, |\___|_| |_| |_|_| |_| |_|\__,_(_)___| .__/| .__/ + __/ | | | | | + |___/ |_| |_| + +tokenizer : tokenizer.spm +compressed_weights : 2b-it-sfp.sbs +model : 2b-it +weights : [no path specified] +max_tokens : 3072 +max_generated_tokens : 2048 + +*Usage* + Enter an instruction and press enter (%Q quits). + +*Examples* + - Write an email to grandma thanking her for the cookies. + - What are some historical attractions to visit around Massachusetts? + - Compute the nth fibonacci number in javascript. + - Write a standup comedy bit about WebGPU programming. + +> What are some outdoorsy places to visit around Boston? + +[ Reading prompt ] ..................... + + +**Boston Harbor and Islands:** + +* **Boston Harbor Islands National and State Park:** Explore pristine beaches, wildlife, and maritime history. +* **Charles River Esplanade:** Enjoy scenic views of the harbor and city skyline. +* **Boston Harbor Cruise Company:** Take a relaxing harbor cruise and admire the city from a different perspective. +* **Seaport Village:** Visit a charming waterfront area with shops, restaurants, and a seaport museum. + +**Forest and Nature:** + +* **Forest Park:** Hike through a scenic forest with diverse wildlife. +* **Quabbin Reservoir:** Enjoy boating, fishing, and hiking in a scenic setting. +* **Mount Forest:** Explore a mountain with breathtaking views of the city and surrounding landscape. + +... 
+``` + +### Usage as a Command Line Tool + +For using the `gemma` executable as a command line tool, it may be useful to +create an alias for gemma.cpp with arguments fully specified: + +```sh +alias gemma2b="~/gemma.cpp/build/gemma -- --tokenizer ~/gemma.cpp/build/tokenizer.spm --compressed_weights ~/gemma.cpp/build/2b-it-sfp.sbs --model 2b-it --verbosity 0" +``` + +Replace the above paths with your own paths to the model and tokenizer paths +from the download. + +Here is an example of prompting `gemma` with a truncated input +file (using a `gemma2b` alias like defined above): + +```sh +cat configs.h | tail -35 | tr '\n' ' ' | xargs -0 echo "What does this C++ code do: " | gemma2b +``` + +> [!NOTE] +> CLI usage of gemma.cpp is experimental and should take context length +> limitations into account. + +The output of the above command should look like: + +```console +$ cat configs.h | tail -35 | tr '\n' ' ' | xargs -0 echo "What does this C++ code do: " | gemma2b +[ Reading prompt ] ...................................................................................................................................................................................................................................................................................................................................................................................................................................................................................... +The code defines two C++ structs, `ConfigGemma7B` and `ConfigGemma2B`, which are used for configuring a deep learning model. + +**ConfigGemma7B**: + +* `seq_len`: Stores the length of the sequence to be processed. It's set to 7168. +* `vocab_size`: Stores the size of the vocabulary, which is 256128. +* `n_layers`: Number of layers in the deep learning model. It's set to 28. +* `dim_model`: Dimension of the model's internal representation. It's set to 3072. +* `dim_ffw_hidden`: Dimension of the feedforward and recurrent layers' hidden representations. It's set to 16 * 3072 / 2. + +**ConfigGemma2B**: + +* `seq_len`: Stores the length of the sequence to be processed. It's also set to 7168. +* `vocab_size`: Size of the vocabulary, which is 256128. +* `n_layers`: Number of layers in the deep learning model. It's set to 18. +* `dim_model`: Dimension of the model's internal representation. It's set to 2048. +* `dim_ffw_hidden`: Dimension of the feedforward and recurrent layers' hidden representations. It's set to 16 * 2048 / 2. + +These structs are used to configure a deep learning model with specific parameters for either Gemma7B or Gemma2B architecture. +``` + +### Incorporating gemma.cpp as a Library in your Project + +The easiest way to incorporate gemma.cpp in your own project is to pull in +gemma.cpp and dependencies using `FetchContent`. You can add the following to your +CMakeLists.txt: + +``` +include(FetchContent) + +FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c) +FetchContent_MakeAvailable(sentencepiece) + +FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main) +FetchContent_MakeAvailable(gemma) + +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f) +FetchContent_MakeAvailable(highway) +``` + +Note for the gemma.cpp `GIT_TAG`, you may replace `origin/main` for a specific +commit hash if you would like to pin the library version. 
+ +After your executable is defined (substitute your executable name for +`[Executable Name]` below): + +``` +target_link_libraries([Executable Name] libgemma hwy hwy_contrib sentencepiece) +FetchContent_GetProperties(gemma) +FetchContent_GetProperties(sentencepiece) +target_include_directories([Executable Name] PRIVATE ${gemma_SOURCE_DIR}) +target_include_directories([Executable Name] PRIVATE ${sentencepiece_SOURCE_DIR}) +``` + +### Building gemma.cpp as a Library + +gemma.cpp can also be used as a library dependency in your own project. The +shared library artifact can be built by modifying the make invocation to build +the `libgemma` target instead of `gemma`. + +> [!NOTE] +> If you are using gemma.cpp in your own project with the `FetchContent` steps +> in the previous section, building the library is done automatically by `cmake` +> and this section can be skipped. + +First, run `cmake`: + +```sh +(cd build && cmake ..) +``` + +Then, run `make` with the `libgemma` target: + +```sh +cd build +make -j [number of parallel threads to use] libgemma +``` + +If this is successful, you should now have a +`libgemma` library file in the `build/` directory. On linux the filename is `libgemma.a`. + +## Acknowledgements and Contacts + +gemma.cpp was started in fall 2023 by [Austin Huang](austinvhuang@google.com) +and [Jan Wassenberg](janwas@google.com), and subsequently released February 2024 +thanks to contributions from Phil Culliton, Paul Chang, and Dan Zheng. + +This is not an officially supported Google product. diff --git a/build/.gitignore b/build/.gitignore new file mode 100644 index 0000000..3822a0b --- /dev/null +++ b/build/.gitignore @@ -0,0 +1,3 @@ +* +!.gitignore +!.hgignore \ No newline at end of file diff --git a/compression/analyze.h b/compression/analyze.h new file mode 100644 index 0000000..d719aee --- /dev/null +++ b/compression/analyze.h @@ -0,0 +1,244 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Normal include guard to placate lint. +#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_ +#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_ + +#include +#include +#include +#include // memcpy + +#include // std::signbit +#include // std::abs +#include + +// copybara:import_next_line:gemma_cpp +#include "compression/distortion.h" +// copybara:import_next_line:gemma_cpp +#include "compression/nuq.h" +// copybara:import_next_line:gemma_cpp +#include "compression/stats.h" +#include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/timer.h" + +#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_ + +// Actual per-target include guard. 
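+// This toggle-based guard lets the SIMD section below be compiled once per
+// target when the header is re-included via Highway's dynamic-dispatch
+// mechanism (hwy/foreach_target.h), while the plain section above is
+// included only once.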
+#if defined(THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE) == defined(HWY_TARGET_TOGGLE) +#ifdef THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE +#undef THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE +#else +#define THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE +#endif + +// copybara:import_next_line:gemma_cpp +#include "compression/nuq-inl.h" +// copybara:import_next_line:gemma_cpp +#include "compression/sfp-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { + +class PerThread { + public: + void NotifyGroup(const float* group) { + Stats s_group; + for (size_t i = 0; i < kGroupSize; ++i) { + // Skip zero so we can see the lowest actual magnitude + if (group[i] == 0.0f || group[i] == -0.0f) continue; + s_all_.Notify(group[i]); + s_group.Notify(group[i]); + + num_tiny_ += std::abs(group[i]) < 1e-3f; + + // b_magn100_.Notify(group[i] * 40.0f + 20.0f); + const uint32_t binary32 = + hwy::BitCastScalar(std::abs(group[i])); + + // const int32_t exp = (binary32 >> 23) - 127; + b_exp256_.Notify(binary32 >> 23); + const uint32_t m4 = (binary32 & 0x7FFFFF) >> (23 - 4); + b_m4_.Notify(m4); + } + s_group_ranges_.Notify(s_group.Max() - s_group.Min()); + s_group_mins_.Notify(s_group.Min()); + s_group_maxs_.Notify(s_group.Max()); + + float desc[kGroupSize]; + memcpy(desc, group, kGroupSize * sizeof(group[0])); + hn::VQSortStatic(desc, kGroupSize, hwy::SortDescending()); + + // Find largest |max/min| (dynamic range) + float max_ratio = 0.0f; + for (size_t i = 0; i < kGroupSize; ++i) { + if (desc[i] != 0.0f && desc[i] != -0.0f) { + max_ratio = std::max(max_ratio, std::abs(desc[0] / desc[i])); + } + } + s_group_max_vs_min_.Notify(max_ratio); + + // Relative errors + float diffs[kGroupSize]; + for (size_t i = 0; i < kGroupSize - 1; ++i) { + // was in descending order. Avoid div by 0. Ignore sign changes. + diffs[i] = std::abs(desc[i]) < 1e-5 + ? 
0 + : std::abs((desc[i] - desc[i + 1]) / desc[i]); + } + hn::VQSortStatic(diffs, kGroupSize, hwy::SortDescending()); + s_cut15_.Notify(diffs[15]); + } + + void Assimilate(const PerThread& other) { + num_tiny_ += other.num_tiny_; + s_all_.Assimilate(other.s_all_); + s_group_ranges_.Assimilate(other.s_group_ranges_); + s_group_mins_.Assimilate(other.s_group_mins_); + s_group_maxs_.Assimilate(other.s_group_maxs_); + s_group_max_vs_min_.Assimilate(other.s_group_max_vs_min_); + s_erange_.Assimilate(other.s_erange_); + s_km_1_.Assimilate(other.s_km_1_); + s_km_2_.Assimilate(other.s_km_2_); + s_cut15_.Assimilate(other.s_cut15_); + b_magn100_.Assimilate(other.b_magn100_); + b_exp256_.Assimilate(other.b_exp256_); + b_m4_.Assimilate(other.b_m4_); + } + + void PrintAll() { + const int skip = Stats::kNoGeomean; + fprintf(stderr, "num tiny %zu\n", num_tiny_); + fprintf(stderr, "weights %s\n", s_all_.ToString(skip).c_str()); + fprintf(stderr, " ranges %s\n", s_group_ranges_.ToString(skip).c_str()); + fprintf(stderr, " mins %s\n", s_group_mins_.ToString(skip).c_str()); + fprintf(stderr, " maxs %s\n", s_group_maxs_.ToString(skip).c_str()); + fprintf(stderr, " Mvm %s\n", s_group_max_vs_min_.ToString(skip).c_str()); + fprintf(stderr, " cut15 %s\n", s_cut15_.ToString(skip).c_str()); + fprintf(stderr, " erange %s\n", s_erange_.ToString(skip).c_str()); + fprintf(stderr, " km1 %s\n", s_km_1_.ToString(skip).c_str()); + fprintf(stderr, " km2 %s\n", s_km_2_.ToString(skip).c_str()); + + // b_magn100_.Print("magn100"); + // b_exp256_.Print("exp"); + // b_m4_.Print("mantissa bits4"); + + fprintf(stderr, "\n"); + } + + private: + size_t num_tiny_ = 0; + Stats s_all_; + Stats s_group_ranges_; + Stats s_group_mins_; + Stats s_group_maxs_; + Stats s_group_max_vs_min_; + Stats s_erange_; + Stats s_km_1_; + Stats s_km_2_; + Stats s_cut15_; + Bins<100> b_magn100_; + Bins<256> b_exp256_; + Bins<16> b_m4_; + uint8_t padding_[64]; // prevent false sharing +}; + +class PerLayer { + public: + void NotifyGroup(const float* group) { + for (size_t i = 0; i < kGroupSize; ++i) { + s_layer_.Notify(group[i]); + } + } + + void UpdateOutliers(const float* layer, size_t weights_per_layer) { + const float layer_mean = s_layer_.Mean(); + const float layer_sd = s_layer_.StandardDeviation(); + for (size_t i = 0; i < weights_per_layer; ++i) { + num_outliers_ += + std::abs(std::abs(layer[i]) - layer_mean) >= 3.0f * layer_sd; + } + } + + const Stats& GetStats() const { return s_layer_; } + size_t Outliers() const { return num_outliers_; } + + private: + Stats s_layer_; + size_t num_outliers_ = 0; + uint8_t padding[64]; // prevent false sharing +}; + +static HWY_NOINLINE void Analyze(const char* caption, float* mat, size_t layers, + size_t weights_per_layer, + hwy::ThreadPool& pool) { + std::vector tls; + std::vector per_layer(layers); + const auto init = [&](size_t num_threads) { + tls.resize(num_threads); + return true; + }; + + pool.Run(0, static_cast(layers), init, + [&](uint32_t idx_layer, size_t idx_thread) { + PerThread& self = tls[idx_thread]; + const float* layer = &mat[idx_layer * weights_per_layer]; + // For each whole group in the layer + for (size_t group_start = 0; + group_start + kGroupSize <= weights_per_layer; + group_start += kGroupSize) { + const float* group = layer + group_start; + per_layer[idx_layer].NotifyGroup(group); + self.NotifyGroup(group); + } + + per_layer[idx_layer].UpdateOutliers(layer, weights_per_layer); + }); + + const int skip = Stats::kNoGeomean; + fprintf(stderr, "\n------------%s\n", caption); + + for 
(size_t i = 1; i < pool.NumThreads(); ++i) { + tls[0].Assimilate(tls[i]); + } + tls[0].PrintAll(); + + Stats s_layer_ranges; + Stats s_layer_outliers; + for (size_t i = 0; i < layers; ++i) { + fprintf(stderr, " %02zu %s\n", i, + per_layer[i].GetStats().ToString(skip).c_str()); + const float range = + per_layer[i].GetStats().Max() - per_layer[i].GetStats().Min(); + s_layer_ranges.Notify(range); + s_layer_outliers.Notify((100.0 * per_layer[i].Outliers()) / + weights_per_layer); + } + fprintf(stderr, "layer outliers%% %s\n", + s_layer_outliers.ToString(skip).c_str()); + fprintf(stderr, "layer ranges %s\n", s_layer_ranges.ToString(skip).c_str()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_ diff --git a/compression/blob_store.cc b/compression/blob_store.cc new file mode 100644 index 0000000..8d6c1d0 --- /dev/null +++ b/compression/blob_store.cc @@ -0,0 +1,348 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// copybara:import_next_line:gemma_cpp +#include "compression/blob_store.h" + +#include // open +#include +#include // SEEK_END - unistd isn't enough for IDE. +#include // O_RDONLY +#include // read, close + +#include +#include + +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/detect_compiler_arch.h" + +namespace gcpp { + +hwy::uint128_t MakeKey(const char* string) { + size_t length = 0; + for (size_t i = 0; string[i] != '\0'; ++i) { + ++length; + } + if (length > 16) { + HWY_ABORT("Key %s is too long, please truncate to 16 chars.", string); + } + + hwy::uint128_t ret; + hwy::ZeroBytes(&ret); + hwy::CopyBytes(string, &ret, length); + return ret; +} + +static void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data, + std::vector& requests) { + // Split into chunks for load-balancing even if blob sizes vary. + constexpr size_t kChunkSize = 4 * 1024 * 1024; + + // Split into whole chunks and possibly one remainder. + uint64_t pos = 0; + if (size >= kChunkSize) { + for (; pos <= size - kChunkSize; pos += kChunkSize) { + requests.emplace_back(offset + pos, kChunkSize, data + pos, 0); + } + } + if (pos != size) { + requests.emplace_back(offset + pos, size - pos, data + pos, 0); + } +} + +struct IO { + // Returns size in bytes or 0. + static uint64_t FileSize(const char* filename) { + int fd = open(filename, O_RDONLY); + if (fd >= 0) { + const off_t size = lseek(fd, 0, SEEK_END); + HWY_ASSERT(close(fd) != -1); + if (size != static_cast(-1)) { + return static_cast(size); + } + } + + return 0; + } + + static bool Read(int fd, uint64_t offset, uint64_t size, void* to) { + uint8_t* bytes = reinterpret_cast(to); + uint64_t pos = 0; + for (;;) { + // pread seems to be faster than lseek + read when parallelized. 
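+      // pread may return fewer bytes than requested, so keep issuing reads
+      // until the whole range is covered or an error/EOF (<= 0) stops us.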
+ const auto bytes_read = pread(fd, bytes + pos, size - pos, offset + pos); + if (bytes_read <= 0) break; + pos += bytes_read; + HWY_ASSERT(pos <= size); + if (pos == size) break; + } + return pos == size; // success if managed to read desired size + } + + static bool Write(const void* from, uint64_t size, uint64_t offset, int fd) { + const uint8_t* bytes = reinterpret_cast(from); + uint64_t pos = 0; + for (;;) { + const auto bytes_written = + pwrite(fd, bytes + pos, size - pos, offset + pos); + if (bytes_written <= 0) break; + pos += bytes_written; + HWY_ASSERT(pos <= size); + if (pos == size) break; + } + return pos == size; // success if managed to write desired size + } +}; // IO + +static_assert(HWY_IS_LITTLE_ENDIAN, "Assumes little endian"); + +// On-disk representation (little-endian). +// +// Deliberately omits a version number because this file format is unchanging. +// Additional data may be added only inside new blobs. Changes to the blob +// contents or type should be handled by renaming keys. +#pragma pack(push, 1) +class BlobStore { + static constexpr uint32_t kMagic = 0x0A534253; // SBS\n + + // Blob offsets on disk and memory addresses are a multiple of this, because + // we pad the header and each blob's size. This matches CUDA alignment and the + // maximum SVE vector size, and exceeds typical x86 cache line sizes (64 or + // 128), which can help performance. + static constexpr size_t kAlign = 256; + + public: + // NOT including padding, so that we can also use ZeroFillPadding after + // copying the header. + static constexpr size_t HeaderSize(size_t num_blobs) { + // 16-byte fixed fields plus per-blob: 16-byte key, 16-byte offset/size. + return 16 + 32 * num_blobs; + } + + // Returns how many bytes to allocate for the header without the subsequent + // blobs. Requires num_blobs_ to already be set, typically by reading + // sizeof(BlobStore) bytes from disk. + size_t PaddedHeaderSize() const { + return hwy::RoundUpTo(HeaderSize(num_blobs_), kAlign); + } + + // Returns aligned offset and zero-fills between that and `offset`. + uint64_t ZeroFillPadding(uint64_t offset) { + uint8_t* const bytes = reinterpret_cast(this); + const uint64_t padded = hwy::RoundUpTo(offset, kAlign); + hwy::ZeroBytes(bytes + offset, padded - offset); + return padded; + } + + BlobError CheckValidity(const uint64_t file_size) { + if (magic_ != kMagic) return __LINE__; + if (num_blobs_ == 0) return __LINE__; + if (file_size_ != file_size) return __LINE__; + + // Ensure blobs are back to back, and zero-pad. + uint64_t offset = ZeroFillPadding(HeaderSize(num_blobs_)); + for (size_t i = 0; i < num_blobs_; ++i) { + const hwy::uint128_t val = keys_[num_blobs_ + i]; + if (val.lo != offset) return __LINE__; + offset = ZeroFillPadding(offset + val.hi); + } + + if (offset != file_size_) return __LINE__; + + return 0; // all OK + } + + static BlobStorePtr Allocate(uint64_t total_size) { + uint8_t* bytes = + static_cast(hwy::AllocateAlignedBytes(total_size)); + if (!bytes) return BlobStorePtr(); + return BlobStorePtr(new (bytes) BlobStore(), hwy::AlignedFreer()); + } + + static std::vector PrepareWriteRequests( + const hwy::uint128_t keys[], const hwy::Span blobs[], + size_t num_blobs) { + // Sanity check and ensure the cast below is safe. + HWY_ASSERT(num_blobs < (1ULL << 20)); + + // Allocate var-length header. 
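+    // 16 fixed bytes plus 32 bytes per blob (16-byte key and 16-byte
+    // offset/size pair), rounded up to kAlign below.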
+ const size_t header_size = HeaderSize(num_blobs); + const size_t padded_header_size = hwy::RoundUpTo(header_size, kAlign); + BlobStorePtr bs = Allocate(padded_header_size); + const uint64_t padded_header_end = bs->ZeroFillPadding(header_size); + HWY_ASSERT(padded_header_end == padded_header_size); + + // All-zero buffer used to write padding to the file without copying the + // input blobs. + static uint8_t zeros[kAlign] = {0}; + + // Total file size will be the header plus all padded blobs. + uint64_t payload = 0; + for (size_t i = 0; i < num_blobs; ++i) { + payload += hwy::RoundUpTo(blobs[i].size(), kAlign); + } + const size_t total_size = padded_header_size + payload; + + // Fill header. + bs->magic_ = kMagic; + bs->num_blobs_ = static_cast(num_blobs); + bs->file_size_ = total_size; + hwy::CopyBytes(keys, bs->keys_, num_blobs * sizeof(keys[0])); + + // First IO request is for the header (not yet filled!). + std::vector requests; + requests.reserve(1 + 2 * num_blobs); + requests.emplace_back(/*offset=*/0, padded_header_size, + reinterpret_cast(bs.get()), 0); + + // Fill second half of keys_ with offset/size and prepare IO requests. + uint64_t offset = padded_header_end; + for (size_t i = 0; i < num_blobs; ++i) { + bs->keys_[num_blobs + i].lo = offset; + bs->keys_[num_blobs + i].hi = blobs[i].size(); + + EnqueueChunkRequests(offset, blobs[i].size(), blobs[i].data(), requests); + offset += blobs[i].size(); + const size_t padded_size = hwy::RoundUpTo(blobs[i].size(), kAlign); + if (padded_size != blobs[i].size()) { + const size_t padding = padded_size - blobs[i].size(); + HWY_ASSERT(padding <= kAlign); + requests.emplace_back(offset, padding, zeros, 0); + offset += padding; + } + } + + HWY_ASSERT(offset == total_size); + return requests; + } + + bool FindKey(const hwy::uint128_t key, uint64_t& offset, size_t& size) const { + for (size_t i = 0; i < num_blobs_; ++i) { + if (keys_[i] == key) { + const hwy::uint128_t val = keys_[num_blobs_ + i]; + offset = val.lo; + size = val.hi; + return true; + } + } + return false; + } + + private: + uint32_t magic_; + uint32_t num_blobs_; // never 0 + uint64_t file_size_; // must match actual size of file + hwy::uint128_t keys_[1]; // length: 2 * num_blobs + // Padding, then the blob identified by keys[0], then padding etc. +}; +#pragma pack(pop) + +BlobError BlobReader::Open(const char* filename) { + fd_ = open(filename, O_RDONLY); + if (fd_ < 0) return __LINE__; + +#if _POSIX_C_SOURCE >= 200112L + // Doubles the readahead window, which seems slightly faster when cached. + (void)posix_fadvise(fd_, 0, 0, POSIX_FADV_SEQUENTIAL); +#endif + + // Read first part of header to get actual size. + BlobStore bs; + if (!IO::Read(fd_, 0, sizeof(bs), &bs)) return __LINE__; + const size_t padded_size = bs.PaddedHeaderSize(); + HWY_ASSERT(padded_size >= sizeof(bs)); + + // Allocate full header. + blob_store_ = BlobStore::Allocate(padded_size); + if (!blob_store_) return __LINE__; + + // Copy what we already read (more efficient than seek + re-read). + hwy::CopySameSize(&bs, blob_store_.get()); + // Read the rest of the header, but not the full file. 
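+  // Blob payloads are not read here; they are fetched later via Enqueue()
+  // and ReadAll().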
+ uint8_t* bytes = reinterpret_cast(blob_store_.get()); + if (!IO::Read(fd_, sizeof(bs), padded_size - sizeof(bs), + bytes + sizeof(bs))) { + return __LINE__; + } + + return blob_store_->CheckValidity(IO::FileSize(filename)); +} + +BlobReader::~BlobReader() { + if (fd_ >= 0) { + HWY_ASSERT(close(fd_) != -1); + } +} + +BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) { + uint64_t offset; + size_t actual_size; + if (!blob_store_->FindKey(key, offset, actual_size)) return __LINE__; + if (actual_size != size) return __LINE__; + + EnqueueChunkRequests(offset, actual_size, reinterpret_cast(data), + requests_); + return 0; +} + +// Parallel synchronous I/O. Alternatives considered: +// - readv is limited to 0x7FFFF000 bytes on Linux (even 64-bit). Note that +// pread calls preadv with a single iovec. +// - O_DIRECT seems undesirable because we do want to use the OS cache +// between consecutive runs. +// - memory-mapped I/O is less predictable and adds noise to measurements. +BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) { + const int fd = fd_; + const auto& requests = requests_; + std::atomic_flag err = ATOMIC_FLAG_INIT; + // >5x speedup from parallel reads when cached. + pool.Run(0, requests.size(), + [fd, &requests, &err](uint64_t i, size_t /*thread*/) { + if (!IO::Read(fd, requests[i].offset, requests[i].size, + requests[i].data)) { + err.test_and_set(); + } + }); + if (err.test_and_set()) return __LINE__; + return 0; +} + +BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, + const char* filename) const { + HWY_ASSERT(keys_.size() == blobs_.size()); + + // Concatenate blobs in memory. + std::vector requests = BlobStore::PrepareWriteRequests( + keys_.data(), blobs_.data(), keys_.size()); + + // Create/replace existing file. + const int fd = open(filename, O_CREAT | O_RDWR | O_TRUNC, 0644); + if (fd < 0) return __LINE__; + + std::atomic_flag err = ATOMIC_FLAG_INIT; + pool.Run(0, requests.size(), + [fd, &requests, &err](uint64_t i, size_t /*thread*/) { + if (!IO::Write(requests[i].data, requests[i].size, + requests[i].offset, fd)) { + err.test_and_set(); + } + }); + if (err.test_and_set()) return __LINE__; + return 0; +} + +} // namespace gcpp diff --git a/compression/blob_store.h b/compression/blob_store.h new file mode 100644 index 0000000..6ced37f --- /dev/null +++ b/compression/blob_store.h @@ -0,0 +1,90 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_BLOB_STORE_H_ +#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_BLOB_STORE_H_ + +#include +#include + +#include + +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" // hwy::uint128_t +#include "hwy/contrib/thread_pool/thread_pool.h" + +namespace gcpp { + +// Convenient way to construct a key from a string (<= 16 chars). +hwy::uint128_t MakeKey(const char* string); + +// Ordered list of opaque blobs (~hundreds), identified by unique opaque +// 128-bit keys. 
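+// The on-disk layout is defined in blob_store.cc; this header only exposes
+// the BlobReader/BlobWriter wrappers around it.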
+class BlobStore; + +// Incomplete type, so dtor will not be called. +using BlobStorePtr = hwy::AlignedFreeUniquePtr; + +// 0 if successful, otherwise the line number of the failing check. +using BlobError = int; + +struct BlobIO { + BlobIO(uint64_t offset, size_t size, void* data, uint64_t padding) + : offset(offset), size(size), data(data), padding(padding) {} + + uint64_t offset; + size_t size; + void* data; + uint64_t padding; +}; + +class BlobReader { + public: + BlobReader() { requests_.reserve(500); } + ~BlobReader(); + + // Opens `filename` and reads its header. + BlobError Open(const char* filename); + + // Enqueues read requests if `key` is found and its size matches `size`. + BlobError Enqueue(hwy::uint128_t key, void* data, size_t size); + + // Reads all enqueued requests. + BlobError ReadAll(hwy::ThreadPool& pool); + + private: + BlobStorePtr blob_store_; // holds header, not the entire file + std::vector requests_; + int fd_ = 0; +}; + +class BlobWriter { + public: + void Add(hwy::uint128_t key, void* data, size_t size) { + keys_.push_back(key); + blobs_.emplace_back(static_cast(data), size); + } + + // Stores all blobs to disk in the given order with padding for alignment. + BlobError WriteAll(hwy::ThreadPool& pool, const char* filename) const; + + private: + std::vector keys_; + std::vector> blobs_; +}; + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_BLOB_STORE_H_ diff --git a/compression/compress-inl.h b/compression/compress-inl.h new file mode 100644 index 0000000..588f5c6 --- /dev/null +++ b/compression/compress-inl.h @@ -0,0 +1,467 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Include guard for headers. +#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_INL_H_ +#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_INL_H_ + +#include +#include +#include + +#include + +// copybara:import_next_line:gemma_cpp +#include "compression/blob_store.h" +// copybara:import_next_line:gemma_cpp +#include "compression/compress.h" +// copybara:import_next_line:gemma_cpp +#include "compression/distortion.h" +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/timer.h" + +#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_INL_H_ + +// Include guard for (potentially) SIMD code. +#if defined(THIRD_PARTY_GEMMA_CPP_COMPRESS_TOGGLE) == defined(HWY_TARGET_TOGGLE) +#ifdef THIRD_PARTY_GEMMA_CPP_COMPRESS_TOGGLE +#undef THIRD_PARTY_GEMMA_CPP_COMPRESS_TOGGLE +#else +#define THIRD_PARTY_GEMMA_CPP_COMPRESS_TOGGLE +#endif + +// copybara:import_next_line:gemma_cpp +#include "compression/nuq-inl.h" +// copybara:import_next_line:gemma_cpp +#include "compression/sfp-inl.h" +#include "hwy/contrib/dot/dot-inl.h" +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +// Enables generic code independent of compression type. 
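+// Each specialization provides Compress(), Decompress() and Dot() with the
+// same signatures, so callers can be written once and instantiated for
+// float, bfloat16, SfpStream or NuqStream weights.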
+template // primary, must specialize +struct CompressTraits {}; + +template <> +struct CompressTraits { + using MatT = float; + + template + static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in, + size_t num, CompressPerThread& tls, + size_t /*out_capacity*/, + MatT* HWY_RESTRICT out, size_t out_ofs) { + using VF = hn::Vec; + const size_t N = hn::Lanes(df); + HWY_DASSERT(num >= 2 * N && num % (2 * N) == 0); + + for (size_t i = 0; i < num; i += 2 * N) { + const VF in0 = hn::LoadU(df, in + i); + const VF in1 = hn::LoadU(df, in + i + N); + hn::StoreU(in0, df, out + out_ofs + i); + hn::StoreU(in1, df, out + out_ofs + i + N); + } + } + + template + static HWY_INLINE void Decompress(DF df, size_t /*in_capacity*/, + const MatT* HWY_RESTRICT in, size_t in_ofs, + float* HWY_RESTRICT out, size_t num) { + using VF = hn::Vec; + const size_t N = hn::Lanes(df); + HWY_DASSERT(num >= 2 * N && num % (2 * N) == 0); + + for (size_t i = 0; i < num; i += 2 * N) { + const VF in0 = hn::LoadU(df, in + in_ofs + i); + const VF in1 = hn::LoadU(df, in + in_ofs + i + N); + hn::StoreU(in0, df, out + i); + hn::StoreU(in1, df, out + i + N); + } + } + + // VecT can be float or hwy::bfloat16_t. + template + static HWY_INLINE float Dot(DF df, size_t /*in_capacity*/, + const MatT* HWY_RESTRICT in, size_t in_ofs, + const VecT* HWY_RESTRICT vec_aligned, + size_t num) { + HWY_DASSERT(num >= hn::Lanes(df) && (num % hn::Lanes(df)) == 0); + HWY_DASSERT(hn::IsAligned(df, vec_aligned)); + constexpr int kAssumptions = + hn::Dot::kAtLeastOneVector | hn::Dot::kMultipleOfVector; + // vec_aligned must be the second argument because hn::Dot supports f32*bf16 + // and f32*f32. + return hn::Dot::Compute(df, in + in_ofs, vec_aligned, num); + } +}; + +template <> +struct CompressTraits { + using MatT = hwy::bfloat16_t; + + template + static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in, + size_t num, CompressPerThread& tls, + size_t /*out_capacity*/, + MatT* HWY_RESTRICT out, size_t out_ofs) { + const hn::RebindToUnsigned du; + const hn::Repartition dbf; + using VF = hn::Vec; + const size_t N = hn::Lanes(df); + + hn::Vec or_sum = hn::Zero(du); + + size_t i = 0; + if (num >= 2 * N) { + for (; i <= num - 2 * N; i += 2 * N) { + const VF in0 = hn::LoadU(df, in + i); + const VF in1 = hn::LoadU(df, in + i + N); + + // Sticky bits so we can warn if any lower bits were set. + or_sum = hn::Or3(or_sum, hn::BitCast(du, in0), hn::BitCast(du, in1)); + hn::StoreU(hn::OrderedDemote2To(dbf, in0, in1), dbf, out + out_ofs + i); + + if (COMPRESS_STATS) { + DistortionStats stats; + for (size_t j = 0; j < 2 * N; ++j) { + stats.Notify(in[i + j], hwy::F32FromBF16(out[out_ofs + i + j])); + } + tls.stats.Notify(stats); + } + } + } + + size_t remaining = num - i; + if (remaining != 0) { + const VF in0 = hn::LoadN(df, in + i, remaining); + const size_t remaining1 = remaining - HWY_MIN(remaining, N / 2); + const VF in1 = hn::LoadN(df, in + i + N, remaining1); + + // Sticky bits so we can warn if any lower bits were set. + or_sum = hn::Or3(or_sum, hn::BitCast(du, in0), hn::BitCast(du, in1)); + hn::StoreU(hn::OrderedDemote2To(dbf, in0, in1), dbf, out + out_ofs + i); + + if (COMPRESS_STATS) { + DistortionStats stats; + for (size_t j = 0; j < remaining; ++j) { + stats.Notify(in[i + j], hwy::F32FromBF16(out[out_ofs + i + j])); + } + tls.stats.Notify(stats); + } + } + + // If the lower 16 bits are not zero, we should implement rounding. 
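+    // (or_sum is the OR of the raw f32 bits of all inputs; masking with 0xFFFF
+    // keeps exactly the bits that OrderedDemote2To discards, so any nonzero
+    // lane means the bf16 truncation was lossy.)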
+ or_sum = hn::And(or_sum, hn::Set(du, 0xFFFF)); + if (!hn::AllTrue(du, hn::Eq(or_sum, hn::Zero(du)))) { + // fprintf(stderr, "Warning: Lossy truncation."); + } + } + + template + static HWY_INLINE void Decompress(DF df, size_t /*in_capacity*/, + const MatT* HWY_RESTRICT in, size_t in_ofs, + float* HWY_RESTRICT out, size_t num) { + const hn::Repartition dbf; + using VBF = hn::Vec; + using VF = hn::Vec; + const size_t N16 = hn::Lanes(dbf); + + size_t i = 0; + if (num >= N16) { + for (i = 0; i <= num - N16; i += N16) { + const VBF in16 = hn::LoadU(dbf, in + in_ofs + i); + const VF in0 = hn::PromoteLowerTo(df, in16); + const VF in1 = hn::PromoteUpperTo(df, in16); + hn::StoreU(in0, df, out + i); + hn::StoreU(in1, df, out + i + N16 / 2); + } + } + + size_t remaining = num - i; + if (remaining != 0) { + const VBF in16 = hn::LoadN(dbf, in + in_ofs + i, remaining); + const VF in0 = hn::PromoteLowerTo(df, in16); + const VF in1 = hn::PromoteUpperTo(df, in16); + hn::StoreN(in0, df, out + i, remaining); + // Avoid wraparound, potentially store nothing. + const size_t remaining1 = remaining - HWY_MIN(remaining, N16 / 2); + hn::StoreN(in1, df, out + i + N16 / 2, remaining1); + } + } + + // VecT can be float or hwy::bfloat16_t. + template + static HWY_INLINE float Dot(DF df, size_t /*in_capacity*/, + const MatT* HWY_RESTRICT in, size_t in_ofs, + const VecT* HWY_RESTRICT vec_aligned, + size_t num) { + HWY_DASSERT(num >= hn::Lanes(df) && (num % hn::Lanes(df)) == 0); + HWY_DASSERT(hn::IsAligned(df, vec_aligned)); + + const hn::Repartition d_vec; + + constexpr int kAssumptions = + hn::Dot::kAtLeastOneVector | hn::Dot::kMultipleOfVector; + // vec_aligned must be first argument because hn::Dot supports f32*bf16 and + // bf16*bf16. + return hn::Dot::Compute(d_vec, vec_aligned, in + in_ofs, num); + } +}; + +template <> +struct CompressTraits { + using MatT = SfpStream; + + template + static HWY_INLINE void Compress(DF df, const float* in, size_t num, + CompressPerThread& tls, + size_t /*out_capacity*/, MatT* out, + size_t out_ofs) { + SfpCodec::Enc(df, in, num, out + out_ofs); + + if (COMPRESS_STATS) { + const hn::Repartition dbf; + auto distorted = hwy::AllocateAligned(num); + SfpCodec::Dec(dbf, out + out_ofs, num, distorted.get()); + DistortionStats stats; + for (size_t i = 0; i < num; ++i) { + stats.Notify(in[i], hwy::F32FromBF16(distorted[i])); + } + tls.stats.Notify(stats); + } + } + + template + static HWY_INLINE void Decompress(D d, size_t /*in_capacity*/, const MatT* in, + size_t in_ofs, OutT* out, size_t num) { + SfpCodec::Dec(d, in + in_ofs, num, out); + } + + template + static HWY_INLINE float Dot(DF df, size_t /*in_capacity*/, const MatT* in, + size_t in_ofs, const VecT* vec_aligned, + size_t num) { + using VF = hn::Vec; + VF sum0 = hn::Zero(df); + VF sum1 = hn::Zero(df); + VF sum2 = hn::Zero(df); + VF sum3 = hn::Zero(df); + + SfpCodec::Dot(df, in + in_ofs, num, vec_aligned, sum0, sum1, sum2, sum3); + + // Reduction tree: sum of all accumulators, then their lanes + sum0 = hn::Add(sum0, sum1); + sum2 = hn::Add(sum2, sum3); + sum0 = hn::Add(sum0, sum2); + return hn::ReduceSum(df, sum0); + } +}; + +template <> +struct CompressTraits { + using MatT = NuqStream; + + template + static HWY_INLINE void Compress(DF df, const float* in, size_t num, + CompressPerThread& tls, size_t out_capacity, + MatT* out, size_t out_ofs) { + NuqCodec::Enc(df, in, num, tls.buf, out_capacity, out, out_ofs); + + if (COMPRESS_STATS) { + for (size_t i = 0; i < num; ++i) { + tls.stats.NotifyIn(in[i] * 100 + 500); + } + + const 
hn::Repartition dbf; + auto distorted = hwy::AllocateAligned(num); + NuqCodec::Dec(dbf, out_capacity, out, out_ofs, distorted.get(), num); + DistortionStats stats; + for (size_t i = 0; i < num; ++i) { + stats.Notify(in[i], hwy::F32FromBF16(distorted[i])); + } + tls.stats.Notify(stats); + } + } + + template + static HWY_INLINE void Decompress(D d, size_t in_capacity, const MatT* in, + size_t in_ofs, OutT* out, size_t num) { + NuqCodec::Dec(d, in_capacity, in, in_ofs, out, num); + } + + template + static HWY_INLINE float Dot(DF df, size_t in_capacity, const MatT* in, + size_t in_ofs, + const VecT* HWY_RESTRICT vec_aligned, + size_t num) { + using VF = hn::Vec; + VF sum0 = hn::Zero(df); + VF sum1 = hn::Zero(df); + VF sum2 = hn::Zero(df); + VF sum3 = hn::Zero(df); + + NuqCodec::Dot(df, in_capacity, in, in_ofs, vec_aligned, num, sum0, sum1, + sum2, sum3); + + // Reduction tree: sum of all accumulators, then their lanes + sum0 = hn::Add(hn::Add(sum0, sum1), hn::Add(sum2, sum3)); + return hn::ReduceSum(df, sum0); + } +}; + +// Compresses `num` inputs to `out` starting at `out_ofs`. This can be used for +// compressing sub-regions of an array. +template +HWY_NOINLINE void Compress(const float* in, size_t num, + CompressWorkingSet& work, size_t out_capacity, + MatT* out, size_t out_ofs, hwy::ThreadPool& pool) { + HWY_DASSERT(out_ofs + num <= out_capacity); + work.tls.resize(pool.NumThreads()); + if (COMPRESS_STATS) { + for (auto& tls : work.tls) { + tls.stats.Reset(); + } + } + + const double t0 = hwy::platform::Now(); + + using Traits = CompressTraits; + constexpr size_t kBatch = 8192; + const size_t num_batches = hwy::DivCeil(num, kBatch); + pool.Run(0, num_batches, + [&](const uint32_t idx_batch, size_t thread) HWY_ATTR { + const hn::ScalableTag df; + + const size_t in_ofs = idx_batch * kBatch; + const size_t my_num = + idx_batch == num_batches - 1 ? (num - in_ofs) : kBatch; + Traits::Compress(df, in + in_ofs, my_num, work.tls[thread], + out_capacity, out, out_ofs + in_ofs); + }); + + const double t1 = hwy::platform::Now(); + const double mb = num * sizeof(in[0]) * 1E-6; + const double mbps = mb / (t1 - t0); + fprintf(stderr, "Compress %.1f MB/s\n", mbps); + + if (COMPRESS_STATS) { + for (size_t i = 1; i < work.tls.size(); ++i) { + work.tls[0].stats.Assimilate(work.tls[i].stats); + } + work.tls[0].stats.PrintAll(); + } +} + +// Compresses an entire std::array into `out`, which is assumed to have exactly +// that much capacity. +template +HWY_INLINE void Compress(const std::array& in, + CompressWorkingSet& work, + CompressedArray& compressed, + hwy::ThreadPool& pool) { + Compress(in.data(), kCapacity, work, kCapacity, compressed.data(), 0, pool); +} + +// Decompresses `num` values from `compressed` starting at `compressed_ofs`. +template +HWY_NOINLINE void Decompress(const CompressedArray& compressed, + size_t compressed_ofs, OutT* out, size_t num) { + HWY_DASSERT(compressed_ofs + num <= compressed.NumElements()); + const hn::ScalableTag d; + using Traits = CompressTraits; + Traits::Decompress(d, kCapacity, compressed.data(), compressed_ofs, out, num); +} + +// As above, but with threading and benchmarking. 
+template +HWY_INLINE void Decompress(const CompressedArray& compressed, + size_t compressed_ofs, OutT* out, size_t num, + hwy::ThreadPool& pool) { + HWY_DASSERT(compressed_ofs + num <= compressed.NumElements()); + const double t0 = hwy::platform::Now(); + + using Traits = CompressTraits; + constexpr size_t kBatch = 8192; + const size_t num_batches = hwy::DivCeil(num, kBatch); + pool.Run( + 0, num_batches, [&](const uint32_t idx_batch, size_t thread) HWY_ATTR { + const hn::ScalableTag d; + + const size_t ofs = idx_batch * kBatch; + const size_t num = idx_batch == num_batches - 1 ? (num - ofs) : kBatch; + Traits::Decompress(d, compressed.NumElements(), compressed.data(), + compressed_ofs + ofs, out + ofs, num); + }); + + const double t1 = hwy::platform::Now(); + const double mb = num * sizeof(MatT) * 1E-6; + const double mbps = mb / (t1 - t0); + fprintf(stderr, "Decompress %.1f MB/s\n", mbps); +} + +// Returns dot product with `vec_aligned` of length `num`. +template +HWY_INLINE float Dot(DF df, const CompressedArray& compressed, + size_t compressed_ofs, const VecT* vec_aligned, + size_t num) { + HWY_DASSERT(compressed_ofs + num <= compressed.NumElements()); + HWY_DASSERT(hn::IsAligned(df, vec_aligned)); + using Traits = CompressTraits; + return Traits::Dot(df, kCapacity, compressed.data(), compressed_ofs, + vec_aligned, num); +} + +// Callback used by ForeachTensor. +class Compressor { + public: + explicit Compressor(hwy::ThreadPool& pool) : pool_(pool) {} + + // Called for each tensor; compresses it and stores to the cache. + template + void operator()(const char* name, const float* weights, + CompressedArray& compressed) { + fprintf(stderr, "Regenerating %s (%zuM), please wait\n", name, + kCapacity / (1000 * 1000)); + Compress(weights, kCapacity, work_, kCapacity, compressed.data(), 0, pool_); + writer_.Add(CacheKey(name), compressed.data(), + compressed.CompressedSize()); + } + + void WriteAll(hwy::ThreadPool& pool, const char* blob_filename) { + const BlobError err = writer_.WriteAll(pool, blob_filename); + if (err != 0) { + fprintf(stderr, "Failed to write blobs to %s (error %d)\n", blob_filename, + err); + } + } + + private: + CompressWorkingSet work_; + hwy::ThreadPool& pool_; + BlobWriter writer_; +}; + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#endif // NOLINT diff --git a/compression/compress.h b/compression/compress.h new file mode 100644 index 0000000..e09d7e5 --- /dev/null +++ b/compression/compress.h @@ -0,0 +1,215 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Target-independent definitions. 
+#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_ +#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_ + +#define COMPRESS_STATS 0 + +#include +#include + +#include +#include +#include + +// IWYU pragma: begin_exports +// copybara:import_next_line:gemma_cpp +#include "compression/blob_store.h" +// copybara:import_next_line:gemma_cpp +#include "compression/nuq.h" +// copybara:import_next_line:gemma_cpp +#include "compression/sfp.h" +// IWYU pragma: end_exports +// copybara:import_next_line:gemma_cpp +#include "compression/distortion.h" +#include "hwy/base.h" // hwy::bfloat16_t +#include "hwy/contrib/thread_pool/thread_pool.h" +#if COMPRESS_STATS +// copybara:import_next_line:gemma_cpp +#include "compression/stats.h" +#endif + +namespace gcpp { + +static inline const char* TypeName(float) { return "f32"; } +static inline const char* TypeName(hwy::bfloat16_t) { return "b16"; } + +namespace detail { +// How many MatT are required to store `capacity` weights. For all but +// NuqStream, this is the same as `capacity`. For use by CompressedArray. +template +constexpr size_t CompressedArrayLen(size_t capacity) { + return capacity; +} +template <> +constexpr size_t CompressedArrayLen(size_t capacity) { + return NuqStream::PackedEnd(capacity); +} +} // namespace detail + +// Compressed representation of floating-point elements. The array length may +// differ from the number of elements. Associated operations such as Dot are +// implemented in SIMD code and are thus non-member functions. +template +class CompressedArray { + static constexpr size_t NumCompressed() { + return detail::CompressedArrayLen(kCapacity); + } + + public: + MatT* data() { return data_.data(); } + const MatT* data() const { return data_.data(); } + + constexpr size_t NumElements() const { return kCapacity; } + + constexpr size_t CompressedSize() const { + return NumCompressed() * sizeof(MatT); + } + + private: + std::array data_; +}; + +#if COMPRESS_STATS +class CompressStats { + public: + void Notify(const DistortionStats& stats) { + const float pnorm = stats.PNorm(); + const float snr = stats.GeomeanValueDivL1(); + num_exact_ += stats.NumExact(); + s_pnorm_.Notify(pnorm); + // No loss - skip to avoid dragging down the average. + if (snr != 0.0f) { + s_snr_.Notify(snr); + } + } + + void NotifyIn(int sfp) { hist_weights_.Notify(sfp); } + + void Assimilate(const CompressStats& other) { + s_pnorm_.Assimilate(other.s_pnorm_); + s_snr_.Assimilate(other.s_snr_); + num_exact_ += other.num_exact_; + hist_weights_.Assimilate(other.hist_weights_); + } + + void PrintAll() { + const int skip = Stats::kNoGeomean; + fprintf(stderr, " pnorm %s\n", s_pnorm_.ToString(skip).c_str()); + fprintf(stderr, " SNR %s\n", s_snr_.ToString(skip).c_str()); + fprintf(stderr, " #exact %.3E\n", static_cast(num_exact_)); + // hist_weights_.Print("indices"); + } + + void Reset() { + s_pnorm_.Reset(); + s_snr_.Reset(); + num_exact_ = 0; + hist_weights_.Reset(); + } + + private: + Stats s_pnorm_; + Stats s_snr_; + size_t num_exact_ = 0; + Bins<1000> hist_weights_; + char padding_[64]; // prevent false sharing +}; +#else +struct CompressStats { + void Notify(const DistortionStats&) {} + void NotifyIn(int) {} + void Assimilate(const CompressStats&) {} + void PrintAll() {} + void Reset() {} +}; +#endif // COMPRESS_STATS + +struct CompressPerThread { + CompressStats stats; + ClusterBuf buf; +}; + +struct CompressWorkingSet { + std::vector tls; +}; + +// Returns key for the given tensor name. 
Also encodes the type, so that +// changing the representation automatically invalidates prior cached files +// (the new blob name will not be found). +template +hwy::uint128_t CacheKey(const char* name) { + // Already used/retired: s, S, n, 1 + const char prefix = hwy::IsSame() ? 'F' + : hwy::IsSame() ? 'B' + : hwy::IsSame() ? '$' + : hwy::IsSame() ? '2' + : '?'; + + return MakeKey((std::string(1, prefix) + name).c_str()); +} + +class CacheLoader { + public: + explicit CacheLoader(const char* blob_filename) { + err_ = reader_.Open(blob_filename); + if (err_ != 0) { + fprintf(stderr, + "Cached compressed weights does not exist yet (code %d), " + "compressing weights and creating file: %s.\n", + err_, blob_filename); + } + } + + // Called for each tensor, enqueues read requests. + template + void operator()(const char* name, const float* null, + CompressedArray& compressed) { + HWY_DASSERT(null == nullptr); + + // Skip if reader_ is invalid or any load failed: we will regenerate + // everything because it's rare to update only a few tensors. + if (err_ != 0) return; + + err_ = reader_.Enqueue(CacheKey(name), compressed.data(), + compressed.CompressedSize()); + if (err_ != 0) { + fprintf(stderr, "Failed to read cache %s (error %d)\n", name, err_); + } + } + + // Returns whether all tensors are successfully loaded from cache. + bool ReadAll(hwy::ThreadPool& pool) { + // reader_ invalid or any Enqueue failed + if (err_ != 0) return false; + + err_ = reader_.ReadAll(pool); + if (err_ != 0) { + fprintf(stderr, "Failed to read all tensors (error %d)\n", err_); + return false; + } + + return true; + } + + private: + BlobReader reader_; + BlobError err_ = 0; +}; + +} // namespace gcpp +#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_ diff --git a/compression/distortion.h b/compression/distortion.h new file mode 100644 index 0000000..8c0742a --- /dev/null +++ b/compression/distortion.h @@ -0,0 +1,99 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_DISTORTION_H_ +#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_DISTORTION_H_ +#include // pow +#include + +#include "hwy/base.h" // ScalarAbs + +namespace gcpp { + +class DistortionStats { + public: + void Notify(float original, float distorted) { + const double l1 = hwy::ScalarAbs(original - distorted); + + if (l1 > max_l1_) { + max_l1_ = l1; + max_idx_ = n_; + } + + const double pow3 = l1 * l1 * l1; + sum_pow3_ += pow3; + sum_pow6_ += pow3 * pow3; + n_ += 1; + + // Avoid division by zero, which happens when there is no error. NumExact() + // reports the number of times this happens. + if (l1 != 0.0) { + const double rel = 1.0 + hwy::ScalarAbs(original) / l1; + // Logarithm is required to prevent overflow. A hierarchical geomean + // could also work, but that is more complex and not necessarily better. 
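+      // We accumulate log(1 + |original| / l1); GeomeanValueDivL1() returns
+      // exp of the mean of these terms, i.e. the geometric mean of per-value
+      // "value / error" ratios.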
+ sum_log_rel_ += log(rel); + num_rel_ += 1; + } + } + + void Assimilate(const DistortionStats& other) { + if (other.max_l1_ > max_l1_) { + max_l1_ = other.max_l1_; + max_idx_ = other.max_idx_; + } + + sum_pow3_ += other.sum_pow3_; + sum_pow6_ += other.sum_pow6_; + n_ += other.n_; + + sum_log_rel_ += other.sum_log_rel_; + num_rel_ += other.num_rel_; + } + + size_t NumExact() const { return n_ - num_rel_; } + + double GeomeanValueDivL1() const { + if (num_rel_ == 0) return 0.0; + return exp(sum_log_rel_ / num_rel_); + } + + double PNorm() const { + // p-norms are a compromise between max-norm (penalizes the largest error + // without dilution, but does not notice any other errors) and L1 (all + // errors contribute, but large errors are diluted by smaller ones). + const double norm3 = pow(sum_pow3_ / n_, 1.0 / 3); + const double norm6 = pow(sum_pow6_ / n_, 1.0 / 6); + return 0.5 * (norm3 + norm6); + } + + size_t MaxIndex() const { return max_idx_; } + double MaxL1() const { return max_l1_; } + + private: + size_t n_ = 0; + size_t max_idx_ = 0; // index that had l1 = max_l1_. + double max_l1_ = -1.0; + + double sum_pow3_ = 0.0; + double sum_pow6_ = 0.0; + + double sum_log_rel_ = 0.0; + size_t num_rel_ = 0; + double padding_; // prevents false sharing +}; + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_DISTORTION_H_ diff --git a/compression/nuq-inl.h b/compression/nuq-inl.h new file mode 100644 index 0000000..767014a --- /dev/null +++ b/compression/nuq-inl.h @@ -0,0 +1,730 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Normal include guard. +#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_ +#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_ + +#include +#include + +// copybara:import_next_line:gemma_cpp +#include "compression/nuq.h" +// copybara:import_next_line:gemma_cpp +#include "compression/sfp.h" +#include "hwy/base.h" + +#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_ + +// Actual per-target include guard. +#if defined(THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE +#undef THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE +#else +#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE +#endif + +// copybara:import_next_line:gemma_cpp +#include "compression/sfp-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +// For internal use by NuqCodec. +class NuqClustering { + // To go from sorted order back to the original order in O(1), we store the + // original index in the lower bits of the float32 mantissa, which means they + // are sorted alongside the value. + struct FloatPayload { + // Resets payload to zero; useful for displaying the actual value. 
+ static HWY_INLINE float Clear(float f) { + const uint32_t binary32 = hwy::BitCastScalar(f); + return hwy::BitCastScalar(binary32 & + ~static_cast(kGroupSize - 1)); + } + + // Sets payload to `bits`. + static HWY_INLINE float Set(float f, size_t bits) { + HWY_DASSERT(bits < kGroupSize); + const uint32_t binary32 = hwy::BitCastScalar(Clear(f)); + return hwy::BitCastScalar(static_cast(binary32 | bits)); + } + + // Obtains the payload (index) previously set by `Set`. + static HWY_INLINE size_t Get(float f) { + return hwy::BitCastScalar(f) & + static_cast(kGroupSize - 1); + } + }; + + // Cumulative sums for O(1) mean and interval sums. + class ClusterCost { + public: + explicit ClusterCost(const float* sorted) { + cumsum_[0] = cumsum2_[0] = 0.0; + for (size_t i = 0; i < kGroupSize; ++i) { + const float x = FloatPayload::Clear(sorted[i]); + cumsum_[1 + i] = x + cumsum_[i]; + cumsum2_[1 + i] = x * x + cumsum2_[i]; + } + + inv_len_[0] = 0.0f; // unused + for (size_t i = 0; i <= kGroupSize; ++i) { + inv_len_[i] = 1.0f / i; + } + } + + float SumOfSorted(size_t first, size_t last) const { + return cumsum_[last + 1] - cumsum_[first]; + } + + // Returns cost of clustering first..last with their mean, for a vector of + // last. O(1) thanks to cumulative sums, which works for Lp-norms with p > + // 1; we choose p=2 for simplicity (fewer terms). + template + hn::Vec operator()(DF df, size_t first, size_t last) const { + // Callers are responsible for ignoring lanes where last < first. + HWY_DASSERT(first < kGroupSize); + HWY_DASSERT(last < kGroupSize); + const size_t len = last - first + 1; + const hn::Vec vlen = + hn::Iota(df, static_cast(static_cast(len))); + + const hn::Vec u_lo = hn::Set(df, cumsum_[first]); + const hn::Vec u_lo2 = hn::Set(df, cumsum2_[first]); + const hn::Vec hi = hn::LoadU(df, cumsum_ + last + 1); + const hn::Vec hi2 = hn::LoadU(df, cumsum2_ + last + 1); + const hn::Vec sum = hn::Sub(hi, u_lo); + const hn::Vec sum2 = hn::Sub(hi2, u_lo2); + + // Compute mean: table lookup is faster than division. + const hn::Vec mu = hn::Mul(sum, hn::LoadU(df, inv_len_ + len)); + + // (x - mu)^2 = sum2 - 2mu*sum + mu^2 + const hn::Vec mu2 = hn::Mul(mu, mu); + const hn::Vec two_mu = hn::Add(mu, mu); + return hn::NegMulAdd(two_mu, sum, hn::MulAdd(vlen, mu2, sum2)); + } + + private: + // Float has enough precision for our relatively small kGroupSize (128). + float cumsum_[kGroupSize + 1]; + float cumsum2_[kGroupSize + 1]; + float inv_len_[kGroupSize + 1]; + }; + + // Cost of clustering 0..last, where the rightmost cluster is j..last. This is + // called in a loop over j, and we return the vector of costs for a batch of + // last = [last, last + N). + template + static HWY_INLINE hn::Vec ClusterDynProg( + DF df, const AlignedMatrix& D, const ClusterCost& cc, + const size_t num_clusters, const size_t last, const size_t j) { + HWY_DASSERT(last < kGroupSize); + HWY_DASSERT(0 != j && j < kGroupSize); + + const hn::RebindToSigned di; + using VF = hn::Vec; + using VI = hn::Vec; + using MI = hn::Mask; + + const VI vlast = hn::Iota(di, static_cast(last)); + + // We have a non-empty rightmost cluster if j <= last <==> j-1 < last. + const MI valid = hn::Lt(hn::Set(di, static_cast(j) - 1), vlast); + // If not valid, return an arbitrary high cost, which will not be the min. + const VF max = hn::Set(df, 1E38f); + // Cost of clustering 0..j-1 with one fewer cluster than now. + const VF vd = hn::Set(df, D(num_clusters - 1, j - 1)); + // Eq2: add to that the cost of another cluster from j..last. 
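+    // In scalar terms, per lane `last`: cost = D(num_clusters - 1, j - 1) +
+    // cost(j..last); the caller loops over j and keeps the minimum, yielding
+    // D(num_clusters, last).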
+ return hn::MaskedAddOr(max, RebindMask(df, valid), vd, cc(df, j, last)); + } + + public: + // Clusters `kGroupSize` values in `x`, which need not be sorted already nor + // aligned, by choosing and filling `centers` (size `kClusters`, ascending + // order, not necessarily equal to one of the `x`). Fills `indices` with the + // index of the cluster to which each `x` belongs (16-bit for bit-packing). + // `buf` is per-thread. + // + // Returns the number of unused clusters, i.e., the starting index within + // `centers`; prior centers are zero-initialized. + // + // O(kClusters * kGroupSize * kGroupSize), but the constant factors are so low + // that this is about 10 times as fast as the O(kClusters * kGroupSize) SMAWK + // as implemented in FAISS, for our kGroupSize <= 128. + template + static HWY_NOINLINE size_t ClusterExactL2(DF df, const float* x, + ClusterBuf& buf, + float* HWY_RESTRICT centers, + uint16_t* HWY_RESTRICT indices) { + const hn::RebindToSigned di; + using VF = hn::Vec; + using VI = hn::Vec; + const VI k1 = hn::Set(di, 1); + const size_t N = hn::Lanes(df); + + HWY_ALIGN float sorted_and_i[kGroupSize]; + for (size_t i = 0; i < kGroupSize; ++i) { + sorted_and_i[i] = FloatPayload::Set(x[i], i); + } + hn::VQSortStatic(sorted_and_i, kGroupSize, hwy::SortAscending()); + ClusterCost cc(sorted_and_i); + + // Reference: https://arxiv.org/abs/1701.07204 + // D[k-1][m] is the lowest cost of clustering x1..m into k clusters. + AlignedMatrix& D = buf.d; + // T[k][m] is the starting index within sorted_and_i[] of the k-th cluster. + AlignedMatrix& T = buf.t; + + // Initialize the first rows for a single cluster. + for (size_t last = 0; last < kGroupSize; last += N) { + hn::Store(cc(df, 0, last), df, &D(0, last)); // Cost of 0..last + hn::Store(Zero(di), di, &T(0, last)); // Cluster index = 0 + } + + for (size_t num_clusters = 1; num_clusters < kClusters; ++num_clusters) { + // For each batch starting at `last`, one per lane: + for (size_t last = 0; last < kGroupSize; last += N) { + VF min = cc(df, 0, last); + VI arg = hn::Zero(di); + // For each j (start of rightmost cluster): + VI vj = k1; + for (size_t j = 1; j < last + N; ++j, vj = Add(vj, k1)) { + const VF c = ClusterDynProg(df, D, cc, num_clusters, last, j); + + // Retain the min cost and the j index that caused it. + const auto less = hn::Lt(c, min); + min = hn::IfThenElse(less, c, min); + arg = hn::IfThenElse(RebindMask(di, less), vj, arg); + } + hn::Store(min, df, &D(num_clusters, last)); + hn::Store(arg, di, &T(num_clusters, last)); + } + } + + // Backtrack to find centers. Clusters are [T(k, last), last]. + size_t last = kGroupSize - 1; + size_t unused_clusters = 0; + for (size_t k = kClusters - 1; k < kClusters; --k) { + const size_t start = static_cast(T(k, last)); + // Center = mean, O(1) thanks to cumulative sums. + const float sum = cc.SumOfSorted(start, last); + const int size = static_cast(last) - static_cast(start) + 1; + HWY_DASSERT(0 < size && size <= kGroupSize); + centers[k] = sum / size; + + // We know the range inside sorted_and_i[]; translate to original indices, + // which are stored inside each of the sorted_and_i mantissas. + for (size_t i = start; i <= last; ++i) { + const size_t idx_x = FloatPayload::Get(sorted_and_i[i]); + HWY_DASSERT(idx_x < kGroupSize); + indices[idx_x] = static_cast(k); + } + + // Not using all clusters. Avoid out of bounds accesses by stopping early. 
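+      // start == 0 means this cluster already covers the smallest sorted
+      // value, so the remaining (lower-numbered) clusters are unused.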
+ if (start == 0) { + unused_clusters = k; + for (size_t cluster = 0; cluster < unused_clusters; ++cluster) { + centers[cluster] = 0.0f; + } + break; + } + + last = start - 1; + HWY_DASSERT(last < kGroupSize); + } + + if (HWY_IS_DEBUG_BUILD) { + // Centers are in ascending order. + for (size_t i = unused_clusters + 1; i < kClusters; ++i) { + HWY_DASSERT(centers[i] >= centers[i - 1]); + } + } + return unused_clusters; + } +}; // NuqClustering + +// Bit-packing 4-bit values is trivial if we have 2 or 4 independent vectors: +// simply shift+OR them together into a full vector of 8 or 16-bit lanes. +// However, the order then depends on the vector length, which is unacceptable +// because we may store the encoding to disk and decode on another CPU. +// +// The dependency on vector length could be removed by introducing fixed-size +// packets and loading the next vector from a fixed offset known to be at +// least the vector length. However, this may require packets that are larger +// than the seek granularity of the application (e.g. matrix rows). +// +// We instead choose a continuous stream layout, which seems to entail the +// nibbles being stored and decoded in-order. This involves nontrivial shuffle +// operations which benefit from special-casing for target and vector length. +class NibbleCodec { + public: + // Packs four u16 vectors' lanes to nibbles within one vector, in order, and + // stores that vector to `out`. + template > + static HWY_INLINE void OrderedPackU16(D16 d16, V16 in0, V16 in1, V16 in2, + V16 in3, uint8_t* HWY_RESTRICT out) { + const hn::Repartition d8; + const hn::Repartition d32; + const hn::Repartition d64; + using V8 = hn::Vec; + + // Pairwise compaction of a single vector so nibbles are packed in-order. + // v16 lanes hold a 4-bit value; OR together adjacent pairs into the lower + // byte of *even* u16. + const auto combine_u16_pair_to_8 = [d16, d32](V16 v16) HWY_ATTR { + return hn::Xor( + v16, hn::BitCast(d16, hn::ShiftRight<12>(hn::BitCast(d32, v16)))); + }; + + const V16 u8_0 = combine_u16_pair_to_8(in0); + const V16 u8_1 = combine_u16_pair_to_8(in1); + const V16 u8_2 = combine_u16_pair_to_8(in2); + const V16 u8_3 = combine_u16_pair_to_8(in3); + V8 packed; + if (HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86) { + // 8-bit ConcatEven is efficient. Let digits denote eight u8 lanes + // of u8_1/0: ?d?3 ?c?2 / ?b?1 ?a?0. 8-bit ConcatEven = d3c2 b1a0, and + // again with the second x2_1 gives 7654 3210. + const V8 x2_0 = hn::ConcatEven(d8, BitCast(d8, u8_1), BitCast(d8, u8_0)); + const V8 x2_1 = hn::ConcatEven(d8, BitCast(d8, u8_3), BitCast(d8, u8_2)); + packed = hn::ConcatEven(d8, x2_1, x2_0); + } else { + // To avoid expensive 8-bit ConcatEven, compact pairs of u32 into the + // lower 16 bits in each u64, with other bits undefined. + const auto combine_u32_pair_to_16 = [d16, d64](V16 v16) HWY_ATTR { + return hn::Xor( + v16, hn::BitCast(d16, hn::ShiftRight<24>(hn::BitCast(d64, v16)))); + }; + const V16 u16_0 = combine_u32_pair_to_16(u8_0); + const V16 u16_1 = combine_u32_pair_to_16(u8_1); + const V16 u16_2 = combine_u32_pair_to_16(u8_2); + const V16 u16_3 = combine_u32_pair_to_16(u8_3); + // In-order compaction of four vectors into one, keeping only the low + // u16 of every u64. This is the same as above but with 16-bit Concat. 
+ const V16 x2_0 = hn::ConcatEven(d16, u16_1, u16_0); + const V16 x2_1 = hn::ConcatEven(d16, u16_3, u16_2); + packed = hn::BitCast(d8, hn::ConcatEven(d16, x2_1, x2_0)); + } + hn::StoreU(packed, d8, out); + } + + // Unpacks `Lanes(d16)` nibbles to u16 lanes. The first comes from the low + // nibble of packed[0], then its high nibble, then the next low nibble, etc. + template > + static HWY_INLINE V16 OrderedUnpackU16(D16 d16, const uint8_t* packed) { + const hn::Repartition d8; + using V8 = hn::Vec; + const hn::CappedTag d_load; + + // We replicate each byte 4x, so that its two nibbles propagate to both + // u16 lanes that they will initialize. The only performance-portable op to + // replicate bytes is TableLookupBytes, which shuffles 128-bit blocks + // independently. Thus each block receives 4 packed bytes, replicates them + // 4x, shifts/masks, and casts to 8 u16 lanes. + // + // Loading 16 bytes via LoadDup128 only works on AVX3; for smaller vectors, + // it may trigger asan errors from overrunning the end. We thus special-case + // vector lengths, handling any non-constexpr, and constexpr <= 512 bit. + V8 rep4; + if (HWY_HAVE_SCALABLE) { + // Non constexpr length: 4 per whole block equals size/4. + const size_t num_bytes = HWY_MAX(1, hn::Lanes(d8) / 4); + const V8 bytes = hn::LoadN(d8, packed, num_bytes); + // Replicate bytes 4x: lowest 4 = 0, next 4 = 1 etc. + const V8 idx = hn::And(hn::Iota(d8, 0), hn::Set(d8, 0xFCu)); + rep4 = hn::TableLookupBytes(bytes, idx); + } else if (hn::MaxLanes(d16) <= 8) { // <= 128-bit + const V8 bytes = hn::ResizeBitCast(d8, hn::LoadU(d_load, packed)); + alignas(16) static constexpr uint8_t kRep4[16] = { + HWY_REP4(0), HWY_REP4(1), HWY_REP4(2), HWY_REP4(3)}; + rep4 = hn::TableLookupBytes(bytes, hn::Load(d8, kRep4)); + } else if (HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86) { + // Plain load, can do 256..512-bit permute across blocks. + const V8 bytes = hn::ResizeBitCast(d8, hn::LoadU(d_load, packed)); + alignas(64) static constexpr uint8_t kRep4[64] = { + HWY_REP4(0), HWY_REP4(1), HWY_REP4(2), HWY_REP4(3), + HWY_REP4(4), HWY_REP4(5), HWY_REP4(6), HWY_REP4(7), + HWY_REP4(8), HWY_REP4(9), HWY_REP4(10), HWY_REP4(11), + HWY_REP4(12), HWY_REP4(13), HWY_REP4(14), HWY_REP4(15)}; + rep4 = hn::TableLookupLanes(bytes, hn::SetTableIndices(d8, kRep4)); + } else if (hn::MaxLanes(d16) == 16) { // 256-bit + const V8 bytes = hn::ResizeBitCast(d8, hn::LoadU(d_load, packed)); + // First copy to upper block for TableLookupBytes. This is slightly + // faster than 64-bit BroadcastLane. + const V8 bcast = hn::ConcatLowerLower(d8, bytes, bytes); + alignas(32) static constexpr uint8_t kRep4[32] = { + HWY_REP4(0), HWY_REP4(1), HWY_REP4(2), HWY_REP4(3), + HWY_REP4(4), HWY_REP4(5), HWY_REP4(6), HWY_REP4(7)}; + rep4 = hn::TableLookupBytes(bcast, hn::Load(d8, kRep4)); + } else if (hn::MaxLanes(d16) == 32) { // 512-bit + const V8 bytes = hn::LoadDup128(d8, packed); + alignas(64) static constexpr uint8_t kRep4[64] = { + HWY_REP4(0), HWY_REP4(1), HWY_REP4(2), HWY_REP4(3), + HWY_REP4(4), HWY_REP4(5), HWY_REP4(6), HWY_REP4(7), + HWY_REP4(8), HWY_REP4(9), HWY_REP4(10), HWY_REP4(11), + HWY_REP4(12), HWY_REP4(13), HWY_REP4(14), HWY_REP4(15)}; + rep4 = hn::TableLookupBytes(bytes, hn::Load(d8, kRep4)); + } else { + HWY_DASSERT(false); + } + + const V16 mask4 = hn::Set(d16, 0xF); + const V16 u16 = BitCast(d16, rep4); + // In-order unpack. Right-shift odd u16 by 4. 
Example with two packed + // bytes, one digit representing a nibble: + // 32 32 32 32 | 10 10 10 10 u16 + // z3 23 32 32 | z1 01 10 10 OddEven+ShiftRight + // zz z3 zz z2 | zz z1 zz z0 And (unpacked result) + return hn::And(mask4, hn::OddEven(hn::ShiftRight<4>(u16), u16)); + } +}; + +// Encode/decode functions. +class NuqCodec { + // 256-bit vectors can hold 16 bf16, otherwise we require 2x128-bit. + template + static constexpr size_t NumTables(DU du) { + return (!HWY_HAVE_SCALABLE && du.MaxBytes() >= 32) ? 1 : 2; + } + + // Unpacks `centers` from SFP into bf16 and loads them into one or two vectors + // for use by [Two]TableLookups. Returns as u16 because TableLookupLanes might + // not be available for bf16. + template + static HWY_INLINE hn::Vec LoadTable(DU du, const uint8_t* centers, + hn::Vec* HWY_RESTRICT tbl1) { + // Cap to the table size (kClusters) for decoding SFP - sufficient, and may + // be faster than a large vector. + const hn::CappedTag d_table; + // We ResizeCast tables to DU: if DU is bigger, table lookups will only + // access lanes < kClusters. If DU is smaller (128-bit), we have 2 tables. + HWY_DASSERT(hn::Lanes(du) >= hn::Lanes(d_table) || NumTables(du) == 2); + + HWY_ALIGN hwy::bfloat16_t table[kClusters]; + SfpCodec::Dec(d_table, reinterpret_cast(centers), + kClusters, table); + + // If we assume >= 128-bit vectors, we can use [Two]TableLookupLanes + // instead of TableLookupBytes, which requires extra interleaving of lo/hi. + HWY_DASSERT(hn::Lanes(du) >= 8); + + if (NumTables(du) == 2) { + // Reduce cap for second half to avoid loading past the end of the table. + const hn::CappedTag d_table2; + *tbl1 = hn::ResizeBitCast(du, hn::LoadU(d_table2, table + kClusters / 2)); + } + return hn::ResizeBitCast(du, hn::Load(d_table, table)); + } + + // Unpacks per-weight indices and sets c0/c1 to the corresponding centers. + template + static HWY_INLINE void TableLookups(DU du, hn::Vec tbl0, hn::Vec tbl1, + const uint8_t* packed, hn::Vec& c0, + hn::Vec& c1) { + using V16 = hn::Vec; + const size_t N16 = hn::Lanes(du); + + const V16 idx0 = NibbleCodec::OrderedUnpackU16(du, packed); + const V16 idx1 = NibbleCodec::OrderedUnpackU16(du, packed + N16 / 2); + + const auto indices0 = hn::IndicesFromVec(du, idx0); + const auto indices1 = hn::IndicesFromVec(du, idx1); + + if (NumTables(du) == 1) { + (void)tbl1; + c0 = hn::TableLookupLanes(tbl0, indices0); + c1 = hn::TableLookupLanes(tbl0, indices1); + } else { + c0 = hn::TwoTablesLookupLanes(du, tbl0, tbl1, indices0); + c1 = hn::TwoTablesLookupLanes(du, tbl0, tbl1, indices1); + } + } + + public: + // Encodes `num` floats starting from `in`. `out` points to compressed + // storage for `out_capacity` values and `out_ofs` indicates the destination + // offset within it, in units of float values, for parallel encoding by + // multiple threads. `num`, `out_capacity`, and `out_ofs` must all be + // multiples of `kGroupSize`. Returns the total number of unused clusters, + // which is expected to be zero. 
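+  // A sketch of the expected allocation (names are placeholders):
+  //   auto out =
+  //       hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(capacity));
+  //   NuqCodec::Enc(df, in, num, buf, capacity, out.get(), /*out_ofs=*/0);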
+ template + static HWY_INLINE size_t Enc(DF df, const float* const in, const size_t num, + ClusterBuf& buf, const size_t out_capacity, + NuqStream* const out, const size_t out_ofs) { + const hn::Repartition d8; + const hn::Repartition d16; + using V8 = hn::Vec; + using V16 = hn::Vec; + + const size_t N16 = hn::Lanes(d16); + HWY_ASSERT(kGroupSize >= 4 * N16); + + HWY_ASSERT(out_ofs + num <= out_capacity); + buf.Resize(num); + HWY_ASSERT(num % kGroupSize == 0); + HWY_ASSERT(out_capacity % kGroupSize == 0); + HWY_ASSERT(out_ofs % kGroupSize == 0); + const size_t num_groups = num / kGroupSize; + const size_t ofs_groups = out_ofs / kGroupSize; + + size_t unused_clusters = 0; + for (size_t g = 0; g < num_groups; ++g) { + const float* HWY_RESTRICT g_in = in + g * kGroupSize; + float* HWY_RESTRICT g_centers = buf.centers.get() + g * kClusters; + uint16_t* HWY_RESTRICT g_idx = buf.idx.get() + g * kGroupSize; + unused_clusters += + NuqClustering::ClusterExactL2(df, g_in, buf, g_centers, g_idx); + } + + uint8_t* centers = &out->byte + ofs_groups * kClusters; + SfpCodec::Enc(df, buf.centers.get(), num_groups * kClusters, + reinterpret_cast(centers)); + uint8_t* packed_start = &out->byte + NuqStream::PackedStart(out_capacity) + + ofs_groups * kGroupSize / 2; + + HWY_UNROLL(1) + for (size_t g = 0; g < num_groups; ++g) { + const uint16_t* HWY_RESTRICT g_idx = buf.idx.get() + g * kGroupSize; + uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2; + + HWY_UNROLL(1) + for (size_t i = 0; i < kGroupSize; i += 4 * N16) { + const V16 idx0 = hn::LoadU(d16, g_idx + i + N16 * 0); + const V16 idx1 = hn::LoadU(d16, g_idx + i + N16 * 1); + const V16 idx2 = hn::LoadU(d16, g_idx + i + N16 * 2); + const V16 idx3 = hn::LoadU(d16, g_idx + i + N16 * 3); + NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3, + g_packed + i / 2); + } + } + + return unused_clusters; + } + + // Decodes `num` values from the stream `in`, starting at the offset `in_ofs` + // (in units of values), to bf16 in `out`. `in_capacity`, `in_ofs` and `num` + // must all be multiples of `kGroupSize`. 
+ template + static HWY_INLINE void Dec(DF dbf, const size_t in_capacity, + const NuqStream* const in, const size_t in_ofs, + hwy::bfloat16_t* const out, const size_t num) { + const hn::RebindToUnsigned d16; + using V16 = hn::Vec; + + const size_t N16 = hn::Lanes(d16); + HWY_DASSERT(kGroupSize >= 4 * N16); + + HWY_DASSERT(in_ofs + num <= in_capacity); + HWY_DASSERT(in_capacity % kGroupSize == 0); + HWY_DASSERT(in_ofs % kGroupSize == 0); + HWY_DASSERT(num % kGroupSize == 0); + const size_t num_groups = num / kGroupSize; + const size_t ofs_groups = in_ofs / kGroupSize; + const uint8_t* tables = &in->byte + ofs_groups * kClusters; + const uint8_t* packed_start = &in->byte + + NuqStream::PackedStart(in_capacity) + + ofs_groups * kGroupSize / 2; + + HWY_UNROLL(1) + for (size_t g = 0; g < num_groups; ++g) { + const uint8_t* g_centers = tables + g * kClusters; + const uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2; + hwy::bfloat16_t* HWY_RESTRICT g_out = out + g * kGroupSize; + + V16 tbl1 = Zero(d16); + const V16 tbl0 = LoadTable(d16, g_centers, &tbl1); + + HWY_UNROLL(1) + for (size_t i = 0; i < kGroupSize; i += 2 * N16) { + V16 c0, c1; + TableLookups(d16, tbl0, tbl1, g_packed + i / 2, c0, c1); + hn::StoreU(BitCast(dbf, c0), dbf, g_out + i + N16 * 0); + hn::StoreU(BitCast(dbf, c1), dbf, g_out + i + N16 * 1); + } + } + } + + // Decodes `num` values from the stream `in`, starting at the offset + // `in_ofs` (in units of values), to f32 in `out`. `in_capacity`, + // `in_ofs` and `num` must all be multiples of `kGroupSize`. + template + static HWY_INLINE void Dec(DF df, const size_t in_capacity, + const NuqStream* const in, const size_t in_ofs, + float* const out, const size_t num) { + const hn::Repartition dbf; + const hn::RebindToUnsigned d16; + using V16 = hn::Vec; + using VF = hn::Vec; + + const size_t NF = hn::Lanes(df); + HWY_DASSERT(kGroupSize >= 4 * NF); + + HWY_DASSERT(in_ofs + num <= in_capacity); + HWY_DASSERT(in_capacity % kGroupSize == 0); + HWY_DASSERT(in_ofs % kGroupSize == 0); + HWY_DASSERT(num % kGroupSize == 0); + const size_t ofs_groups = in_ofs / kGroupSize; + const size_t num_groups = num / kGroupSize; + const uint8_t* tables = &in->byte + ofs_groups * kClusters; + const uint8_t* packed_start = &in->byte + + NuqStream::PackedStart(in_capacity) + + ofs_groups * kGroupSize / 2; + + HWY_UNROLL(1) + for (size_t g = 0; g < num_groups; ++g) { + const uint8_t* g_centers = tables + g * kClusters; + const uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2; + float* HWY_RESTRICT g_out = out + g * kGroupSize; + + V16 tbl1 = Zero(d16); + const V16 tbl0 = LoadTable(d16, g_centers, &tbl1); + + HWY_UNROLL(1) + for (size_t i = 0; i < kGroupSize; i += 4 * NF) { + V16 c0, c1; + TableLookups(d16, tbl0, tbl1, g_packed + i / 2, c0, c1); + const VF f0 = hn::PromoteLowerTo(df, BitCast(dbf, c0)); + const VF f1 = hn::PromoteUpperTo(df, BitCast(dbf, c0)); + const VF f2 = hn::PromoteLowerTo(df, BitCast(dbf, c1)); + const VF f3 = hn::PromoteUpperTo(df, BitCast(dbf, c1)); + hn::StoreU(f0, df, g_out + i + NF * 0); + hn::StoreU(f1, df, g_out + i + NF * 1); + hn::StoreU(f2, df, g_out + i + NF * 2); + hn::StoreU(f3, df, g_out + i + NF * 3); + } + } + } + + // Accumulates into `sum0..3` dot products of decoded values with `num` bf16 + // from `vec_aligned`. DF is f32 because sum0..3 are also f32. `in_capacity`, + // `in_ofs` and `num` must all be multiples of `kGroupSize`. 
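+  // Callers typically pass four zero accumulators and reduce them afterwards,
+  // as CompressTraits<NuqStream>::Dot does in compress-inl.h.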
+  template <class DF>
+  static HWY_INLINE void Dot(DF df, const size_t in_capacity,
+                             const NuqStream* const in, const size_t in_ofs,
+                             const hwy::bfloat16_t* const vec_aligned,
+                             const size_t num, hn::Vec<DF>& sum0,
+                             hn::Vec<DF>& sum1, hn::Vec<DF>& sum2,
+                             hn::Vec<DF>& sum3) {
+    const hn::Repartition<hwy::bfloat16_t, DF> dbf;
+    const hn::RebindToUnsigned<decltype(dbf)> d16;
+    using VBF = hn::Vec<decltype(dbf)>;
+    using V16 = hn::Vec<decltype(d16)>;
+    const size_t N16 = hn::Lanes(d16);
+    HWY_DASSERT(kGroupSize >= 4 * N16);
+
+    HWY_DASSERT(in_ofs + num <= in_capacity);
+    HWY_DASSERT(in_capacity % kGroupSize == 0);
+    HWY_DASSERT(in_ofs % kGroupSize == 0);
+    HWY_DASSERT(num % kGroupSize == 0);
+    const size_t ofs_groups = in_ofs / kGroupSize;
+    const size_t num_groups = num / kGroupSize;
+    const uint8_t* tables = &in->byte + ofs_groups * kClusters;
+    const uint8_t* packed_start = &in->byte +
+                                  NuqStream::PackedStart(in_capacity) +
+                                  ofs_groups * kGroupSize / 2;
+
+    HWY_UNROLL(1)
+    for (size_t g = 0; g < num_groups; ++g) {
+      const uint8_t* g_centers = tables + g * kClusters;
+      const uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2;
+      const hwy::bfloat16_t* HWY_RESTRICT g_in = vec_aligned + g * kGroupSize;
+
+      V16 tbl1 = Zero(d16);
+      const V16 tbl0 = LoadTable(d16, g_centers, &tbl1);
+
+      HWY_UNROLL(1)
+      for (size_t i = 0; i < kGroupSize; i += 2 * N16) {
+        V16 c0, c1;
+        TableLookups(d16, tbl0, tbl1, g_packed + i / 2, c0, c1);
+        const VBF in0 = hn::Load(dbf, g_in + i + N16 * 0);
+        const VBF in1 = hn::Load(dbf, g_in + i + N16 * 1);
+        sum0 = hn::ReorderWidenMulAccumulate(df, in0, BitCast(dbf, c0), sum0,
+                                             sum1);
+        sum2 = hn::ReorderWidenMulAccumulate(df, in1, BitCast(dbf, c1), sum2,
+                                             sum3);
+      }
+    }
+  }
+
+  // Accumulates into `sum0..3` dot products of decoded values with `num` f32
+  // from `vec_aligned`. `in_capacity`, `in_ofs` and `num` must all be
+  // multiples of `kGroupSize`.
+ template + static HWY_INLINE void Dot(DF df, const size_t in_capacity, + const NuqStream* const in, const size_t in_ofs, + const float* const vec_aligned, const size_t num, + hn::Vec& sum0, hn::Vec& sum1, + hn::Vec& sum2, hn::Vec& sum3) { + const hn::Repartition dbf; + const hn::RebindToUnsigned d16; + using VF = hn::Vec; + using V16 = hn::Vec; + const size_t NF = hn::Lanes(df); + HWY_DASSERT(kGroupSize >= 4 * NF); + + HWY_DASSERT(in_ofs + num <= in_capacity); + HWY_DASSERT(in_capacity % kGroupSize == 0); + HWY_DASSERT(in_ofs % kGroupSize == 0); + HWY_DASSERT(num % kGroupSize == 0); + const size_t ofs_groups = in_ofs / kGroupSize; + const size_t num_groups = num / kGroupSize; + const uint8_t* tables = &in->byte + ofs_groups * kClusters; + const uint8_t* packed_start = &in->byte + + NuqStream::PackedStart(in_capacity) + + ofs_groups * kGroupSize / 2; + + HWY_UNROLL(1) + for (size_t g = 0; g < num_groups; ++g) { + const uint8_t* g_centers = tables + g * kClusters; + const uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2; + const float* HWY_RESTRICT g_in = vec_aligned + g * kGroupSize; + + V16 tbl1 = Zero(d16); + const V16 tbl0 = LoadTable(d16, g_centers, &tbl1); + + HWY_UNROLL(1) + for (size_t i = 0; i < kGroupSize; i += 4 * NF) { + V16 c0, c1; + TableLookups(d16, tbl0, tbl1, g_packed + i / 2, c0, c1); + const VF in0 = hn::LoadU(df, g_in + i + NF * 0); + const VF in1 = hn::LoadU(df, g_in + i + NF * 1); + const VF in2 = hn::LoadU(df, g_in + i + NF * 2); + const VF in3 = hn::LoadU(df, g_in + i + NF * 3); + const VF f0 = hn::PromoteLowerTo(df, BitCast(dbf, c0)); + const VF f1 = hn::PromoteUpperTo(df, BitCast(dbf, c0)); + const VF f2 = hn::PromoteLowerTo(df, BitCast(dbf, c1)); + const VF f3 = hn::PromoteUpperTo(df, BitCast(dbf, c1)); + sum0 = hn::MulAdd(in0, f0, sum0); + sum1 = hn::MulAdd(in1, f1, sum1); + sum2 = hn::MulAdd(in2, f2, sum2); + sum3 = hn::MulAdd(in3, f3, sum3); + } + } + } +}; // NuqCodec + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_ diff --git a/compression/nuq.h b/compression/nuq.h new file mode 100644 index 0000000..08d162f --- /dev/null +++ b/compression/nuq.h @@ -0,0 +1,116 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_H_ +#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_H_ + +// Non-uniform quantization: a compressed representation of f32 inputs that +// supports seeking at a granularity of kGroupSize, decoding to bf16/f32, and a +// fused decode/dot product with bf16/f32 vectors. + +#include +#include + +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" // HWY_INLINE + +namespace gcpp { + +// 4-bit indices are a sweet spot in terms of quality per size. +static constexpr size_t kClusters = 16; + +// Number of weights that share a table. 
Larger = slower encode, higher error, +// smaller size (table amortized over more weights). This is the minimum +// granularity for seeking/decoding in the stream, and must be at least four +// times the number of bf16 elements per vector. +static constexpr size_t kGroupSize = 256; + +// Points to the *start* of a NUQ stream. Aligning the allocation (see +// aligned_allocator.h) may be speed up decoding but is not required. +// +// See go/streaming-weight-decode for background and design. Layout: first one +// table of kClusters entries per group, in ascending order of group index, +// then two packed indices per byte. +// +// Indices are stored in-order to enable vector-length agnostic decode, because +// streams may be persisted to disk and used by other CPUs. +// +// To enable parallel encoding and decoding, Enc/Dec have `offset` parameters +// which refer to the stream, NOT the raw from/to pointers, which point directly +// to the source/destination. Offsets are in units of values, NOT compressed +// bytes within the stream. +#pragma pack(push, 1) +struct NuqStream { + // Returns offset of packed indices from the start of the stream. This matches + // the (padded) total table size because table entries are bytes. `capacity` + // is already a multiple of `kGroupSize`. + static constexpr size_t PackedStart(size_t capacity) { + // Round up to avoid cache-line splits when loading indices. No effect on + // size as long as capacity / kGroupSize is a multiple of 4. + return hwy::RoundUpTo((capacity / kGroupSize) * kClusters, 64); + } + + // Returns number of NuqStream to allocate for the stream, which matches its + // size in bytes. `capacity` is already a multiple of `kGroupSize`. + static constexpr size_t PackedEnd(size_t capacity) { + return PackedStart(capacity) + capacity / 2; // two 4-bit indices per byte. + } + + uint8_t byte; +}; +#pragma pack(pop) + +static inline const char* TypeName(NuqStream) { return "NUQ"; } + +// Storage for dynamic programming. There are two matrices; we use separate +// allocations to avoid type punning. +template +class AlignedMatrix { + public: + AlignedMatrix() : mem_(hwy::AllocateAligned(kClusters * kGroupSize)) {} + + HWY_INLINE const T& operator()(size_t row, size_t col) const { + return mem_[row * kGroupSize + col]; + } + + HWY_INLINE T& operator()(size_t row, size_t col) { + return mem_[row * kGroupSize + col]; + } + + private: + hwy::AlignedFreeUniquePtr mem_; +}; + +// Reuse memory across calls to Enc to avoid per-call allocations. +struct ClusterBuf { + void Resize(size_t new_num) { + if (new_num < num) return; + + num = new_num; + const size_t num_groups = hwy::DivCeil(num, kGroupSize); + centers = hwy::AllocateAligned(num_groups * kClusters); + idx = hwy::AllocateAligned(num); + } + + AlignedMatrix d; + AlignedMatrix t; + + size_t num = 0; + hwy::AlignedFreeUniquePtr centers; + hwy::AlignedFreeUniquePtr idx; +}; + +} // namespace gcpp +#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_H_ diff --git a/compression/nuq_test.cc b/compression/nuq_test.cc new file mode 100644 index 0000000..75bdc18 --- /dev/null +++ b/compression/nuq_test.cc @@ -0,0 +1,428 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include // std::shuffle +#include + +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE \ + "third_party/gemma_cpp/compression/nuq_test.cc" // NOLINT +#include "hwy/foreach_target.h" // IWYU pragma: keep +// Other headers that include Highway must come after foreach_target.h +// copybara:import_next_line:gemma_cpp +#include "compression/distortion.h" +// copybara:import_next_line:gemma_cpp +#include "compression/nuq-inl.h" +// copybara:import_next_line:gemma_cpp +#include "compression/nuq.h" +#include "hwy/highway.h" +#include "hwy/tests/hwy_gtest.h" +#include "hwy/tests/test_util-inl.h" +#include "hwy/timer.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { + +// All-equal inputs: only one cluster +struct TestFlat { + template + HWY_INLINE void operator()(T /*unused*/, DF df) { + // Run this simple test only once to save time/debug output. + if (!(HWY_ONCE && hn::Lanes(df) == hn::Lanes(hn::ScalableTag()))) { + return; + } + + auto in = hwy::AllocateAligned(kGroupSize); + HWY_ASSERT(in); + for (size_t i = 0; i < kGroupSize; ++i) { + in[i] = 0.5f; + } + ClusterBuf buf; + float centers[kClusters]; + uint16_t indices[kGroupSize]; + const size_t unused_clusters = + NuqClustering::ClusterExactL2(df, in.get(), buf, centers, indices); + HWY_ASSERT(unused_clusters == kClusters - 1); + + for (size_t i = 0; i < unused_clusters; ++i) { + HWY_ASSERT(centers[i] == 0.0f); + } + HWY_ASSERT(centers[unused_clusters] == 0.5f); + for (size_t i = 0; i < kGroupSize; ++i) { + HWY_ASSERT(indices[i] == unused_clusters); + } + } +}; + +void TestAllFlat() { hn::ForGEVectors<64, TestFlat>()(float()); } + +// Generate shuffled plateaus, one per cluster +struct TestPlateaus { + template + HWY_INLINE void operator()(T /*unused*/, DF df) { + // Run this simple test only once to save time/debug output. 
+ if (!(HWY_ONCE && hn::Lanes(df) == hn::Lanes(hn::ScalableTag()))) { + return; + } + + auto in = hwy::AllocateAligned(kGroupSize); + HWY_ASSERT(in); + + for (size_t i = 0; i < kGroupSize; ++i) { + const size_t idx_cluster = i / (kGroupSize / kClusters); + HWY_ASSERT(idx_cluster < kClusters); + in[i] = (1.0f * idx_cluster / kClusters) - 0.5f; + HWY_ASSERT(-0.5f <= in[i] && in[i] < 0.5f); + } + + std::random_device rd; + std::mt19937 rng(rd()); + std::shuffle(in.get(), in.get() + kGroupSize, rng); + + ClusterBuf buf; + float centers[kClusters]; + uint16_t indices[kGroupSize]; + const size_t unused_clusters = + NuqClustering::ClusterExactL2(df, in.get(), buf, centers, indices); + HWY_ASSERT(unused_clusters == 0); + + DistortionStats stats; + for (size_t i = 0; i < kGroupSize; ++i) { + HWY_ASSERT(indices[i] < kClusters); + stats.Notify(in[i], centers[indices[i]]); + } + const float pnorm = stats.PNorm(); + const float snr = stats.GeomeanValueDivL1(); + fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr, + stats.MaxIndex(), stats.MaxL1()); + HWY_ASSERT(pnorm == 0.0f); + HWY_ASSERT(snr == 0.0f); + } +}; + +void TestAllPlateaus() { hn::ForGEVectors<64, TestPlateaus>()(float()); } + +struct TestRamp { + template + HWY_INLINE void operator()(T /*unused*/, DF df) { + // Run this simple test only once to save time/debug output. + if (!(HWY_ONCE && hn::Lanes(df) == hn::Lanes(hn::ScalableTag()))) { + return; + } + + auto in = hwy::AllocateAligned(kGroupSize); + HWY_ASSERT(in); + + for (size_t i = 0; i < kGroupSize; ++i) { + in[i] = (1.0f * i / kGroupSize) - 0.45f; // slightly asymmetric + HWY_ASSERT(-0.45f <= in[i] && in[i] < 0.55f); + } + + std::random_device rd; + std::mt19937 rng(rd()); + std::shuffle(in.get(), in.get() + kGroupSize, rng); + + ClusterBuf buf; + float centers[kClusters]; + uint16_t indices[kGroupSize]; + const size_t unused_clusters = + NuqClustering::ClusterExactL2(df, in.get(), buf, centers, indices); + HWY_ASSERT(unused_clusters == 0); + + DistortionStats stats; + for (size_t i = 0; i < kGroupSize; ++i) { + HWY_ASSERT(indices[i] < kClusters); + stats.Notify(in[i], centers[indices[i]]); + } + const float pnorm = stats.PNorm(); + const float snr = stats.GeomeanValueDivL1(); + fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr, + stats.MaxIndex(), stats.MaxL1()); + static_assert(kGroupSize == 128 || kGroupSize == 256, "Update expected"); + + const float expected_pnorm = kGroupSize == 128 ? 2.08E-2f : 2.1E-2f; + const float expected_snr = kGroupSize == 128 ? 
16.9f : 17.6f; + HWY_ASSERT(expected_pnorm <= pnorm && pnorm < 1.02f * expected_pnorm); + HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr); + } +}; + +void TestAllRamp() { hn::ForGEVectors<64, TestRamp>()(float()); } + +struct TestNormal { + template + HWY_INLINE void operator()(T /*unused*/, DF df) { + auto in = hwy::AllocateAligned(kGroupSize); + HWY_ASSERT(in); + + std::mt19937 rng(123); + std::normal_distribution dist{0.001f, 0.3f}; + for (size_t i = 0; i < kGroupSize; ++i) { + in[i] = dist(rng); + } + std::shuffle(in.get(), in.get() + kGroupSize, rng); + + ClusterBuf buf; + float centers[kClusters]; + uint16_t indices[kGroupSize]; + double elapsed = hwy::HighestValue(); + for (size_t rep = 0; rep < 100; ++rep) { + const double t0 = hwy::platform::Now(); + const size_t unused_clusters = + NuqClustering::ClusterExactL2(df, in.get(), buf, centers, indices); + HWY_ASSERT(unused_clusters == 0); + const double t1 = hwy::platform::Now(); + elapsed = HWY_MIN(elapsed, t1 - t0); + } + fprintf(stderr, "Vec %zu Enc %.2f MB/s\n", Lanes(df) * 4, + kGroupSize * sizeof(float) * 1E-6 / elapsed); + + DistortionStats stats; + for (size_t i = 0; i < kGroupSize; ++i) { + HWY_ASSERT(indices[i] < kClusters); + stats.Notify(in[i], centers[indices[i]]); + } + const float pnorm = stats.PNorm(); + const float snr = stats.GeomeanValueDivL1(); + fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr, + stats.MaxIndex(), stats.MaxL1()); + static_assert(kGroupSize == 128 || kGroupSize == 256, "Update expected"); + const float expected_pnorm = kGroupSize == 128 ? 3E-2f : 3.4E-2f; + const float expected_snr = kGroupSize == 128 ? 17.4f : 13.1f; + HWY_ASSERT(expected_pnorm <= pnorm && pnorm < 1.02f * expected_pnorm); + HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr); + } +}; + +void TestAllNormal() { hn::ForGEVectors<64, TestNormal>()(float()); } + +// Can encode and decode sub-regions. 
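+// As documented in nuq.h, the `offset` arguments below are in units of values
+// (not compressed bytes), and kGroupSize is the seek/decode granularity, so
+// TestOffset can re-encode a 2*kGroupSize slice at offset 5*kGroupSize in place.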
+struct TestOffset { + template + HWY_INLINE void operator()(T /*unused*/, D d) { + const hn::Repartition df; + const size_t total = 10 * kGroupSize; + const size_t kMidLen = 2 * kGroupSize; // length of middle piece + + auto in = hwy::AllocateAligned(total); // Enc() requires f32 + auto dec1 = hwy::AllocateAligned(total); + auto dec2 = hwy::AllocateAligned(kMidLen); + auto nuq = hwy::AllocateAligned(NuqStream::PackedEnd(total)); + HWY_ASSERT(in && dec1 && dec2 && nuq); + + std::mt19937 rng(123); + std::normal_distribution dist{0.001f, 0.3f}; + for (size_t i = 0; i < total; ++i) { + in[i] = dist(rng); + } + + // Encode + decode everything + ClusterBuf buf; + (void)NuqCodec::Enc(df, in.get(), total, buf, total, nuq.get(), 0); + NuqCodec::Dec(d, total, nuq.get(), 0, dec1.get(), total); + + // Overwrite middle with first inputs + const size_t offset = 5 * kGroupSize; + (void)NuqCodec::Enc(df, in.get(), kMidLen, buf, total, nuq.get(), offset); + + // Decoded middle now matches previously decoded first + NuqCodec::Dec(d, total, nuq.get(), offset, dec2.get(), kMidLen); + for (size_t i = 0; i < kMidLen; ++i) { + HWY_ASSERT(dec1[i] == dec2[i]); + } + } +}; + +void TestAllOffsetF32() { + const hn::ForGEVectors<128, TestOffset> test; + test(float()); +} + +void TestAllOffsetBF16() { + const hn::ForGEVectors<128, TestOffset> test; + test(hwy::bfloat16_t()); +} + +struct TestStream { + template + HWY_INLINE void operator()(T /*unused*/, D d) { + const hn::Repartition df; + const size_t num = 4 * kGroupSize; + auto in = hwy::AllocateAligned(num); // Enc() requires f32 + auto out = hwy::AllocateAligned(num); + auto nuq = hwy::AllocateAligned(NuqStream::PackedEnd(num)); + HWY_ASSERT(in && out && nuq); + + std::mt19937 rng(123); + std::normal_distribution dist{0.001f, 0.3f}; + for (size_t i = 0; i < num; ++i) { + in[i] = dist(rng); + } + + ClusterBuf buf; + double elapsed = hwy::HighestValue(); + for (size_t rep = 0; rep < 100; ++rep) { + const double t0 = hwy::platform::Now(); + const size_t unused_clusters = + NuqCodec::Enc(df, in.get(), num, buf, num, nuq.get(), 0); + HWY_ASSERT(unused_clusters == 0); + const double t1 = hwy::platform::Now(); + elapsed = HWY_MIN(elapsed, t1 - t0); + } + fprintf(stderr, "Vec %zu Enc %.2f MB/s\n", Lanes(d) * sizeof(T), + num * sizeof(float) * 1E-6 / elapsed); + + elapsed = hwy::HighestValue(); + for (size_t rep = 0; rep < 100; ++rep) { + const double t0 = hwy::platform::Now(); + NuqCodec::Dec(d, num, nuq.get(), 0, out.get(), num); + const double t1 = hwy::platform::Now(); + elapsed = HWY_MIN(elapsed, t1 - t0); + } + fprintf(stderr, "Vec %zu Dec %.2f MB/s\n", Lanes(d) * sizeof(T), + num * sizeof(T) * 1E-6 / elapsed); + + DistortionStats stats; + for (size_t i = 0; i < num; ++i) { + stats.Notify(in[i], hwy::ConvertScalarTo(out[i])); + } + const float pnorm = stats.PNorm(); + const float snr = stats.GeomeanValueDivL1(); + fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr, + stats.MaxIndex(), stats.MaxL1()); + static_assert(kGroupSize == 128 || kGroupSize == 256, "Update expected"); + const float expected_pnorm = kGroupSize == 128 ? 3.44E-2f : 3.88E-2f; + const float expected_snr = kGroupSize == 128 ? 
15.0f : 13.3f; + HWY_ASSERT(expected_pnorm <= pnorm && pnorm < 1.02f * expected_pnorm); + HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr); + } +}; + +void TestAllStreamF32() { + const hn::ForGEVectors<128, TestStream> test; + test(float()); +} + +void TestAllStreamBF16() { + const hn::ForGEVectors<128, TestStream> test; + test(hwy::bfloat16_t()); +} + +struct TestDot { + template + HWY_INLINE void operator()(T /*unused*/, D d) { + const hn::Repartition df; + const size_t num = 4 * kGroupSize; + auto in = hwy::AllocateAligned(num); + auto dec = hwy::AllocateAligned(num); + auto vec = hwy::AllocateAligned(num); + auto nuq = hwy::AllocateAligned(NuqStream::PackedEnd(num)); + HWY_ASSERT(in && dec && vec && nuq); + + std::mt19937 rng(123); + std::normal_distribution dist{0.001f, 0.3f}; + for (size_t i = 0; i < num; ++i) { + in[i] = dist(rng); + vec[i] = hwy::ConvertScalarTo(dist(rng)); + } + // This changes the correlation between in and vec, which considerably + // affects the error of the result. + std::shuffle(in.get(), in.get() + num, rng); + + ClusterBuf buf; + const size_t unused_clusters = + NuqCodec::Enc(df, in.get(), num, buf, num, nuq.get(), 0); + HWY_ASSERT(unused_clusters == 0); + + double actual = 0.0; + double elapsed = hwy::HighestValue(); + for (size_t rep = 0; rep < 20; ++rep) { + hn::Vec sum0 = hn::Zero(df); + hn::Vec sum1 = hn::Zero(df); + hn::Vec sum2 = hn::Zero(df); + hn::Vec sum3 = hn::Zero(df); + const double t0 = hwy::platform::Now(); + NuqCodec::Dot(df, num, nuq.get(), 0, vec.get(), num, sum0, sum1, sum2, + sum3); + const double t1 = hwy::platform::Now(); + elapsed = HWY_MIN(elapsed, t1 - t0); + sum0 = hn::Add(hn::Add(sum0, sum1), hn::Add(sum2, sum3)); + actual = hn::ReduceSum(df, sum0); + } + + NuqCodec::Dec(df, num, nuq.get(), 0, dec.get(), num); + fprintf(stderr, "Vec %zu Dec %.2f MB/s\n", Lanes(d) * sizeof(T), + num * sizeof(in[0]) * 1E-6 / elapsed); + + double expected = 0.0; // using original input + double expected2 = 0.0; // using decoded NUQ + for (size_t i = 0; i < num; ++i) { + expected += in[i] * hwy::ConvertScalarTo(vec[i]); + expected2 += dec[i] * hwy::ConvertScalarTo(vec[i]); + } + const double l1 = hwy::ScalarAbs(expected - actual); + const double snr = 1.0 + hwy::ScalarAbs(expected) / l1; + fprintf(stderr, "expected %.3f e2 %.4f actual %.4f l1 %E snr %.2f\n", + expected, expected2, actual, l1, snr); + HWY_ASSERT(hwy::ScalarAbs(expected2 - actual) < 1E-4); + static_assert(kGroupSize == 128 || kGroupSize == 256, "Update expected"); + const double expected_l1 = kGroupSize == 128 ? 7.3E-2 : 4.34E-2; + const double expected_snr = kGroupSize == 128 ? 9.7f + : sizeof(T) == 2 ? 
14.5f + : 14.9f; + HWY_ASSERT(expected_l1 <= l1 && l1 < 1.02f * expected_l1); + HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr); + } +}; + +void TestAllDotF32() { + const hn::ForGEVectors<128, TestDot> test; + test(float()); +} +void TestAllDotBF16() { + const hn::ForGEVectors<128, TestDot> test; + test(hwy::bfloat16_t()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace gcpp { +HWY_BEFORE_TEST(NuqTest); +HWY_EXPORT_AND_TEST_P(NuqTest, TestAllFlat); +HWY_EXPORT_AND_TEST_P(NuqTest, TestAllPlateaus); +HWY_EXPORT_AND_TEST_P(NuqTest, TestAllRamp); +HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNormal); +HWY_EXPORT_AND_TEST_P(NuqTest, TestAllOffsetF32); +HWY_EXPORT_AND_TEST_P(NuqTest, TestAllOffsetBF16); +HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamF32); +HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamBF16); +HWY_EXPORT_AND_TEST_P(NuqTest, TestAllDotF32); +HWY_EXPORT_AND_TEST_P(NuqTest, TestAllDotBF16); +} // namespace gcpp + +#endif diff --git a/compression/sfp-inl.h b/compression/sfp-inl.h new file mode 100644 index 0000000..62b8955 --- /dev/null +++ b/compression/sfp-inl.h @@ -0,0 +1,515 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Normal include guard to placate lint. +#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_ +#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_ + +#include +#include + +// copybara:import_next_line:gemma_cpp +#include "compression/sfp.h" +#include "hwy/base.h" + +#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_ + +// Actual per-target include guard. +#if defined(THIRD_PARTY_GEMMA_CPP_SFP_INL_TOGGLE) == defined(HWY_TARGET_TOGGLE) +#ifdef THIRD_PARTY_GEMMA_CPP_SFP_INL_TOGGLE +#undef THIRD_PARTY_GEMMA_CPP_SFP_INL_TOGGLE +#else +#define THIRD_PARTY_GEMMA_CPP_SFP_INL_TOGGLE +#endif + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +// For unsigned numbers with MSB zero, signed comparison is faster on x86. +template +HWY_INLINE hn::Mask SignedGt(DU du, hn::Vec a, hn::Vec b) { + const hn::RebindToSigned di; + return hn::RebindMask(du, hn::Gt(BitCast(di, a), hn::BitCast(di, b))); +} +template +HWY_INLINE hn::Mask SignedLt(DU du, hn::Vec a, hn::Vec b) { + return SignedGt(du, b, a); +} + +// Encode/decode functions. +class SfpCodec { + public: + // Returns 8-bit packed representation of `lo` and `hi` bytes of bf16. 31 ops. + // Implementation detail, public because called by test. + template + static HWY_INLINE hn::Vec EncBytes(D d, const hn::Vec lo, + const hn::Vec hi) { + const hn::Vec k1 = hn::Set(d, 1u); + const hn::Vec k80 = hn::Set(d, 0x80u); + + // Copy sign for later insertion. + const hn::Vec sign_in_msb = hi; + // Biased exponent = lower 7 bits of hi and MSB of lo. Modified below. 
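+    // bf16 bit layout: hi = s eeeeeee (sign, upper 7 exponent bits),
+    //                  lo = e mmmmmmm (exponent LSB, 7 mantissa bits).
+    // (hi + hi) shifts the sign out of the 8-bit lane and the exponent up by
+    // one; OR-ing in (lo >> 7) appends the exponent LSB, yielding the full
+    // 8-bit biased exponent.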
+ hn::Vec biased_e = hn::Or(hn::Add(hi, hi), hn::ShiftRight<7>(lo)); + HWY_ASSERT(hn::AllTrue(d, hn::Lt(biased_e, k80))); // <= 2^0 + + // Clear MSB to isolate the mantissa and enable signed comparisons, then + // shift right by *one* (plus 1 to undo the prior add/left-shift) to leave + // headroom for overflow during rounding. + const hn::Vec m6 = hn::ShiftRight<2>(hn::Add(lo, lo)); + + // The place to round depends on whether the exponent is large (>= -7) - if + // so, we retain three mantissa bits, otherwise two. However, rounding can + // also cause the exponent to increase. We first choose a threshold that + // rounds up to 1.0*2^-7 for both two and three bit mantissas: + // >= 1.1111 * 2^-8 (0.007568359375). This entails the exponent being + // greater, or equal and the mantissa > (1111000 >> 1) - 1 = 0x3B. + const hn::Vec kMinLargeE = hn::Set(d, 127 - 8); + const hn::Mask is_large_before_round = hn::Or( + SignedGt(d, biased_e, kMinLargeE), + hn::And(hn::Eq(biased_e, kMinLargeE), SignedGt(d, m6, Set(d, 0x3B)))); + + // To retain the most-significant 3 or 2 mantissa bits, we will right-shift + // by is_large_before_round ? 3 : 4. Variable Shr is expensive for 8-bit + // elements, so (<< 1) if is_large_before_round, then always (>> 4). + const hn::Vec m_shl4 = + hn::MaskedAddOr(m6, is_large_before_round, m6, m6); + + // Before shifting (truncation), round to nearest even to reduce bias. If + // the lowest remaining mantissa bit is odd, increase the offset. Example + // with the lowest remaining bit (left) and next lower two bits; the + // latter, plus two more, will be truncated. + // 0[00] + 1 = 0[01] + // 0[01] + 1 = 0[10] + // 0[10] + 1 = 0[11] (round down toward even) + // 0[11] + 1 = 1[00] (round up) + // 1[00] + 10 = 1[10] + // 1[01] + 10 = 1[11] + // 1[10] + 10 = C0[00] (round up toward even with C=1 carry out) + // 1[11] + 10 = C0[01] (round up toward even with C=1 carry out) + const hn::Vec odd_bit = hn::And(hn::ShiftRight<4>(m_shl4), k1); + const hn::Vec rounded = hn::Add(m_shl4, hn::Add(odd_bit, Set(d, 7))); + // Update the exponent if rounding overflowed. + const hn::Vec carry_bit = + hn::IfThenElse(is_large_before_round, k80, hn::Set(d, 0x40u)); + const hn::Vec carry_clear = hn::AndNot(carry_bit, rounded); + HWY_DASSERT(hn::AllTrue(d, hn::Lt(carry_clear, carry_bit))); + const hn::Mask is_overflow = hn::Ne(carry_clear, rounded); + biased_e = hn::MaskedAddOr(biased_e, is_overflow, biased_e, k1); + HWY_DASSERT(hn::AllTrue(d, hn::Lt(biased_e, Set(d, 128)))); + + // Detect if zero or the min exponent. + const hn::Vec kMinNormal = hn::Set(d, 127 - 23); + const hn::Mask is_zero = SignedLt(d, biased_e, kMinNormal); + const hn::Mask is_min = hn::Eq(biased_e, kMinNormal); + + // 1.1110xxx * 2^-8 was considered small above, and thus rounded up to 2^-7, + // which the decoder will consider large, and expect 3 mantissa bits. If we + // set the threshold above to 1.111, then it does NOT round up. Thus we + // check exponent >= -7 *after* rounding. + const hn::Mask is_large = SignedGt(d, biased_e, hn::Set(d, 127 - 8)); + + // To extract and pack the mantissa, only is_large matters. Either it + // matches is_large_before_round, or the rounding resulted in mantissa=0, so + // we either extract two or three bits by shifting out the lower 5..6 bits. + // is_large_before is_large rounded want + // 0 0 0Cmm???? mm + // 0 1 0100???? 
000 + // 1 0 impossible - + // 1 1 Cmmm???0 mmm + hn::Vec m = hn::ShiftRight<4>(carry_clear); + HWY_DASSERT(hn::AllTrue( + d, SignedLt(d, m, + hn::IfThenElse(is_large, hn::Set(d, 8), hn::Set(d, 4))))); + + // 1.0 * 2^-23 has the same encoding as zero, so round it up to 1.01. + m = hn::MaskedMaxOr(m, is_min, m, k1); + + const hn::Vec e_bias = hn::IfThenElse( + is_large, + hn::Set(d, hwy::BitCastScalar(static_cast(15 - 127))), + hn::Set(d, hwy::BitCastScalar(static_cast(23 - 127)))); + const hn::Vec e = hn::Add(biased_e, e_bias); + HWY_DASSERT( + hn::AllTrue(d, hn::Lt(hn::IfThenZeroElse(is_zero, e), hn::Set(d, 16)))); + + // Shift exponent left 2 or 3 bits to make space for `m`. + const hn::Vec em = + hn::Or(m, hn::ShiftLeft<2>(hn::MaskedAddOr(e, is_large, e, e))); + HWY_DASSERT(hn::AllTrue(d, hn::Lt(hn::IfThenZeroElse(is_zero, em), k80))); + const hn::Vec encoded = hn::BitwiseIfThenElse(k80, sign_in_msb, em); + // Doing this last ensures -0 is replaced with 0. + return hn::IfThenZeroElse(is_zero, encoded); + } + + // Decodes u8 `encoded` into `lo` and `hi` bytes of bf16. 12 ops. + // Implementation detail, public because called by test. + template + static HWY_INLINE void DecBytes(D d, hn::Vec encoded, hn::Vec& lo, + hn::Vec& hi) { + const hn::Vec k0 = hn::Zero(d); + const hn::Vec k80 = hn::Set(d, 0x80u); + + HWY_DASSERT(hn::AllTrue(d, hn::Ne(encoded, k80))); // -0 is reserved + // Copy sign for later insertion via BitwiseIfThenElse. + const hn::Vec sign_in_msb = encoded; + encoded = hn::AndNot(k80, encoded); + + // Special-case zero, negated so we can use MaskedAddOr. Signed comparison + // is fine because we have cleared the sign bit. + const hn::Mask is_nonzero = SignedGt(d, encoded, k0); + // If MSB is clear, we have two mantissa bits, otherwise three. + const hn::Mask is_small_e = SignedLt(d, encoded, hn::Set(d, 64)); + // If is_small_e, add/left-shift 0xxxx.mm to 0xxxx.mm0; else keep 1xxx.mmm. + const hn::Vec e4m3 = + hn::MaskedAddOr(encoded, is_small_e, encoded, encoded); + HWY_DASSERT(hn::AllTrue(d, hn::Lt(e4m3, k80))); + const hn::Vec e = hn::ShiftRight<3>(e4m3); // 4-bit exponent only + HWY_DASSERT(hn::AllTrue(d, hn::Lt(e, Set(d, 16u)))); + // The encoded exponent for 2^0 is 15, so subtract 15. Add 127 for the + // binary32/bf16 bias. Subtract another 8 if is_small_e because its lowest + // encoded value (0) should be less than the lowest 'large' exponent 2^-7. + const hn::Vec e_bias = hn::IfThenElse( + is_small_e, hn::Set(d, 127u - 15u - 8u), hn::Set(d, 127u - 15u)); + // Special-case zero or add e_bias. If encoded=0, e and e4m3 are zero, but + // we must zero e_bias to get the desired all-zero bf16. + const hn::Vec biased_e = hn::MaskedAddOr(k0, is_nonzero, e_bias, e); + // The decoded binary32 exponent should be at most 2^0. + HWY_DASSERT(hn::AllTrue(d, hn::Lt(biased_e, k80))); + + // Shift the MSB of e4m3's mantissa into the MSB of the bf16 mantissa. + const hn::Vec m7 = hn::ShiftLeft<4>(e4m3); + // Lower byte of bf16 = exponent LSB || mantissa. + lo = hn::BitwiseIfThenElse(k80, hn::ShiftLeft<7>(biased_e), m7); + // Upper byte of bf16 = sign || lower 7 bits of exponent. + hi = hn::BitwiseIfThenElse(k80, sign_in_msb, hn::ShiftRight<1>(biased_e)); + } + + // Encodes `num` bf16 values from `in_bf` to `out_packed`. 
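+  // Illustrative round-trip, as exercised by TestEncDec in sfp_test.cc (the
+  // descriptor choice here is an example, not a requirement):
+  //   const hn::ScalableTag<hwy::bfloat16_t> dbf;
+  //   SfpCodec::Enc(dbf, bf_in, num, sfp_out);   // num values -> num bytes
+  //   SfpCodec::Dec(dbf, sfp_out, num, bf_back); // lossy decode back to bf16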
+ template + static HWY_INLINE void Enc(DBF dbf, const hwy::bfloat16_t* HWY_RESTRICT in_bf, + size_t num, SfpStream* HWY_RESTRICT out_packed) { + const hn::Repartition d8; + using V8 = hn::Vec; + const size_t N16 = hn::Lanes(dbf); + + size_t i = 0; + if (num >= 2 * N16) { + HWY_UNROLL(1) + for (; i <= num - 2 * N16; i += 2 * N16) { + const V8 packed = Enc2B(dbf, in_bf + i); + hn::StoreU(packed, d8, &out_packed->byte + i); + } + } + + const size_t remaining = num - i; + HWY_DASSERT(remaining < 2 * N16); + if (remaining != 0) { + HWY_ALIGN hwy::bfloat16_t padded[2 * hn::MaxLanes(dbf)]; + hwy::ZeroBytes(padded, sizeof(padded)); + hwy::CopyBytes(in_bf + i, padded, remaining * sizeof(padded[0])); + const V8 packed = Enc2B(dbf, padded); + hn::StoreN(packed, d8, &out_packed->byte + i, remaining); + } + } + + // Encodes `num` f32 values from `in_f` to `packed`. + template + static HWY_INLINE void Enc(DF df, const float* HWY_RESTRICT in_f, size_t num, + SfpStream* HWY_RESTRICT out_packed) { + const hn::Repartition d8; + using V8 = hn::Vec; + const size_t NF = hn::Lanes(df); + + size_t i = 0; + if (num >= 4 * NF) { + HWY_UNROLL(1) + for (; i <= num - 4 * NF; i += 4 * NF) { + const V8 packed = Enc4F(df, in_f + i); + hn::StoreU(packed, d8, &out_packed->byte + i); + } + } + + const size_t remaining = num - i; + HWY_DASSERT(remaining < 4 * NF); + if (remaining != 0) { + HWY_ALIGN float padded[4 * hn::MaxLanes(df)]; + hwy::ZeroBytes(padded, sizeof(padded)); + hwy::CopyBytes(in_f + i, padded, remaining * sizeof(padded[0])); + const V8 packed = Enc4F(df, padded); + hn::StoreN(packed, d8, &out_packed->byte + i, remaining); + } + } + + // Decodes `num` values from `in_packed` to `out_bf`. + template + static HWY_INLINE void Dec(DBF dbf, const SfpStream* HWY_RESTRICT in_packed, + size_t num, hwy::bfloat16_t* HWY_RESTRICT out_bf) { + const hn::Repartition d8; + using V8 = hn::Vec; + using VBF = hn::Vec; + const size_t N16 = hn::Lanes(dbf); + + size_t i = 0; + if (num >= 2 * N16) { + HWY_UNROLL(1) + for (; i <= num - 2 * N16; i += 2 * N16) { + const V8 packed = hn::LoadU(d8, &in_packed->byte + i); + VBF bf0, bf1; + Dec2B(dbf, packed, bf0, bf1); + hn::StoreU(bf0, dbf, out_bf + i); + hn::StoreU(bf1, dbf, out_bf + i + N16); + } + } + + const size_t remaining = num - i; + HWY_DASSERT(remaining < 2 * N16); + if (remaining != 0) { + const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining); + HWY_ALIGN hwy::bfloat16_t padded[2 * hn::MaxLanes(dbf)]; + VBF bf0, bf1; + Dec2B(dbf, packed, bf0, bf1); + hn::StoreU(bf0, dbf, padded); + hn::StoreU(bf1, dbf, padded + N16); + hwy::CopyBytes(padded, out_bf + i, remaining * sizeof(padded[0])); + } + } + + // Decodes `num` values from `in_packed` to `out_f`. 
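+  // Decoding is order-preserving and independent of vector length (see
+  // TestOrder in sfp_test.cc), so streams encoded on one machine decode
+  // identically on another.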
+ template + static HWY_INLINE void Dec(DF df, const SfpStream* HWY_RESTRICT in_packed, + size_t num, float* HWY_RESTRICT out_f) { + const hn::Repartition d8; + using V8 = hn::Vec; + using VF = hn::Vec; + const size_t NF = hn::Lanes(df); + + size_t i = 0; + if (num >= 4 * NF) { + HWY_UNROLL(1) + for (; i <= num - 4 * NF; i += 4 * NF) { + const V8 packed = hn::LoadU(d8, &in_packed->byte + i); + VF f0, f1, f2, f3; + Dec4F(df, packed, f0, f1, f2, f3); + hn::StoreU(f0, df, out_f + i + NF * 0); + hn::StoreU(f1, df, out_f + i + NF * 1); + hn::StoreU(f2, df, out_f + i + NF * 2); + hn::StoreU(f3, df, out_f + i + NF * 3); + } + } + + const size_t remaining = num - i; + HWY_DASSERT(remaining < 4 * NF); + if (remaining != 0) { + const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining); + HWY_ALIGN float padded[4 * hn::MaxLanes(df)]; + VF f0, f1, f2, f3; + Dec4F(df, packed, f0, f1, f2, f3); + hn::StoreU(f0, df, padded + NF * 0); + hn::StoreU(f1, df, padded + NF * 1); + hn::StoreU(f2, df, padded + NF * 2); + hn::StoreU(f3, df, padded + NF * 3); + hwy::CopyBytes(padded, out_f + i, remaining * sizeof(padded[0])); + } + } + + // Fused decode and dot product with bf16 into four output accumulators. + template + static HWY_INLINE void Dot(DF df, const SfpStream* HWY_RESTRICT in_packed, + size_t num, + const hwy::bfloat16_t* HWY_RESTRICT vec_aligned, + hn::Vec& sum0, hn::Vec& sum1, + hn::Vec& sum2, hn::Vec& sum3) { + const hn::Repartition d8; + const hn::Repartition dbf; + using V8 = hn::Vec; + using VBF = hn::Vec; + const size_t N16 = hn::Lanes(dbf); + + size_t i = 0; + if (num >= 2 * N16) { + HWY_UNROLL(1) + for (; i <= num - 2 * N16; i += 2 * N16) { + const V8 packed = hn::LoadU(d8, &in_packed->byte + i); + const VBF v0 = hn::LoadU(dbf, vec_aligned + i); + const VBF v1 = hn::LoadU(dbf, vec_aligned + i + N16); + VBF bf0, bf1; + Dec2B(dbf, packed, bf0, bf1); + sum0 = hn::ReorderWidenMulAccumulate(df, bf0, v0, sum0, sum1); + sum2 = hn::ReorderWidenMulAccumulate(df, bf1, v1, sum2, sum3); + } + } + + const size_t remaining = num - i; + if (remaining != 0) { + const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining); + HWY_ALIGN hwy::bfloat16_t padded[2 * hn::MaxLanes(dbf)]; + hwy::ZeroBytes(padded, sizeof(padded)); + hwy::CopyBytes(vec_aligned + i, padded, remaining * sizeof(padded[0])); + const VBF v0 = hn::LoadU(dbf, padded); + const VBF v1 = hn::LoadU(dbf, padded + N16); + VBF bf0, bf1; + Dec2B(dbf, packed, bf0, bf1); + sum0 = hn::ReorderWidenMulAccumulate(df, bf0, v0, sum0, sum1); + sum2 = hn::ReorderWidenMulAccumulate(df, bf1, v1, sum2, sum3); + } + } + + // Fused decode and dot product with f32 into four output accumulators. 
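+  // Callers zero-initialize sum0..sum3 and reduce afterwards, e.g. as in
+  // TestDot (sfp_test.cc):
+  //   sum0 = hn::Add(hn::Add(sum0, sum1), hn::Add(sum2, sum3));
+  //   const float dot = hn::ReduceSum(df, sum0);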
+ template + static HWY_INLINE void Dot(DF df, const SfpStream* HWY_RESTRICT in_packed, + size_t num, const float* HWY_RESTRICT vec_aligned, + hn::Vec& sum0, hn::Vec& sum1, + hn::Vec& sum2, hn::Vec& sum3) { + const hn::Repartition d8; + using V8 = hn::Vec; + using VF = hn::Vec; + const size_t NF = hn::Lanes(df); + + size_t i = 0; + if (num >= 4 * NF) { + HWY_UNROLL(1) + for (; i <= num - 4 * NF; i += 4 * NF) { + const V8 packed = hn::LoadU(d8, &in_packed->byte + i); + const VF v0 = hn::LoadU(df, vec_aligned + i + NF * 0); + const VF v1 = hn::LoadU(df, vec_aligned + i + NF * 1); + const VF v2 = hn::LoadU(df, vec_aligned + i + NF * 2); + const VF v3 = hn::LoadU(df, vec_aligned + i + NF * 3); + VF f0, f1, f2, f3; + Dec4F(df, packed, f0, f1, f2, f3); + sum0 = hn::MulAdd(f0, v0, sum0); + sum1 = hn::MulAdd(f1, v1, sum1); + sum2 = hn::MulAdd(f2, v2, sum2); + sum3 = hn::MulAdd(f3, v3, sum3); + } + } + + const size_t remaining = num - i; + if (remaining != 0) { + const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining); + HWY_ALIGN float padded[4 * hn::MaxLanes(df)]; + hwy::ZeroBytes(padded, sizeof(padded)); + hwy::CopyBytes(vec_aligned + i, padded, remaining * sizeof(padded[0])); + const VF v0 = hn::LoadU(df, padded + NF * 0); + const VF v1 = hn::LoadU(df, padded + NF * 1); + const VF v2 = hn::LoadU(df, padded + NF * 2); + const VF v3 = hn::LoadU(df, padded + NF * 3); + VF f0, f1, f2, f3; + Dec4F(df, packed, f0, f1, f2, f3); + sum0 = hn::MulAdd(f0, v0, sum0); + sum1 = hn::MulAdd(f1, v1, sum1); + sum2 = hn::MulAdd(f2, v2, sum2); + sum3 = hn::MulAdd(f3, v3, sum3); + } + } + + private: + // Wrappers to avoid code duplication across float/bf16 input types and + // the main loop/remainder. + + // Returns vector of packed bytes for callers to StoreU or StoreN. + template >> + static HWY_INLINE V8 Enc2U(D16 d16, const hn::Vec w0, + const hn::Vec w1) { + const hn::Repartition d8; + + // Although more expensive on AVX3, in-order packing enables streaming + // decompression without fixed-size packets. + const V8 lo = hn::ConcatEven(d8, hn::BitCast(d8, w1), hn::BitCast(d8, w0)); + const V8 hi = hn::ConcatOdd(d8, hn::BitCast(d8, w1), hn::BitCast(d8, w0)); + return EncBytes(d8, lo, hi); + } + + template >> + static HWY_INLINE V8 Enc2B(DBF dbf, const hwy::bfloat16_t* HWY_RESTRICT in) { + const hn::Repartition d16; + const size_t N16 = hn::Lanes(d16); + using V16 = hn::Vec; + + const V16 w0 = hn::BitCast(d16, hn::LoadU(dbf, in)); + const V16 w1 = hn::BitCast(d16, hn::LoadU(dbf, in + N16)); + return Enc2U(d16, w0, w1); + } + + template >> + static HWY_INLINE V8 Enc4F(DF df, const float* HWY_RESTRICT in) { + const hn::Repartition d16; + const hn::Repartition dbf; + using VF = hn::Vec; + using V16 = hn::Vec; + const size_t NF = hn::Lanes(df); + + const VF f0 = hn::LoadU(df, in + NF * 0); + const VF f1 = hn::LoadU(df, in + NF * 1); + const VF f2 = hn::LoadU(df, in + NF * 2); + const VF f3 = hn::LoadU(df, in + NF * 3); + // Chop off the lower 16 bits; EncBytes still rounds properly. 
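+    // (EncBytes applies its own round-to-nearest-even to the bf16 mantissa,
+    // so truncating f32 to bf16 here is sufficient.)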
+ const V16 w0 = hn::BitCast(d16, hn::OrderedDemote2To(dbf, f0, f1)); + const V16 w1 = hn::BitCast(d16, hn::OrderedDemote2To(dbf, f2, f3)); + return Enc2U(d16, w0, w1); + } + + template >> + static HWY_INLINE void Dec2U(D16 d16, V8 packed, hn::Vec& w0, + hn::Vec& w1) { + const hn::Repartition d8; + V8 lo, hi; + DecBytes(d8, packed, lo, hi); + w0 = hn::BitCast(d16, hn::InterleaveWholeLower(d8, lo, hi)); + w1 = hn::BitCast(d16, hn::InterleaveWholeUpper(d8, lo, hi)); + } + + template >> + static HWY_INLINE void Dec2B(DBF dbf, V8 packed, hn::Vec& bf0, + hn::Vec& bf1) { + const hn::Repartition d16; + using V16 = hn::Vec; + V16 w0, w1; + Dec2U(d16, packed, w0, w1); + bf0 = hn::BitCast(dbf, w0); + bf1 = hn::BitCast(dbf, w1); + } + + template >> + static HWY_INLINE void Dec4F(DF df, V8 packed, hn::Vec& f0, + hn::Vec& f1, hn::Vec& f2, + hn::Vec& f3) { + const hn::Repartition dbf; + using VBF = hn::Vec; + VBF bf0, bf1; + Dec2B(dbf, packed, bf0, bf1); + f0 = hn::PromoteLowerTo(df, bf0); + f1 = hn::PromoteUpperTo(df, bf0); + f2 = hn::PromoteLowerTo(df, bf1); + f3 = hn::PromoteUpperTo(df, bf1); + } +}; // SfpCodec + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_ diff --git a/compression/sfp.h b/compression/sfp.h new file mode 100644 index 0000000..332ca43 --- /dev/null +++ b/compression/sfp.h @@ -0,0 +1,51 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_H_ +#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_H_ + +// Switching Floating Point: a hybrid 8-bit float representation of bf16/f32 +// inputs that combines the advantages of e4m3 and e5m2 into a single format. +// It supports seeking at a granularity of 1, decoding to bf16/f32, and a +// fused decode/dot product with bf16/f32 vectors. + +#include + +namespace gcpp { + +// Points to the *start* of an SFP stream. Values are stored in-order to enable +// vector-length agnostic seeking, because streams may be written to disk for +// loading on other CPUs. +// +// Characteristics: +// - 24-bit dynamic range, with max exponent 2^0. +// - 3 bit mantissa for values >= 2^-7, otherwise 2. +// +// This is faster to decode than a straightforward implementation of eXmY, in +// part because SFP does not require subnormals. Unlike OCP MX, it also does not +// require side information (shared exponents). +// +// Although the representation could probably be shrunk to 6-7 bits, more +// savings can be had by non-uniform clustering - see nuq.h. 
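+//
+// Byte layout (mirrors the scalar reference F32FromSFP8 in sfp_test.cc):
+//   values >= 2^-7: s eeee mmm -> 1.mmm * 2^(eeee - 15), eeee in [8, 15]
+//   values <  2^-7: s 0eeee mm -> 1.mm  * 2^(eeee - 23), eeee in [0, 15]
+// For example, 0x7F decodes to 1.111b * 2^0 = 1.875 (the largest magnitude),
+// 0x00 decodes to 0.0f, and 0x80 (-0) is reserved.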
+#pragma pack(push, 1) +struct SfpStream { + uint8_t byte; +}; +#pragma pack(pop) + +static inline const char* TypeName(SfpStream) { return "SFP"; } + +} // namespace gcpp +#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_H_ diff --git a/compression/sfp_test.cc b/compression/sfp_test.cc new file mode 100644 index 0000000..ee35743 --- /dev/null +++ b/compression/sfp_test.cc @@ -0,0 +1,440 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// copybara:import_next_line:gemma_cpp +#include "compression/sfp.h" + +#include +#include +#include + +#include +#include +#include + +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE \ + "third_party/gemma_cpp/compression/sfp_test.cc" // NOLINT +#include "hwy/foreach_target.h" // IWYU pragma: keep +// Any highway.h must come after foreach_target.h +// copybara:import_next_line:gemma_cpp +#include "compression/distortion.h" +// copybara:import_next_line:gemma_cpp +#include "compression/sfp-inl.h" +#include "hwy/highway.h" +#include "hwy/tests/hwy_gtest.h" +#include "hwy/tests/test_util-inl.h" +#include "hwy/timer.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { + +// Decode +float F32FromSFP8(uint32_t sfp) { + HWY_ASSERT(sfp < 256); + HWY_ASSERT(sfp != 0x80); // -0 is reserved + + const uint32_t sign32 = (sfp & 0x80) << 24; + sfp &= 0x7F; + const bool large_e = sfp >= 64; + const size_t m_bits = large_e ? 3 : 2; + uint32_t m = sfp & ((1u << m_bits) - 1u); + size_t e = sfp >> m_bits; + if (sfp == 0) return 0.0f; + const uint32_t e_bias = large_e ? 15 : 23; + const uint32_t exp32 = static_cast(127 + e - e_bias) << 23; + const uint32_t mnt32 = m << (23 - m_bits); + const uint32_t binary32 = sign32 | exp32 | mnt32; + float result; + hwy::CopySameSize(&binary32, &result); + return result; +} + +void TestAllUnique() { + std::set unique; + for (uint32_t sfp = 0; sfp < 256; ++sfp) { + if (sfp == 0x80) continue; // -0 is reserved + unique.insert(F32FromSFP8(sfp)); + } + HWY_ASSERT_EQ(size_t{255}, unique.size()); + if (false) { + for (float f : unique) { + fprintf(stderr, "%e\n", f); + } + } +} + +// ------------------------------ Foreach compressed representation + +// Encode +HWY_INLINE uint32_t SFP8FromF32(float f) { + HWY_ASSERT(-1.875f <= f && f <= 1.875f); + + constexpr uint32_t kMaskM = hwy::MantissaMask(); + uint32_t binary32; + hwy::CopySameSize(&f, &binary32); + const uint32_t s = (binary32 & hwy::SignMask()) >> 24; + binary32 &= ~hwy::SignMask(); + f = hwy::ScalarAbs(f); + + // >= 1.1111 * 2^-8 rounds up to 1.0*2^-7. + bool large_e = (f >= 0.007568359375f); + + const uint32_t org_binary32 = binary32; + const uint32_t m32 = binary32 & kMaskM; + binary32 = (binary32 & ~kMaskM) | m32; + size_t m_bits = large_e ? 
3 : 2; + const uint32_t is_odd = (m32 >> (23 - m_bits)) & 1; + const uint32_t round = is_odd + (1u << (23 - m_bits - 1)) - 1; + const uint32_t rounded = binary32 + round; + + // >= 1.111 also rounds up, but only if it was considered !large_e before. + if (f >= 0.00732421875f) { + large_e = true; + m_bits = 3; + } + + uint32_t m = (kMaskM & rounded) >> (23 - m_bits); + int32_t e = (rounded >> 23) - 127; + + if (e <= -23) { + // 2^-23 is the smallest normal exponent. Zero has e = -127. Do not set the + // SFP sign bit because the encoding for -0 is reserved. + if (e < -23) return 0; + // e = 2^-23: round up mantissa because m=0 encodes 0.0f. + if (m == 0) m = 1; + } + + if (false) { + fprintf(stderr, "in %x round %x rounded %x e %d m %x large_e %d\n", + org_binary32, round, rounded, e, m, large_e); + } + uint32_t e_sfp = e + (large_e ? 15 : 23); + HWY_ASSERT(e_sfp < 16); + + const uint32_t encoded = (e_sfp << m_bits) | m | s; + HWY_ASSERT(encoded < 256); + return encoded; +} + +// For every possible encoding: ensure re-encoding the decoded value matches it. +struct TestDecEnc { + template + HWY_INLINE void operator()(T /*unused*/, D d) { + const hn::RepartitionToWide d16; + const hn::Rebind dbf; + const hn::Repartition df; + for (uint32_t encoded = 0; encoded < 256; ++encoded) { + if (encoded == 0x80) continue; // -0 is reserved + const float decoded = F32FromSFP8(encoded); + const uint32_t encoded2 = SFP8FromF32(decoded); + + hn::Vec dec_lo, dec_hi; + SfpCodec::DecBytes(d, hn::Set(d, encoded), dec_lo, dec_hi); + const hn::Vec dec = + hn::BitCast(dbf, hn::ZipLower(d16, dec_lo, dec_hi)); + const float vdecoded = hn::GetLane(hn::PromoteLowerTo(df, dec)); + const uint32_t vencoded2 = + hn::GetLane(SfpCodec::EncBytes(d, dec_lo, dec_hi)); + + if (decoded != vdecoded || encoded2 != vencoded2 || encoded != encoded2) { + HWY_ABORT("enc %u -> dec %E=%x=%E -> enc %u %u\n", encoded, decoded, + hwy::BitCastScalar(decoded), vdecoded, encoded2, + vencoded2); + } + } + } +}; + +void TestAllDecEnc() { hn::ForGEVectors<32, TestDecEnc>()(uint8_t()); } + +// ------------------------------ Golden (known values) + +// Generate values, encode, decode back to that value. +struct TestGolden { + template + HWY_INLINE void operator()(T /*unused*/, D d) { + const hn::Repartition df; + const hn::Repartition dbf; + const hn::RebindToUnsigned d16; + + struct Golden { + float in; + float out; + }; + const Golden golden[] = { + // All mantissa bits set, all discarded zero (no rounding) + {0.46875f, 0.46875f}, + {0.9375f, 0.9375f}, + // All mantissa bits set, one below it set (round up to pow2) + {0.484375f, 0.5f}, + {0.96875f, 1.0f}, + // Lowest mantissa bit set, all discarded zero (no rounding) + {0.28125f, 0.28125f}, + {0.5625f, 0.5625f}, + // Lowest mantissa bit set, one below it set (round up to even) + {0.296875f, 0.3125f}, + {0.59375f, 0.625f}, + // All mantissa zero, all discarded set (round up) + {0.279296875f, 0.28125f}, + {0.55859375f, 0.5625f}, + // All mantissa zero, one below it set (round DOWN to pow2) + {0.265625f, 0.25f}, + {0.53125f, 0.5f}, + + // At inflection point: 1.max*2^-8 rounds up to 1.0*2^-7 + {0.0068359375f, 0.0068359375f}, // 1.11 -> 1.11 + {0.00732421875f, 0.0078125f}, // 1.111 -> 1.11[1] -> 1.0 + {0.007568359375f, 0.0078125f}, // 1.1111 -> 1.0 + + // Above 1.0: no longer special-cased. + {1.0f, 1.0f}, + {1.0625f, 1.0f}, // 1.000100 + + // Smallest normal exponents - we no longer use subnormals. 
+ {2.384185791015625E-7f, 2.384185791015625E-7f}, // 1.00p-22 + {1.49011611938E-07f, 1.49011611938E-07f}, // 1.01p-23 + {1.19209289551E-07f, 1.49011611938E-07f}, // 1.00p-23 -> 1.01p-23 + {5.96046447754E-08f, 0.0f}, // 1.00p-24 -> 0 + {8.94069671631E-08f, 0.0f}, // 1.10p-24 -> 0 + {1.11758708954E-07f, 1.49011611938E-07f}, // 1.111p-24-> 1.01p-23 + + // 1100_010 * 2^-7 rounds down to 110 + {0.013841f, 0.013671875f}, + }; + constexpr size_t kNumGolden = sizeof(golden) / sizeof(Golden); + for (uint32_t s : {0, 1}) { + for (size_t i = 0; i < kNumGolden; ++i) { + const float in = s ? -golden[i].in : golden[i].in; + const float out = s ? -golden[i].out : golden[i].out; + const hn::Vec in_bf = + hn::OrderedDemote2To(dbf, hn::Set(df, in), hn::Set(df, in)); + const uint32_t encoded = SFP8FromF32(in); + const uint32_t vencoded = hn::GetLane(SfpCodec::EncBytes( + d, hn::BitCast(d, in_bf), + hn::BitCast(d, hn::ShiftRight<8>(hn::BitCast(d16, in_bf))))); + const float decoded = F32FromSFP8(encoded); + hn::Vec dec_lo, dec_hi; + SfpCodec::DecBytes(d, hn::Set(d, encoded), dec_lo, dec_hi); + const hn::Vec dec = + hn::BitCast(dbf, hn::ZipLower(d16, dec_lo, dec_hi)); + const float vdecoded = hn::GetLane(hn::PromoteLowerTo(df, dec)); + + if (decoded != vdecoded || decoded != out || encoded != vencoded) { + HWY_ABORT("@%zu in %E dec %E %E golden %E\n", i, in, decoded, + vdecoded, golden[i].out); + } + } // i + } // s + } +}; + +void TestAllGolden() { + // Full vectors only, other tests cover partial vectors. + TestGolden()(uint8_t(), hn::ScalableTag()); +} + +// ------------------------------ Foreach bf16 input + +// Generate all values, encode, decode back. +struct TestEncDec { + template + HWY_INLINE void operator()(T /*unused*/, DBF dbf) { + const hn::Repartition du8; + + // We only use the upper 4 of 7 bf16 mantissa bits, so force the lower three + // bits to zero to reduce the number of inputs. 
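+    // kStep = 8 = 2^3 clears those three bits, so 0x8000 / 8 = 4096
+    // non-negative bf16 patterns are enumerated; both signs are added below.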
+ constexpr size_t kStep = 8; + const size_t max = 0x8000 / 8; + + auto in = hwy::AllocateAligned(max); + auto packed = hwy::AllocateAligned(max); + auto dec = hwy::AllocateAligned(max); + HWY_ASSERT(in && packed && dec); + size_t num = 0; + for (size_t i = 0; i < max; ++i) { + const uint16_t bits = i * kStep; + const float f = hwy::F32FromBF16(hwy::BitCastScalar(bits)); + // Keep if within range + if (hwy::ScalarIsFinite(f) && f <= 1.875f) { + in[num] = hwy::BF16FromF32(f); + in[num + 1] = hwy::BF16FromF32(-f); + num += 2; + } + } + + double enc_elapsed = hwy::HighestValue(); + double dec_elapsed = hwy::HighestValue(); + for (size_t rep = 0; rep < 100; ++rep) { + const double t0 = hwy::platform::Now(); + SfpCodec::Enc(dbf, in.get(), num, packed.get()); + const double t1 = hwy::platform::Now(); + SfpCodec::Dec(dbf, packed.get(), num, dec.get()); + const double t2 = hwy::platform::Now(); + enc_elapsed = HWY_MIN(enc_elapsed, t1 - t0); + dec_elapsed = HWY_MIN(dec_elapsed, t2 - t1); + } + const double enc_mbs = num * sizeof(T) * 1E-6 / enc_elapsed; + const double dec_mbs = num * sizeof(T) * 1E-6 / dec_elapsed; + fprintf(stderr, "Vec size %zu Enc %.2f MB/s Dec %.2f MB/s\n", Lanes(du8), + enc_mbs, dec_mbs); + + { + double sum = 0.0; + DistortionStats stats; + for (size_t i = 0; i < num; ++i) { + const float out = hwy::F32FromBF16(dec[i]); + sum += hwy::ConvertScalarTo(hwy::ScalarAbs(in[i])); + stats.Notify(in[i], out); + } + const double avg = sum / num; + fprintf(stderr, "Avg magnitude %.3E, p-norm %.3E snr %.2f @%zu = %.4E\n", + avg, stats.PNorm(), stats.GeomeanValueDivL1(), stats.MaxIndex(), + stats.MaxL1()); + } + } +}; + +void TestAllEncDec() { hn::ForGEVectors<32, TestEncDec>()(hwy::bfloat16_t()); } + +// ------------------------------ Order + +// Store 8-bit iota, decode, encode, check iota == packed. This ensures +// Enc/Dec are preserving the order independent of vector length. +struct TestOrder { + template + HWY_INLINE void operator()(T /*unused*/, DBF dbf) { + const hn::Repartition du8; + + const size_t num = 10 * hn::Lanes(du8) / 3; + + auto iota = hwy::AllocateAligned(num); + auto packed = hwy::AllocateAligned(num); + auto bf = hwy::AllocateAligned(num); + HWY_ASSERT(iota && packed && bf); + for (size_t i = 0; i < num; ++i) { + // Clear sign bit so we can also check that bf is in ascending order. + iota[i].byte = i & 127; + } + + SfpCodec::Dec(dbf, iota.get(), num, bf.get()); + SfpCodec::Enc(dbf, bf.get(), num, packed.get()); + + for (size_t i = 0; i < num; ++i) { + if (iota[i].byte != packed[i].byte) { + HWY_ABORT("@%zu: %d %d\n", i, iota[i].byte, packed[i].byte); + } + } + } +}; + +void TestAllOrder() { hn::ForGEVectors<32, TestOrder>()(hwy::bfloat16_t()); } + +// ------------------------------ Dot + +struct TestDot { + template + HWY_INLINE void operator()(T /*unused*/, D d) { + const hn::Repartition df; + const size_t num = 384; + auto in = hwy::AllocateAligned(num); + auto dec = hwy::AllocateAligned(num); + auto vec = hwy::AllocateAligned(num); + auto sfp = hwy::AllocateAligned(num); + HWY_ASSERT(in && dec && vec && sfp); + + std::mt19937 rng(123); + std::normal_distribution dist{0.001f, 0.3f}; + for (size_t i = 0; i < num; ++i) { + in[i] = hwy::ConvertScalarTo(dist(rng)); + vec[i] = hwy::ConvertScalarTo(dist(rng)); + } + // This changes the correlation between in and vec, which considerably + // affects the error of the result. 
+ std::shuffle(in.get(), in.get() + num, rng); + + SfpCodec::Enc(d, in.get(), num, sfp.get()); + + double actual = 0.0; + double elapsed = hwy::HighestValue(); + for (size_t rep = 0; rep < 200; ++rep) { + hn::Vec sum0 = hn::Zero(df); + hn::Vec sum1 = hn::Zero(df); + hn::Vec sum2 = hn::Zero(df); + hn::Vec sum3 = hn::Zero(df); + const double t0 = hwy::platform::Now(); + SfpCodec::Dot(df, sfp.get(), num, vec.get(), sum0, sum1, sum2, sum3); + const double t1 = hwy::platform::Now(); + elapsed = HWY_MIN(elapsed, t1 - t0); + sum0 = hn::Add(hn::Add(sum0, sum1), hn::Add(sum2, sum3)); + actual = hn::ReduceSum(df, sum0); + } + + SfpCodec::Dec(d, sfp.get(), num, dec.get()); + fprintf(stderr, "Vec %zu Dot %.2f MB/s\n", Lanes(d) * sizeof(T), + num * sizeof(T) * 1E-6 / elapsed); + + double expected = 0.0; // using original input + double expected2 = 0.0; // using decoded SFP + for (size_t i = 0; i < num; ++i) { + expected += hwy::ConvertScalarTo(in[i]) * + hwy::ConvertScalarTo(vec[i]); + expected2 += hwy::ConvertScalarTo(dec[i]) * + hwy::ConvertScalarTo(vec[i]); + } + const double l1 = hwy::ScalarAbs(expected - actual); + const double snr = 1.0 + hwy::ScalarAbs(expected) / l1; + fprintf(stderr, "expected %.3f e2 %.4f actual %.4f l1 %E snr %.2f\n", + expected, expected2, actual, l1, snr); + HWY_ASSERT(hwy::ScalarAbs(expected2 - actual) < 1E-4); + const double expected_l1 = sizeof(T) == 2 ? 1.52E-2 : 1.15E-2; + const double expected_snr = sizeof(T) == 2 ? 80.1f : 104.9f; + HWY_ASSERT(expected_l1 <= l1 && l1 < 1.02f * expected_l1); + HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr); + } +}; + +void TestAllDotF32() { + const hn::ForGEVectors<128, TestDot> test; + test(float()); +} +void TestAllDotBF16() { + const hn::ForGEVectors<128, TestDot> test; + test(hwy::bfloat16_t()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace gcpp { +HWY_BEFORE_TEST(SfpTest); +HWY_EXPORT_AND_TEST_P(SfpTest, TestAllUnique); +HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDecEnc); +HWY_EXPORT_AND_TEST_P(SfpTest, TestAllGolden); +HWY_EXPORT_AND_TEST_P(SfpTest, TestAllEncDec); +HWY_EXPORT_AND_TEST_P(SfpTest, TestAllOrder); +HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDotF32); +HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDotBF16); +} // namespace gcpp + +#endif diff --git a/compression/stats.cc b/compression/stats.cc new file mode 100644 index 0000000..bfc7cbf --- /dev/null +++ b/compression/stats.cc @@ -0,0 +1,117 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// copybara:import_next_line:gemma_cpp +#include "compression/stats.h" + +#include + +#include // std::min +#include + +#include "hwy/base.h" // HWY_ASSERT + +void Stats::Assimilate(const Stats& other) { + const int64_t total_n = n_ + other.n_; + if (total_n == 0) return; // Nothing to do; prevents div by zero. 
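+  // Pairwise merge of the running central moments m1_..m4_ (the batch
+  // counterpart of the online update in Stats::Notify); min/max/product
+  // combine directly.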
+ + min_ = std::min(min_, other.min_); + max_ = std::max(max_, other.max_); + + product_ *= other.product_; + + const double product_n = n_ * other.n_; + const double n2 = n_ * n_; + const double other_n2 = other.n_ * other.n_; + const int64_t total_n2 = total_n * total_n; + const double total_n3 = static_cast(total_n2) * total_n; + // Precompute reciprocal for speed - used at least twice. + const double inv_total_n = 1.0 / total_n; + const double inv_total_n2 = 1.0 / total_n2; + + const double delta = other.m1_ - m1_; + const double delta2 = delta * delta; + const double delta3 = delta * delta2; + const double delta4 = delta2 * delta2; + + m1_ = (n_ * m1_ + other.n_ * other.m1_) * inv_total_n; + + const double new_m2 = m2_ + other.m2_ + delta2 * product_n * inv_total_n; + + const double new_m3 = + m3_ + other.m3_ + delta3 * product_n * (n_ - other.n_) * inv_total_n2 + + 3.0 * delta * (n_ * other.m2_ - other.n_ * m2_) * inv_total_n; + + m4_ += other.m4_ + + delta4 * product_n * (n2 - product_n + other_n2) / total_n3 + + 6.0 * delta2 * (n2 * other.m2_ + other_n2 * m2_) * inv_total_n2 + + 4.0 * delta * (n_ * other.m3_ - other.n_ * m3_) * inv_total_n; + + m2_ = new_m2; + m3_ = new_m3; + n_ = total_n; +} + +std::string Stats::ToString(int exclude) const { + if (Count() == 0) return std::string("(none)"); + + char buf[300]; + int pos = 0; + int ret; // snprintf - bytes written or negative for error. + + if ((exclude & kNoCount) == 0) { + ret = snprintf(buf + pos, sizeof(buf) - pos, "Count=%9zu ", + static_cast(Count())); + HWY_ASSERT(ret > 0); + pos += ret; + } + + if ((exclude & kNoMeanSD) == 0) { + const float sd = StandardDeviation(); + if (sd > 100) { + ret = snprintf(buf + pos, sizeof(buf) - pos, "Mean=%8.2E SD=%7.1E ", + Mean(), sd); + } else { + ret = snprintf(buf + pos, sizeof(buf) - pos, "Mean=%8.6f SD=%7.5f ", + Mean(), sd); + } + HWY_ASSERT(ret > 0); + pos += ret; + } + + if ((exclude & kNoMinMax) == 0) { + ret = snprintf(buf + pos, sizeof(buf) - pos, "Min=%8.5e Max=%8.5e ", Min(), + Max()); + HWY_ASSERT(ret > 0); + pos += ret; + } + + if ((exclude & kNoSkewKurt) == 0) { + ret = snprintf(buf + pos, sizeof(buf) - pos, "Skew=%5.2f Kurt=%7.2f ", + Skewness(), Kurtosis()); + HWY_ASSERT(ret > 0); + pos += ret; + } + + if ((exclude & kNoGeomean) == 0) { + ret = snprintf(buf + pos, sizeof(buf) - pos, "GeoMean=%9.6f ", + GeometricMean()); + HWY_ASSERT(ret > 0); + pos += ret; + } + + HWY_ASSERT(pos < sizeof(buf)); + return buf; +} diff --git a/compression/stats.h b/compression/stats.h new file mode 100644 index 0000000..1f0d262 --- /dev/null +++ b/compression/stats.h @@ -0,0 +1,190 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_STATS_H_ +#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_STATS_H_ + +#include +#include + +#include +#include +#include + +#include "hwy/base.h" // HWY_ASSERT + +// Thread-compatible. 
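+// Counts occurrences of integer values in [0, N), e.g. (illustrative,
+// assuming the usual Bins<N> instantiation):
+//   Bins<16> counts;
+//   counts.Notify(3);              // asserts 0 <= bin < 16
+//   counts.Print("cluster sizes"); // prints bins up to the last nonzero one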
+template +class Bins { + public: + Bins() { Reset(); } + + template + void Notify(T bin) { + HWY_ASSERT(T{0} <= bin && bin < static_cast(N)); + counts_[static_cast(bin)]++; + } + + void Assimilate(const Bins& other) { + for (size_t i = 0; i < N; ++i) { + counts_[i] += other.counts_[i]; + } + } + + void Print(const char* caption) { + fprintf(stderr, "\n%s [%zu]\n", caption, N); + size_t last_nonzero = 0; + for (size_t i = N - 1; i < N; --i) { + if (counts_[i] != 0) { + last_nonzero = i; + break; + } + } + for (size_t i = 0; i <= last_nonzero; ++i) { + fprintf(stderr, " %zu\n", counts_[i]); + } + } + + void Reset() { + for (size_t i = 0; i < N; ++i) { + counts_[i] = 0; + } + } + + private: + size_t counts_[N]; +}; + +// Descriptive statistics of a variable (4 moments). Thread-compatible. +class Stats { + public: + Stats() { Reset(); } + + void Notify(const float x) { + ++n_; + + min_ = std::min(min_, x); + max_ = std::max(max_, x); + + product_ *= x; + + // Online moments. Reference: https://goo.gl/9ha694 + const double d = x - m1_; + const double d_div_n = d / n_; + const double d2n1_div_n = d * (n_ - 1) * d_div_n; + const int64_t n_poly = n_ * n_ - 3 * n_ + 3; + m1_ += d_div_n; + m4_ += d_div_n * (d_div_n * (d2n1_div_n * n_poly + 6.0 * m2_) - 4.0 * m3_); + m3_ += d_div_n * (d2n1_div_n * (n_ - 2) - 3.0 * m2_); + m2_ += d2n1_div_n; + } + + void Assimilate(const Stats& other); + + int64_t Count() const { return n_; } + + float Min() const { return min_; } + float Max() const { return max_; } + + double GeometricMean() const { + return n_ == 0 ? 0.0 : pow(product_, 1.0 / n_); + } + + double Mean() const { return m1_; } + // Same as Mu2. Assumes n_ is large. + double SampleVariance() const { + return n_ == 0 ? 0.0 : m2_ / static_cast(n_); + } + // Unbiased estimator for population variance even for smaller n_. + double Variance() const { + if (n_ == 0) return 0.0; + if (n_ == 1) return m2_; + return m2_ / static_cast(n_ - 1); + } + double StandardDeviation() const { return std::sqrt(Variance()); } + // Near zero for normal distributions; if positive on a unimodal distribution, + // the right tail is fatter. Assumes n_ is large. + double SampleSkewness() const { + if (std::abs(m2_) < 1E-7) return 0.0; + return m3_ * std::sqrt(static_cast(n_)) / std::pow(m2_, 1.5); + } + // Corrected for bias (same as Wikipedia and Minitab but not Excel). + double Skewness() const { + if (n_ == 0) return 0.0; + const double biased = SampleSkewness(); + const double r = (n_ - 1.0) / n_; + return biased * std::pow(r, 1.5); + } + // Near zero for normal distributions; smaller values indicate fewer/smaller + // outliers and larger indicates more/larger outliers. Assumes n_ is large. + double SampleKurtosis() const { + if (std::abs(m2_) < 1E-7) return 0.0; + return m4_ * n_ / (m2_ * m2_); + } + // Corrected for bias (same as Wikipedia and Minitab but not Excel). + double Kurtosis() const { + if (n_ == 0) return 0.0; + const double biased = SampleKurtosis(); + const double r = (n_ - 1.0) / n_; + return biased * r * r; + } + + // Central moments, useful for "method of moments"-based parameter estimation + // of a mixture of two Gaussians. Assumes Count() != 0. 
+ double Mu1() const { return m1_; } + double Mu2() const { return m2_ / static_cast(n_); } + double Mu3() const { return m3_ / static_cast(n_); } + double Mu4() const { return m4_ / static_cast(n_); } + + // Which statistics to EXCLUDE in ToString + enum { + kNoCount = 1, + kNoMeanSD = 2, + kNoMinMax = 4, + kNoSkewKurt = 8, + kNoGeomean = 16 + }; + std::string ToString(int exclude = 0) const; + + void Reset() { + n_ = 0; + + min_ = hwy::HighestValue(); + max_ = hwy::LowestValue(); + + product_ = 1.0; + + m1_ = 0.0; + m2_ = 0.0; + m3_ = 0.0; + m4_ = 0.0; + } + + private: + int64_t n_; // signed for faster conversion + safe subtraction + + float min_; + float max_; + + double product_; // for geomean + + // Moments + double m1_; + double m2_; + double m3_; + double m4_; +}; + +#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_STATS_H_ diff --git a/configs.h b/configs.h new file mode 100644 index 0000000..278d5ea --- /dev/null +++ b/configs.h @@ -0,0 +1,57 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Model configurations + +#ifndef THIRD_PARTY_GEMMA_CPP_CONFIGS_H_ +#define THIRD_PARTY_GEMMA_CPP_CONFIGS_H_ + +#include + +namespace gcpp { + +static constexpr size_t kSeqLen = 7168; + +struct ConfigGemma7B { + // NOLINTBEGIN(google3-readability-class-member-naming) + static constexpr int seq_len = kSeqLen; + static constexpr int vocab_size = 256128; + static constexpr int n_layers = 28; + static constexpr int dim_model = 3072; + static constexpr int dim_ffw_hidden = 16 * 3072 / 2; // = 24576 + static constexpr int n_heads = 16; + static constexpr int n_kv_heads = 16; // standard MHA, no GQA or MQA + static constexpr int dim_qkv = 256; // query size == key size == value size + static constexpr int top_k = 1; + // NOLINTEND(google3-readability-class-member-naming) +}; + +struct ConfigGemma2B { + // NOLINTBEGIN(google3-readability-class-member-naming) + static constexpr int seq_len = kSeqLen; + static constexpr int vocab_size = 256128; + static constexpr int n_layers = 18; + static constexpr int dim_model = 2048; + static constexpr int dim_ffw_hidden = 16 * 2048 / 2; // = 16384 + static constexpr int n_heads = 8; + static constexpr int n_kv_heads = 8; // TODO(austinvhuang): add MQA support + static constexpr int dim_qkv = 256; // query size == key size == value size + static constexpr int top_k = 1; + // NOLINTEND(google3-readability-class-member-naming) +}; + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_ diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md new file mode 100644 index 0000000..ea73169 --- /dev/null +++ b/docs/CONTRIBUTING.md @@ -0,0 +1,32 @@ +# How to Contribute + +We would love to accept your patches and contributions to this project. + +## Before you begin + +### Sign our Contributor License Agreement + +Contributions to this project must be accompanied by a +[Contributor License Agreement](https://cla.developers.google.com/about) (CLA). 
+You (or your employer) retain the copyright to your contribution; this simply +gives us permission to use and redistribute your contributions as part of the +project. + +If you or your current employer have already signed the Google CLA (even if it +was for a different project), you probably don't need to do it again. + +Visit to see your current agreements or to +sign a new one. + +### Review our Community Guidelines + +This project follows [Google's Open Source Community +Guidelines](https://opensource.google/conduct/). + +## Contribution process + +### Code Reviews + +All submissions, including submissions by project members, require review. We +use [GitHub pull requests](https://docs.github.com/articles/about-pull-requests) +for this purpose. \ No newline at end of file diff --git a/gemma.cc b/gemma.cc new file mode 100644 index 0000000..5cea7d5 --- /dev/null +++ b/gemma.cc @@ -0,0 +1,811 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Lightweight C++ implementation of the gemma model. + +// Compiles this file for multiple architectures via "foreach_target.h", to +// which we pass the filename via macro 'argument'. +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "gemma.cc" // NOLINT +#include "hwy/foreach_target.h" // IWYU pragma: keep +// Must come after foreach_target.h to avoid redefinition errors. +// copybara:import_next_line:gemma_cpp +#include "compression/compress-inl.h" +// copybara:import_next_line:gemma_cpp +#include "ops.h" +// copybara:import_next_line:gemma_cpp +#include "util/args.h" // Path +#include "hwy/contrib/matvec/matvec-inl.h" +#include "hwy/highway.h" +#include "hwy/profiler.h" +#include "hwy/timer.h" + +// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last +// compile pass, whereas we want this defined in the first. 
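The comment above is part of highway's multi-target idiom: `foreach_target.h` re-includes gemma.cc once per SIMD target, and the `#if HWY_ONCE` section near the end of the file defines the exported dispatch table exactly once. For orientation, here is a minimal sketch of the same pattern; the file name `example.cc` and the functions `MulByTwo` / `CallMulByTwo` are hypothetical and not part of gemma.cpp.

```cpp
// example.cc - minimal foreach_target / dynamic-dispatch sketch (illustrative).
#include <cstddef>

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "example.cc"  // this file's own name
#include "hwy/foreach_target.h"          // re-includes this file per target
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace example {
namespace HWY_NAMESPACE {  // expands to a per-target namespace, e.g. N_AVX2

// Doubles n floats; assumes n is a multiple of the vector length.
void MulByTwo(float* HWY_RESTRICT x, size_t n) {
  namespace hn = hwy::HWY_NAMESPACE;
  const hn::ScalableTag<float> d;
  for (size_t i = 0; i < n; i += hn::Lanes(d)) {
    const auto v = hn::LoadU(d, x + i);
    hn::StoreU(hn::Add(v, v), d, x + i);
  }
}

}  // namespace HWY_NAMESPACE
}  // namespace example
HWY_AFTER_NAMESPACE();

#if HWY_ONCE  // true only on the last compile pass
namespace example {
HWY_EXPORT(MulByTwo);  // builds the per-target function table
void CallMulByTwo(float* x, size_t n) {
  HWY_DYNAMIC_DISPATCH(MulByTwo)(x, n);
}
}  // namespace example
#endif
```

gemma.cc follows the same shape: the per-target code lives in `gcpp::HWY_NAMESPACE`, and the `HWY_ONCE` section at the bottom exports `GetCompressedWeightsT`, `Generate2B`, and `Generate7B` for `HWY_DYNAMIC_DISPATCH`.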
+#ifndef GEMMA_ONCE +#define GEMMA_ONCE + +#include +#include + +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include +#include +#include + +// copybara:import_next_line:gemma_cpp +#include "compression/compress.h" +// copybara:import_next_line:gemma_cpp +#include "configs.h" +// copybara:import_next_line:gemma_cpp +#include "gemma.h" +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +// copybara:import_next_line:sentencepiece +#include "src/sentencepiece_processor.h" +// #include "third_party/sentencepiece/src/util.h" + +namespace gcpp { + +template +struct Layer { + Layer() = default; + // NOLINTBEGIN(google3-readability-class-member-naming) + static constexpr size_t n_heads = TConfig::n_heads; + static constexpr size_t dim_model = TConfig::dim_model; + static constexpr size_t dim_qkv = TConfig::dim_qkv; + static constexpr size_t dim_ffw_hidden = TConfig::dim_ffw_hidden; + static constexpr size_t size_attn_vec_einsum_w = + n_heads * dim_qkv * dim_model; + // 3x for (query, key, value) + static constexpr size_t size_qkv_einsum_w = 3 * n_heads * dim_qkv * dim_model; + // 2x for (gelu gating vector, gated vector) + static constexpr size_t size_gating_einsum_w = 2 * dim_ffw_hidden * dim_model; + static constexpr size_t size_linear_w = dim_model * dim_ffw_hidden; + std::array attn_vec_einsum_w; + std::array qkv_einsum_w; + std::array gating_einsum_w; + std::array linear_w; + std::array pre_attention_norm_scale; + std::array pre_ffw_norm_scale; + // NOLINTEND(google3-readability-class-member-naming) +}; + +template +struct Weights { + Weights() = default; + + hwy::AlignedUniquePtr[]> layers; // n_layers + + std::array + embedder_input_embedding; + + std::array final_norm_scale; +}; + +// Only called if cached loading fails. +template +hwy::AlignedUniquePtr> LoadWeights(const Path& checkpoint) { + PROFILER_ZONE("Startup.LoadWeights"); + using TWeights = Weights; + hwy::AlignedUniquePtr weights = hwy::MakeUniqueAligned(); + weights->layers = + hwy::MakeUniqueAlignedArray>(TConfig::n_layers); + + FILE* fptr; + fptr = fopen(checkpoint.path.c_str(), "rb"); + if (fptr == nullptr) { + HWY_ABORT("Failed to open model file %s - does it exist?", + checkpoint.path.c_str()); + } + bool ok = true; + ok &= 1 == fread(&(weights->embedder_input_embedding), + sizeof(weights->embedder_input_embedding), 1, fptr); + ok &= 1 == fread(&(weights->final_norm_scale), + sizeof(weights->final_norm_scale), 1, fptr); + for (size_t layer = 0; layer < TConfig::n_layers; ++layer) { + Layer* layer_view = &weights->layers[layer]; + ok &= 1 == fread(&layer_view->attn_vec_einsum_w, + sizeof(layer_view->attn_vec_einsum_w), 1, fptr); + ok &= 1 == fread(&layer_view->qkv_einsum_w, + sizeof(layer_view->qkv_einsum_w), 1, fptr); + ok &= 1 == fread(&layer_view->gating_einsum_w, + sizeof(layer_view->gating_einsum_w), 1, fptr); + ok &= 1 == + fread(&layer_view->linear_w, sizeof(layer_view->linear_w), 1, fptr); + ok &= 1 == fread(&layer_view->pre_attention_norm_scale, + sizeof(layer_view->pre_attention_norm_scale), 1, fptr); + ok &= 1 == fread(&layer_view->pre_ffw_norm_scale, + sizeof(layer_view->pre_ffw_norm_scale), 1, fptr); + } + if (!ok) { + HWY_ABORT("Failed to read from %s - might be a directory, or too small?", + checkpoint.path.c_str()); + } + HWY_ASSERT(0 == fclose(fptr)); + return weights; +} + +template +struct CompressedLayer { + // No ctor/dtor, allocated via AllocateAligned. 
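+  // Holds the same per-layer tensors as `Layer` above, but in the compressed
+  // WeightT representation; the two RMSNorm scale vectors are the exception
+  // noted just below, since there is not yet an RMSNorm that accepts every
+  // WeightT.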
+ + using TLayer = gcpp::Layer; + + // # NOLINTBEGIN(google3-readability-class-member-naming) + static constexpr size_t dim_model = TConfig::dim_model; + static constexpr size_t dim_ffw_hidden = TConfig::dim_ffw_hidden; + // NOLINTEND(google3-readability-class-member-naming) + + // Compressed Parameters + // We don't yet have an RMSNorm that accepts all WeightT. + CompressedArray c_pre_attention_norm_scale; + CompressedArray c_pre_ffw_norm_scale; + CompressedArray c_gating_einsum_w; + CompressedArray c_linear_w; + CompressedArray c_qkv_einsum_w; + CompressedArray c_attn_vec_einsum_w; +}; + +// Array instead of single large allocation for parallel mem init. Split out of +// CompressedWeights so that only these pointers are initialized, not the +// CompressedArray. +template +struct CompressedLayerPointers { + explicit CompressedLayerPointers(hwy::ThreadPool& pool) { + pool.Run(0, TConfig::n_layers, [this](uint64_t task, size_t /*thread*/) { + this->c_layers[task] = hwy::AllocateAligned>(1); + }); + } + + using CLayer = CompressedLayer; + std::array, TConfig::n_layers> c_layers; +}; + +template +struct CompressedWeights { + // No ctor/dtor, allocated via AllocateAligned. + + CompressedArray + c_embedder_input_embedding; + + CompressedArray c_final_norm_scale; + + // Must be last so that the other arrays remain aligned. + CompressedLayerPointers c_layer_ptrs; + + const CompressedLayer* CLayer(size_t layer) const { + return c_layer_ptrs.c_layers[layer].get(); + } + CompressedLayer* CLayer(size_t layer) { + return c_layer_ptrs.c_layers[layer].get(); + } +}; + +// Aligned. +template +struct Activations { + // # NOLINTBEGIN(google3-readability-class-member-naming) + static constexpr size_t batch_size = BatchSize; + using LayerConfig = Layer; + static constexpr size_t dim_model = TConfig::dim_model; + static constexpr size_t dim_qkv = TConfig::dim_qkv; + static constexpr size_t n_heads = TConfig::n_heads; + static constexpr size_t n_kv_heads = TConfig::n_kv_heads; + static constexpr size_t size_cache_pos = + TConfig::n_layers * n_kv_heads * dim_qkv; + static constexpr size_t size_cache_layer = n_kv_heads * dim_qkv; + // NOLINTEND(google3-readability-class-member-naming) + + std::array x; // input + std::array pre_att_rms_out; + std::array q; // query vector + std::array + att; // attention vector + std::array + att_out; // attention output + std::array + att_post1; // attention output after linear transformation, per head + std::array + att_post2; // accumulation of attention outputs over heads + std::array bf_pre_ffw_rms_out; + std::array ffw_hidden; + // bf_ version can't be used until GeluMulToBF16 issue in FFW() is resolved. + // std::array + // bf_ffw_hidden; + std::array ffw_out; + std::array logits; +}; + +// GemmaImpl is a template and thus cannot be exposed in gemma.h, hence we +// define an abstract base class. 
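The comment above describes ordinary type erasure: the public header (gemma.h) exposes only the abstract `GemmaInterface`, while the templated `GemmaImpl` stays inside this translation unit and is instantiated once per model config. A reduced sketch of the idiom, using hypothetical names (`ModelInterface`, `ModelImpl`, `TinyConfig`) rather than the real classes:

```cpp
#include <memory>

// Public header side: no template parameters leak out.
struct ModelInterface {
  virtual ~ModelInterface() = default;
  virtual int Predict(int token) = 0;
};

// Implementation side: one instantiation per compile-time config.
template <class TConfig>
struct ModelImpl : ModelInterface {
  int Predict(int token) override { return token % TConfig::vocab_size; }
};

struct TinyConfig { static constexpr int vocab_size = 256; };
struct BigConfig  { static constexpr int vocab_size = 256128; };

// The concrete config is chosen at runtime.
std::unique_ptr<ModelInterface> MakeModel(bool big) {
  if (big) return std::make_unique<ModelImpl<BigConfig>>();
  return std::make_unique<ModelImpl<TinyConfig>>();
}
```

The `Gemma` constructor further down does exactly this kind of runtime selection, switching on `Model::GEMMA_2B` versus `Model::GEMMA_7B`.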
+struct GemmaInterface { + virtual ~GemmaInterface() = default; + + virtual const sentencepiece::SentencePieceProcessor& Tokenizer() const = 0; + + // TODO: group pool/callbacks into struct + virtual void Generate(const InferenceArgs& args, + const std::vector& prompt, size_t start_pos, + hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, + const StreamFunc& stream_token, + const AcceptFunc& accept_token, std::mt19937& gen, + int verbosity) = 0; +}; + +template +struct GemmaImpl : public GemmaInterface { + GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& pool); + + ~GemmaImpl() { + using CWeights = CompressedWeights; + CWeights* c_weights = reinterpret_cast(compressed_weights.get()); + c_weights->c_layer_ptrs.~CompressedLayerPointers(); + } + + const sentencepiece::SentencePieceProcessor& Tokenizer() const { + return tokenizer; + } + + void Generate(const InferenceArgs& args, const std::vector& prompt, + size_t start_pos, hwy::ThreadPool& pool, + hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, + const AcceptFunc& accept_token, std::mt19937&, int verbosity); + + sentencepiece::SentencePieceProcessor tokenizer; + + // CompressedWeights + hwy::AlignedFreeUniquePtr compressed_weights; + hwy::AlignedUniquePtr> prefill; + hwy::AlignedUniquePtr> state; + KVCache kv_cache; +}; + +} // namespace gcpp +#endif // GEMMA_ONCE + +// SIMD code, compiled once per target. +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { + +template +HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, + Activations& activations, + const CompressedLayer* c_layer, + KVCache& kv_cache, hwy::ThreadPool& pool) { + PROFILER_ZONE("Gen.Attention"); + const size_t pos = batch_start + batch_idx; + HWY_DASSERT(batch_idx < batch_size); + static constexpr size_t dim_qkv = gcpp::Activations::dim_qkv; + static constexpr size_t size_cache_pos = + gcpp::Activations::size_cache_pos; + static constexpr size_t size_cache_layer = + gcpp::Activations::size_cache_layer; + static constexpr size_t dim_model = + gcpp::Activations::dim_model; + static constexpr size_t n_heads = TConfig::n_heads; + const float kQueryScale = 1.0 / sqrtf(static_cast(dim_qkv)); + + pool.Run(0, n_heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { + // linear projections to QKV + const size_t head_offset = + 3 * dim_qkv * dim_model; // 3x for QKV dimensions + const size_t q_offset = head * head_offset + 0 * dim_qkv * dim_model; + const size_t k_offset = head * head_offset + 1 * dim_qkv * dim_model; + const size_t v_offset = head * head_offset + 2 * dim_qkv * dim_model; + + float* HWY_RESTRICT q = + activations.q.data() + head * dim_qkv + batch_idx * n_heads * dim_qkv; + + const size_t batch_offset = batch_idx * dim_model; + + MatVecLoop( + c_layer->c_qkv_einsum_w, q_offset, + activations.pre_att_rms_out.data() + batch_offset, q); + + const size_t kv_offset = + pos * size_cache_pos + layer * size_cache_layer + head * dim_qkv; + + TwoOfsMatVecLoop( + c_layer->c_qkv_einsum_w, k_offset, v_offset, + activations.pre_att_rms_out.data() + batch_offset, + kv_cache.key_cache.get() + kv_offset, + kv_cache.value_cache.get() + kv_offset); + + // Calculate scores + float* HWY_RESTRICT head_att = activations.att.data() + + head * TConfig::seq_len + + batch_idx * n_heads * dim_qkv; + + Rope(q, dim_qkv, pos); + Rope(kv_cache.key_cache.get() + kv_offset, dim_qkv, pos); + MulByConst(kQueryScale, q, dim_qkv); + // Compute Q dot K scores + for (size_t pos2 = 0; pos2 <= pos; ++pos2) { + const size_t cache_offset = + 
pos2 * size_cache_pos + layer * size_cache_layer + head * dim_qkv; + const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset; + const float score = Dot(q, k2, dim_qkv); + head_att[pos2] = score; + } + Softmax(head_att, pos + 1); + + // Weighted summation + float* HWY_RESTRICT att_out = activations.att_out.data() + head * dim_qkv + + batch_idx * n_heads * dim_qkv; + hwy::ZeroBytes(att_out, dim_qkv * sizeof(*att_out)); + for (size_t pos2 = 0; pos2 <= pos; ++pos2) { + const size_t cache_offset = + pos2 * size_cache_pos + layer * size_cache_layer + head * dim_qkv; + float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset; + MulByConstAndAdd(head_att[pos2], v2, att_out, dim_qkv); + } + // linear projection from dim_qkv back to dim_model, sum projections + // across heads + float* HWY_RESTRICT head_out = + head == 0 + ? activations.att_post2.data() + batch_idx * dim_model + : activations.att_post1.data() + head * batch_size * dim_model; + MatVecLoop(c_layer->c_attn_vec_einsum_w, + head * dim_model * dim_qkv, att_out, + head_out); + }); + + // accumulate output across all heads into att_post2. head 0 already wrote + // directly to att_post2. + for (size_t head = 1; head < n_heads; ++head) { + AddFrom(activations.att_post1.data() + head * batch_size * dim_model, + activations.att_post2.data() + batch_idx * dim_model, dim_model); + } +} + +template +HWY_NOINLINE void FFW(Activations& activations, + size_t batch_idx, const CompressedLayer* c_layer, + hwy::ThreadPool& pool) { + HWY_DASSERT(batch_idx < batch_size); + static constexpr size_t dim_model = TConfig::dim_model; + static constexpr size_t dim_ffw_hidden = TConfig::dim_ffw_hidden; + const size_t hidden_offset = batch_idx * dim_ffw_hidden * 2; + + { + PROFILER_ZONE("Gen.FFW.GatedGELU"); + const hwy::bfloat16_t* HWY_RESTRICT vec = + activations.bf_pre_ffw_rms_out.data() + batch_idx * dim_model; + float* HWY_RESTRICT out = activations.ffw_hidden.data() + hidden_offset; + float* HWY_RESTRICT out_mul = out + dim_ffw_hidden; + + // Same matrix, first and second half of rows. Could fuse into one MatVec, + // but separating them could help on NUMA e.g. multiple sockets. + MatVec(c_layer->c_gating_einsum_w, + dim_ffw_hidden * dim_model, vec, out_mul, + pool); + + // Gate, will go through the nonlinearity. 
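+    // The MatVec above produced the multiplier half; the MatVec below
+    // produces the gate half from the first dim_ffw_hidden rows of the same
+    // c_gating_einsum_w matrix. Combined by Transform1 further down, this is
+    // the gated-GELU FFW: ffw_hidden[i] = Gelu((W_gate x)[i]) * ((W_up x)[i]),
+    // which c_linear_w then projects back to dim_model.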
+ MatVec(c_layer->c_gating_einsum_w, 0, vec, out, + pool); + + namespace hn = hwy::HWY_NAMESPACE; + using DF = hn::ScalableTag; + using VF = hn::Vec; + hn::Transform1(DF(), out, dim_ffw_hidden, out_mul, + [](DF df, VF v, VF mul) + HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); }); + } + + PROFILER_ZONE("Gen.FFW\\GatedGELU"); + MatVec( + c_layer->c_linear_w, 0, activations.ffw_hidden.data() + hidden_offset, + activations.ffw_out.data() + batch_idx * dim_model, pool); +} + +template +HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos, + const CompressedWeights& c_weights, + Activations& activations, + KVCache& kv_cache, hwy::ThreadPool& pool, + hwy::ThreadPool& inner_pool) { + PROFILER_ZONE("Gen.Prefill\\Att\\FFW"); + static constexpr size_t dim_model = TConfig::dim_model; + static const float kEmbScaling = sqrtf(static_cast(dim_model)); + + pool.Run( + 0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR { + const int token = tokens[token_idx]; + Decompress(c_weights.c_embedder_input_embedding, token * dim_model, + activations.x.data() + token_idx * dim_model, dim_model); + MulByConst(kEmbScaling, activations.x.data() + token_idx * dim_model, + dim_model); + }); + + for (size_t layer = 0; layer < TConfig::n_layers; ++layer) { + const CompressedLayer* c_layer = c_weights.CLayer(layer); + + for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { + RMSNorm(activations.x.data() + token_idx * dim_model, + c_layer->c_pre_attention_norm_scale.data(), + activations.pre_att_rms_out.data() + token_idx * dim_model, + dim_model); + Attention(pos, token_idx, layer, activations, + c_layer, kv_cache, pool); + } + + // TODO: sink the loop into these functions, i.e. make them matmuls. + pool.Run( + 0, num_tokens, + [&](const uint64_t token_idx, size_t thread_id) HWY_ATTR { + AddFrom(activations.att_post2.data() + token_idx * dim_model, + activations.x.data() + token_idx * dim_model, dim_model); + RMSNorm(activations.x.data() + token_idx * dim_model, + c_layer->c_pre_ffw_norm_scale.data(), + activations.bf_pre_ffw_rms_out.data() + token_idx * dim_model, + dim_model); + FFW(activations, token_idx, c_layer, inner_pool); + AddFrom(activations.ffw_out.data() + token_idx * dim_model, + activations.x.data() + token_idx * dim_model, dim_model); + }); + } // foreach layer + + pool.Run( + 0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR { + RMSNormInplace(c_weights.c_final_norm_scale.data(), + activations.x.data() + token_idx * dim_model, dim_model); + }); +} + +// n = 1 specialization +template +void Transformer(int token, size_t pos, + const CompressedWeights& c_weights, + Activations& activations, KVCache& kv_cache, + hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool) { + static constexpr size_t n_layers = TConfig::n_layers; + static constexpr size_t dim_model = TConfig::dim_model; + + static const float kEmbScaling = sqrtf(static_cast(dim_model)); + + Decompress(c_weights.c_embedder_input_embedding, token * dim_model, + activations.x.data(), dim_model); + + MulByConst(kEmbScaling, activations.x.data(), dim_model); + + for (size_t layer = 0; layer < n_layers; ++layer) { + const CompressedLayer* c_layer = c_weights.CLayer(layer); + RMSNorm(activations.x.data(), c_layer->c_pre_attention_norm_scale.data(), + activations.pre_att_rms_out.data(), dim_model); + Attention(pos, 0, layer, activations, c_layer, kv_cache, pool); + AddFrom(activations.att_post2.data(), activations.x.data(), dim_model); + RMSNorm(activations.x.data(), 
c_layer->c_pre_ffw_norm_scale.data(), + activations.bf_pre_ffw_rms_out.data(), dim_model); + FFW(activations, /* batch_idx = */ 0, c_layer, pool); + AddFrom(activations.ffw_out.data(), activations.x.data(), dim_model); + } + RMSNormInplace(c_weights.c_final_norm_scale.data(), activations.x.data(), + dim_model); +} + +template +void GenerateImpl(GemmaImpl& gemma, const InferenceArgs& args, + const std::vector& prompt, size_t pos, + hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, + const StreamFunc& stream_token, + const AcceptFunc& accept_token, std::mt19937& gen, + int verbosity) { + static constexpr size_t dim_model = TConfig::dim_model; + static constexpr size_t vocab_size = TConfig::vocab_size; + static constexpr size_t top_k = TConfig::top_k; + Activations& activations = *gemma.state.get(); + Activations& prefill_activations = + *gemma.prefill.get(); + const CompressedWeights& c_weights = + *reinterpret_cast*>( + gemma.compressed_weights.get()); + KVCache& kv_cache = gemma.kv_cache; + int token; + + // pos indexes the KV cache. In the first turn of a chat, pos = 0. + // + // After the first turn, pos gets passed in with > 0 corresponding to the + // current token position in the KV cache. + // + // pos_offset keeps track of the relative position within the turn, starting + // at 0 each turn. During prefill, pos_offset corresponds to the index into + // the prompt vector. + // + // In single-turn (non-chat) usage, pos and pos_offset start at 0 and are + // always equal. + size_t pos_offset = 0; // offset relative to pos + double prefill_start = hwy::platform::Now(); + + // Prefill stops before prompt.size() - 1 since the last prompt token is the + // first input token for generation. + while (pos_offset < prompt.size() - 1) { + const size_t end_offset = + std::min(kPrefillBatchSize, prompt.size() - 1 - pos_offset); + HWY_DASSERT(end_offset < prompt.size()); + const int* batch_tokens = prompt.data() + pos_offset; + Prefill(batch_tokens, end_offset, pos, + c_weights, prefill_activations, + kv_cache, pool, inner_pool); + for (size_t idx = 0; idx < end_offset; ++idx) { + stream_token(batch_tokens[idx], 0.0); + } + pos += end_offset; + pos_offset += end_offset; + } + + if (verbosity >= 2) { + // in the future this output should not occur in GenerateImpl but instead + // should be available as observable state for frontend code to handle I/O. + double prefill_end = hwy::platform::Now(); + const double prefill_tok_sec = pos_offset / (prefill_end - prefill_start); + std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]\n"; + } + + double gen_start = hwy::platform::Now(); + + HWY_DASSERT(pos_offset == prompt.size() - 1); + + if (verbosity >= 2) { + // Provide usage warnings if max_new_tokens is out of range. + if (args.max_generated_tokens > args.max_tokens) { + std::cout << "Warning: max_new_tokens should be <= max_tokens" + << std::endl; + } else if ((prompt.size() + args.max_generated_tokens) > args.max_tokens) { + std::cout << "Warning: Prompt size + max_new_tokens exceeds max_tokens." 
+ << std::endl; + } + } + + auto pos_gen_start = pos_offset; + token = prompt.at(pos_offset); + size_t generate_pos = 0; + for (; pos < args.max_tokens && generate_pos < args.max_generated_tokens; + ++pos, ++pos_offset, ++generate_pos) { + Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool); + float* final_activation = activations.x.data(); + if (pos_offset >= prompt.size()) { + PROFILER_ZONE("Gen.Embedding"); + // Generation phase + MatVec(c_weights.c_embedder_input_embedding, 0, + final_activation, activations.logits.data(), + pool); + // Barrier: must have all logits so we can subtract max. + Softmax(activations.logits.data(), vocab_size); + token = SampleTopK(activations.logits.data(), vocab_size, gen, + args.temperature, accept_token); + } + if (!stream_token(token, activations.logits[token])) { + token = EOS_ID; + } + if (token == EOS_ID) { + if (verbosity >= 2) { + double gen_end = hwy::platform::Now(); + const double gen_tok_sec = + (pos_offset - pos_gen_start) / (gen_end - gen_start); + std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n"; + } + break; + } + } +} + +void Generate2B(GemmaImpl& gemma, const InferenceArgs& args, + const std::vector& prompt, size_t start_pos, + hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, + const StreamFunc& stream_token, const AcceptFunc& accept_token, + std::mt19937& gen, int verbosity) { + GenerateImpl(gemma, args, prompt, start_pos, pool, inner_pool, stream_token, + accept_token, gen, verbosity); +} + +void Generate7B(GemmaImpl& gemma, const InferenceArgs& args, + const std::vector& prompt, size_t start_pos, + hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, + const StreamFunc& stream_token, const AcceptFunc& accept_token, + std::mt19937& gen, int verbosity) { + GenerateImpl(gemma, args, prompt, start_pos, pool, inner_pool, stream_token, + accept_token, gen, verbosity); +} + +// Calls func(name, float*, CompressedArray&) for each tensor. float* is null +// if weights = null, which happens during the first call where we attempt to +// load from cache. +// +// This avoids repeating the list of tensors between loading and compressing. +template +void ForEachTensor(const Weights* weights, + CompressedWeights& c_weights, Func& func) { + func("c_embedding", + weights ? weights->embedder_input_embedding.data() : nullptr, + c_weights.c_embedder_input_embedding); + func("c_final_norm", weights ? weights->final_norm_scale.data() : nullptr, + c_weights.c_final_norm_scale); + + char name[16]; + for (size_t layer_idx = 0; layer_idx < TConfig::n_layers; ++layer_idx) { + Layer* layer = weights ? &weights->layers[layer_idx] : nullptr; + CompressedLayer* c_layer = c_weights.CLayer(layer_idx); + + snprintf(name, sizeof(name), "pre_ff_ns_%lu", layer_idx); + func(name, layer ? layer->pre_ffw_norm_scale.data() : nullptr, + c_layer->c_pre_ffw_norm_scale); + + snprintf(name, sizeof(name), "gating_ein_%lu", layer_idx); + func(name, layer ? layer->gating_einsum_w.data() : nullptr, + c_layer->c_gating_einsum_w); + + snprintf(name, sizeof(name), "linear_w_%lu", layer_idx); + func(name, layer ? layer->linear_w.data() : nullptr, c_layer->c_linear_w); + snprintf(name, sizeof(name), "qkv_ein_%lu", layer_idx); + + func(name, layer ? layer->qkv_einsum_w.data() : nullptr, + c_layer->c_qkv_einsum_w); + snprintf(name, sizeof(name), "att_ein_%lu", layer_idx); + + func(name, layer ? 
layer->attn_vec_einsum_w.data() : nullptr, + c_layer->c_attn_vec_einsum_w); + + snprintf(name, sizeof(name), "pre_att_ns_%lu", layer_idx); + func(name, layer ? layer->pre_attention_norm_scale.data() : nullptr, + c_layer->c_pre_attention_norm_scale); + } +} + +template +hwy::AlignedFreeUniquePtr GetCompressedWeights( + const Path& model, const Path& cache, hwy::ThreadPool& pool) { + PROFILER_ZONE("Startup.LoadCache"); + + if (!std::filesystem::exists(model.path) && + !std::filesystem::exists(cache.path)) { + HWY_ABORT( + "Either the model weights (--weights) or cached compressed weights " + "(--compressed_weights) must exist."); + } + + // Allocate compressed weights. + using CWeights = CompressedWeights; + hwy::AlignedFreeUniquePtr c_weights_u8 = + hwy::AllocateAligned(sizeof(CWeights)); + CWeights* c_weights = reinterpret_cast(c_weights_u8.get()); + new (&c_weights->c_layer_ptrs) CompressedLayerPointers(pool); + + // First attempt to load them from cache, without requiring weights. + CacheLoader loader(cache.path.c_str()); + ForEachTensor(nullptr, *c_weights, loader); + if (loader.ReadAll(pool)) return c_weights_u8; + + // Get weights, compress, and store in cache. + hwy::AlignedUniquePtr> weights = LoadWeights(model); + Compressor compressor(pool); + ForEachTensor(weights.get(), *c_weights, compressor); + compressor.WriteAll(pool, cache.path.c_str()); + + return c_weights_u8; +} + +// Type-erased because this function is called via a function pointer. +hwy::AlignedFreeUniquePtr GetCompressedWeightsT( + const LoaderArgs& args, hwy::ThreadPool& pool) { + switch (args.ModelType()) { + case Model::GEMMA_2B: + return GetCompressedWeights(args.model, args.cache, pool); + case Model::GEMMA_7B: + return GetCompressedWeights(args.model, args.cache, pool); + default: + HWY_ABORT("Model type %d unknown.", static_cast(args.ModelType())); + } +} + +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace gcpp { + +HWY_EXPORT(GetCompressedWeightsT); +HWY_EXPORT(Generate2B); +HWY_EXPORT(Generate7B); + +KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) { + KVCache kv_cache = {}; + kv_cache.key_cache = hwy::AllocateAligned(seq_len * size_cache_pos); + kv_cache.value_cache = hwy::AllocateAligned(seq_len * size_cache_pos); + return kv_cache; +} + +template +GemmaImpl::GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& pool) + : compressed_weights( + HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)), + prefill(hwy::MakeUniqueAligned>()), + state(hwy::MakeUniqueAligned>()), + kv_cache( + CreateKVCache(Config::n_layers * Config::n_kv_heads * Config::dim_qkv, + Config::seq_len)) { + PROFILER_ZONE("Startup.tokenizer"); + + HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok()); +} + +template <> +void GemmaImpl::Generate(const InferenceArgs& args, + const std::vector& prompt, + size_t start_pos, hwy::ThreadPool& pool, + hwy::ThreadPool& inner_pool, + const StreamFunc& stream_token, + const AcceptFunc& accept_token, + std::mt19937& gen, int verbosity) { + HWY_DYNAMIC_DISPATCH(Generate2B) + (*this, args, prompt, start_pos, pool, inner_pool, stream_token, accept_token, + gen, verbosity); +} +template <> +void GemmaImpl::Generate(const InferenceArgs& args, + const std::vector& prompt, + size_t start_pos, hwy::ThreadPool& pool, + hwy::ThreadPool& inner_pool, + const StreamFunc& stream_token, + const AcceptFunc& accept_token, + std::mt19937& gen, int verbosity) { + HWY_DYNAMIC_DISPATCH(Generate7B) + (*this, args, prompt, start_pos, pool, inner_pool, 
stream_token, accept_token, + gen, verbosity); +} + +Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) { + const Model model_type = args.ModelType(); + model_training = args.ModelTraining(); + switch (model_type) { + case Model::GEMMA_2B: + impl_.reset(new GemmaImpl(args, pool)); + break; + case Model::GEMMA_7B: + impl_.reset(new GemmaImpl(args, pool)); + break; + default: + HWY_ABORT("Model type %d unknown.", static_cast(model_type)); + } +} +Gemma::~Gemma() = default; // after GemmaInterface is defined + +const sentencepiece::SentencePieceProcessor& Gemma::Tokenizer() const { + return impl_->Tokenizer(); +} + +void GenerateGemma(Gemma& gemma, const InferenceArgs& args, + const std::vector& prompt, size_t start_pos, + hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, + const StreamFunc& stream_token, + const AcceptFunc& accept_token, std::mt19937& gen, + int verbosity) { + pool.SetWaitMode(hwy::PoolWaitMode::kSpin); + gemma.impl_->Generate(args, prompt, start_pos, pool, inner_pool, stream_token, + accept_token, gen, verbosity); + pool.SetWaitMode(hwy::PoolWaitMode::kBlock); +} + +} // namespace gcpp +#endif // HWY_ONCE diff --git a/gemma.h b/gemma.h new file mode 100644 index 0000000..9647a6c --- /dev/null +++ b/gemma.h @@ -0,0 +1,207 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_H_ + +#include +#include +#include +#include +#include +#include +#include + +// copybara:import_next_line:gemma_cpp +#include "configs.h" // kSeqLen +// copybara:import_next_line:gemma_cpp +#include "compression/compress.h" // SfpStream/NuqStream +// copybara:import_next_line:gemma_cpp +#include "util/args.h" // ArgsBase +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" // hwy::bfloat16_t +#include "hwy/contrib/thread_pool/thread_pool.h" +// copybara:import_next_line:sentencepiece +#include "src/sentencepiece_processor.h" + +namespace gcpp { + +// Allowable types for GEMMA_WEIGHT_T (can be specified at compilation time): +// float, hwy::bfloat16_t, SfpStream, NuqStream +#ifndef GEMMA_WEIGHT_T +#define GEMMA_WEIGHT_T SfpStream +#endif // !GEMMA_WEIGHT_T +using WeightT = GEMMA_WEIGHT_T; + +using EmbedderInputT = hwy::bfloat16_t; +constexpr size_t kPrefillBatchSize = 16; +constexpr bool kSystemPrompt = false; + +struct KVCache { + hwy::AlignedFreeUniquePtr + key_cache; // batch_size * seq_len * n_layers * n_kv_heads * dim_qkv + hwy::AlignedFreeUniquePtr + value_cache; // batch_size * seq_len * n_layers * n_kv_heads * dim_qkv +}; + +// Model variants: see configs.h for details. 
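The size comment on `KVCache` above, together with `CreateKVCache(Config::n_layers * Config::n_kv_heads * Config::dim_qkv, Config::seq_len)` in gemma.cc, gives a quick way to estimate memory use. A back-of-envelope sketch using the `ConfigGemma2B` constants from configs.h (standalone arithmetic for illustration, not code from this repository):

```cpp
#include <cstddef>
#include <cstdio>

// KV-cache sizing estimate for the 2B config (seq_len=7168, n_layers=18,
// n_kv_heads=8, dim_qkv=256), matching CreateKVCache's float allocation.
int main() {
  constexpr std::size_t seq_len = 7168, n_layers = 18, n_kv_heads = 8,
                        dim_qkv = 256;
  constexpr std::size_t floats_per_pos = n_layers * n_kv_heads * dim_qkv;  // 36864
  constexpr std::size_t floats_per_cache = seq_len * floats_per_pos;  // 264241152
  constexpr double mib = floats_per_cache * sizeof(float) / (1024.0 * 1024.0);
  std::printf("key or value cache: %.0f MiB\n", mib);  // 1008 MiB each
  return 0;
}
```

So a full-length 2B context holds roughly 2 GiB of fp32 KV state (key plus value caches together); the 7B config, with 28 layers and 16 KV heads, needs about 3.1x more.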
+enum class Model { GEMMA_2B, GEMMA_7B }; +enum class ModelTraining { GEMMA_IT, GEMMA_PT }; + +struct LoaderArgs : public ArgsBase { + LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } + + static std::string ToLower(const std::string& text) { + std::string result = text; + std::transform(begin(result), end(result), begin(result), + [](unsigned char c) { return std::tolower(c); }); + return result; + } + + gcpp::Model ModelType() const { + const std::string model_type_lc = ToLower(model_type); + if (model_type_lc == "2b-pt" || model_type_lc == "2b-it") { + return gcpp::Model::GEMMA_2B; + } else { + return gcpp::Model::GEMMA_7B; + } + } + + gcpp::ModelTraining ModelTraining() const { + const std::string model_type_lc = ToLower(model_type); + if (model_type_lc == "7b-pt" || model_type_lc == "2b-pt") { + return gcpp::ModelTraining::GEMMA_PT; + } else { + return gcpp::ModelTraining::GEMMA_IT; + } + } + + // Returns error string or nullptr if OK. + const char* Validate() const { + const std::string model_type_lc = ToLower(model_type); + if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" && + model_type_lc != "2b-it" && model_type_lc != "7b-it") { + return "Model type must be 2b-pt, 7b-pt, 2b-it, or " + "7b-it."; + } + if (tokenizer.path.empty()) { + return "Missing --tokenizer flag, a file for the tokenizer is required."; + } + if (model_type.empty()) { + return "Missing --model flag, need to specify either 2b-pt, 7b-pt, " + "2b-it, or 7b-it."; + } + if (cache.path.empty()) { + return "Missing --compressed_weights flag, a file for the compressed " + "model."; + } + return nullptr; + } + + Path tokenizer; + Path model; // uncompressed weights OR + Path cache; // compressed weights + std::string model_type; + + template + void ForEach(const Visitor& visitor) { + visitor(tokenizer, "tokenizer", Path(), + "Path name of tokenizer model file. (required)"); + visitor( + cache, "compressed_weights", Path(), + "Path name of compressed weights file, regenerated from `--weights` " + "file if " + "the compressed weights file does not exist. (required)"); + visitor(model_type, "model", std::string(), + "Model type - can be 2b-it (2B parameters, instruction-tuned), " + "2b-pt (2B parameters, pretrained), 7b-it (7B parameters, " + "instruction-tuned), or 7b-pt (7B parameters, pretrained). " + "(required)"); + visitor(model, "weights", Path(), + "Path name of model weights (.sbs) file. Only required if " + "compressed_weights file is not present and needs to be " + "regenerated. Otherwise, not needed"); + } +}; + +struct GemmaInterface; + +struct Gemma { + Gemma(const LoaderArgs& args, hwy::ThreadPool& pool); + ~Gemma(); // must be defined after GemmaInterface's dtor is defined. + + const sentencepiece::SentencePieceProcessor& Tokenizer() const; + + std::unique_ptr impl_; + gcpp::ModelTraining model_training; +}; + +// StreamFunc is called with (token, probability). For prompt tokens, +// probability is 0.0f. +using StreamFunc = std::function; +using AcceptFunc = std::function; + +struct InferenceArgs : public ArgsBase { + InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } + + size_t max_tokens; + size_t max_generated_tokens; + + float temperature; + bool deterministic; + bool multiturn; + + // Returns error string or nullptr if OK. 
+ const char* Validate() const { + if (max_tokens > gcpp::kSeqLen) { + return "max_tokens is larger than the maximum sequence length (see " + "configs.h)."; + } + if (max_generated_tokens > max_tokens) { + return "Maximum number of generated tokens is larger than the maximum " + "total tokens."; + } + return nullptr; + } + + template + void ForEach(const Visitor& visitor) { + visitor(max_tokens, "max_tokens", size_t{3072}, + "Maximum number of tokens in prompt + generation."); + visitor(max_generated_tokens, "max_generated_tokens", size_t{2048}, + "Maximum number of tokens to generate."); + + visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2); + visitor(deterministic, "deterministic", false, + "Make top-k sampling deterministic", 2); + visitor(multiturn, "multiturn", true, + "Multiturn mode (if 0, this clears the KV cache after every " + "interaction without quitting)", + 2); + } +}; + +void GenerateGemma(Gemma& gemma, const InferenceArgs& args, + const std::vector& prompt, size_t start_pos, + hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, + const StreamFunc& stream_token, + const AcceptFunc& accept_token, std::mt19937& g, + int verbosity); + +constexpr int EOS_ID = 1; + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_H_ diff --git a/ops.h b/ops.h new file mode 100644 index 0000000..ac91cc5 --- /dev/null +++ b/ops.h @@ -0,0 +1,682 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Include guard for non-SIMD code. +#ifndef THIRD_PARTY_GEMMA_CPP_OPS_H_ +#define THIRD_PARTY_GEMMA_CPP_OPS_H_ +#include +#include + +#include +#include +#include + +// copybara:import_next_line:gemma_cpp +#include "compression/compress.h" +#include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/profiler.h" + +#endif // THIRD_PARTY_GEMMA_CPP_OPS_H_ + +// Include guard for (potentially) SIMD code. +#if defined(THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE) == defined(HWY_TARGET_TOGGLE) +#ifdef THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE +#undef THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE +#else +#define THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE +#endif + +// copybara:import_next_line:gemma_cpp +#include "compression/compress-inl.h" +#include "hwy/cache_control.h" // FlushStream +#include "hwy/contrib/algo/transform-inl.h" +#include "hwy/contrib/dot/dot-inl.h" +#include "hwy/contrib/math/math-inl.h" +#include "hwy/contrib/matvec/matvec-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +HWY_INLINE constexpr size_t MaxCols() { + // Vec + mat rows should fit into 32 KiB L1. + return 2048; +} + +template +HWY_INLINE constexpr size_t RowsPerStrip() { + // Aim for 128 work items to reduce pool overhead. Must be at least one + // vector; prefer a power of two for faster division. 
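+  // Worked example of the formula below: with kOuter = 2048 (the 2B model
+  // dimension) and 8 float lanes (e.g. AVX2), kRowsPerStrip =
+  // HWY_MAX(8, 1 << FloorLog2(2048 / 128)) = 16, so MatVec hands the pool
+  // 2048 / 16 = 128 strips.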
+ constexpr size_t kRowsPerStrip = + HWY_MAX(hn::ScalableTag().MaxLanes(), + 1ULL << hwy::FloorLog2(kOuter / 128)); + return kRowsPerStrip; +} + +// Simple version without tiling nor threading. +template +HWY_INLINE void MatVecLoop(const CompressedArray& mat, + const size_t mat_ofs, + const VecT* HWY_RESTRICT vec_aligned, + float* HWY_RESTRICT out) { + PROFILER_ZONE("MatVecLoop"); + const hn::ScalableTag df; + + for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) { + const size_t row_ofs = mat_ofs + idx_row * kInner; + out[idx_row] = Dot(df, mat, row_ofs, vec_aligned, kInner); + } +} + +// Simple version without tiling nor threading, but two offsets/outputs. +template +HWY_INLINE void TwoOfsMatVecLoop(const CompressedArray& mat, + const size_t mat_ofs0, const size_t mat_ofs1, + const VecT* HWY_RESTRICT vec_aligned, + float* HWY_RESTRICT out0, + float* HWY_RESTRICT out1) { + PROFILER_ZONE("MatVecLoop"); + const hn::ScalableTag df; + + for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) { + const size_t row_ofs0 = mat_ofs0 + (idx_row)*kInner; + const size_t row_ofs1 = mat_ofs1 + (idx_row)*kInner; + out0[idx_row] = Dot(df, mat, row_ofs0, vec_aligned, kInner); + out1[idx_row] = Dot(df, mat, row_ofs1, vec_aligned, kInner); + } +} + +namespace detail { + +// For each i = [0, num_rows), compute partial (length `num_cols`) dot product +// of row i with `vec_aligned` and add into `out[i]`. The upper-left coordinate +// of the tile is r0, c0. +template +HWY_INLINE void AccumulatePartialDotProducts( + DF df, const CompressedArray& mat, size_t mat_ofs, + size_t mat_stride, size_t r0, size_t c0, size_t num_rows, size_t num_cols, + const VecT* HWY_RESTRICT vec_aligned, float* HWY_RESTRICT out) { + for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) { + const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride; + out[idx_row] += Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols); + } +} + +// Same as above, but sets out[i] to the first partial dot product, which +// avoids having to zero-initialize and accumulate. +template +HWY_INLINE void SetFirstPartialDotProducts( + DF df, const CompressedArray& mat, size_t mat_ofs, + size_t mat_stride, size_t r0, size_t c0, size_t num_rows, size_t num_cols, + const VecT* HWY_RESTRICT vec_aligned, float* HWY_RESTRICT out) { + for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) { + const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride; + out[idx_row] = Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols); + } +} + +// Adds together partial dot products for all tiles with the same r0 (a +// horizontal strip of the entire matrix); the result is the full dot product +// for rows r in [r0, r0 + num_rows), which we store into in out[r - r0]. +template +HWY_INLINE void FullDotProductsForStrip( + DF df, const CompressedArray& mat, size_t mat_ofs, + size_t mat_stride, size_t r0, size_t num_rows, + const VecT* HWY_RESTRICT vec_aligned, float* HWY_RESTRICT out) { + // Tall and skinny: set `out` to the single dot product. + if (mat_stride < MaxCols()) { + SetFirstPartialDotProducts(df, mat, mat_ofs, mat_stride, r0, 0, num_rows, + mat_stride, vec_aligned, out); + return; + } + + // We have at least MaxCols, so start by setting `out` to that: + SetFirstPartialDotProducts(df, mat, mat_ofs, mat_stride, r0, 0, num_rows, + MaxCols(), vec_aligned, out); + // For further multiples of MaxCols, accumulate. Remainders handled below. 
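+  // E.g. with mat_stride = 5000 and MaxCols() = 2048: the call above covered
+  // columns [0, 2048), the loop below adds [2048, 4096), and the final block
+  // accumulates the remaining columns [4096, 5000).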
+ size_t c0 = MaxCols(); + HWY_UNROLL(1) + for (; c0 <= mat_stride - MaxCols(); c0 += MaxCols()) { + AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows, + MaxCols(), vec_aligned, out); + } + + if (c0 < mat_stride) { // Final cols + AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows, + mat_stride - c0, vec_aligned, out); + } +} + +} // namespace detail + +// Stores dot products of rows with `vec_aligned` to a buffer, then stores them +// to `out`. +template +HWY_INLINE void MatVec(const CompressedArray& mat, + const size_t mat_ofs, + const VecT* HWY_RESTRICT const vec_aligned, + float* HWY_RESTRICT out, hwy::ThreadPool& pool) { + PROFILER_ZONE("MatVec"); + + const hn::ScalableTag df; + constexpr size_t kRowsPerStrip = RowsPerStrip(); + constexpr size_t kNumStrips = kOuter / kRowsPerStrip; + + // For each entire strip. + pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR { + PROFILER_ZONE("MatVec.lambda"); + const size_t r0 = strip * kRowsPerStrip; + detail::FullDotProductsForStrip(df, mat, mat_ofs, kInner, r0, kRowsPerStrip, + vec_aligned, out + r0); + }); + + // Remaining rows + const size_t r0 = kNumStrips * kRowsPerStrip; + if (r0 < kOuter) { + PROFILER_ZONE("MatVec remainder"); + const size_t num_rows = kOuter - r0; + detail::FullDotProductsForStrip(df, mat, mat_ofs, kInner, r0, num_rows, + vec_aligned, out + r0); + } +} + +template +static HWY_INLINE hn::Vec Gelu(D d, hn::Vec v) { + const hn::Vec kMul = Set(d, 0.044715f); + const hn::Vec kSqrt2OverPi = hn::Set(d, 0.797884560804236f); + const hn::Vec kHalf = Set(d, 0.5f); + + // tanh approximation matches training. + const hn::Vec v3 = hn::Mul(hn::Mul(v, v), v); + const hn::Vec arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v)); + // 0.5 * (1 + tan) = MulAdd(0.5, tan, 0.5). 
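+  // Overall this evaluates the tanh approximation
+  //   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
+  // matching the kSqrt2OverPi and kMul constants above.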
+ const hn::Vec cdf = hn::MulAdd(kHalf, hn::Tanh(d, arg), kHalf); + return Mul(v, cdf); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x, + size_t size) { + namespace hn = hwy::HWY_NAMESPACE; + using D = hn::ScalableTag; + hn::Transform(D(), x, size, [](D d, hn::Vec v) { return Gelu(d, v); }); +} + +// out[i] = BF(mul[i] * Gelu(gelu_in[i])) +static HWY_NOINLINE HWY_MAYBE_UNUSED void GeluMulToBF16( + const float* HWY_RESTRICT gelu_in, const float* HWY_RESTRICT mul, + hwy::bfloat16_t* HWY_RESTRICT out, size_t size) { + namespace hn = hwy::HWY_NAMESPACE; + const hn::ScalableTag df; + const hn::Repartition dbf; + const size_t NF = hn::Lanes(df); + using VF = hn::Vec; + + size_t i = 0; + if (size >= 2 * NF) { + for (; i < size - 2 * NF; i += 2 * NF) { + const VF mul0 = LoadU(df, mul + i); + const VF mul1 = LoadU(df, mul + i + NF); + const VF g0 = Mul(mul0, Gelu(df, LoadU(df, gelu_in + i))); + const VF g1 = Mul(mul1, Gelu(df, LoadU(df, gelu_in + i + NF))); + const hn::Vec bf = hn::OrderedDemote2To(dbf, g0, g1); + StoreU(bf, dbf, out + i); + } + } + if (i != size) { + const size_t remaining = size - i; + const VF mul0 = LoadN(df, mul + i, remaining); + const VF g0 = Mul(mul0, Gelu(df, LoadN(df, gelu_in + i, remaining))); + const hn::Half dbfh; + const hn::Vec bfh = hn::DemoteTo(dbfh, g0); + StoreN(bfh, dbfh, out + i, remaining); + } +} + +// Two matrices, same vector +// TODO(janwas): apply optimizations from MatVec/replace with above overload +template +HWY_NOINLINE void TwoMatVec(const CompressedArray& mat0, + const CompressedArray& mat1, + const size_t mat_ofs, + const VecT* HWY_RESTRICT vec_aligned, + float* HWY_RESTRICT out0, float* HWY_RESTRICT out1, + hwy::ThreadPool& pool) { + const hn::ScalableTag df; + const size_t NF = hn::Lanes(df); + + // Process multiple rows at a time so that we write multiples of a cache line + // to avoid false sharing (>= 64). + constexpr size_t kRowsPerStrip = 128 / sizeof(float); + const uint32_t num_strips = kOuter / kRowsPerStrip; + + // No remainder handling after ThreadPool. + static_assert(kOuter % kRowsPerStrip == 0, "Add remainder handling"); + + // Required for Stream loop, otherwise we might have partial vectors. + HWY_DASSERT(kRowsPerStrip >= NF); + pool.Run(0, num_strips, + [&](const uint32_t strip, size_t /*thread*/) HWY_ATTR { + // MSVC workaround: duplicate to ensure constexpr. + constexpr size_t kRowsPerStrip = 128 / sizeof(float); + // Software write-combining to avoid cache pollution from out. + // Although `out` may be used later, keeping it out of the cache + // now and avoiding RFOs is a consistent 5% overall win. + HWY_ALIGN float buf0[kRowsPerStrip]; + HWY_ALIGN float buf1[kRowsPerStrip]; + + // Only handle entire strips here because the Stream is not masked. 
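+          // (hn::Stream writes full, unmasked vectors non-temporally, hence
+          // the static_assert above that kOuter is a multiple of
+          // kRowsPerStrip.)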
+ const size_t begin = strip * kRowsPerStrip; + for (size_t idx_row = 0; idx_row < kRowsPerStrip; ++idx_row) { + const size_t row_ofs = mat_ofs + (begin + idx_row) * kInner; + buf0[idx_row] = Dot(df, mat0, row_ofs, vec_aligned, kInner); + buf1[idx_row] = Dot(df, mat1, row_ofs, vec_aligned, kInner); + } + + HWY_UNROLL(4) + for (size_t i = 0; i != kRowsPerStrip; i += NF) { + hn::Stream(hn::Load(df, buf0 + i), df, out0 + begin + i); + } + HWY_UNROLL(4) + for (size_t i = 0; i != kRowsPerStrip; i += NF) { + hn::Stream(hn::Load(df, buf1 + i), df, out1 + begin + i); + } + }); + hwy::FlushStream(); +} + +// Baseline Naive MatMul +template +HWY_NOINLINE void MatMul(const CompressedArray& mat, + const size_t mat_ofs, const VecT* HWY_RESTRICT vec, + float* HWY_RESTRICT out, hwy::ThreadPool& pool) { + for (size_t i = 0; i < kBatchSize; ++i) { + MatVec( + mat, mat_ofs, vec + i * kInner, out + i * kOuter, pool); + } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a, + const float* HWY_RESTRICT b, + size_t size) { + const hn::ScalableTag d; + HWY_DASSERT(size >= hn::Lanes(d)); + HWY_DASSERT(size % hn::Lanes(d) == 0); + constexpr int kAssumptions = + hn::Dot::kAtLeastOneVector | hn::Dot::kMultipleOfVector; + return hn::Dot::Compute(d, a, b, size); +} + +// = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT. +static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2( + const float* HWY_RESTRICT a, size_t size) { + float total = 0.f; + for (size_t i = 0; i < size; ++i) { + total += a[i] * a[i]; + } + return total; +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( + const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight, + float* HWY_RESTRICT out, size_t size) { + constexpr float eps = 1e-6f; + float ss = SquaredL2(x, size); + ss = 1.0f / sqrtf(ss / static_cast(size) + eps); + for (size_t j = 0; j < size; j++) { + // Note 1.0f centering here + out[j] = (1.0f + weight[j]) * (ss * x[j]); + } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( + const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight, + float* HWY_RESTRICT out, size_t size) { + constexpr float eps = 1e-6f; + float ss = SquaredL2(x, size); + ss = 1.0f / sqrtf(ss / static_cast(size) + eps); + for (size_t j = 0; j < size; j++) { + // Note 1.0f centering here + out[j] = (1.0f + hwy::F32FromBF16(weight[j])) * (ss * x[j]); + } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( + const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) { + constexpr float eps = 1e-6f; + float ss = SquaredL2(inout, size); + ss = 1.0f / sqrtf(ss / static_cast(size) + eps); + for (size_t j = 0; j < size; j++) { + // Note 1.0f centering here + inout[j] = (1.0f + weight[j]) * (ss * inout[j]); + } +} + +// w=bf16 -> f +static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( + const hwy::bfloat16_t* HWY_RESTRICT weight, float* HWY_RESTRICT inout, + const size_t size) { + namespace hn = hwy::HWY_NAMESPACE; + const hn::ScalableTag dbf; + const hn::Repartition df32; + using VF = hn::Vec; + const size_t N32 = hn::Lanes(df32); + + constexpr float eps = 1e-6f; + const float ss = SquaredL2(inout, size); + const VF vss = Set(df32, 1.0f / sqrtf(ss / static_cast(size) + eps)); + + HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); + for (size_t i = 0; i < size; i += 2 * N32) { + const hn::Vec w16 = hn::LoadU(dbf, weight + i); + const VF w0 = hn::PromoteLowerTo(df32, w16); + const VF w1 = hn::PromoteUpperTo(df32, w16); + const VF m0 = hn::Mul(vss, hn::LoadU(df32, inout + i)); + const VF 
m1 = hn::Mul(vss, hn::LoadU(df32, inout + i + N32)); + // (1+weight) * m = m + weight*m = one FMA. + hn::StoreU(hn::MulAdd(m0, w0, m0), df32, inout + i); + hn::StoreU(hn::MulAdd(m1, w1, m1), df32, inout + i + N32); + } +} + +// f, f -> bf +// TODO(janwas): consider generic function with adapter for loading bf16/f32 +static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( + const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight, + hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) { + namespace hn = hwy::HWY_NAMESPACE; + const hn::ScalableTag dbf; + const hn::Repartition df32; + using VF = hn::Vec; + const size_t N32 = hn::Lanes(df32); + + constexpr float eps = 1e-6f; + const float ss = SquaredL2(x, size); + const VF vss = Set(df32, 1.0f / sqrtf(ss / static_cast(size) + eps)); + + HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); + for (size_t i = 0; i < size; i += 2 * N32) { + const VF w0 = hn::LoadU(df32, weight + i); + const VF w1 = hn::LoadU(df32, weight + i + N32); + const VF m0 = hn::Mul(vss, hn::LoadU(df32, x + i)); + const VF m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32)); + // (1+weight) * m = m + weight*m = one FMA. + const VF out0 = hn::MulAdd(m0, w0, m0); + const VF out1 = hn::MulAdd(m1, w1, m1); + hn::StoreU(hn::OrderedDemote2To(dbf, out0, out1), dbf, out + i); + } +} + +// x=f, w=bf16 -> bf16 to enable W16A16 MatVec. +static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( + const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight, + hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) { + namespace hn = hwy::HWY_NAMESPACE; + const hn::ScalableTag dbf; + const hn::Repartition df32; + using VF = hn::Vec; + const size_t N32 = hn::Lanes(df32); + + constexpr float eps = 1e-6f; + const float ss = SquaredL2(x, size); + const VF vss = Set(df32, 1.0f / sqrtf(ss / size + eps)); + + HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); + for (size_t i = 0; i < size; i += 2 * N32) { + const hn::Vec w16 = hn::LoadU(dbf, weight + i); + const VF w0 = hn::PromoteLowerTo(df32, w16); + const VF w1 = hn::PromoteUpperTo(df32, w16); + const VF m0 = hn::Mul(vss, hn::LoadU(df32, x + i)); + const VF m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32)); + // (1+weight) * m = m + weight*m = one FMA. + const VF out0 = hn::MulAdd(m0, w0, m0); + const VF out1 = hn::MulAdd(m1, w1, m1); + hn::StoreU(hn::OrderedDemote2To(dbf, out0, out1), dbf, out + i); + } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings( + float* HWY_RESTRICT x, size_t dim_model, size_t pos) { + const size_t num_timescales = dim_model / 2; + const float log_timescale_increment = + logf(10000.0f) / + (num_timescales != 0 + ? static_cast(static_cast(num_timescales) - 1) + : 1.0f); + for (size_t dim = 0; dim < num_timescales; ++dim) { + const float inv_timescale = + expf(static_cast(dim) * -log_timescale_increment); + x[dim] += sinf(pos * inv_timescale); + x[num_timescales + dim] += cosf(pos * inv_timescale); + } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x, + size_t dim_qkv, size_t pos) { + HWY_DASSERT(dim_qkv % 2 == 0); + const size_t half_dim_qkv = dim_qkv / 2; + for (size_t dim = 0; dim < half_dim_qkv; ++dim) { + const float freq_exponents = static_cast(2 * static_cast(dim)) / + static_cast(dim_qkv); + // Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably. 
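+    // Each pair (x[dim], x[dim + half_dim_qkv]) is rotated by the angle
+    //   theta = pos / 10000^(2 * dim / dim_qkv),
+    // i.e. standard rotary position embedding with the two halves of the
+    // vector treated as real/imaginary components.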
+ const float timescale = powf(10000.0f, freq_exponents); + const float theta = pos / timescale; + const float cos_val = cosf(theta); + const float sin_val = sinf(theta); + const float x0 = x[dim]; + const float x1 = x[dim + half_dim_qkv]; + x[dim] = x0 * cos_val - x1 * sin_val; + x[dim + half_dim_qkv] = x0 * sin_val + x1 * cos_val; + } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul, + float* HWY_RESTRICT x, + size_t dim_qkv, + size_t pos) { + HWY_DASSERT(dim_qkv % 2 == 0); + const size_t half_dim_qkv = dim_qkv / 2; + for (size_t dim = 0; dim < half_dim_qkv; ++dim) { + const float freq_exponents = static_cast(2 * static_cast(dim)) / + static_cast(dim_qkv); + // Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably. + const float timescale = powf(10000.0f, freq_exponents); + const float theta = pos / timescale; + const float cos_val = cosf(theta); + const float sin_val = sinf(theta); + const float x0 = x[dim]; + const float x1 = x[dim + half_dim_qkv]; + x[dim] = mul * (x0 * cos_val - x1 * sin_val); + x[dim + half_dim_qkv] = mul * (x0 * sin_val + x1 * cos_val); + } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom( + const float* HWY_RESTRICT other, float* HWY_RESTRICT x, size_t size) { + for (size_t i = 0; i < size; ++i) { + x[i] += other[i]; + } +} + +static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other, + float* HWY_RESTRICT x, size_t size, + size_t max_pos) { + HWY_DASSERT(max_pos <= size); + for (size_t i = 0; i < max_pos; ++i) { + x[i] *= other[i]; + } +} + +static HWY_INLINE HWY_MAYBE_UNUSED void MulBy(const float* HWY_RESTRICT other, + float* HWY_RESTRICT x, + size_t size) { + return MulBy(other, x, size, size); +} + +static HWY_NOINLINE void MulByConst(float c, float* HWY_RESTRICT x, size_t size, + size_t max_pos) { + HWY_DASSERT(max_pos <= size); + for (size_t i = 0; i < max_pos; ++i) { + x[i] *= c; + } +} + +static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(float c, + float* HWY_RESTRICT x, + size_t size) { + MulByConst(c, x, size, size); +} + +static HWY_NOINLINE void MulByConstAndAdd(float c, const float* HWY_RESTRICT x, + float* HWY_RESTRICT out, size_t size, + size_t max_pos) { + for (size_t i = 0; i < max_pos; ++i) { + out[i] += x[i] * c; + } +} + +static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( + float c, const float* HWY_RESTRICT x, float* HWY_RESTRICT out, + size_t size) { + MulByConstAndAdd(c, x, out, size, size); +} + +static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, size_t size, + size_t mask_pos) { + HWY_DASSERT(size != 0); + HWY_DASSERT(mask_pos <= size); + + namespace hn = hwy::HWY_NAMESPACE; + using D = hn::ScalableTag; + const D d; + using V = hn::Vec; + + // Find max so we can subtract it below. + const V vmin = hn::Set(d, hwy::LowestValue()); + V max = vmin; + hn::Foreach(d, x, mask_pos, vmin, + [&max](D d, V v) { max = hn::Max(max, v); }); + max = hn::MaxOfLanes(d, max); // broadcast + + // Subtract max (avoid precision loss for large exponents) and exponentiate. 
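+  // I.e. x[i] = exp(x[i] - max); combined with the normalization below this
+  // yields softmax(x)[i] = exp(x[i] - max) / sum_j exp(x[j] - max).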
+ V sum = hn::Zero(d); + hn::Transform(d, x, mask_pos, [&sum, max](D d, V v) { + const V out = hn::Exp(d, hn::Sub(v, max)); + sum = hn::Add(sum, out); + return out; + }); + + // Normalize to probability distribution + const float mul = 1.0f / hn::ReduceSum(d, sum); + MulByConst(mul, x, size, mask_pos); +} + +static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x, + size_t size) { + Softmax(x, size, size); +} + +static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x, + size_t size, size_t max_pos) { + HWY_DASSERT(max_pos <= size); + + namespace hn = hwy::HWY_NAMESPACE; + using D = hn::ScalableTag; + const D d; + using V = hn::Vec; + + const V inv_cap = hn::Set(d, 1.0f / cap); + const V vcap = hn::Set(d, cap); + + hn::Transform(d, x, size, [vcap, inv_cap](D d, hn::Vec v) { + return hn::Mul(vcap, hn::Tanh(d, hn::Mul(inv_cap, v))); + }); +} + +static HWY_INLINE HWY_MAYBE_UNUSED void LogitsSoftCap(const float cap, + float* HWY_RESTRICT x, + size_t size) { + LogitsSoftCap(cap, x, size, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED size_t +SampleArgmax(const float* probabilities, size_t vocab_size) { + size_t max_index = 0; + float max_prob = probabilities[0]; + for (size_t i = 1; i < vocab_size; ++i) { + if (probabilities[i] > max_prob) { + max_index = i; + max_prob = probabilities[i]; + } + } + return max_index; +} + +template +static HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution +create_distribution(std::array& top_k, float temperature) { + // re-normalize distribution + for (size_t i = 0; i < k; ++i) { + top_k[i] = exp(log(top_k[i]) / temperature); + } + float denominator = 0.0f; + for (size_t i = 0; i < k; ++i) { + denominator += top_k[i]; + } + denominator = 1.0f / denominator; + MulByConst(denominator, top_k.data(), k); + return std::discrete_distribution(std::begin(top_k), std::end(top_k)); +} + +template +static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK( + const float* HWY_RESTRICT probabilities, size_t vocab_size, + std::mt19937& gen, float temperature, TAcceptToken& accept_token) { + static_assert(k != 0, ""); + // TODO(austinvhuang): Optimize this implementation. + std::array top_k{}; // sorted from highest [0], to lowest [k-1] + std::array indices{}; + for (size_t i = 0; i < vocab_size; ++i) { + if (probabilities[i] < top_k[k - 1] && accept_token(static_cast(i))) { + continue; + } + for (size_t j = 0; j < k; ++j) { + if (probabilities[i] > top_k[j] && accept_token(static_cast(i))) { + // shift elements by 1, insert the new value, move on to next value + for (size_t idx = k - 1; idx > j; --idx) { + top_k[idx] = top_k[idx - 1]; + indices[idx] = indices[idx - 1]; + } + top_k[j] = probabilities[i]; + indices[j] = static_cast(i); + break; + } + } + } + return indices[create_distribution(top_k, temperature)(gen)]; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#endif // NOLINT diff --git a/run.cc b/run.cc new file mode 100644 index 0000000..87d8445 --- /dev/null +++ b/run.cc @@ -0,0 +1,261 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Command line text interface to gemma.
+
+#include <ctime>
+#include <iostream>
+#include <random>
+#include <string>
+#include <thread>  // NOLINT
+#include <vector>
+
+// copybara:import_next_line:gemma_cpp
+#include "compression/compress.h"
+// copybara:import_next_line:gemma_cpp
+#include "gemma.h"  // Gemma
+// copybara:import_next_line:gemma_cpp
+#include "util/app.h"
+// copybara:import_next_line:gemma_cpp
+#include "util/args.h"  // HasHelp
+#include "hwy/base.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
+#include "hwy/highway.h"
+#include "hwy/per_target.h"
+#include "hwy/profiler.h"
+#include "hwy/timer.h"
+
+namespace gcpp {
+
+void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
+              gcpp::AppArgs& app) {
+  fprintf(stderr,
+          "\ngemma.cpp\n---------\n\nTo run gemma.cpp, you need to "
+          "specify 3 required model loading arguments: --tokenizer, "
+          "--compressed_weights, "
+          "and --model.\n\nModel Loading Arguments\n\n");
+  loader.Help();
+  fprintf(stderr, "\nInference Arguments\n\n");
+  inference.Help();
+  fprintf(stderr, "\nApplication Arguments\n\n");
+  app.Help();
+  fprintf(stderr, "\n\n");
+}
+
+void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
+  loader.Print(app.verbosity);
+  inference.Print(app.verbosity);
+  app.Print(app.verbosity);
+
+  if (app.verbosity >= 2) {
+    time_t now = time(nullptr);
+    char* dt = ctime(&now);  // NOLINT
+    std::cout << "Date & Time : " << dt
+              << "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize
+              << "\n"
+              << "Hardware concurrency : "
+              << std::thread::hardware_concurrency() << std::endl
+              << "Instruction set : "
+              << hwy::TargetName(hwy::DispatchedTarget()) << " ("
+              << hwy::VectorBytes() * 8 << " bits)" << "\n"
+              << "Weight Type : "
+              << gcpp::TypeName(gcpp::WeightT()) << "\n"
+              << "EmbedderInput Type : "
+              << gcpp::TypeName(gcpp::EmbedderInputT()) << "\n";
+  }
+}
+
+void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
+               hwy::ThreadPool& inner_pool, const InferenceArgs& args,
+               int verbosity, const gcpp::AcceptFunc& accept_token) {
+  PROFILER_ZONE("Gen.misc");
+  int abs_pos = 0;      // absolute token index over all turns
+  int current_pos = 0;  // token index within the current turn
+  int prompt_size{};
+
+  std::mt19937 gen;
+  if (args.deterministic) {
+    gen.seed(42);
+  } else {
+    std::random_device rd;
+    gen.seed(rd());
+  }
+
+  // callback function invoked for each generated token.
+  auto stream_token = [&abs_pos, &current_pos, &args, &gen, &prompt_size,
+                       tokenizer = &model.Tokenizer(),
+                       verbosity](int token, float) {
+    ++abs_pos;
+    ++current_pos;
+    if (current_pos < prompt_size) {
+      std::cerr << "."
+                << std::flush;
+    } else if (token == gcpp::EOS_ID) {
+      if (!args.multiturn) {
+        abs_pos = 0;
+        if (args.deterministic) {
+          gen.seed(42);
+        }
+      }
+      if (verbosity >= 2) {
+        std::cout << "\n[ End ]" << std::endl;
+      }
+    } else {
+      std::string token_text;
+      HWY_ASSERT(
+          tokenizer->Decode(std::vector<int>{token}, &token_text).ok());
+      // +1 since position is incremented above
+      if (current_pos == prompt_size + 1) {
+        // first token of response
+        token_text.erase(0, token_text.find_first_not_of(" \t\n"));
+        if (verbosity >= 1) {
+          std::cout << std::endl << std::endl;
+        }
+      }
+      // TODO(austinvhuang): is explicit space necessary?
+      std::cout << token_text << std::flush;
+    }
+    return true;
+  };
+
+  while (abs_pos < args.max_tokens) {
+    std::string prompt_string;
+    std::vector<int> prompt;
+    current_pos = 0;
+    {
+      PROFILER_ZONE("Gen.input");
+      if (verbosity >= 1) {
+        std::cout << "> " << std::flush;
+      }
+      std::getline(std::cin, prompt_string);
+    }
+
+    if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q") {
+      return;
+    }
+
+    if (model.model_training == ModelTraining::GEMMA_IT) {
+      // For instruction-tuned models: add control tokens.
+      prompt_string = "<start_of_turn>user\n" + prompt_string +
+                      "<end_of_turn>\n<start_of_turn>model\n";
+      if (abs_pos > 0) {
+        // Prepend "<end_of_turn>" token if this is a multi-turn dialogue
+        // continuation.
+        prompt_string = "<end_of_turn>\n" + prompt_string;
+      }
+    }
+
+    HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok());
+
+    // For both pre-trained and instruction-tuned models: prepend "<bos>" token
+    // if needed.
+    if (abs_pos == 0) {
+      prompt.insert(prompt.begin(), 2);
+    }
+
+    prompt_size = prompt.size();
+
+    std::cerr << std::endl << "[ Reading prompt ] " << std::flush;
+
+    const double time_start = hwy::platform::Now();
+    GenerateGemma(model, args, prompt, abs_pos, pool, inner_pool, stream_token,
+                  accept_token, gen, verbosity);
+    const double time_end = hwy::platform::Now();
+    const double tok_sec = current_pos / (time_end - time_start);
+    if (verbosity >= 2) {
+      std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
+                << std::endl
+                << tok_sec << " tokens / sec" << std::endl;
+    }
+    std::cout << std::endl << std::endl;
+  }
+  std::cout
+      << "max_tokens (" << args.max_tokens
+      << ") exceeded. Use a larger value if desired using the --max_tokens "
+      << "command line flag.\n";
+}
+
+void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
+  PROFILER_ZONE("Run.misc");
+
+  hwy::ThreadPool inner_pool(0);
+  hwy::ThreadPool pool(app.num_threads);
+  // For many-core, pinning threads to cores helps.
+  if (app.num_threads > 10) {
+    PinThreadToCore(app.num_threads - 1);  // Main thread
+
+    pool.Run(0, pool.NumThreads(),
+             [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
+  }
+
+  gcpp::Gemma model(loader, pool);
+
+  if (const char* error = inference.Validate()) {
+    ShowHelp(loader, inference, app);
+    HWY_ABORT("\nInvalid args: %s", error);
+  }
+
+  if (app.verbosity >= 1) {
+    static const std::string banner_ascii_art =
+        "  __ _  ___ _ __ ___  _ __ ___   __ _   ___ _ __  _ __\n"
+        " / _` |/ _ \\ '_ ` _ \\| '_ ` _ \\ / _` | / __| '_ \\| '_ \\\n"
+        "| (_| |  __/ | | | | | | | | | | (_| || (__| |_) | |_) |\n"
+        " \\__, |\\___|_| |_| |_|_| |_| |_|\\__,_(_)___| .__/| .__/\n"
+        "  __/ |                                    | |   | |\n"
+        " |___/                                     |_|   |_|";
+
+    const std::string instructions =
+        "*Usage*\n"
+        "  Enter an instruction and press enter (%Q quits).\n\n"
+        "*Examples*\n"
+        "  - Write an email to grandma thanking her for the cookies.\n"
+        "  - What are some historical attractions to visit around "
+        "Massachusetts?\n"
+        "  - Compute the nth fibonacci number in javascript.\n"
+        "  - Write a standup comedy bit about GPU programming.\n";
+
+    std::cout << "\033[2J\033[1;1H"  // clear screen
+              << banner_ascii_art << "\n\n";
+    ShowConfig(loader, inference, app);
+    std::cout << "\n" << instructions << "\n";
+  }
+
+  ReplGemma(model, pool, inner_pool, inference, app.verbosity,
+            /*accept_token=*/[](int) { return true; });
+}
+
+}  // namespace gcpp
+
+int main(int argc, char** argv) {
+  {
+    PROFILER_ZONE("Startup.misc");
+
+    gcpp::LoaderArgs loader(argc, argv);
+    gcpp::InferenceArgs inference(argc, argv);
+    gcpp::AppArgs app(argc, argv);
+
+    if (gcpp::HasHelp(argc, argv)) {
+      ShowHelp(loader, inference, app);
+      return 0;
+    }
+
+    if (const char* error = loader.Validate()) {
+      ShowHelp(loader, inference, app);
+      HWY_ABORT("\nInvalid args: %s", error);
+    }
+
+    gcpp::Run(loader, inference, app);
+  }
+  PROFILER_PRINT_RESULTS();  // Must call outside the zone above.
+  return 0;
+}
diff --git a/util/app.h b/util/app.h
new file mode 100644
index 0000000..966fa41
--- /dev/null
+++ b/util/app.h
@@ -0,0 +1,85 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Shared between various frontends.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
+#define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
+
+#include <stddef.h>
+#include <stdio.h>
+
+#include <algorithm>  // std::clamp
+#include <thread>     // NOLINT
+
+// copybara:import_next_line:gemma_cpp
+#include "util/args.h"
+#include "hwy/base.h"  // HWY_ASSERT
+
+namespace gcpp {
+
+static inline void PinThreadToCore(size_t cpu_index) {
+#if HWY_OS_LINUX
+  // Forces the thread to run on the logical processor with the same number.
+  cpu_set_t cset;             // bit array
+  CPU_ZERO(&cset);            // clear all
+  CPU_SET(cpu_index, &cset);  // set bit indicating which processor to run on.
+  HWY_ASSERT(0 == sched_setaffinity(0, sizeof(cset), &cset));
+#else
+  (void)cpu_index;
+#endif
+}
+
+class AppArgs : public ArgsBase<AppArgs> {
+  static constexpr size_t kDefaultNumThreads = ~size_t{0};
+
+  void ChooseNumThreads() {
+    if (num_threads == kDefaultNumThreads) {
+      // This is a rough heuristic, replace with something better in the future.
+      num_threads = static_cast<size_t>(std::clamp(
+          static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
+    }
+  }
+
+ public:
+  AppArgs(int argc, char* argv[]) {
+    InitAndParse(argc, argv);
+    ChooseNumThreads();
+  }
+
+  Path log;  // output
+  int verbosity;
+  size_t num_threads;
+
+  template <class Visitor>
+  void ForEach(const Visitor& visitor) {
+    visitor(log, "log", Path{"/tmp/log.txt"}, "Logging file", 2);
+    visitor(verbosity, "verbosity", 1,
+            "Show verbose developer information\n    0 = only print generation "
+            "output\n    1 = standard user-facing terminal ui\n    2 = show "
+            "developer/debug info.\n    Default = 1.",
+            2);
+    visitor(num_threads, "num_threads",
+            kDefaultNumThreads,  // see ChooseNumThreads
+            "Number of threads to use. Default value is set based on an "
+            "estimate of how many concurrent threads are supported.",
+            2);
+  }
+};
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
diff --git a/util/args.h b/util/args.h
new file mode 100644
index 0000000..ce03ef2
--- /dev/null
+++ b/util/args.h
@@ -0,0 +1,223 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Command line arguments.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_ARGS_H_
+#define THIRD_PARTY_GEMMA_CPP_UTIL_ARGS_H_
+
+#include <stdio.h>
+
+#include <algorithm>  // std::transform
+#include <string>
+
+#include "hwy/base.h"  // HWY_ABORT
+
+namespace gcpp {
+
+// Wrapper for strings representing a path name. Differentiates vs. arbitrary
+// strings and supports shortening for display purposes.
+struct Path {
+  Path& operator=(const char* other) {
+    path = other;
+    return *this;
+  }
+
+  std::string Shortened() const {
+    constexpr size_t max_len = 48;
+    constexpr size_t cut_point = max_len / 2 - 5;
+    if (path.size() > max_len) {
+      return std::string(begin(path), begin(path) + cut_point) + " ... " +
+             std::string(end(path) - cut_point, end(path));
+    }
+    if (path.empty()) return "[no path specified]";
+    return path;
+  }
+
+  std::string path;
+};
+
+// Args is a class that provides a ForEach member function which visits each of
+// its member variables. ArgsBase provides functions called by Args to
+// initialize values to their defaults (passed as an argument to the visitor),
+// print and parse, without having to repeat the args for each usage.
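+//
+// As a rough illustration of the pattern (hypothetical; MyArgs and its flags
+// are not part of this file), a frontend could define:
+//
+//   struct MyArgs : public ArgsBase<MyArgs> {
+//     MyArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+//     Path weights;
+//     float temperature;
+//     template <class Visitor>
+//     void ForEach(const Visitor& visitor) {
+//       visitor(weights, "weights", Path{}, "Path to model weights.", 1);
+//       visitor(temperature, "temperature", 1.0f, "Sampling temperature.", 1);
+//     }
+//   };
+//
+// so that `MyArgs args(argc, argv);` fills both fields from --weights and
+// --temperature, falling back to the defaults given above.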
+template <class Derived>
+class ArgsBase {
+  struct InitVisitor {
+    template <typename T>
+    void operator()(T& t, const char* /*name*/, const T& init,
+                    const char* /*help*/, int /*print_verbosity*/ = 0) const {
+      t = init;
+    }
+  };
+
+  struct HelpVisitor {
+    template <typename T>
+    void operator()(T&, const char* name, T /*init*/, const char* help,
+                    int /*print_verbosity*/ = 0) const {
+      fprintf(stderr, " --%s : %s\n", name, help);
+    }
+  };
+
+  class PrintVisitor {
+   public:
+    explicit PrintVisitor(int verbosity) : verbosity_(verbosity) {}
+
+    template <typename T>
+    void operator()(const T& t, const char* name, const T& /*init*/,
+                    const char* /*help*/, int print_verbosity = 0) const {
+      if (verbosity_ >= print_verbosity) {
+        fprintf(stderr, "%-30s: %s\n", name, std::to_string(t).c_str());
+      }
+    }
+
+    void operator()(const std::string& t, const char* name,
+                    const std::string& /*init*/, const char* /*help*/,
+                    int print_verbosity = 0) const {
+      if (verbosity_ >= print_verbosity) {
+        fprintf(stderr, "%-30s: %s\n", name, t.c_str());
+      }
+    }
+    void operator()(const Path& t, const char* name, const Path& /*init*/,
+                    const char* /*help*/, int print_verbosity = 0) const {
+      if (verbosity_ >= print_verbosity) {
+        fprintf(stderr, "%-30s: %s\n", name, t.Shortened().c_str());
+      }
+    }
+
+   private:
+    int verbosity_;
+  };
+
+  // Supported types: integer, float, std::string, bool, Path. This is O(N^2):
+  // for each arg, we search through argv. If there are more than a dozen args,
+  // consider adding a hash-map to speed this up.
+  class ParseVisitor {
+   public:
+    ParseVisitor(int argc, char* argv[]) : argc_(argc), argv_(argv) {}
+
+    template <typename T>
+    void operator()(T& t, const char* name, const T& /*init*/,
+                    const char* /*help*/, int /*print_verbosity*/ = 0) const {
+      const std::string prefixed = std::string("--") + name;
+      for (int i = 1; i < argc_; ++i) {
+        if (std::string(argv_[i]) == prefixed) {
+          if (i + 1 >= argc_) {
+            HWY_ABORT("Missing value for %s\n", name);
+          }
+          if (!SetValue(argv_[i + 1], t)) {
+            HWY_ABORT("Invalid value for %s, got %s\n", name, argv_[i + 1]);
+          }
+          return;
+        }
+      }
+    }
+
+   private:
+    // Returns false if an invalid value is detected.
+    template <typename T, HWY_IF_NOT_FLOAT(T)>
+    static bool SetValue(const char* string, T& t) {
+      t = std::stoi(string);
+      return true;
+    }
+
+    template <typename T, HWY_IF_FLOAT(T)>
+    static bool SetValue(const char* string, T& t) {
+      t = std::stof(string);
+      return true;
+    }
+
+    static bool SetValue(const char* string, std::string& t) {
+      t = string;
+      return true;
+    }
+    static bool SetValue(const char* string, Path& t) {
+      t.path = string;
+      return true;
+    }
+
+    static bool SetValue(const char* string, bool& t) {
+      std::string value(string);
+      // Lower-case. Arg names are expected to be ASCII-only.
+      std::transform(value.begin(), value.end(), value.begin(), [](char c) {
+        return 'A' <= c && c <= 'Z' ? c - ('Z' - 'z') : c;
+      });
+
+      if (value == "true" || value == "on" || value == "1") {
+        t = true;
+        return true;
+      } else if (value == "false" || value == "off" || value == "0") {
+        t = false;
+        return true;
+      } else {
+        return false;
+      }
+    }
+
+    int argc_;
+    char** argv_;
+  };  // ParseVisitor
+
+  template <class Visitor>
+  void ForEach(Visitor& visitor) {
+    static_cast<Derived*>(this)->ForEach(visitor);
+  }
+
+ public:
+  // WARNING: cannot call from ctor because the derived ctor has not yet run.
+ void Init() { + InitVisitor visitor; + ForEach(visitor); + } + + void Help() { + HelpVisitor visitor; + ForEach(visitor); + } + + void Print(int verbosity = 0) { + PrintVisitor visitor(verbosity); + ForEach(visitor); + } + + void Parse(int argc, char* argv[]) { + ParseVisitor visitor(argc, argv); + ForEach(visitor); + } + + // For convenience, enables single-line constructor. + void InitAndParse(int argc, char* argv[]) { + Init(); + Parse(argc, argv); + } +}; + +static bool HasHelp(int argc, char* argv[]) { + // TODO(austinvhuang): handle case insensitivity + if (argc == 1) { + // no arguments - print help + return true; + } + for (int i = 1; i < argc; ++i) { + if (std::string(argv[i]) == "--help") { + return true; + } + } + return false; +} + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_UTIL_ARGS_H_