Merge branch 'dev' into gemma.cpp-windows-build-fix

This commit is contained in:
Hitesh K V 2025-10-16 20:18:09 +05:30 committed by GitHub
commit c55120fc6d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
98 changed files with 8324 additions and 6281 deletions

View File

@ -46,6 +46,7 @@ jobs:
-D CMAKE_BUILD_TYPE=${{ matrix.build_type }}
-D CMAKE_C_COMPILER_LAUNCHER=ccache
-D CMAKE_CXX_COMPILER_LAUNCHER=ccache
-DCMAKE_POLICY_VERSION_MINIMUM=3.5
- name: Build
run: cmake --build ${{ github.workspace }}/build --preset ${{ matrix.preset }} --config ${{ matrix.build_type }} -j 4

250
API_SERVER_README.md Normal file
View File

@ -0,0 +1,250 @@
# Gemma.cpp API Server
This is an HTTP API server for gemma.cpp that implements the Google API protocol, allowing you to interact with Gemma models through REST API endpoints compatible with the Google API format.
## Features
- **API-compatible**: Implements Google API endpoints
- **Unified client/server**: Single codebase supports both local and public API modes
- **Text generation**: Support for `generateContent` endpoint
- **Streaming support**: Server-Sent Events (SSE) for `streamGenerateContent`
- **Model management**: Support for `/v1beta/models` endpoint
- **Session management**: Maintains conversation context with KV cache
- **JSON responses**: All responses in Google API format
- **Error handling**: Proper HTTP status codes and error messages
## Building
The API server is built alongside the main gemma.cpp project:
```bash
# Configure the build
cmake -B build -DCMAKE_BUILD_TYPE=Release
# Build the API server and client
cmake --build build --target gemma_api_server gemma_api_client -j 8
```
The binaries will be created at:
- `build/gemma_api_server` - Local API server
- `build/gemma_api_client` - Unified client for both local and public APIs
## Usage
### Starting the Local API Server
```bash
./build/gemma_api_server \
--tokenizer path/to/tokenizer.spm \
--weights path/to/model.sbs \
--port 8080
```
**Required arguments:**
- `--tokenizer`: Path to the tokenizer file (`.spm`)
- `--weights`: Path to the model weights file (`.sbs`)
**Optional arguments:**
- `--port`: Port to listen on (default: 8080)
- `--model`: Model name for API endpoints (default: gemma3-4b)
### Using the Unified Client
#### With Local Server
```bash
# Interactive chat with local server
./build/gemma_api_client --interactive 1 --host localhost --port 8080
# Single prompt with local server
./build/gemma_api_client --prompt "Hello, how are you?"
```
#### With Public Google API
```bash
# Set API key and use public API
export GOOGLE_API_KEY="your-api-key-here"
./build/gemma_api_client --interactive 1
# Or pass API key directly
./build/gemma_api_client --api_key "your-api-key" --interactive 1
```
## API Endpoints
The server implements Google API endpoints:
### 1. Generate Content - `POST /v1beta/models/gemma3-4b:generateContent`
Generate a response for given content (non-streaming).
**Request:**
```json
{
"contents": [
{
"parts": [
{"text": "Why is the sky blue?"}
]
}
],
"generationConfig": {
"temperature": 0.9,
"topK": 1,
"maxOutputTokens": 1024
}
}
```
**Response:**
```json
{
"candidates": [
{
"content": {
"parts": [
{"text": "The sky appears blue because..."}
],
"role": "model"
},
"finishReason": "STOP",
"index": 0
}
],
"promptFeedback": {
"safetyRatings": []
},
"usageMetadata": {
"promptTokenCount": 5,
"candidatesTokenCount": 25,
"totalTokenCount": 30
}
}
```
### 2. Stream Generate Content - `POST /v1beta/models/gemma3-4b:streamGenerateContent`
Generate a response with Server-Sent Events (SSE) streaming.
**Request:** Same as above
**Response:** Stream of SSE events:
```
data: {"candidates":[{"content":{"parts":[{"text":"The"}],"role":"model"},"index":0}],"promptFeedback":{"safetyRatings":[]}}
data: {"candidates":[{"content":{"parts":[{"text":" sky"}],"role":"model"},"index":0}],"promptFeedback":{"safetyRatings":[]}}
data: [DONE]
```
### 3. List Models - `GET /v1beta/models`
List available models.
**Response:**
```json
{
"models": [
{
"name": "models/gemma3-4b",
"displayName": "Gemma3 4B",
"description": "Gemma3 4B model running locally"
}
]
}
```
## Example Usage
### Using curl with Local Server
```bash
# Generate content (non-streaming)
curl -X POST http://localhost:8080/v1beta/models/gemma3-4b:generateContent \
-H "Content-Type: application/json" \
-d '{
"contents": [{"parts": [{"text": "Hello, how are you?"}]}],
"generationConfig": {"temperature": 0.9, "topK": 1, "maxOutputTokens": 1024}
}'
# Stream generate content (SSE)
curl -X POST http://localhost:8080/v1beta/models/gemma3-4b:streamGenerateContent \
-H "Content-Type: application/json" \
-d '{
"contents": [{"parts": [{"text": "Tell me a story"}]}],
"generationConfig": {"temperature": 0.9, "topK": 1, "maxOutputTokens": 1024}
}'
# List models
curl http://localhost:8080/v1beta/models
```
### Multi-turn Conversation with curl
```bash
# First message
curl -X POST http://localhost:8080/v1beta/models/gemma3-4b:generateContent \
-H "Content-Type: application/json" \
-d '{
"contents": [
{"parts": [{"text": "Hi, my name is Alice"}]}
]
}'
# Follow-up message with conversation history
curl -X POST http://localhost:8080/v1beta/models/gemma3-4b:generateContent \
-H "Content-Type: application/json" \
-d '{
"contents": [
{"parts": [{"text": "Hi, my name is Alice"}]},
{"parts": [{"text": "Hello Alice! Nice to meet you."}]},
{"parts": [{"text": "What is my name?"}]}
]
}'
```
### Using Python
```python
import requests
# Generate content
response = requests.post('http://localhost:8080/v1beta/models/gemma3-4b:generateContent',
json={
'contents': [{'parts': [{'text': 'Explain quantum computing in simple terms'}]}],
'generationConfig': {
'temperature': 0.9,
'topK': 1,
'maxOutputTokens': 1024
}
}
)
result = response.json()
if 'candidates' in result and result['candidates']:
text = result['candidates'][0]['content']['parts'][0]['text']
print(text)
```
## Configuration Options
The Google API supports various generation configuration options:
- **temperature**: Controls randomness (0.0 to 2.0, default: 1.0)
- **topK**: Top-K sampling parameter (default: 1)
- **maxOutputTokens**: Maximum number of tokens to generate (default: 8192)
## Key Features
- **Unified Implementation**: Same codebase handles both local server and public API
- **Session Management**: Maintains conversation context using KV cache
- **Streaming Support**: Real-time token generation via Server-Sent Events
- **Error Handling**: Comprehensive error responses and HTTP status codes
- **Memory Efficient**: Optimized token processing and caching
## Compatibility
This implementation is compatible with:
- Google API format and endpoints
- Standard HTTP clients (curl, browsers, Python requests, etc.)
- Server-Sent Events (SSE) for streaming responses
- JSON request/response format

View File

@ -29,9 +29,24 @@ exports_files([
cc_library(
name = "basics",
srcs = ["util/basics.cc"],
hdrs = ["util/basics.h"],
deps = [
"@highway//:hwy",
"@highway//:timer",
"@highway//hwy/contrib/sort:vqsort",
],
)
cc_test(
name = "basics_test",
srcs = ["util/basics_test.cc"],
deps = [
":basics",
"@googletest//:gtest_main", # buildcleaner: keep
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:timer",
],
)
@ -96,9 +111,41 @@ cc_library(
":basics",
":threading",
":topology",
":zones",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:profiler",
"@highway//:thread_pool",
],
)
cc_library(
name = "zones",
srcs = ["util/zones.cc"],
hdrs = ["util/zones.h"],
deps = [
"@highway//:profiler",
],
)
cc_test(
name = "flash_attention_test",
srcs = ["gemma/flash_attention_test.cc"],
deps = [
":configs",
":gemma_args",
":gemma_lib",
":kv_cache",
":mat",
":matmul",
":ops",
":threading_context",
":weights",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
"//compression:types",
"@highway//:hwy",
"@highway//:hwy_test_util",
],
)
@ -223,10 +270,10 @@ cc_library(
":configs",
":gemma_args",
":mat",
":matmul",
":model_store",
":tensor_info",
":threading_context",
":zones",
"//compression:compress",
"//io:blob_store",
"@highway//:hwy",
@ -256,15 +303,36 @@ test_suite(
)
cc_library(
name = "matmul",
name = "matmul_env",
srcs = ["ops/matmul.cc"],
hdrs = ["ops/matmul.h"],
deps = [
":allocator",
":basics",
":configs",
":mat",
":threading",
":threading_context",
"@highway//:bit_set",
"@highway//:hwy",
"@highway//:nanobenchmark",
"@highway//:profiler",
],
)
cc_library(
name = "matmul",
# allow depending only on this target, without also matmul_env.
hdrs = ["ops/matmul.h"],
textual_hdrs = ["ops/matmul-inl.h"],
deps = [
":allocator",
":basics",
":mat",
":matmul_env",
":threading",
":threading_context",
":zones",
"//compression:compress",
"@highway//:bit_set",
"@highway//:hwy",
@ -281,6 +349,7 @@ cc_library(
"ops/matmul_static_f32.cc",
"ops/matmul_static_nuq.cc",
"ops/matmul_static_sfp.cc",
"ops/matmul_static_i8.cc",
],
hdrs = [
"ops/matmul_static.h",
@ -294,7 +363,9 @@ cc_library(
":basics",
":mat",
":matmul",
":matmul_env",
":threading_context",
":zones",
"//compression:compress",
"//compression:types",
"@highway//:hwy",
@ -310,21 +381,21 @@ cc_library(
"ops/dot-inl.h",
"ops/sum-inl.h",
"ops/fp_arith-inl.h",
"ops/matvec-inl.h",
"ops/ops-inl.h",
],
deps = [
":allocator",
":basics",
":mat",
":matmul",
":matmul_env", # MMOptions
":matmul_static",
":threading_context",
":zones",
"//compression:compress",
"@highway//:algo",
"@highway//:bit_set",
"@highway//:hwy",
"@highway//:math",
"@highway//:matvec",
"@highway//:profiler",
"@highway//:thread_pool",
"@highway//hwy/contrib/sort:vqsort",
@ -375,7 +446,9 @@ cc_test(
":ops",
":test_util",
":threading_context",
":zones",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:test_util",
"//compression:types",
"@highway//:hwy",
"@highway//:hwy_test_util",
@ -384,27 +457,6 @@ cc_test(
],
)
cc_test(
name = "gemma_matvec_test",
size = "small",
timeout = "long",
srcs = ["ops/gemma_matvec_test.cc"],
linkstatic = True,
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["ops_tests"],
deps = [
":mat",
":ops",
":threading_context",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:thread_pool",
],
)
cc_test(
name = "matmul_test",
size = "small",
@ -417,7 +469,7 @@ cc_test(
deps = [
":basics",
":mat",
":matmul",
":matmul_env",
":matmul_static",
":ops",
":threading_context",
@ -445,7 +497,8 @@ cc_test(
],
deps = [
":basics",
":matmul",
":matmul_env",
":matmul_static",
":threading_context",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
@ -478,7 +531,6 @@ cc_library(
":args",
":basics",
":mat",
":matmul",
"//io",
"@highway//:hwy",
"@highway//:profiler",
@ -489,15 +541,15 @@ cc_library(
name = "gemma_lib",
srcs = [
"gemma/attention.cc",
"gemma/flash_attention.cc",
"gemma/gemma.cc",
"gemma/griffin.cc",
"gemma/vit.cc",
],
hdrs = [
"gemma/activations.h",
"gemma/attention.h",
"gemma/flash_attention.h",
"gemma/gemma.h",
"gemma/griffin.h",
"gemma/vit.h",
],
exec_properties = {
@ -508,27 +560,30 @@ cc_library(
"gemma/gemma-inl.h",
],
deps = [
":allocator",
":basics",
":configs",
":gemma_args",
":kv_cache",
":mat",
":matmul",
":matmul_env",
":model_store",
":ops",
":threading",
":threading_context",
":weights",
":zones",
"//compression:compress",
"//compression:types",
"//io:blob_store",
"//io",
"//io:blob_store",
"//paligemma:image",
"@highway//:bit_set",
"@highway//:hwy",
"@highway//:nanobenchmark", # timer
"@highway//:profiler",
"@highway//:thread_pool",
"@highway//hwy/contrib/sort:vqsort",
],
)
@ -553,7 +608,7 @@ cc_library(
":cross_entropy",
":gemma_args",
":gemma_lib",
":matmul",
":matmul_env",
":ops",
":threading_context",
":tokenizer",
@ -584,7 +639,7 @@ cc_library(
":gemma_args",
":gemma_lib",
":kv_cache",
":matmul",
":matmul_env",
":threading",
":threading_context",
":tokenizer",
@ -645,7 +700,7 @@ cc_binary(
":benchmark_helper",
":gemma_args",
":gemma_lib",
":matmul",
":matmul_env",
":tokenizer",
"//compression:types",
"//paligemma:image",

View File

@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 1d16731233de45a365b43867f27d0a5f73925300 EXCLUDE_FROM_ALL)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(highway)
## Note: absl needs to be installed by sentencepiece. This will only happen if
@ -33,6 +33,28 @@ FetchContent_MakeAvailable(sentencepiece)
FetchContent_Declare(json GIT_REPOSITORY https://github.com/nlohmann/json.git GIT_TAG 9cca280a4d0ccf0c08f47a99aa71d1b0e52f8d03 EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(json)
# Find OpenSSL for HTTPS support
find_package(OpenSSL)
if(OPENSSL_FOUND)
message(STATUS "OpenSSL found, enabling HTTPS support")
set(HTTPLIB_USE_OPENSSL_IF_AVAILABLE ON)
else()
message(STATUS "OpenSSL not found, HTTPS support disabled")
set(HTTPLIB_USE_OPENSSL_IF_AVAILABLE OFF)
endif()
# HTTP library for API server
FetchContent_Declare(httplib GIT_REPOSITORY https://github.com/yhirose/cpp-httplib.git GIT_TAG v0.18.1 EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(httplib)
# Create interface target for httplib (header-only library)
add_library(httplib_interface INTERFACE)
target_include_directories(httplib_interface INTERFACE ${httplib_SOURCE_DIR})
if(OPENSSL_FOUND)
target_link_libraries(httplib_interface INTERFACE OpenSSL::SSL OpenSSL::Crypto)
target_compile_definitions(httplib_interface INTERFACE CPPHTTPLIB_OPENSSL_SUPPORT)
endif()
set(BENCHMARK_ENABLE_TESTING OFF)
set(BENCHMARK_ENABLE_GTEST_TESTS OFF)
@ -46,6 +68,7 @@ set(SOURCES
compression/compress.h
compression/nuq-inl.h
compression/sfp-inl.h
compression/int-inl.h
compression/types.h
compression/test_util-inl.h
evals/benchmark_helper.cc
@ -57,12 +80,12 @@ set(SOURCES
gemma/attention.h
gemma/configs.cc
gemma/configs.h
gemma/flash_attention.cc
gemma/flash_attention.h
gemma/gemma_args.h
gemma/gemma-inl.h
gemma/gemma.cc
gemma/gemma.h
gemma/griffin.cc
gemma/griffin.h
gemma/kv_cache.cc
gemma/kv_cache.h
gemma/model_store.cc
@ -87,10 +110,10 @@ set(SOURCES
ops/matmul_static_f32.cc
ops/matmul_static_nuq.cc
ops/matmul_static_sfp.cc
ops/matmul_static_i8.cc
ops/matmul-inl.h
ops/matmul.cc
ops/matmul.h
ops/matvec-inl.h
ops/ops-inl.h
ops/ops.h
ops/sum-inl.h
@ -98,6 +121,7 @@ set(SOURCES
paligemma/image.h
util/allocator.cc
util/allocator.h
util/basics.cc
util/basics.h
util/mat.cc
util/mat.h
@ -108,6 +132,8 @@ set(SOURCES
util/threading.h
util/topology.cc
util/topology.h
util/zones.cc
util/zones.h
)
# Add C API sources only when building DLL
@ -195,16 +221,17 @@ set(GEMMA_TEST_FILES
compression/nuq_test.cc
compression/sfp_test.cc
evals/gemma_test.cc
gemma/flash_attention_test.cc
gemma/tensor_info_test.cc
io/blob_store_test.cc
io/fields_test.cc
ops/bench_matmul.cc
ops/dot_test.cc
ops/gemma_matvec_test.cc
ops/matmul_test.cc
ops/ops_test.cc
paligemma/image_test.cc
paligemma/paligemma_test.cc
util/basics_test.cc
util/threading_test.cc
)
@ -232,3 +259,12 @@ endif() # GEMMA_ENABLE_TESTS
add_executable(migrate_weights io/migrate_weights.cc)
target_link_libraries(migrate_weights libgemma hwy hwy_contrib)
# API server with SSE support
add_executable(gemma_api_server gemma/api_server.cc)
target_link_libraries(gemma_api_server libgemma hwy hwy_contrib nlohmann_json::nlohmann_json httplib_interface)
# API client for testing
add_executable(gemma_api_client gemma/api_client.cc)
target_link_libraries(gemma_api_client libgemma hwy hwy_contrib nlohmann_json::nlohmann_json httplib_interface)

View File

@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5")
# Require a more recent version.
git_override(
module_name = "highway",
commit = "1d16731233de45a365b43867f27d0a5f73925300",
commit = "9781a1698ee0756ef1eaaf96930113ed7cb6d3ee",
remote = "https://github.com/google/highway",
)

View File

@ -53,7 +53,7 @@ Guidelines](https://opensource.google.com/conduct/).
- LLM
- CPU-only inference for: Gemma 2-3, Griffin(SSM), PaliGemma 2.
- CPU-only inference for: Gemma 2-3, PaliGemma 2.
- Sampling with TopK and temperature.
- Backward pass (VJP) and Adam optimizer for Gemma research.
@ -222,23 +222,6 @@ Example invocation for the following configuration:
--tokenizer tokenizer.spm --weights gemma2-2b-it-sfp.sbs
```
### RecurrentGemma
This repository includes a version of Gemma based on Griffin
([paper](https://arxiv.org/abs/2402.19427),
[code](https://github.com/google-deepmind/recurrentgemma)). Its architecture
includes both recurrent layers and local attention, thus it is more efficient
for longer sequences and has a smaller memory footprint than standard Gemma. We
here provide a C++ implementation of this model based on the paper.
To use the recurrent version of Gemma included in this repository, build the
gemma binary as noted above in Step 3. Download the compressed weights and
tokenizer from the RecurrentGemma
[Kaggle](https://www.kaggle.com/models/google/recurrentgemma/gemmaCpp) as in
Step 1, and run the binary as follows:
`./gemma --tokenizer tokenizer.spm --model gr2b-it --weights 2b-it-sfp.sbs`
### PaliGemma Vision-Language Model
This repository includes a version of the PaliGemma 2 VLM
@ -469,7 +452,7 @@ FetchContent_MakeAvailable(sentencepiece)
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main)
FetchContent_MakeAvailable(gemma)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 92d327e841d78e11ae888757a3e16d291951cf64)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee)
FetchContent_MakeAvailable(highway)
```
@ -535,7 +518,7 @@ gemma.cpp was started in fall 2023 by
Griffin support was implemented in April 2024 thanks to contributions by Andrey
Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode
Vandevenne, Luca Versari, Martin Bruse, Phil Culliton, Sami Boukortt, Thomas
Fischbacher and Zoltan Szabadka.
Fischbacher and Zoltan Szabadka. It was removed in 2025-09.
Gemma-2 support was implemented in June/July 2024 with the help of several
people.

View File

@ -80,6 +80,37 @@ cc_library(
],
)
cc_library(
name = "int",
textual_hdrs = ["int-inl.h"],
deps = [
":types",
"//:basics",
"@highway//:hwy",
],
)
cc_test(
name = "int_test",
size = "small",
timeout = "long",
srcs = ["int_test.cc"],
features = ["fully_static_link"],
linkstatic = True,
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["hwy_ops_test"],
deps = [
":distortion",
":int",
"@googletest//:gtest_main", # buildcleaner: keep
"//:test_util",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
],
)
cc_library(
name = "test_util",
textual_hdrs = [
@ -144,6 +175,7 @@ cc_library(
textual_hdrs = ["compress-inl.h"],
deps = [
":distortion",
":int",
":nuq",
":sfp",
"//:basics",
@ -182,6 +214,7 @@ cc_library(
name = "analyze",
textual_hdrs = ["analyze.h"],
deps = [
":int",
":nuq",
":sfp",
":types",

View File

@ -47,6 +47,7 @@
#include "hwy/highway.h"
// After highway.h
#include "compression/int-inl.h"
#include "compression/nuq-inl.h"
#include "compression/sfp-inl.h"
@ -55,12 +56,6 @@ namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
#ifdef HWY_IS_TEST
static constexpr bool kIsTest = true;
#else
static constexpr bool kIsTest = false;
#endif
// Enables generic code independent of compression type.
template <typename T> // primary, must specialize
struct CompressTraits {};
@ -422,6 +417,34 @@ struct CompressTraits<SfpStream> {
}
};
// Integer quantization.
template <>
struct CompressTraits<I8Stream> {
using Packed = I8Stream;
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw,
size_t num, CompressPerThread& tls,
const PackedSpan<Packed>& packed,
const size_t packed_ofs) {
IntCodec::Enc(df, raw, num, packed, packed_ofs);
}
template <class D> // Caller checks this is f32 or bf16
static HWY_INLINE void Load2(D d, const PackedSpan<const Packed>& packed,
const size_t packed_ofs, hn::Vec<D>& raw0,
hn::Vec<D>& raw1) {
IntCodec::Dec2(d, packed, packed_ofs, raw0, raw1);
}
template <class D, typename Raw>
static HWY_INLINE void DecompressAndZeroPad(
D d, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
Raw* raw, const size_t num) {
IntCodec::DecompressAndZeroPad(d, packed, packed_ofs, raw, num);
}
};
// Nonuniform quantization, 4.5 bits per element, two separate streams.
template <>
struct CompressTraits<NuqStream> {
@ -485,9 +508,6 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
}
}
const bool want_bench = COMPRESS_STATS || !kIsTest;
const double t0 = want_bench ? hwy::platform::Now() : 0.0;
using Traits = CompressTraits<hwy::RemoveConst<Packed>>;
constexpr size_t kBatch = 8192;
const size_t num_batches = hwy::DivCeil(num, kBatch);
@ -502,13 +522,6 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
packed, packed_ofs + my_pos);
});
if (want_bench) { // Avoids log spam in tests
const double t1 = hwy::platform::Now();
const double mb = static_cast<double>(num) * sizeof(raw[0]) * 1E-6;
const double mbps = mb / (t1 - t0);
fprintf(stderr, "Compress %.1f MB/s\n", mbps);
}
if constexpr (COMPRESS_STATS) {
for (size_t i = 1; i < work.tls.size(); ++i) {
work.tls[0].stats.Assimilate(work.tls[i].stats);
@ -709,6 +722,243 @@ HWY_INLINE float DecompressAndCall(D, const PackedSpan<const VT> v,
comp3);
}
// Similar to `hn::Transform*`, but for compressed `T`. Used by ops-inl.h.
// `DF` is the decompressed type, typically `float`.
template <class DF, typename T, class Func>
HWY_INLINE void DecompressAndCompressInplace(DF df, T* HWY_RESTRICT inout,
size_t num, Func&& func) {
const auto packed_inout = MakeSpan(inout, num);
using VF = hn::Vec<decltype(df)>;
HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df);
size_t i = 0;
if (num >= 2 * NF) {
for (; i <= num - 2 * NF; i += 2 * NF) {
VF v0, v1;
Decompress2(df, packed_inout, i, v0, v1);
const VF out0 = func(df, v0);
const VF out1 = func(df, v1);
Compress2(df, out0, out1, packed_inout, i);
}
}
const size_t remaining = num - i;
HWY_DASSERT(remaining < 2 * NF);
if (HWY_UNLIKELY(remaining != 0)) {
HWY_ALIGN float buf_inout[2 * hn::MaxLanes(df)];
// Ensure the second vector is zeroed even if remaining <= NF.
hn::Store(hn::Zero(df), df, buf_inout + NF);
DecompressAndZeroPad(df, packed_inout, i, buf_inout, remaining);
const VF v0 = hn::Load(df, buf_inout);
const VF v1 = hn::Load(df, buf_inout + NF);
const VF out0 = func(df, v0);
const VF out1 = func(df, v1);
Compress2(df, out0, out1, MakeSpan(buf_inout, 2 * NF), 0);
// Clang generates incorrect code for CopyBytes if num = 2.
for (size_t j = 0; j < remaining; ++j) {
inout[i + j] = hwy::ConvertScalarTo<T>(buf_inout[j]);
}
}
}
// One extra argument. `DF` is the decompressed type, typically `float`.
template <class DF, typename T, typename T1, class Func>
HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout,
size_t num,
const T1* HWY_RESTRICT p1,
const size_t p1_ofs,
Func&& func) {
const auto packed_inout = MakeSpan(inout, num);
const auto packed1 = MakeSpan(p1, p1_ofs + num);
using VF = hn::Vec<decltype(df)>;
HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df);
size_t i = 0;
if (num >= 2 * NF) {
for (; i <= num - 2 * NF; i += 2 * NF) {
VF v0, v1;
Decompress2(df, packed_inout, i, v0, v1);
VF v10, v11;
Decompress2(df, packed1, p1_ofs + i, v10, v11);
const VF out0 = func(df, v0, v10);
const VF out1 = func(df, v1, v11);
Compress2(df, out0, out1, packed_inout, i);
}
}
const size_t remaining = num - i;
HWY_DASSERT(remaining < 2 * NF);
if (HWY_UNLIKELY(remaining != 0)) {
HWY_ALIGN float buf_inout[2 * hn::MaxLanes(df)];
HWY_ALIGN float buf1[2 * hn::MaxLanes(df)];
// Ensure the second vector is zeroed even if remaining <= NF.
hn::Store(hn::Zero(df), df, buf_inout + NF);
hn::Store(hn::Zero(df), df, buf1 + NF);
DecompressAndZeroPad(df, packed_inout, i, buf_inout, remaining);
DecompressAndZeroPad(df, packed1, p1_ofs + i, buf1, remaining);
const VF v0 = hn::Load(df, buf_inout);
const VF v1 = hn::Load(df, buf_inout + NF);
const VF v10 = hn::Load(df, buf1);
const VF v11 = hn::Load(df, buf1 + NF);
const VF out0 = func(df, v0, v10);
const VF out1 = func(df, v1, v11);
Compress2(df, out0, out1, MakeSpan(buf_inout, 2 * NF), 0);
// Clang generates incorrect code for CopyBytes if num = 2.
for (size_t j = 0; j < remaining; ++j) {
inout[i + j] = hwy::ConvertScalarTo<T>(buf_inout[j]);
}
}
}
// Single input, separate output. `DF` is the decompressed type, typically
// `float`.
template <class DF, typename T, typename T1, class Func>
HWY_INLINE void Decompress1AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
const T1* HWY_RESTRICT p1,
Func&& func) {
const auto packed_out = MakeSpan(out, num);
const auto packed1 = MakeSpan(p1, num);
using VF = hn::Vec<decltype(df)>;
HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df);
size_t i = 0;
if (num >= 2 * NF) {
for (; i <= num - 2 * NF; i += 2 * NF) {
VF v10, v11;
Decompress2(df, packed1, i, v10, v11);
const VF out0 = func(df, v10);
const VF out1 = func(df, v11);
Compress2(df, out0, out1, packed_out, i);
}
}
const size_t remaining = num - i;
HWY_DASSERT(remaining < 2 * NF);
if (HWY_UNLIKELY(remaining != 0)) {
HWY_ALIGN float buf1[2 * hn::MaxLanes(df)];
HWY_ALIGN float buf_out[2 * hn::MaxLanes(df)];
// Ensure the second vector is zeroed even if remaining <= NF.
hn::Store(hn::Zero(df), df, buf1 + NF);
DecompressAndZeroPad(df, packed1, i, buf1, remaining);
const VF v10 = hn::Load(df, buf1);
const VF v11 = hn::Load(df, buf1 + NF);
const VF out0 = func(df, v10);
const VF out1 = func(df, v11);
Compress2(df, out0, out1, MakeSpan(buf_out, 2 * NF), 0);
// Clang generates incorrect code for CopyBytes if num = 2.
for (size_t j = 0; j < remaining; ++j) {
out[i + j] = hwy::ConvertScalarTo<T>(buf_out[j]);
}
}
}
// Two inputs. `DF` is the decompressed type, typically `float`.
template <class DF, typename T, typename T1, typename T2, class Func>
HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
const T1* HWY_RESTRICT p1,
const T2* HWY_RESTRICT p2,
const size_t p2_ofs, Func&& func) {
const auto packed_out = MakeSpan(out, num);
const auto packed1 = MakeSpan(p1, num);
const auto packed2 = MakeSpan(p2, p2_ofs + num);
using VF = hn::Vec<decltype(df)>;
HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df);
size_t i = 0;
if (num >= 2 * NF) {
for (; i <= num - 2 * NF; i += 2 * NF) {
VF v10, v11, v20, v21;
Decompress2(df, packed1, i, v10, v11);
Decompress2(df, packed2, p2_ofs + i, v20, v21);
const VF out0 = func(df, v10, v20);
const VF out1 = func(df, v11, v21);
Compress2(df, out0, out1, packed_out, i);
}
}
const size_t remaining = num - i;
HWY_DASSERT(remaining < 2 * NF);
if (HWY_UNLIKELY(remaining != 0)) {
HWY_ALIGN float buf1[2 * hn::MaxLanes(df)];
HWY_ALIGN float buf2[2 * hn::MaxLanes(df)];
HWY_ALIGN float buf_out[2 * hn::MaxLanes(df)];
// Ensure the second vector is zeroed even if remaining <= NF.
hn::Store(hn::Zero(df), df, buf1 + NF);
hn::Store(hn::Zero(df), df, buf2 + NF);
DecompressAndZeroPad(df, packed1, i, buf1, remaining);
DecompressAndZeroPad(df, packed2, p2_ofs + i, buf2, remaining);
const VF v10 = hn::Load(df, buf1);
const VF v11 = hn::Load(df, buf1 + NF);
const VF v20 = hn::Load(df, buf2);
const VF v21 = hn::Load(df, buf2 + NF);
const VF out0 = func(df, v10, v20);
const VF out1 = func(df, v11, v21);
Compress2(df, out0, out1, MakeSpan(buf_out, 2 * NF), 0);
// Clang generates incorrect code for CopyBytes if num = 2.
for (size_t j = 0; j < remaining; ++j) {
out[i + j] = hwy::ConvertScalarTo<T>(buf_out[j]);
}
}
}
// Three inputs. `DF` is the decompressed type, typically `float`.
template <class DF, typename T, typename T1, typename T2, typename T3,
class Func>
HWY_INLINE void Decompress3AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
const T1* HWY_RESTRICT p1,
const T2* HWY_RESTRICT p2,
const T3* HWY_RESTRICT p3,
Func&& func) {
const auto packed_out = MakeSpan(out, num);
const auto packed1 = MakeSpan(p1, num);
const auto packed2 = MakeSpan(p2, num);
const auto packed3 = MakeSpan(p3, num);
using VF = hn::Vec<decltype(df)>;
HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df);
size_t i = 0;
if (num >= 2 * NF) {
for (; i <= num - 2 * NF; i += 2 * NF) {
VF v10, v11, v20, v21, v30, v31;
Decompress2(df, packed1, i, v10, v11);
Decompress2(df, packed2, i, v20, v21);
Decompress2(df, packed3, i, v30, v31);
const VF out0 = func(df, v10, v20, v30);
const VF out1 = func(df, v11, v21, v31);
Compress2(df, out0, out1, packed_out, i);
}
}
const size_t remaining = num - i;
HWY_DASSERT(remaining < 2 * NF);
if (HWY_UNLIKELY(remaining != 0)) {
HWY_ALIGN float buf1[2 * hn::MaxLanes(df)];
HWY_ALIGN float buf2[2 * hn::MaxLanes(df)];
HWY_ALIGN float buf3[2 * hn::MaxLanes(df)];
HWY_ALIGN float buf_out[2 * hn::MaxLanes(df)];
// Ensure the second vector is zeroed even if remaining <= NF.
hn::Store(hn::Zero(df), df, buf1 + NF);
hn::Store(hn::Zero(df), df, buf2 + NF);
hn::Store(hn::Zero(df), df, buf3 + NF);
DecompressAndZeroPad(df, packed1, i, buf1, remaining);
DecompressAndZeroPad(df, packed2, i, buf2, remaining);
DecompressAndZeroPad(df, packed3, i, buf3, remaining);
const VF v10 = hn::Load(df, buf1);
const VF v11 = hn::Load(df, buf1 + NF);
const VF v20 = hn::Load(df, buf2);
const VF v21 = hn::Load(df, buf2 + NF);
const VF v30 = hn::Load(df, buf3);
const VF v31 = hn::Load(df, buf3 + NF);
const VF out0 = func(df, v10, v20, v30);
const VF out1 = func(df, v11, v21, v31);
Compress2(df, out0, out1, MakeSpan(buf_out, 2 * NF), 0);
// Clang generates incorrect code for CopyBytes if num = 2.
for (size_t j = 0; j < remaining; ++j) {
out[i + j] = hwy::ConvertScalarTo<T>(buf_out[j]);
}
}
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp

View File

@ -18,11 +18,10 @@
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
#include "compression/compress.h"
#include <stddef.h>
#include <stdio.h>
#include "compression/compress.h"
#include "compression/distortion.h"
#include "util/test_util.h"
#include "hwy/aligned_allocator.h"
@ -45,7 +44,7 @@ namespace hn = hwy::HWY_NAMESPACE;
// Calls Compress and Decompress2 and verifies the distortion/error.
template <typename Packed>
struct TestDecompress2T {
struct TestDecompress2 {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const size_t N = hn::Lanes(d);
@ -120,12 +119,12 @@ struct TestDecompress2T {
}
};
void TestAllDecompress2() { ForeachPackedAndRawType<TestDecompress2T>(); }
void TestAllDecompress2() { ForeachPackedAndRawType<TestDecompress2>(); }
// Calls Compress and DecompressAndZeroPad for all short lengths and verifies
// the distortion/error.
template <typename Packed>
struct TestShortLengthsT {
struct TestShortLengths {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const size_t N = hn::Lanes(d);
@ -196,7 +195,82 @@ struct TestShortLengthsT {
}
};
void TestAllShortLengths() { ForeachPackedAndRawType<TestShortLengthsT>(); }
void TestAllShortLengths() { ForeachPackedAndRawType<TestShortLengths>(); }
// Verifies the arguments and remainder handling of `DecompressAndCompress*`.
class TestDecompressAndCompress {
public:
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
ForeachActivationType3<Test>(d);
}
private:
struct Test {
template <typename T1, typename T2, typename T3, /*Deduced:*/ class D>
void operator()(T1, T2, T3, D d) {
hwy::RandomState rng;
using DF = hn::Repartition<float, D>;
using VF = hn::Vec<DF>;
const DF df;
for (size_t num = 1; num < 7 * hn::Lanes(d); ++num) {
auto p = hwy::AllocateAligned<T1>(num);
auto p1 = hwy::AllocateAligned<T2>(num);
auto p2 = hwy::AllocateAligned<T3>(num);
auto out = hwy::AllocateAligned<T1>(num);
auto expected1 = hwy::AllocateAligned<T1>(num);
auto expected2 = hwy::AllocateAligned<T1>(num);
auto expected3 = hwy::AllocateAligned<T1>(num);
HWY_ASSERT(p && p1 && p2 && out && expected1 && expected2 && expected3);
// Two bits each, totalling 6 bits which fit in the BF16 mantissa.
for (size_t i = 0; i < num; ++i) {
const size_t mod = i & 3;
p[i] = hwy::ConvertScalarTo<T1>(mod);
p1[i] = hwy::ConvertScalarTo<T2>(mod << 2);
p2[i] = hwy::ConvertScalarTo<T3>(mod << 4);
// For `Decompress1AndCompressInplace` to not overwrite `p`.
out[i] = p[i];
expected1[i] = hwy::ConvertScalarTo<T1>(mod);
expected2[i] = hwy::ConvertScalarTo<T1>((mod << 2) | mod);
expected3[i] =
hwy::ConvertScalarTo<T1>((mod << 4) | (mod << 2) | mod);
}
DecompressAndCompressInplace(df, p.get(), num,
[](DF, VF v) HWY_ATTR -> VF { return v; });
HWY_ASSERT_ARRAY_EQ(expected1.get(), p.get(), num);
// Uses `out` so as not to overwrite `p`.
Decompress1AndCompressInplace(
df, out.get(), num, p1.get(), /*p1_ofs=*/0,
[](DF, VF v, VF v1) HWY_ATTR -> VF { return hn::Add(v, v1); });
HWY_ASSERT_ARRAY_EQ(expected2.get(), out.get(), num);
Decompress1AndCompressTo(df, out.get(), num, p.get(),
[](DF, VF v) HWY_ATTR -> VF { return v; });
HWY_ASSERT_ARRAY_EQ(expected1.get(), out.get(), num);
Decompress2AndCompressTo(
df, out.get(), num, p.get(), p1.get(), /*p2_ofs=*/0,
[](DF, VF v, VF v1) HWY_ATTR -> VF { return hn::Add(v, v1); });
HWY_ASSERT_ARRAY_EQ(expected2.get(), out.get(), num);
Decompress3AndCompressTo(
df, out.get(), num, p.get(), p1.get(), p2.get(),
[](DF, VF v, VF v1, VF v2)
HWY_ATTR -> VF { return hn::Add(hn::Add(v, v1), v2); });
HWY_ASSERT_ARRAY_EQ(expected3.get(), out.get(), num);
}
}
};
};
void TestAllDecompressAndCompress() {
// The Highway Test interface (`ForGE128Vectors`) only supports a single type.
// We hard-code one here, and use `ForeachActivationType` internally.
hn::ForGE128Vectors<TestDecompressAndCompress>()(float());
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
@ -208,6 +282,7 @@ namespace gcpp {
HWY_BEFORE_TEST(CompressTest);
HWY_EXPORT_AND_TEST_P(CompressTest, TestAllDecompress2);
HWY_EXPORT_AND_TEST_P(CompressTest, TestAllShortLengths);
HWY_EXPORT_AND_TEST_P(CompressTest, TestAllDecompressAndCompress);
HWY_AFTER_TEST();
} // namespace gcpp
#endif // HWY_ONCE

474
compression/int-inl.h Normal file
View File

@ -0,0 +1,474 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Normal include guard.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_H_
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <cstdint>
#include <cstdio>
#include "compression/types.h"
#include "util/basics.h"
#include "hwy/base.h"
#include "hwy/print-inl.h"
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_H_
// Actual per-target include guard.
#if defined(THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_TOGGLE) == \
defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_TOGGLE
#endif
#include "hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
// Encode/decode functions.
class IntCodec {
using ScaleT = hwy::bfloat16_t;
static constexpr size_t kGroupSize = I8Stream::kGroupSize;
// Offset (in bytes) of a group's start for packed_ofs (in elements) within a
// set of groups.
static constexpr size_t GroupByteOffset(size_t packed_ofs) {
const size_t kBytesPerGroup = (2 * sizeof(ScaleT)) + kGroupSize;
return (packed_ofs / kGroupSize) * kBytesPerGroup;
}
public:
template <class DBF, HWY_IF_BF16_D(DBF)>
static HWY_INLINE void DequantizeGroup(
DBF dbf, const PackedSpan<const I8Stream>& packed, size_t packed_ofs,
hwy::bfloat16_t* HWY_RESTRICT raw, size_t num) {
using T = ScaleT;
const hn::ScalableTag<float> df;
const hn::Rebind<int32_t, decltype(df)> di32;
const hn::Rebind<int16_t, decltype(di32)> di16;
const hn::Rebind<int8_t, decltype(di16)> di8;
const hn::Twice<hn::Rebind<hwy::bfloat16_t, decltype(df)>> dbf16;
const size_t N = hn::Lanes(di8);
const size_t N16 = hn::Lanes(dbf16);
using VI8 = hn::Vec<decltype(di8)>;
using VF = hn::Vec<decltype(df)>;
T inv_scale, zeropoint;
hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs), &inv_scale,
sizeof(T));
hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs) + sizeof(T),
&zeropoint, sizeof(T));
float inv_scale_f = hwy::ConvertScalarTo<float>(inv_scale);
float zeropoint_f = hwy::ConvertScalarTo<float>(zeropoint);
VF inv_scale_vec = hn::Set(df, inv_scale_f);
VF zeroscale_vec = hn::Set(df, -zeropoint_f * (inv_scale_f));
// Then iterate over remainder of packed, extracting num / N vectors and
// inserting into raw.
const size_t g_num = HWY_MIN(num, kGroupSize);
const size_t current_offset = GroupByteOffset(packed_ofs) +
(2 * sizeof(T)) + (packed_ofs % kGroupSize);
size_t i = 0;
for (i = 0; i + 4 * N <= g_num; i += 4 * N) {
const VI8 val0 =
hn::LoadU(di8, &packed.ptr->i + current_offset + i + 0 * N);
const VI8 val1 =
hn::LoadU(di8, &packed.ptr->i + current_offset + i + 1 * N);
const VI8 val2 =
hn::LoadU(di8, &packed.ptr->i + current_offset + i + 2 * N);
const VI8 val3 =
hn::LoadU(di8, &packed.ptr->i + current_offset + i + 3 * N);
const VF val0_f =
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0)));
const VF val1_f =
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val1)));
const VF val2_f =
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val2)));
const VF val3_f =
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val3)));
VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec);
VF dequantized_val1 = hn::MulAdd(inv_scale_vec, val1_f, zeroscale_vec);
VF dequantized_val2 = hn::MulAdd(inv_scale_vec, val2_f, zeroscale_vec);
VF dequantized_val3 = hn::MulAdd(inv_scale_vec, val3_f, zeroscale_vec);
hn::StoreU(
hn::OrderedDemote2To(dbf16, dequantized_val0, dequantized_val1),
dbf16, raw + i + 0 * N16);
hn::StoreU(
hn::OrderedDemote2To(dbf16, dequantized_val2, dequantized_val3),
dbf16, raw + i + 1 * N16);
}
for (; i + N <= g_num; i += N) {
const VI8 val0 = hn::LoadU(di8, &packed.ptr->i + current_offset + i);
const VF val0_f =
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0)));
VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec);
const hn::Rebind<hwy::bfloat16_t, decltype(df)> dbf_half;
hn::StoreU(hn::DemoteTo(dbf_half, dequantized_val0), dbf_half, raw + i);
}
if (i < g_num) {
const size_t remaining = g_num - i;
const VI8 val0 =
hn::LoadN(di8, &packed.ptr->i + current_offset + i, remaining);
const VF val0_f =
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0)));
VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec);
const hn::Rebind<hwy::bfloat16_t, decltype(df)> dbf_half;
hn::StoreN(hn::DemoteTo(dbf_half, dequantized_val0), dbf_half, raw + i,
remaining);
}
}
// Dequantizes `num` floats from `packed` into `raw`. `packed` points to
// compressed storage and `packed_ofs` indicates the destination offset
// within it, in number of elements. Scaling values are interleaved with int
// values to allow for easier unpacking.
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void DequantizeGroup(
DF df, const PackedSpan<const I8Stream>& packed, size_t packed_ofs,
float* HWY_RESTRICT raw, size_t num) {
using T = ScaleT;
const hn::Rebind<int32_t, decltype(df)> di32;
const hn::Rebind<int16_t, decltype(di32)> di16;
const hn::Rebind<int8_t, decltype(di16)> di8;
const hn::Rebind<int8_t, decltype(df)> df8;
const size_t N = hn::Lanes(di8);
const size_t N32 = hn::Lanes(df);
using VI8 = hn::Vec<decltype(di8)>;
using VF = hn::Vec<decltype(df)>;
// HWY_ASSERT(num % 2 * N == 0);
// Load scale and zero point from the beginning - ensure correct pointer
// offset.
T inv_scale, zeropoint;
hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs), &inv_scale,
sizeof(T));
hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs) + sizeof(T),
&zeropoint, sizeof(T));
float inv_scale_f = hwy::ConvertScalarTo<float>(inv_scale);
float zeropoint_f = hwy::ConvertScalarTo<float>(zeropoint);
VF inv_scale_vec = hn::Set(df, inv_scale_f);
VF zeroscale_vec = hn::Set(df, -zeropoint_f * (inv_scale_f));
// Then iterate over remainder of packed, extracting num / N vectors and
// inserting into raw.
const size_t g_num = HWY_MIN(num, kGroupSize);
const size_t current_offset = GroupByteOffset(packed_ofs) +
(2 * sizeof(T)) + (packed_ofs % kGroupSize);
size_t i = 0;
for (; i + 2 * N <= g_num; i += 2 * N) {
const VI8 val0 =
hn::LoadU(di8, &packed.ptr->i + current_offset + i + 0 * N);
const VI8 val1 =
hn::LoadU(di8, &packed.ptr->i + current_offset + i + 1 * N);
const VF val0_f =
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0)));
const VF val1_f =
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val1)));
VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec);
VF dequantized_val1 = hn::MulAdd(inv_scale_vec, val1_f, zeroscale_vec);
hn::StoreU(dequantized_val0, df, raw + i + 0 * N32);
hn::StoreU(dequantized_val1, df, raw + i + 1 * N32);
}
for (; i + N <= g_num; i += N) {
const VI8 val0 = hn::LoadU(di8, &packed.ptr->i + current_offset + i);
const VF val0_f =
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0)));
VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec);
hn::StoreU(dequantized_val0, df, raw + i);
}
if (i < g_num) {
const size_t remaining = g_num - i;
const VI8 val0 =
hn::LoadN(di8, &packed.ptr->i + current_offset + i, remaining);
const VF val0_f =
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0)));
VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec);
hn::StoreN(dequantized_val0, df, raw + i, remaining);
}
}
// Quantizes `num` floats from `raw` into `packed`. `packed` points to
// compressed storage and `packed_ofs` indicates the destination offset
// within it, in number of elements. Scaling values are interleaved with
// int values to allow for easier unpacking.
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void QuantizeGroup(DF df, const float* HWY_RESTRICT raw,
size_t num,
const PackedSpan<I8Stream>& packed,
size_t packed_ofs) {
using T = ScaleT;
const hn::Repartition<int32_t, DF> di32;
const hn::Half<hn::Repartition<int16_t, decltype(di32)>> di16;
const hn::Half<hn::Repartition<int8_t, decltype(di16)>> di8;
const size_t N = hn::Lanes(df);
using VI8 = hn::Vec<decltype(di8)>;
using VF = hn::Vec<decltype(df)>;
HWY_DASSERT(packed_ofs % kGroupSize == 0);
HWY_DASSERT(num % 2 * N == 0);
// Calculate min/max using SIMD
float min_val = hwy::HighestValue<float>();
float max_val = hwy::LowestValue<float>();
VF vmin = hn::Set(df, hwy::HighestValue<float>());
VF vmax = hn::Set(df, hwy::LowestValue<float>());
size_t j = 0;
for (; j + N <= num; j += N) {
const VF xi = hn::LoadU(df, raw + j);
vmin = hn::Min(vmin, xi);
vmax = hn::Max(vmax, xi);
}
min_val = hn::ReduceMin(df, vmin);
max_val = hn::ReduceMax(df, vmax);
for (; j < num; ++j) {
min_val = HWY_MIN(min_val, raw[j]);
max_val = HWY_MAX(max_val, raw[j]);
}
// Calculate range, scale and zeropoint
float x_range = max_val - min_val;
x_range = x_range == 0.0f ? 1.0f : x_range;
const float scale_f = 255.0f / x_range;
const float zeropoint_f = static_cast<float>(
static_cast<int32_t>(-scale_f * min_val - 128.0f)); // Correct casting
const T scale = hwy::ConvertScalarTo<T>(scale_f);
// inv_scale is used for all dequantization.
const T inv_scale = hwy::ConvertScalarTo<T>(1.0f / scale_f);
const T zeropoint = hwy::ConvertScalarTo<T>(zeropoint_f);
memcpy(&packed.ptr->i + GroupByteOffset(packed_ofs), &inv_scale, sizeof(T));
memcpy(&packed.ptr->i + GroupByteOffset(packed_ofs) + sizeof(T), &zeropoint,
sizeof(T));
const size_t g_num = HWY_MIN(num, kGroupSize);
VF mul = hn::Set(df, hwy::ConvertScalarTo<float>(scale));
VF add = hn::Set(df, hwy::ConvertScalarTo<float>(zeropoint));
const size_t current_offset = GroupByteOffset(packed_ofs) +
(2 * sizeof(T)) + (packed_ofs % kGroupSize);
size_t i = 0;
for (; i + 2 * N <= g_num; i += 2 * N) {
const VI8 val0 = hn::DemoteTo(
di8,
hn::DemoteTo(di16, NearestInt(hn::MulAdd(
mul, hn::LoadU(df, raw + i + 0 * N), add))));
const VI8 val1 = hn::DemoteTo(
di8,
hn::DemoteTo(di16, NearestInt(hn::MulAdd(
mul, hn::LoadU(df, raw + i + 1 * N), add))));
hn::StoreU(val0, di8, &packed.ptr->i + current_offset + i + 0 * N);
hn::StoreU(val1, di8, &packed.ptr->i + current_offset + i + 1 * N);
}
size_t remaining = g_num - i;
HWY_DASSERT(remaining < 2 * N);
if (HWY_UNLIKELY(remaining == 0)) return;
if (remaining > N) {
const VI8 val0 = hn::DemoteTo(
di8, hn::DemoteTo(di16, NearestInt(hn::MulAdd(
mul, hn::LoadU(df, raw + i), add))));
hn::StoreU(val0, di8, &packed.ptr->i + current_offset + i);
const size_t remaining1 = remaining - N;
const VI8 val1 = hn::DemoteTo(
di8,
hn::DemoteTo(di16,
NearestInt(hn::MulAdd(
mul, hn::LoadN(df, raw + i + N, remaining1), add))));
hn::StoreN(val1, di8, &packed.ptr->i + current_offset + i + N,
remaining1);
} else { // remaining <= N
const VI8 val0 = hn::DemoteTo(
di8, hn::DemoteTo(di16,
NearestInt(hn::MulAdd(
mul, hn::LoadN(df, raw + i, remaining), add))));
hn::StoreN(val0, di8, &packed.ptr->i + current_offset + i, remaining);
}
}
// Encodes `num` floats from `raw` into `packed`. `packed` points to
// compressed storage and `packed_ofs` indicates the destination offset
// within it, in number of elements. Scaling values are interleaved with
// int
// values to allow for easier unpacking.
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Enc(DF df, const float* HWY_RESTRICT raw,
const size_t num,
const PackedSpan<I8Stream>& packed,
size_t packed_ofs) {
HWY_ASSERT(packed_ofs % kGroupSize == 0);
const size_t num_groups = hwy::DivCeil(num, kGroupSize);
size_t current_offset = packed_ofs;
for (size_t g = 0; g < num_groups; ++g) {
const size_t g_num = HWY_MIN(num - g * kGroupSize, kGroupSize);
const float* HWY_RESTRICT g_in = raw + g * kGroupSize;
QuantizeGroup(df, g_in, g_num, packed, current_offset);
current_offset += g_num;
}
}
// Decompresses to two bf16 vectors. `packed_ofs` must be a multiple of two
// vectors so that we only have to load one group's table.
template <class DBF, HWY_IF_BF16_D(DBF)>
static HWY_INLINE void Dec2(DBF dbf, const PackedSpan<const I8Stream>& packed,
const size_t packed_ofs, hn::Vec<DBF>& raw0,
hn::Vec<DBF>& raw1) {
const hn::Repartition<float, decltype(dbf)> df;
using VF = hn::Vec<decltype(df)>;
const size_t NF = hn::Lanes(df);
HWY_ASSERT(packed_ofs % 2 * NF == 0);
VF raw0_f, raw1_f, raw2_f, raw3_f;
Dec2(df, packed, packed_ofs + 0 * 2 * NF, raw0_f, raw1_f);
Dec2(df, packed, packed_ofs + 1 * 2 * NF, raw2_f, raw3_f);
raw0 = hn::OrderedDemote2To(dbf, raw0_f, raw1_f);
raw1 = hn::OrderedDemote2To(dbf, raw2_f, raw3_f);
}
// Decompresses to two f32 vectors. `packed_ofs` must be a multiple of two
// vectors so that we only have to load one group's table.
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Dec2(DF df, const PackedSpan<const I8Stream>& packed,
const size_t packed_ofs, hn::Vec<DF>& raw0,
hn::Vec<DF>& raw1) {
using T = ScaleT;
const hn::Rebind<int32_t, decltype(df)> di32;
const hn::Rebind<int16_t, decltype(di32)> di16;
const hn::Rebind<int8_t, decltype(di16)> di8;
const hn::Rebind<int8_t, decltype(df)> df8;
const size_t N = hn::Lanes(di8);
using VI8 = hn::Vec<decltype(di8)>;
using VF = hn::Vec<decltype(df)>;
T inv_scale, zeropoint;
hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs), &inv_scale,
sizeof(T));
hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs) + sizeof(T),
&zeropoint, sizeof(T));
float inv_scale_f = hwy::ConvertScalarTo<float>(inv_scale);
float zeropoint_f = hwy::ConvertScalarTo<float>(zeropoint);
VF inv_scale_vec = hn::Set(df, inv_scale_f);
VF zeroscale_vec = hn::Set(df, -zeropoint_f * (inv_scale_f));
const size_t current_offset = GroupByteOffset(packed_ofs) +
(2 * sizeof(T)) + (packed_ofs % kGroupSize);
const VI8 val0 = hn::LoadU(di8, &packed.ptr->i + current_offset + 0 * N);
const VI8 val1 = hn::LoadU(di8, &packed.ptr->i + current_offset + 1 * N);
const VF val0_f =
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0)));
const VF val1_f =
hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val1)));
raw0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec);
raw1 = hn::MulAdd(inv_scale_vec, val1_f, zeroscale_vec);
}
template <class D, typename Raw = hn::TFromD<D>>
static HWY_INLINE void DecompressAndZeroPad(
D d, const PackedSpan<const I8Stream>& packed, size_t packed_ofs,
Raw* HWY_RESTRICT raw, size_t num) {
if (num == 0) return;
const size_t N = hn::Lanes(d);
const size_t padded_num = hwy::RoundUpTo(num, N);
if (padded_num > num) {
hwy::ZeroBytes(raw + num, (padded_num - num) * sizeof(Raw));
}
size_t current_packed_ofs = packed_ofs;
Raw* HWY_RESTRICT current_raw = raw;
size_t num_to_decompress = num;
if (size_t within_group = current_packed_ofs % kGroupSize;
within_group != 0) {
const size_t remaining_in_group = kGroupSize - within_group;
const size_t num_in_first_group =
HWY_MIN(num_to_decompress, remaining_in_group);
DequantizeGroup(d, packed, current_packed_ofs, current_raw,
num_in_first_group);
current_packed_ofs += num_in_first_group;
current_raw += num_in_first_group;
num_to_decompress -= num_in_first_group;
}
if (num_to_decompress == 0) return;
HWY_DASSERT(current_packed_ofs % kGroupSize == 0);
const size_t num_full_groups = num_to_decompress / kGroupSize;
for (size_t g = 0; g < num_full_groups; ++g) {
DequantizeGroup(d, packed, current_packed_ofs, current_raw, kGroupSize);
current_packed_ofs += kGroupSize;
current_raw += kGroupSize;
}
const size_t remaining = num_to_decompress % kGroupSize;
if (remaining != 0) {
DequantizeGroup(d, packed, current_packed_ofs, current_raw, remaining);
}
}
}; // IntCodec
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_H_

494
compression/int_test.cc Normal file
View File

@ -0,0 +1,494 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// SFP uses ConcatEven/Odd which are not supported; skip SVE for faster tests.
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SVE)
#endif
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include "util/test_util.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/tests/hwy_gtest.h"
#include "hwy/tests/test_util.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "compression/int_test.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "compression/int-inl.h"
#include "hwy/tests/test_util-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
static constexpr size_t kGroupSize = I8Stream::kGroupSize;
static constexpr float kTolerance = 50000.0f;
// Can encode and decode sub-regions.
// Quantizes and de-quantizes a single (potentially partial) group to check
// that the quantizer is working correctly.
struct TestQuantize {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const size_t total = kGroupSize / 2; // already padded
const hn::ScalableTag<float> df;
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
auto dec1 = hwy::AllocateAligned<T>(total);
auto dec2 = hwy::AllocateAligned<T>(total);
auto dec3 = hwy::AllocateAligned<T>(total);
auto i8_stream = hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(total));
HWY_ASSERT(in && dec1 && dec2 && dec3 && i8_stream);
const auto int_span = MakeSpan(i8_stream.get(), total);
hwy::RandomState rng;
for (size_t i = 0; i < total; ++i) {
in[i] = static_cast<float>(RandomGaussian(rng));
}
IntCodec::QuantizeGroup(df, in.get(), total, int_span, 0);
IntCodec::DequantizeGroup(d, MakeConst(int_span), 0, dec1.get(), total);
const float epsilon =
hwy::ConvertScalarTo<float>(hwy::Epsilon<hwy::bfloat16_t>());
const float tolerance = kTolerance * epsilon;
for (size_t i = 0; i < total; ++i) {
const float expected_value = static_cast<float>(in[i]);
const float actual_value = hwy::ConvertScalarTo<float>(dec1[i]);
if (!(expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance)) {
fprintf(stderr,
"in[%zu] = %f, dec1[%zu] = %f, tolerance = %f, epsilon = %f\n",
i, expected_value, i, actual_value, tolerance, epsilon);
}
}
// Check that ::Enc works correctly as well.
IntCodec::Enc(df, in.get(), total, int_span, 0);
IntCodec::DequantizeGroup(d, MakeConst(int_span), 0, dec2.get(), total);
for (size_t i = 0; i < total; ++i) {
const float expected_value = static_cast<float>(in[i]);
const float actual_value = hwy::ConvertScalarTo<float>(dec2[i]);
if (!(expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance)) {
fprintf(stderr,
"in[%zu] = %f, dec2[%zu] = %f, tolerance = %f, epsilon = %f\n",
i, expected_value, i, actual_value, tolerance, epsilon);
}
}
// Check that ::DecompressAndZeroPad works correctly for one group as well.
IntCodec::Enc(df, in.get(), total, int_span, 0);
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec3.get(),
total);
for (size_t i = 0; i < total; ++i) {
const float expected_value = static_cast<float>(in[i]);
const float actual_value = hwy::ConvertScalarTo<float>(dec3[i]);
if (!(expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance)) {
fprintf(stderr,
"in[%zu] = %f, dec3[%zu] = %f, tolerance = %f, epsilon = %f\n",
i, expected_value, i, actual_value, tolerance, epsilon);
HWY_ASSERT(false);
}
}
}
};
void TestQuantizeBF16() { hn::ForGEVectors<128, TestQuantize>()(BF16()); }
void TestQuantizeF32() { hn::ForGEVectors<128, TestQuantize>()(float()); }
// Can encode and decode sub-regions.
// Quantizes and de-quantizes multiple (potentially partial) groups to check
// that DecompressAndZeroPad is working correctly.
struct TestMultiGroup {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const hn::Repartition<float, D> df;
const size_t total = kGroupSize * 2 + kGroupSize / 4; // already padded
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
auto dec1 = hwy::AllocateAligned<T>(total);
auto dec2 = hwy::AllocateAligned<T>(total);
auto i8_stream = hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(total));
HWY_ASSERT(in && dec1 && i8_stream);
const auto int_span = MakeSpan(i8_stream.get(), total);
hwy::RandomState rng;
for (size_t i = 0; i < total; ++i) {
in[i] = static_cast<float>(RandomGaussian(rng));
}
const float epsilon =
hwy::ConvertScalarTo<float>(hwy::Epsilon<hwy::bfloat16_t>());
const float tolerance = kTolerance * epsilon;
// Check that ::DecompressAndZeroPad works correctly for one group as well.
IntCodec::Enc(df, in.get(), total, int_span, 0);
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec2.get(),
total);
for (size_t i = 0; i < total; ++i) {
const float expected_value = static_cast<float>(in[i]);
const float actual_value = hwy::ConvertScalarTo<float>(dec2[i]);
if (!(expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance)) {
fprintf(stderr,
"in[%zu] = %f, dec2[%zu] = %f, tolerance = %f, epsilon = %f\n",
i, expected_value, i, actual_value, tolerance, epsilon);
HWY_ASSERT(false);
}
}
}
};
void TestMultiGroupBF16() { hn::ForGEVectors<128, TestMultiGroup>()(BF16()); }
void TestMultiGroupF32() { hn::ForGEVectors<128, TestMultiGroup>()(float()); }
// Can encode and decode sub-regions.
struct TestOffset {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const hn::Repartition<float, D> df;
const size_t total = 10 * kGroupSize; // already padded
const size_t kMidLen = 2 * kGroupSize; // length of middle piece
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
auto dec1 = hwy::AllocateAligned<T>(total);
auto dec2 = hwy::AllocateAligned<T>(kMidLen);
auto i8_stream = hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(total));
HWY_ASSERT(in && dec1 && dec2 && i8_stream);
const auto int_span = MakeSpan(i8_stream.get(), total);
hwy::RandomState rng;
for (size_t i = 0; i < total; ++i) {
in[i] = static_cast<float>(RandomGaussian(rng));
}
// Encode + decode everything
(void)IntCodec::Enc(df, in.get(), total, int_span, 0);
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec1.get(),
total);
MaybeCheckInitialized(dec1.get(), total * sizeof(T));
// Overwrite middle with first inputs
const size_t offset = 5 * kGroupSize;
(void)IntCodec::Enc(df, in.get(), kMidLen, int_span, offset);
// Decoded middle now matches previously decoded first
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), offset, dec2.get(),
kMidLen);
MaybeCheckInitialized(dec2.get(), kMidLen * sizeof(T));
for (size_t i = 0; i < kMidLen; ++i) {
HWY_ASSERT(dec1[i] == dec2[i]);
}
}
};
void TestOffsetBF16() { hn::ForGEVectors<128, TestOffset>()(BF16()); }
void TestOffsetF32() { hn::ForGEVectors<128, TestOffset>()(float()); }
// Can encode and decode sub-regions.
struct TestUnalignedOffset {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const hn::Repartition<float, D> df;
const size_t total = 10 * kGroupSize; // already padded
const int num_unaligned_offsets = 4;
const std::array<size_t, num_unaligned_offsets> unaligned_offsets = {
4, kGroupSize + 100, 2 * kGroupSize + 100, 3 * kGroupSize + 100};
const std::array<size_t, num_unaligned_offsets> num = {4, 16, 32, 64};
for (int i = 0; i < num_unaligned_offsets; ++i) {
const size_t unaligned_offset = unaligned_offsets[i];
const size_t num_decompressed = num[i];
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
auto dec1 = hwy::AllocateAligned<T>(total);
auto i8_stream =
hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(total));
auto dec2 = hwy::AllocateAligned<T>(num_decompressed);
HWY_ASSERT(in && dec1 && dec2 && i8_stream);
const auto int_span = MakeSpan(i8_stream.get(), total);
hwy::RandomState rng;
for (size_t i = 0; i < total; ++i) {
in[i] = static_cast<float>(RandomGaussian(rng));
}
// // Encode + decode everything
(void)IntCodec::Enc(df, in.get(), total, int_span, 0);
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec1.get(),
total);
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), unaligned_offset,
dec2.get(), num_decompressed);
for (size_t i = 0; i < num_decompressed; ++i) {
T expected = hwy::ConvertScalarTo<T>(dec1[unaligned_offset + i]);
T actual = hwy::ConvertScalarTo<T>(dec2[i]);
HWY_ASSERT_EQ(expected, actual);
}
}
}
};
void TestUnalignedOffsetBF16() {
hn::ForGEVectors<128, TestUnalignedOffset>()(BF16());
}
void TestUnalignedOffsetF32() {
hn::ForGEVectors<128, TestUnalignedOffset>()(float());
}
// Can encode and decode sub-regions.
// Uses Dec2 to decode all elements in the packed buffer, then
// compares against DecompressAndZeroPad.
struct TestDec2 {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const hn::Repartition<float, D> df;
// incl. partial group to test partial group handling
const size_t total = kGroupSize * 10 + kGroupSize / 2;
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
auto dec0 = hwy::AllocateAligned<T>(total);
auto dec1 = hwy::AllocateAligned<T>(total);
auto i8_stream = hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(total));
HWY_ASSERT(in && dec0 && dec1 && i8_stream);
const auto int_span = MakeSpan(i8_stream.get(), total);
hwy::RandomState rng;
for (size_t i = 0; i < total; ++i) {
in[i] = static_cast<float>(RandomGaussian(rng));
}
// Non-interleaved encode + decode for comparison
(void)IntCodec::Enc(df, in.get(), total, int_span, 0);
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec0.get(),
total);
// Encode + decode everything
(void)IntCodec::Enc(df, in.get(), total, int_span, 0);
using V = hn::Vec<decltype(d)>;
const size_t N = Lanes(d);
for (size_t i = 0; i < total; i += 2 * N) {
V f0, f1;
IntCodec::Dec2(d, MakeConst(int_span), i, f0, f1);
hn::StoreU(f0, d, dec1.get() + i + 0 * N);
hn::StoreU(f1, d, dec1.get() + i + 1 * N);
}
for (size_t i = 0; i < total; ++i) {
if (dec0[i] != dec1[i]) {
fprintf(stderr, "dec0[%zu] = %g, dec1[%zu] = %g\n", i,
hwy::ConvertScalarTo<float>(dec0[i]), i,
hwy::ConvertScalarTo<float>(dec1[i]));
}
HWY_ASSERT(dec0[i] == dec1[i]);
}
}
};
void TestDec2BF16() { hn::ForGEVectors<128, TestDec2>()(BF16()); }
void TestDec2F32() { hn::ForGEVectors<128, TestDec2>()(float()); }
// Tests that DecompressAndZeroPad fully populates the output array.
// This is intended to catch uninitialized value errors.
struct TestDequantizeAndZeroPad {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const hn::ScalableTag<float> df;
constexpr size_t kSize = 4096;
auto in = hwy::AllocateAligned<float>(kSize);
auto actual_dec = hwy::AllocateAligned<T>(kSize);
auto i8_stream = hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(kSize));
HWY_ASSERT(in && actual_dec && i8_stream);
const auto int_span = MakeSpan(i8_stream.get(), kSize);
// Fill with a known pattern.
for (size_t i = 0; i < kSize; ++i) {
in[i] = static_cast<float>(i) - 128.0f;
}
IntCodec::Enc(df, in.get(), kSize, int_span, 0);
// Initialize with a sentinel value to detect if it's overwritten.
const T sentinel = hwy::ConvertScalarTo<T>(-999.0f);
std::fill(actual_dec.get(), actual_dec.get() + kSize, sentinel);
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, actual_dec.get(),
kSize);
MaybeCheckInitialized(actual_dec.get(), kSize * sizeof(T));
// Check that all sentinels were overwritten.
for (size_t i = 0; i < kSize; ++i) {
EXPECT_NE(hwy::ConvertScalarTo<float>(actual_dec[i]),
hwy::ConvertScalarTo<float>(sentinel))
<< " at index " << i;
}
}
};
void TestAllDequantizeAndZeroPad() {
hn::ForGEVectors<128, TestDequantizeAndZeroPad>()(BF16());
hn::ForGEVectors<128, TestDequantizeAndZeroPad>()(float());
}
// Tests that DecompressAndZeroPad works correctly for small and unaligned
// inputs. This is intended to catch uninitialized value errors in remainder
// handling.
struct TestSmallDequantize {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const hn::ScalableTag<float> df;
constexpr size_t kGroupSize = I8Stream::kGroupSize;
constexpr size_t kMaxNum = kGroupSize * 3;
auto in = hwy::AllocateAligned<float>(kMaxNum);
auto actual_dec = hwy::AllocateAligned<T>(kMaxNum);
auto i8_stream =
hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(kMaxNum));
HWY_ASSERT(in && actual_dec && i8_stream);
const auto int_span =
MakeSpan(i8_stream.get(), I8Stream::PackedEnd(kMaxNum));
// Fill with a known pattern.
for (size_t i = 0; i < kMaxNum; ++i) {
in[i] = static_cast<float>(i) - 128.0f;
}
IntCodec::Enc(df, in.get(), kMaxNum, int_span, 0);
for (size_t num = 1; num < kGroupSize * 2; ++num) {
for (size_t offset = 0; offset < kGroupSize; offset += 16) {
const T sentinel = hwy::ConvertScalarTo<T>(-999.0f);
std::fill(actual_dec.get(), actual_dec.get() + num, sentinel);
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), offset,
actual_dec.get(), num);
MaybeCheckInitialized(actual_dec.get(), num);
// Check that all sentinels were overwritten.
for (size_t i = 0; i < num; ++i) {
EXPECT_NE(hwy::ConvertScalarTo<float>(actual_dec[i]),
hwy::ConvertScalarTo<float>(sentinel))
<< " at index " << i << " for num=" << num
<< " offset=" << offset;
}
}
}
}
};
void TestAllSmallDequantize() {
hn::ForGEVectors<128, TestSmallDequantize>()(BF16());
hn::ForGEVectors<128, TestSmallDequantize>()(float());
}
// Tests that DecompressAndZeroPad works correctly for a specific failing input.
struct TestSpecificDequantize {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const hn::ScalableTag<float> df;
constexpr size_t kSize = 737280;
auto in = hwy::AllocateAligned<float>(kSize);
auto actual_dec = hwy::AllocateAligned<T>(kSize);
auto i8_stream = hwy::AllocateAligned<I8Stream>(I8Stream::PackedEnd(kSize));
HWY_ASSERT(in && actual_dec && i8_stream);
const auto int_span = MakeSpan(i8_stream.get(), kSize);
// Fill with a known pattern.
for (size_t i = 0; i < kSize; ++i) {
in[i] = static_cast<float>(i) - 128.0f;
}
IntCodec::Enc(df, in.get(), kSize, int_span, 0);
const size_t num = 64;
const size_t offset = 392704;
const T sentinel = hwy::ConvertScalarTo<T>(-999.0f);
std::fill(actual_dec.get(), actual_dec.get() + num, sentinel);
IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), offset,
actual_dec.get(), num);
// Check that all sentinels were overwritten.
for (size_t i = 0; i < num; ++i) {
EXPECT_NE(hwy::ConvertScalarTo<float>(actual_dec[i]),
hwy::ConvertScalarTo<float>(sentinel))
<< " at index " << i << " for num=" << num << " offset=" << offset;
}
}
};
void TestAllSpecificDequantize() {
hn::ForGEVectors<128, TestSpecificDequantize>()(BF16());
hn::ForGEVectors<128, TestSpecificDequantize>()(float());
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_BEFORE_TEST(IntTest);
HWY_EXPORT_AND_TEST_P(IntTest, TestOffsetF32);
HWY_EXPORT_AND_TEST_P(IntTest, TestOffsetBF16);
HWY_EXPORT_AND_TEST_P(IntTest, TestQuantizeF32);
HWY_EXPORT_AND_TEST_P(IntTest, TestQuantizeBF16);
HWY_EXPORT_AND_TEST_P(IntTest, TestDec2BF16);
HWY_EXPORT_AND_TEST_P(IntTest, TestDec2F32);
HWY_EXPORT_AND_TEST_P(IntTest, TestMultiGroupF32);
HWY_EXPORT_AND_TEST_P(IntTest, TestMultiGroupBF16);
HWY_EXPORT_AND_TEST_P(IntTest, TestUnalignedOffsetBF16);
HWY_EXPORT_AND_TEST_P(IntTest, TestUnalignedOffsetF32);
HWY_EXPORT_AND_TEST_P(IntTest, TestAllDequantizeAndZeroPad);
HWY_EXPORT_AND_TEST_P(IntTest, TestAllSmallDequantize);
HWY_EXPORT_AND_TEST_P(IntTest, TestAllSpecificDequantize);
HWY_AFTER_TEST();
} // namespace gcpp
#endif // HWY_ONCE

View File

@ -24,7 +24,6 @@
#include <algorithm> // std::shuffle
#include <array>
#include <random>
#include "compression/distortion.h"
#include "util/test_util.h"
@ -104,8 +103,8 @@ struct TestPlateaus {
HWY_ASSERT(-0.5f <= in[i] && in[i] < 0.5f);
}
std::random_device rd; // NOLINT
std::mt19937 rng(rd());
AesCtrEngine engine(/*deterministic=*/true);
RngStream rng(engine, 0);
std::shuffle(in.get(), in.get() + kGroupSize, rng);
NuqStream::ClusterBuf buf;
@ -151,8 +150,8 @@ struct TestRamp {
HWY_ASSERT(-0.45f <= in[i] && in[i] < 0.55f);
}
std::random_device rd; // NOLINT
std::mt19937 rng(rd());
AesCtrEngine engine(/*deterministic=*/true);
RngStream rng(engine, 0);
std::shuffle(in.get(), in.get() + kGroupSize, rng);
NuqStream::ClusterBuf buf;

View File

@ -87,9 +87,6 @@ class SbsWriterImpl : public ISbsWriter {
return;
}
fprintf(stderr, "Compressing %s (%zu x %zu = %zuM) to %s, please wait\n",
name, mat.Rows(), mat.Cols(), weights.size() / (1000 * 1000),
TypeName(TypeEnum<Packed>()));
HWY_ASSERT(weights.size() == mat.Extents().Area());
Compress(weights.data(), weights.size(), working_set_, mat.Span(),
/*packed_ofs=*/0, pool);
@ -116,6 +113,9 @@ class SbsWriterImpl : public ISbsWriter {
case Type::kF32:
InsertT<float>(name, weights, tensor_info);
break;
case Type::kI8:
InsertT<I8Stream>(name, weights, tensor_info);
break;
default:
HWY_ABORT("Unsupported destination (compressed) type %s",
TypeName(type));

View File

@ -90,8 +90,15 @@ class CompressionTest(absltest.TestCase):
info_256,
)
writer.insert(
"tensor_i8",
np.array([0.000375] * 128 + [0.00006] * 128, dtype=np.float32),
configs.Type.kI8,
info_256,
)
config = configs.ModelConfig(
configs.Model.GEMMA_TINY,
configs.Model.GEMMA2_2B,
configs.Type.kSFP,
configs.PromptWrapping.GEMMA_IT,
)
@ -101,7 +108,7 @@ class CompressionTest(absltest.TestCase):
print("Ignore next two warnings; test does not enable model deduction.")
reader = compression.SbsReader(temp_file.full_path)
self.assertEqual(reader.config.model, configs.Model.GEMMA_TINY)
self.assertEqual(reader.config.model, configs.Model.GEMMA2_2B)
self.assertEqual(reader.config.weight, configs.Type.kSFP)
mat = reader.find_mat("tensor0")
@ -140,6 +147,11 @@ class CompressionTest(absltest.TestCase):
self.assertEqual(mat.type, configs.Type.kF32)
self.assertAlmostEqual(mat.scale, 1.0)
mat = reader.find_mat("tensor_i8")
self.assertEqual(mat.cols, 256)
self.assertEqual(mat.rows, 1)
self.assertEqual(mat.type, configs.Type.kI8)
self.assertAlmostEqual(mat.scale, 1.0)
if __name__ == "__main__":
absltest.main()

View File

@ -1,8 +0,0 @@
# General Remarks about the "PyTree" Abstraction
The pytree wrangling code in this project does not use any of the existing
"pytree" modules. The deeper reason here is that our approach is based on an
analysis of the notion that emphasizes deeper underlying principles. This is
being discussed internally at the time of this writing.

View File

@ -1,275 +0,0 @@
"""Ad-hoc glue code for building the griffin model-file for the C++ binary.
Usage:
python3 -m venv $HOME/clients/griffin-venv
. $HOME/clients/griffin-venv/bin/activate
python3 -m pip install -r requirements.txt
time python3 build_model_file_for_cpp_binary.py \
$HOME/GRIFFIN/model_data \
cpp_load_log.txt /tmp/G2B.data
real 3m5.821s
user 2m9.205s
sys 2m46.720s
./compress_weights --weights /tmp/G2B.data --model gr2b-it \
--compressed_weights /tmp/G2B.compressed
./gemma --tokenizer tokenizer.spm --weights /tmp/G2B.compressed \
--model gr2b-it
Weights for the recurrent-gemma model that can be converted with this script
can be found at:
https://www.kaggle.com/models/google/recurrentgemma/flax/2b-it
"""
import pprint
import re
import sys
from typing import Any, Mapping
import numpy
import orbax.checkpoint
import ml_model_transforms
import pytree_transforms
def _fn_identity(x): return x
def _fn_transpose(x): return x.T
def _fn_transpose_all_heads(x): return x.transpose(0, 2, 1)
def _fn_scaled_softplus(a):
return -8 * numpy.logaddexp(a, 0)
def _fn_attention_moveaxis(a):
return a.reshape(10, 256, 2560).transpose(0, 2, 1)
def _aspec(pieces=(), transforms=()):
"""Short-hand array-save-specification.
Args:
pieces: Sequence of key-sequences identifying an array.
transforms: Sequence of transformations, indexed in
parallel to `pieces`, to apply to data arrays prior to saving.
Will be padded with identity-transformations to the length of `pieces`.
Returns:
Specification as for use in _LAYETR_NAME_MAPPING.
"""
# `zip` trims to shortest sequence, so this amounts to using
# default-transforms.
# tuple() since we need a Sequence here, not a stateful-iterator zip_object.
return tuple(zip(pieces, list(transforms) + [_fn_identity] * len(pieces)))
_LAYER_NAME_MAPPING = pytree_transforms.deep_freeze({
# Recurrent Layer
'griffin_linear_x_w': _aspec(
[('recurrent_block', 'linear_x', 'kernel')],
[_fn_transpose]),
'griffin_linear_x_biases': _aspec(
[('recurrent_block', 'linear_x', 'bias')]),
'griffin_linear_y_w': _aspec(
[('recurrent_block', 'linear_y', 'kernel')],
[_fn_transpose]),
'griffin_linear_y_biases': _aspec(
[('recurrent_block', 'linear_y', 'bias')]),
'griffin_linear_out_w': _aspec(
[('recurrent_block', 'linear_out', 'kernel')],
[_fn_transpose]),
'griffin_linear_out_biases': _aspec(
[('recurrent_block' ,'linear_out', 'bias')]),
'griffin_conv_w': _aspec(
[('recurrent_block', 'conv_1d', 'w')]),
'griffin_conv_biases': _aspec(
[('recurrent_block', 'conv_1d', 'b')]),
'griffin_gate_w': _aspec(
[('recurrent_block', 'rg_lru', 'input_gate', 'w'),
('recurrent_block', 'rg_lru', 'a_gate', 'w')],
[_fn_transpose_all_heads, _fn_transpose_all_heads]),
'griffin_gate_biases': _aspec(
[('recurrent_block', 'rg_lru', 'input_gate', 'b'),
('recurrent_block', 'rg_lru', 'a_gate', 'b')]),
'griffin_a': _aspec(
[('recurrent_block', 'rg_lru', 'a_param')],
[_fn_scaled_softplus]),
# Attention Layer
'qkv_einsum_w': _aspec(
[('attention_block', 'proj_q', 'kernel'),
('attention_block', 'proj_k', 'kernel'),
('attention_block', 'proj_v', 'kernel'),
],
[_fn_transpose, _fn_transpose, _fn_transpose]),
'attn_vec_einsum_w': _aspec(
[('attention_block', 'proj_final', 'kernel')],
[_fn_attention_moveaxis]),
'attention_output_biases': _aspec(
[('attention_block', 'proj_final', 'bias')]),
# Common
'pre_attention_norm_scale': _aspec(
[('temporal_pre_norm', 'scale')]),
'pre_ffw_norm_scale': _aspec(
[('channel_pre_norm', 'scale')]),
'gating_einsum_w': _aspec(
[('mlp_block', 'ffw_up', 'w')],
[_fn_transpose_all_heads]),
'ffw_gating_biases': _aspec(
[('mlp_block', 'ffw_up', 'b')]),
'linear_w': _aspec(
[('mlp_block', 'ffw_down', 'kernel')],
[_fn_transpose]),
'ffw_output_biases': _aspec(
[('mlp_block', 'ffw_down', 'bias')]),
# Other
'embedder_input_embedding': _aspec(
[('embedder', 'input_embedding')]),
'final_norm_scale': _aspec(
[('final_norm', 'scale')]),
})
def process_param_line(line : str) -> tuple[None | str, int, str]:
"""Processes a "loading parameters" log-line from the griffin binary."""
# This is slightly more permissive than strictly needed, to also handle
# some earlier form of the output.
matched = re.match(
r'(?a)Loading Parameters:? \('
r'(?:layer=(?P<layer>\d+), )?'
r'size (?P<size>\d+)\):? '
r'(?P<tag>\S+)',
line)
if not matched:
return None
layer = matched['layer']
wanted_size = int(matched['size'])
cpp_tag = matched['tag']
return matched['layer'], int(matched['size']), matched['tag']
def collect_pytree_keys(param_lines):
"""Collects all the pytree keys and transforms for model-serialization."""
pytree_keys = []
array_transforms = []
unsatisfied = []
for maybe_spec in map(process_param_line, param_lines):
if not maybe_spec: continue # Skip non-parameter lines.
layer, wanted_size, cpp_tag = maybe_spec
pytree_key_tails_and_transforms = _LAYER_NAME_MAPPING.get(cpp_tag, ())
if not pytree_key_tails_and_transforms:
unsatisfied.append((layer, cpp_tag))
else:
for key_tail, array_transform in pytree_key_tails_and_transforms:
pytree_keys.append(
key_tail if layer is None
else (f'blocks.{layer}',) + key_tail)
array_transforms.append(array_transform)
return pytree_keys, array_transforms, unsatisfied
class UnsatisfiedArrayLoadsError(ValueError):
"""Some array-loads could not be satisfied."""
def flatten_model_for_cpp_binary(tree,
cpp_expectations_logfile_path : str,
out_path : str,
unsatisfied_ok : bool = False
):
"""Produces a model-parameters file readable by the C++ binary.
Args:
tree: The pytree with model-parameters.
cpp_expectations_logfile_path:
Path to a logfile produced by the C++ binary that shows
the expected array-order.
out_path: Path to the model-weights file to be written.
unsatisfied_ok: If true, we ignore the presence of unsatisfied
array-loads and write a model-parameters file that skips these pieces.
This will lead to an unusable model-parameters file which however
still might be useful for other analysis.
Returns:
Tuple `(unknown_keys, missing_keys)`, where `unknown_keys`
is a sequence of `(layer_or_None, name)` descriptions of the keys
in the C++ log that could not be satisfied, and `missing_keys`
is a sequence of linearized pytree key-sequences for keys
not found in the checkpoint.
Raises:
UnsatisfiedArrayLoadsError: If some of the expected arrays
could not be included in the output and `unsatisfied_ok`
is false.
"""
with open(cpp_expectations_logfile_path, 'rt') as h_log:
pytree_keys, array_transforms, unknown_keys = collect_pytree_keys(
list(h_log))
rank_by_pytree_key = {k: n for n, k in enumerate(pytree_keys)}
array_transform_by_pytree_key = dict(zip(pytree_keys, array_transforms))
#
model_contents = ml_model_transforms.model_contents(tree)
missing_keys = set(pytree_keys) - model_contents.keys()
if (unknown_keys or missing_keys) and not unsatisfied_ok:
raise ValueError(
f'Unsatisfied loads: unknown_keys: {unknown_keys!r}, '
f'missing keys: {sorted(missing_keys)!r}')
ml_model_transforms.model_save(
tree,
filepath_stem=out_path,
data_suffix='',
manifest_suffix=None,
array_transform_by_pytree_key=array_transform_by_pytree_key,
key=rank_by_pytree_key.get,
report=lambda line: print(line, file=sys.stderr),
byte_align=1)
return tuple(unknown_keys), tuple(sorted(missing_keys))
def main(args):
"""Creates the model-file.
Args:
sys.argv[] parameters from command line sans the leading one.
Returns:
The pytree with all the de-serialized variables, such as for convenient
`python3 -i` inspection.
"""
try:
model_dir, cpp_load_log, out_path = args
except Exception:
sys.exit(f'Usage: {__file__} [model_dir] [cpp_load_log] [output_filename]')
pattern = ("recurrent", "recurrent", "attention")
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
variables = orbax_checkpointer.restore(model_dir)
if sorted(variables) == ['params']:
print('Warning: Using `variables["params"]` as tree-root.', file=sys.stderr)
variables_to_use = variables['params']
else:
variables_to_use = variables
unknown, missing = flatten_model_for_cpp_binary(variables_to_use,
cpp_load_log,
out_path,
unsatisfied_ok=True)
print('Model file saved.\n'
f'# unknown:\n{pprint.pformat(unknown)}\n'
f'# missing:\n{pprint.pformat(missing)}')
return variables
if __name__ == '__main__':
# Return value assignment is for `python3 -i ...` inspection.
pytree = main(sys.argv[1:])

View File

@ -1,380 +0,0 @@
Loading Parameters (size 2622750720): embedder_input_embedding
Loading Parameters (size 10240): final_norm_scale
Loading Parameters: (layer=0, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=0, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=0, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=0, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=0, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=0, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=0, size 40960) griffin_conv_w
Loading Parameters: (layer=0, size 10240) griffin_conv_biases
Loading Parameters: (layer=0, size 5242880) griffin_gate_w
Loading Parameters: (layer=0, size 20480) griffin_gate_biases
Loading Parameters: (layer=0, size 10240) griffin_a
Loading Parameters: (layer=0, size 157286400) gating_einsum_w
Loading Parameters: (layer=0, size 78643200) linear_w
Loading Parameters: (layer=0, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=0, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=0, size 61440) ffw_gating_biases
Loading Parameters: (layer=0, size 10240) ffw_output_biases
Loading Parameters: (layer=1, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=1, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=1, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=1, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=1, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=1, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=1, size 40960) griffin_conv_w
Loading Parameters: (layer=1, size 10240) griffin_conv_biases
Loading Parameters: (layer=1, size 5242880) griffin_gate_w
Loading Parameters: (layer=1, size 20480) griffin_gate_biases
Loading Parameters: (layer=1, size 10240) griffin_a
Loading Parameters: (layer=1, size 157286400) gating_einsum_w
Loading Parameters: (layer=1, size 78643200) linear_w
Loading Parameters: (layer=1, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=1, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=1, size 61440) ffw_gating_biases
Loading Parameters: (layer=1, size 10240) ffw_output_biases
Loading Parameters: (layer=2, size 26214400) attn_vec_einsum_w
Loading Parameters: (layer=2, size 78643200) qkv_einsum_w
Loading Parameters: (layer=2, size 157286400) gating_einsum_w
Loading Parameters: (layer=2, size 78643200) linear_w
Loading Parameters: (layer=2, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=2, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=2, size 61440) ffw_gating_biases
Loading Parameters: (layer=2, size 10240) ffw_output_biases
Loading Parameters: (layer=2, size 10240) attention_output_biases
Loading Parameters: (layer=3, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=3, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=3, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=3, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=3, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=3, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=3, size 40960) griffin_conv_w
Loading Parameters: (layer=3, size 10240) griffin_conv_biases
Loading Parameters: (layer=3, size 5242880) griffin_gate_w
Loading Parameters: (layer=3, size 20480) griffin_gate_biases
Loading Parameters: (layer=3, size 10240) griffin_a
Loading Parameters: (layer=3, size 157286400) gating_einsum_w
Loading Parameters: (layer=3, size 78643200) linear_w
Loading Parameters: (layer=3, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=3, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=3, size 61440) ffw_gating_biases
Loading Parameters: (layer=3, size 10240) ffw_output_biases
Loading Parameters: (layer=4, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=4, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=4, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=4, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=4, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=4, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=4, size 40960) griffin_conv_w
Loading Parameters: (layer=4, size 10240) griffin_conv_biases
Loading Parameters: (layer=4, size 5242880) griffin_gate_w
Loading Parameters: (layer=4, size 20480) griffin_gate_biases
Loading Parameters: (layer=4, size 10240) griffin_a
Loading Parameters: (layer=4, size 157286400) gating_einsum_w
Loading Parameters: (layer=4, size 78643200) linear_w
Loading Parameters: (layer=4, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=4, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=4, size 61440) ffw_gating_biases
Loading Parameters: (layer=4, size 10240) ffw_output_biases
Loading Parameters: (layer=5, size 26214400) attn_vec_einsum_w
Loading Parameters: (layer=5, size 78643200) qkv_einsum_w
Loading Parameters: (layer=5, size 157286400) gating_einsum_w
Loading Parameters: (layer=5, size 78643200) linear_w
Loading Parameters: (layer=5, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=5, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=5, size 61440) ffw_gating_biases
Loading Parameters: (layer=5, size 10240) ffw_output_biases
Loading Parameters: (layer=5, size 10240) attention_output_biases
Loading Parameters: (layer=6, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=6, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=6, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=6, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=6, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=6, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=6, size 40960) griffin_conv_w
Loading Parameters: (layer=6, size 10240) griffin_conv_biases
Loading Parameters: (layer=6, size 5242880) griffin_gate_w
Loading Parameters: (layer=6, size 20480) griffin_gate_biases
Loading Parameters: (layer=6, size 10240) griffin_a
Loading Parameters: (layer=6, size 157286400) gating_einsum_w
Loading Parameters: (layer=6, size 78643200) linear_w
Loading Parameters: (layer=6, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=6, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=6, size 61440) ffw_gating_biases
Loading Parameters: (layer=6, size 10240) ffw_output_biases
Loading Parameters: (layer=7, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=7, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=7, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=7, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=7, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=7, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=7, size 40960) griffin_conv_w
Loading Parameters: (layer=7, size 10240) griffin_conv_biases
Loading Parameters: (layer=7, size 5242880) griffin_gate_w
Loading Parameters: (layer=7, size 20480) griffin_gate_biases
Loading Parameters: (layer=7, size 10240) griffin_a
Loading Parameters: (layer=7, size 157286400) gating_einsum_w
Loading Parameters: (layer=7, size 78643200) linear_w
Loading Parameters: (layer=7, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=7, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=7, size 61440) ffw_gating_biases
Loading Parameters: (layer=7, size 10240) ffw_output_biases
Loading Parameters: (layer=8, size 26214400) attn_vec_einsum_w
Loading Parameters: (layer=8, size 78643200) qkv_einsum_w
Loading Parameters: (layer=8, size 157286400) gating_einsum_w
Loading Parameters: (layer=8, size 78643200) linear_w
Loading Parameters: (layer=8, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=8, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=8, size 61440) ffw_gating_biases
Loading Parameters: (layer=8, size 10240) ffw_output_biases
Loading Parameters: (layer=8, size 10240) attention_output_biases
Loading Parameters: (layer=9, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=9, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=9, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=9, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=9, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=9, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=9, size 40960) griffin_conv_w
Loading Parameters: (layer=9, size 10240) griffin_conv_biases
Loading Parameters: (layer=9, size 5242880) griffin_gate_w
Loading Parameters: (layer=9, size 20480) griffin_gate_biases
Loading Parameters: (layer=9, size 10240) griffin_a
Loading Parameters: (layer=9, size 157286400) gating_einsum_w
Loading Parameters: (layer=9, size 78643200) linear_w
Loading Parameters: (layer=9, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=9, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=9, size 61440) ffw_gating_biases
Loading Parameters: (layer=9, size 10240) ffw_output_biases
Loading Parameters: (layer=10, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=10, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=10, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=10, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=10, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=10, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=10, size 40960) griffin_conv_w
Loading Parameters: (layer=10, size 10240) griffin_conv_biases
Loading Parameters: (layer=10, size 5242880) griffin_gate_w
Loading Parameters: (layer=10, size 20480) griffin_gate_biases
Loading Parameters: (layer=10, size 10240) griffin_a
Loading Parameters: (layer=10, size 157286400) gating_einsum_w
Loading Parameters: (layer=10, size 78643200) linear_w
Loading Parameters: (layer=10, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=10, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=10, size 61440) ffw_gating_biases
Loading Parameters: (layer=10, size 10240) ffw_output_biases
Loading Parameters: (layer=11, size 26214400) attn_vec_einsum_w
Loading Parameters: (layer=11, size 78643200) qkv_einsum_w
Loading Parameters: (layer=11, size 157286400) gating_einsum_w
Loading Parameters: (layer=11, size 78643200) linear_w
Loading Parameters: (layer=11, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=11, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=11, size 61440) ffw_gating_biases
Loading Parameters: (layer=11, size 10240) ffw_output_biases
Loading Parameters: (layer=11, size 10240) attention_output_biases
Loading Parameters: (layer=12, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=12, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=12, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=12, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=12, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=12, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=12, size 40960) griffin_conv_w
Loading Parameters: (layer=12, size 10240) griffin_conv_biases
Loading Parameters: (layer=12, size 5242880) griffin_gate_w
Loading Parameters: (layer=12, size 20480) griffin_gate_biases
Loading Parameters: (layer=12, size 10240) griffin_a
Loading Parameters: (layer=12, size 157286400) gating_einsum_w
Loading Parameters: (layer=12, size 78643200) linear_w
Loading Parameters: (layer=12, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=12, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=12, size 61440) ffw_gating_biases
Loading Parameters: (layer=12, size 10240) ffw_output_biases
Loading Parameters: (layer=13, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=13, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=13, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=13, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=13, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=13, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=13, size 40960) griffin_conv_w
Loading Parameters: (layer=13, size 10240) griffin_conv_biases
Loading Parameters: (layer=13, size 5242880) griffin_gate_w
Loading Parameters: (layer=13, size 20480) griffin_gate_biases
Loading Parameters: (layer=13, size 10240) griffin_a
Loading Parameters: (layer=13, size 157286400) gating_einsum_w
Loading Parameters: (layer=13, size 78643200) linear_w
Loading Parameters: (layer=13, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=13, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=13, size 61440) ffw_gating_biases
Loading Parameters: (layer=13, size 10240) ffw_output_biases
Loading Parameters: (layer=14, size 26214400) attn_vec_einsum_w
Loading Parameters: (layer=14, size 78643200) qkv_einsum_w
Loading Parameters: (layer=14, size 157286400) gating_einsum_w
Loading Parameters: (layer=14, size 78643200) linear_w
Loading Parameters: (layer=14, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=14, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=14, size 61440) ffw_gating_biases
Loading Parameters: (layer=14, size 10240) ffw_output_biases
Loading Parameters: (layer=14, size 10240) attention_output_biases
Loading Parameters: (layer=15, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=15, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=15, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=15, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=15, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=15, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=15, size 40960) griffin_conv_w
Loading Parameters: (layer=15, size 10240) griffin_conv_biases
Loading Parameters: (layer=15, size 5242880) griffin_gate_w
Loading Parameters: (layer=15, size 20480) griffin_gate_biases
Loading Parameters: (layer=15, size 10240) griffin_a
Loading Parameters: (layer=15, size 157286400) gating_einsum_w
Loading Parameters: (layer=15, size 78643200) linear_w
Loading Parameters: (layer=15, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=15, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=15, size 61440) ffw_gating_biases
Loading Parameters: (layer=15, size 10240) ffw_output_biases
Loading Parameters: (layer=16, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=16, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=16, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=16, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=16, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=16, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=16, size 40960) griffin_conv_w
Loading Parameters: (layer=16, size 10240) griffin_conv_biases
Loading Parameters: (layer=16, size 5242880) griffin_gate_w
Loading Parameters: (layer=16, size 20480) griffin_gate_biases
Loading Parameters: (layer=16, size 10240) griffin_a
Loading Parameters: (layer=16, size 157286400) gating_einsum_w
Loading Parameters: (layer=16, size 78643200) linear_w
Loading Parameters: (layer=16, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=16, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=16, size 61440) ffw_gating_biases
Loading Parameters: (layer=16, size 10240) ffw_output_biases
Loading Parameters: (layer=17, size 26214400) attn_vec_einsum_w
Loading Parameters: (layer=17, size 78643200) qkv_einsum_w
Loading Parameters: (layer=17, size 157286400) gating_einsum_w
Loading Parameters: (layer=17, size 78643200) linear_w
Loading Parameters: (layer=17, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=17, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=17, size 61440) ffw_gating_biases
Loading Parameters: (layer=17, size 10240) ffw_output_biases
Loading Parameters: (layer=17, size 10240) attention_output_biases
Loading Parameters: (layer=18, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=18, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=18, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=18, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=18, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=18, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=18, size 40960) griffin_conv_w
Loading Parameters: (layer=18, size 10240) griffin_conv_biases
Loading Parameters: (layer=18, size 5242880) griffin_gate_w
Loading Parameters: (layer=18, size 20480) griffin_gate_biases
Loading Parameters: (layer=18, size 10240) griffin_a
Loading Parameters: (layer=18, size 157286400) gating_einsum_w
Loading Parameters: (layer=18, size 78643200) linear_w
Loading Parameters: (layer=18, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=18, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=18, size 61440) ffw_gating_biases
Loading Parameters: (layer=18, size 10240) ffw_output_biases
Loading Parameters: (layer=19, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=19, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=19, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=19, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=19, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=19, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=19, size 40960) griffin_conv_w
Loading Parameters: (layer=19, size 10240) griffin_conv_biases
Loading Parameters: (layer=19, size 5242880) griffin_gate_w
Loading Parameters: (layer=19, size 20480) griffin_gate_biases
Loading Parameters: (layer=19, size 10240) griffin_a
Loading Parameters: (layer=19, size 157286400) gating_einsum_w
Loading Parameters: (layer=19, size 78643200) linear_w
Loading Parameters: (layer=19, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=19, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=19, size 61440) ffw_gating_biases
Loading Parameters: (layer=19, size 10240) ffw_output_biases
Loading Parameters: (layer=20, size 26214400) attn_vec_einsum_w
Loading Parameters: (layer=20, size 78643200) qkv_einsum_w
Loading Parameters: (layer=20, size 157286400) gating_einsum_w
Loading Parameters: (layer=20, size 78643200) linear_w
Loading Parameters: (layer=20, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=20, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=20, size 61440) ffw_gating_biases
Loading Parameters: (layer=20, size 10240) ffw_output_biases
Loading Parameters: (layer=20, size 10240) attention_output_biases
Loading Parameters: (layer=21, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=21, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=21, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=21, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=21, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=21, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=21, size 40960) griffin_conv_w
Loading Parameters: (layer=21, size 10240) griffin_conv_biases
Loading Parameters: (layer=21, size 5242880) griffin_gate_w
Loading Parameters: (layer=21, size 20480) griffin_gate_biases
Loading Parameters: (layer=21, size 10240) griffin_a
Loading Parameters: (layer=21, size 157286400) gating_einsum_w
Loading Parameters: (layer=21, size 78643200) linear_w
Loading Parameters: (layer=21, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=21, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=21, size 61440) ffw_gating_biases
Loading Parameters: (layer=21, size 10240) ffw_output_biases
Loading Parameters: (layer=22, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=22, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=22, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=22, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=22, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=22, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=22, size 40960) griffin_conv_w
Loading Parameters: (layer=22, size 10240) griffin_conv_biases
Loading Parameters: (layer=22, size 5242880) griffin_gate_w
Loading Parameters: (layer=22, size 20480) griffin_gate_biases
Loading Parameters: (layer=22, size 10240) griffin_a
Loading Parameters: (layer=22, size 157286400) gating_einsum_w
Loading Parameters: (layer=22, size 78643200) linear_w
Loading Parameters: (layer=22, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=22, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=22, size 61440) ffw_gating_biases
Loading Parameters: (layer=22, size 10240) ffw_output_biases
Loading Parameters: (layer=23, size 26214400) attn_vec_einsum_w
Loading Parameters: (layer=23, size 78643200) qkv_einsum_w
Loading Parameters: (layer=23, size 157286400) gating_einsum_w
Loading Parameters: (layer=23, size 78643200) linear_w
Loading Parameters: (layer=23, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=23, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=23, size 61440) ffw_gating_biases
Loading Parameters: (layer=23, size 10240) ffw_output_biases
Loading Parameters: (layer=23, size 10240) attention_output_biases
Loading Parameters: (layer=24, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=24, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=24, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=24, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=24, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=24, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=24, size 40960) griffin_conv_w
Loading Parameters: (layer=24, size 10240) griffin_conv_biases
Loading Parameters: (layer=24, size 5242880) griffin_gate_w
Loading Parameters: (layer=24, size 20480) griffin_gate_biases
Loading Parameters: (layer=24, size 10240) griffin_a
Loading Parameters: (layer=24, size 157286400) gating_einsum_w
Loading Parameters: (layer=24, size 78643200) linear_w
Loading Parameters: (layer=24, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=24, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=24, size 61440) ffw_gating_biases
Loading Parameters: (layer=24, size 10240) ffw_output_biases
Loading Parameters: (layer=25, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=25, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=25, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=25, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=25, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=25, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=25, size 40960) griffin_conv_w
Loading Parameters: (layer=25, size 10240) griffin_conv_biases
Loading Parameters: (layer=25, size 5242880) griffin_gate_w
Loading Parameters: (layer=25, size 20480) griffin_gate_biases
Loading Parameters: (layer=25, size 10240) griffin_a
Loading Parameters: (layer=25, size 157286400) gating_einsum_w
Loading Parameters: (layer=25, size 78643200) linear_w
Loading Parameters: (layer=25, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=25, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=25, size 61440) ffw_gating_biases
Loading Parameters: (layer=25, size 10240) ffw_output_biases

View File

@ -1,371 +0,0 @@
"""Transformations for python-trees representing the parameters of a ML model.
Important: This module assumes that byte-order is the same on the
machine that serializes data and the machine that deserializes
data. If, for example, numpy-data gets dumped, respectively loaded,
with a dtype-specification of numpy.float32, on-file byte-order
will be host byte order.
"""
import ast
import hashlib
import itertools
import pprint
import sys
import time
from typing import Any, Callable, Iterable, Iterator, Mapping, TypeVar
import numpy
import pytree_transforms
NT = TypeVar('NT')
def ml_model_leaf_summary(path, x, sep=', '):
"""Produces a textual summary of a leaf-node and its path.
Args:
path: The path-to-root, as a reverse-order recursive
pair of path-components, with `()` as root.
x: The leaf-object.
sep: the separator between description-elements.
Default ', ' allows for convenient line-by-line processing
(such as via grep, perl -ne, etc.), but using e.g. sep=',\n '
might be more useful for human consumption.
Returns:
A human-readable string providing information about the node.
"""
# Using `repr` for path-components to get a faithful presentation.
# (...which still however would be somewat painful to correctly
# split into components.)
path_str = ','.join(map(repr,
pytree_transforms.linearize_revtuple_path(path)))
tx = type(x)
mod = tx.__module__ # Either a module or a string like 'builtins'.
modname = mod if isinstance(mod, str) else mod.__name__
type_str = f'{modname}.{tx.__qualname__}'
try:
# `numpy.ndarray` instances have a `.data` property that gives access
# to a buffer via which we can hashlib-fingerprint the numerical
# contents. We here simply try to produce a fingerprint and also look
# up the .dtype of the object. Technically, there is a somewhat-unsound
# assumption here that if these operations succeed, we are indeed looking
# at a ndarray or sufficiently similar object for these operations to
# make sense. As the output is declared "for human consumption", this
# fishiness is not a problem.
fp = hashlib.sha256(x.data).hexdigest()
start = list(itertools.islice(x.flat, 5))
stats_str = (
f'min={numpy.min(x):.6g}, max={numpy.max(x):.6g}, '
f'mean={numpy.mean(x):.6g}, std={numpy.std(x):.6g}')
return (f'{path_str:60s}: <{type_str}{sep}'
f'fp=0x{fp}{sep}{stats_str}{sep}shape={x.shape}, '
f'dtype={x.dtype}{sep}start={start}>')
except (AttributeError, ValueError, TypeError):
# Fallback - trying to include information about the data-content
# of a likely-numerical-array failed.
return f'{path_str:60s}: {type_str}({repr(x)})'
# A specialized node-handler.
# Interface follows node-handler expectations defined in pytree_transforms.
def _ml_model_tree_node_handler(path: tuple, node : NT) -> (
None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT],
Iterator[tuple[Any, NT]]]):
"""Processes a tree-node as required by pytree-iteration and -mapping.
Args:
path: revtuple path to the current node.
node: a tree-node in a ML-model tree that is recursively
built out of `numpy.ndarray` leaf-values and dicts mapping
node-name string-keys to other such nodes representing subtrees -
and nothing else.
Returns:
`None` if the tree-node is to be regarded as a leaf, otherwise
a pair `(rebuilder, iterator)`, where `iterator` iterates
over the data-content of the node, each item represented as a pair
of `(lookup_path_item, value_item)`, and `rebuilder` is a function
which, when applied to `iterator` or any iterable with the same
elements, returns a node that is equivalent to the original.
Raises:
NotAMLModelTreeNodeError: If the tree contains a node that is neither
a `dict` nor a `numpy.ndarray` instance.
"""
# The astute reader will notice that we are doing something fishy
# here - this code could not be translated to Haskell as-is, since
# `NT` cannot actually be a proper type-variable in the sense
# of parametric polymorphism.
del path # Unused.
if isinstance(node, dict):
return dict, iter(node.items())
if isinstance(node, numpy.ndarray):
return None
raise pytree_transforms.NotAMLModelTreeNodeError(
f'Type of bad node: {type(node)}')
def _ml_model_extract_leaf_transform(
path: pytree_transforms.RevTuplePath,
leaf: Any):
"""Maps an array-leaf to a pair `(full_path, lambda: array)`.
The computation that produces the leaf-value is lazified underneath
a `lambda`, since if we e.g. performed a memory-expensive
transformation (such as some dtype-changes) directly at this point,
then going from an iterator over tree-items for one-by-one
consumption to a list of these items would have all the
dtype-transformed values around simultaneously. We want to avoid
situations where we can do nothing about having multiple variants
of the data simultaneously in memory.
"""
# Hack: If we are encountering a `bfloat16` numpy-array,
# we pretend to have the data as a numpy.float32 array,
# since that's about all that contemporary CPUs can process
# efficiently here.
linearized_path = pytree_transforms.linearize_revtuple_path(path)
try:
# We have to use some trickery to detect `bfloat16`.
if leaf.dtype.descr[-1] == ('', '<V2'):
return linearized_path, lambda: leaf.astype(numpy.float32)
else:
return linearized_path, lambda: leaf
except Exception:
return linearized_path, lambda: leaf
# Here, we cannot properly specify the return-type, since this can
# either be a leaf-type or something recursively-defined.
def revtuple_autovifify_from_linear(
keys_and_vals: Iterable[tuple[Any, Any]]) -> Any:
"""Performs perl-style autovivification on a nested-dict tree.
Args:
keys_and_vals: An iterable of pairs `(key_path, value)`, where
`key_path` is a sequence of keys to be used to navigate to
the result via iterative dict-lookup, left-to-right.
Must not have duplicate keys, and must not more than one key if
an empty-sequence key is present. If this iterable is an
iterator, it will be fully exhausted on successful execution.
Returns:
An object representing a nested-dict structure such that
for every `key_path` from `keys_and_vals`, recursive-dict-lookup
on the elements of that path starting from this object will
produce the corresponding value. An empty `keys_and_vals`
set will return `{}`. Every dict in the nested return-value
that has been populated by autovivification is newly allocated.
"""
# Code structure is a bit gnarly here due to f(keys_and_vals=[((), x)])
# having to evaluate to x and not a dict.
# There may be ways to prettify/simplify this.
result = None
empty = {}
for linear_path, val in keys_and_vals:
if linear_path == ():
if result is not None:
raise ValueError('Root-value seen alongside other values.')
result = val
else:
if result is None:
result = {}
elif type(result) is not dict:
# We already did encounter a root-value.
raise ValueError('Root-value seen alongside other values.')
cursor = result
for n in range(len(linear_path) - 1):
cursor = cursor.setdefault(linear_path[n], empty)
if cursor is empty:
# Regenerate `empty` if we just used it up.
empty = {}
cursor[linear_path[-1]] = val
return {} if result is None else result
def model_overview(tree, out=None) -> None:
"""Prints a human-readable overview to `(out or sys.stdout)`."""
actual_out = out or sys.stdout
for line in pytree_transforms.pytree_leaf_iter(
tree, ml_model_leaf_summary,
_ml_model_tree_node_handler):
print(line, file=actual_out)
def model_contents(tree) -> Mapping[tuple[str, ...], Any]:
"""Maps a model to a {pytree_keys: data_array} mapping.
Args:
tree: The ML-model parameter-tree, built recursively out of
dict-instances with numpy.ndarray instances as leaves.
Returns:
A mapping from linearized pytree-key-sequence tuple to the corresponding
leaf-value.
"""
def leaf_transform(revtuple_path, leaf):
return pytree_transforms.linearize_revtuple_path(revtuple_path), leaf
return dict(
pytree_transforms.pytree_leaf_iter(
tree, leaf_transform, _ml_model_tree_node_handler))
def _fn_identity(x): return x
def model_save(tree,
filepath_stem: str,
data_suffix: str = '.data',
manifest_suffix: str | None = '.manifest',
key: Callable[[tuple[str, ...]], Any] | None = None,
array_transform_by_pytree_key: (
Mapping[tuple[str, ...],
Callable[[numpy.ndarray], numpy.ndarray]] |
None) = None,
report: Callable[[str], None] | None = None,
byte_align: int = 8) -> tuple[int, float]:
"""Saves the content of a ML-model parameter-tree to filesystem.
After successful execution, the file f"{filepath_stem}.data"
will hold the combined numerical model-parameters, and
f"{filepath_stem}.manifest" will contain the key for interpreting
(and rebuilding) the data.
Args:
tree: The ML-model parameter-tree, built recursively out of
dict-instances with numpy.ndarray instances as leaves.
filepath_stem: Filesystem location for data.
data_suffix: Suffix to use for the data file.
manifest_suffix: Either `None`, in which case no manifest-file
will get written, or the suffix for the manifest-file.
key: `None` or a key-function that will be applied to the linear model-path
and used for sorting the data arrays by increasing value of the
key-function. If the key-function returns `None` on an item,
then this item is not included.
array_transform_by_pytree_key: Optional mapping from pytree-key
to an array-to-array transformation function to apply to the array
prior to serialization.
report: Optional callable for logging progress-reports.
byte_align: byte-alignment to use for numerical array data.
Numerical arrays whose size in bytes is not a multiple of this
will get padded to the next full multiple.
Returns:
A pair of `(size, time_sec)`, where `size` is the total byte-size
of the `.data` file and `time_sec` is the elapsed time
for saving the model, in seconds.
"""
time0 = time.monotonic()
if array_transform_by_pytree_key is None:
array_transform_by_pytree_key = {}
model_lazy_items = (
pytree_transforms.pytree_leaf_iter(
tree, _ml_model_extract_leaf_transform,
_ml_model_tree_node_handler))
if key is not None:
to_write = [
nkv[1:] for nkv in sorted(
(nkv for nkv in ((key(path), path, v)
for path, v in model_lazy_items)
if nkv[0] is not None), key=lambda nkv: nkv[0])]
else:
to_write = list(model_lazy_items)
#
def lazy_arr_path_shape_dtype_size(path_and_lazy_arr):
path, lazy_arr = path_and_lazy_arr
arr = array_transform_by_pytree_key.get(path, _fn_identity)(lazy_arr())
return path, arr.shape, arr.dtype, arr.data.nbytes
arrs_path_shape_dtype_nbytes = list(
map(lazy_arr_path_shape_dtype_size, to_write))
# We need to know the total size of all the data.
bytesizes = [nbytes for *_, nbytes in arrs_path_shape_dtype_nbytes]
padded_bytesizes = [-(-bytesize // byte_align * byte_align)
for bytesize in bytesizes]
offsets = numpy.cumsum([0] + padded_bytesizes)
membuf = numpy.memmap(filepath_stem + data_suffix,
mode='w+', shape=offsets[-1])
try:
for (path, shape, dtype, nbytes), offset, (_, lazy_arr) in zip(
arrs_path_shape_dtype_nbytes, offsets, to_write):
# Note that if getting the array from the lazy lambda involved some
# computation, such as a copying dtype-change, that computation would
# end up being done multiple times here - including once above, to compute
# byte-sizes, and once more here.
transformed_arr = array_transform_by_pytree_key.get(
path,
_fn_identity)(lazy_arr())
membuf[offset : offset + nbytes] = numpy.frombuffer(
transformed_arr.ravel().data, 'u1')
if report is not None:
samples = ', '.join(map(str, transformed_arr.ravel()[:5]))
report(f'# Adding: {path!r}\n bytes: {nbytes:10d}, '
f'shape: {shape!r:30},\n start: [{samples}, ...]')
transformed_arr = None # Drop memory references to numerical arrays ASAP.
finally:
if membuf is not None:
membuf.flush()
# NumPy wart: the memory-buffer is a resource that conceptually
# should be .close()able - since mmap()ing holds on to a
# file descriptor. However, it looks as if that clean-up were done
# in the "finalizer", despite that having meanwhile been widely
# understood as dubious practice. So, the best we can do here is
# to explicitly and clearly remove our reference to the instance.
del membuf
if manifest_suffix is not None:
# We still have to serialize the data that allows us to reconstruct
# a tree that is equivalent to the original.
manifest_data = [
dict(path=path,
dtype=dtype.descr[-1][-1],
shape=shape,
nbytes=nbytes,
offset=offset)
for (path, shape, dtype, nbytes), offset in zip(
arrs_path_shape_dtype_nbytes, offsets)]
with open(filepath_stem + '.manifest', 'wt') as h_manifest:
pprint.pprint(manifest_data, stream=h_manifest)
time_taken = time.monotonic() - time0
return offsets[-1], time_taken
def model_load(filepath_stem, mmapped=True):
"""Loads a model saved by `model_save`.
Tries to load the model from f"{filepath_stem}.data"
and f"{filepath_stem}.manifest".
Args:
filepath_stem: The model location on the filesystem.
mmapped: Whether data-arrays will be slices of a
`numpy.memmap` mapped buffer, to be paged in
on demand only, or in-memory copies of the data.
Returns:
A dict/numpy.ndarray tree representation of the model,
equivalent to the original model.
"""
with open(filepath_stem + '.manifest', 'rt') as h_manifest:
manifest = ast.literal_eval(h_manifest.read())
membuf = numpy.memmap(filepath_stem + '.data', mode='r+')
paths_and_arrays = []
for item in manifest:
path = item['path']
dtype = numpy.dtype(item['dtype'])
shape = item['shape']
nbytes = item['nbytes']
offset = item['offset']
data_array = numpy.frombuffer(membuf[offset : offset + nbytes].data,
dtype=dtype).reshape(shape)
paths_and_arrays.append(
(path,
data_array if mmapped else data_array.copy()))
# At this point, the memory-buffer is no longer needed. Still, if
# data-arrays retain references to the underlying data
# (i.e. when mmapped=False), this should keep the mapping
# - and hence file descriptor - open. We then are in a somewhat
# undesirable situation of clean-up of a resource that happens in a
# hard-to-predict way releasing a file descriptor.
del membuf
return revtuple_autovifify_from_linear(paths_and_arrays)

View File

@ -1,92 +0,0 @@
"""Basic tests for 'algebraic data type based pytree' transformations."""
import io
import os
import tempfile
import unittest
import numpy
import ml_model_transforms
def _get_model(prefix):
return {
prefix + 'a1': numpy.arange(1000, 1024).reshape(6, 4).astype(numpy.float32),
prefix + 'a2': numpy.arange(2000, 2048).reshape(6, 8).astype(numpy.float32),
prefix + 'b1': {
prefix + 'c1': numpy.arange(100, 127).reshape(3, 3, 3).astype(numpy.int8),
prefix + 'c2': numpy.arange(100, 128).reshape(7, 4).astype(numpy.float64)
}}
class MLModeltransformsTest(unittest.TestCase):
"""Basic correctness validation tests for ML-model transformations."""
def test_ml_model_leaf_summary(self):
"""Tests guarantees given by `ml_model_leaf_summary`."""
summary = ml_model_transforms.ml_model_leaf_summary(
('a', ()),
numpy.arange(1000, 1024).reshape(6, 4).astype(numpy.int16),
sep='##')
self.assertIn('##', summary) # Separator is respected.
self.assertIn('(6, 4)', summary) # Shape is mentioned somewhere.
self.assertIn('int16', summary) # dtype is mentioned somewhere.
def test_revtuple_autovivify_from_linear(self):
"""Tests guarantees given by `revtuple_autovifify_from_linear`."""
with self.subTest(guarantee='empty'):
self.assertEqual(
ml_model_transforms.revtuple_autovifify_from_linear([]),
{})
with self.subTest(guarantee='generic'):
keys_vals = [(('a', 'b1', 'c1'), 1001),
(('a', 'b2'), 1002),
(('a2',), 1003),
]
self.assertEqual(
ml_model_transforms.revtuple_autovifify_from_linear(keys_vals),
{'a': {'b1': {'c1': 1001}, 'b2': 1002}, 'a2': 1003})
def test_model_overview(self):
"""Tests guarantees given by `model_overview`."""
model = _get_model('xyz')
out_io = io.StringIO()
ml_model_transforms.model_overview(model, out=out_io)
overview = out_io.getvalue()
self.assertIn('xyz', overview)
def test_model_contents(self):
"""Tests guarantees given by `model_contents`."""
model = _get_model('pq_')
contents = ml_model_transforms.model_contents(model)
fingerprints = {k: (a.shape, a.ravel()[:3].tolist())
for k, a in contents.items()}
self.assertEqual(fingerprints,
{('pq_a1',): ((6, 4), [1000.0, 1001.0, 1002.0]),
('pq_a2',): ((6, 8), [2000.0, 2001.0, 2002.0]),
('pq_b1', 'pq_c1'): ((3, 3, 3), [100, 101, 102]),
('pq_b1', 'pq_c2'): ((7, 4), [100.0, 101.0, 102.0])})
def test_model_save_load_basic(self):
"""Tests basic guarantees given by `model_save` and `model_load`."""
# What we care about here is that the round trip works - so
# it makes more sense to test saving and loading as one unit.
model_orig = _get_model('model_')
with tempfile.TemporaryDirectory() as tempdir:
filepath_stem = os.path.join(tempdir, 'the_model')
total_size, total_time = ml_model_transforms.model_save(model_orig,
filepath_stem)
self.assertGreater(total_size, 0)
self.assertGreater(total_time, 0)
model_reloaded = ml_model_transforms.model_load(filepath_stem)
contents_orig = ml_model_transforms.model_contents(model_orig)
contents_reloaded = ml_model_transforms.model_contents(model_reloaded)
self.assertEqual(
{k: v.tolist() for k, v in contents_orig.items()},
{k: v.tolist() for k, v in contents_reloaded.items()})
if __name__ == '__main__':
unittest.main()

View File

@ -1,508 +0,0 @@
"""Tools for transforming "nested python object" tree data structures.
# Context
The motivation for this module came from ML applications that ought to
be based on a principled handling of nested Python data structures.
Having such principled pytree-transforming code available solves
some other problems, such as doing away with a need to abuse
tree-mapping for-side-effect-only and having to use a hope-and-pray
approach to processing very deeply nested values which with a recursive
approach might trigger a RecursionError.
We specifically want to cover the use case of having ML model
parameters that are available in a nested Python data structure for
which there "almost" is a unique-up-to-unique-isomorphism mapping from
and to this Algebraic Data Type:
`data ModelParams a = Array a | Node [(String, ModelParams a)]`
In this correspondence, `a` is some array-type (perhaps
`numpy.ndarray`, `jax.numpy.ndarray`, `tf.tensor`, etc.), but the
data-processing code is effectively entirely agnostic to this, and a
`Node` is "almost" an associative-list of (key, value) pairs
representing a Python dict. (Note: The "almost" here is mostly about
the conceptual wart that assoc-lists can in principle have key
duplicates, but Python dicts can not. This is however not a problem
since all we need is the transformation in one direction,
i.e. whatever data-processing `f` we want to express on the
model-parameters-pytree, we can express by specifying a "faithful"
mapping `m` into the above algebraic data type through which every
such pytree data transform factorizes, i.e. for every `f` we can find
a `g` such that `f(p) = g(m(p))`.)
## Components
The main workhorse in this module is the `pytree_iter` function that
maps a "PyTree (such as representing `ModelParams`)" to an iterator
over values obtained by applying a mapping-function to the "key-path"
and leaf-value for every leaf, where the "key-path" contains a
linked-list representation of the reversed sequence of keys from the
tree-root, with list-nodes being represented by pairs
`(latest_dict_key, rest_path)`, and the empty path being represented
by `()`.
For the sake of genericity, `pytree_iter` is built in such a way that
it actually can handle any kind of traversal of PyTree-trees that do
represent algebraic data types (note however that some some do not) -
but for this to make sense, the user must have a way to define how to
interpret tree-nodes, in particular identify leaves. This requires
providing a function `node_handler` with the same signature and
behavior as described below for "node handlers".
Additionally, this module provides mapping-over-pytrees via
`pytree_map`, which is also built in such a way that it makes the
correspondence between an algebraic data type and its Python
nested-tree representation explicit. Despite being powerful and
flexible, this, however, may in general require a bit more effort to
wire up, since node-rebuilding can be fairly nontrivial.
Furthermore, as a prominent application, this module provides a simple
deep-freezing function that translates a nested Python data structure
to deeply-immutable form.
## Concepts and Conventions
"revtuple representation":
As we iterate over a tree, we will have to keep track of the
path-to-tree-root. Naturally, two sibling nodes `n1` and `n2`
will share the same parent-path (being siblings), so it makes
sense to use a linked-list-with-shared-tail representation.
Python does not have a natural notion for that, so we use
recursively-constructed tuples `(node_tag, parent_path)`
that represent the path-from-root in-reverse-order, i.e.
for a non-empty path `p`, `p[0]` is the node-tag at the
deepest nesting level. We call this a "revtuple representation"
of the path.
"node handler":
A node-handler classifies a tree-node as "leaf or other node", and
for non-leaf nodes provides information about both its children and
how to rebuild it. The behavior of a node-handler function must be
in alignment with this docstring:
'''Processes a tree-node as required by pytree-iteration and -mapping.
Args:
revtuple_path: Revtuple-representation of the path-from-root
to the current node.
node: a tree-node in a ML-model tree that is recursively
built out of leaf-values and other nodes.
Returns:
`None` if the tree-node is to be regarded as a leaf, otherwise
a pair `(rebuilder, iterator)`, where `iterator` iterates
over the data-content of the node, each item represented as a pair
of `(lookup_path_item, value_item)`, and `rebuilder` is a function
which, when applied to an iterable of the aforementioned value-items
(or some transformation thereof) returns a node that is equivalent
to the original (or up to a transformation of the contents).
Raises:
InvalidTreeNodeError: If the tree contains a node of a kind
that is not expected to show up.
'''
Examples:
(The behavior of a node-handler is somewhat nontrivial, so covering
two very common cases via examples is in order.)
This node-handler would allow descending into (nested)
instances of `list` (but not subclass instances thereof):
```def list_node_handler(revtuple_path, obj):
''' ... '''
if type(obj) is list:
return list, enumerate(obj)
else:
return None
```
This node-handler would allow descending into (nested) mappings,
which upon rebuilding would get turned into `dict` instances:
```def mapping_node_handler(revtuple_path, obj):
''' ... '''
if isinstance(obj, collections.abc.Mapping):
# For generic mappings, we cannot rely on key- and item-iteration
# being guaranteed to use identical iteration-order.
items = list(obj.items())
keys = [kv[0] for kv in items]
return (lambda values: dict(zip(keys, values))), items
else:
return None
```
A dict/mapping node-handler can of course rename keys, add or remove
entries, make decisions based on the item-path, or map a dict to
an associative list, etc.
## Further Design Notes
The `pytree_map` function requests the leaf-transform and node-handler
to be side-effect-free functions. This is both required to leave
implementation-side flexibility, and also follows the general LISP
recommendation to not abuse mapping (which should be a pure
data-transformation) for imperative data processing. Overall, if
a need for more general "nested datastructures" processing becomes
pressing, it is for the better if this leads to a proper articulation
of the specific needs, to be addressed with appropriate design, rather
than abuse of functional data-transforms becoming "a bad idiom
that turned into established practice".
"""
import collections.abc
import immutabledict
import numpy
from typing import Any, Callable, Iterable, Iterator, TypeVar
T = TypeVar('T')
U = TypeVar('U')
KT = TypeVar('KT')
NT = TypeVar('NT')
## Type of the reverse-order-keys-to-root path.
# (This code actually illustrates why https://xkcd.com/2483/ is very misguided.)
RevTuplePath = tuple
## Type of the `leaf_transform` function-argument used for tree-iteration.
#
# This would be the correct type we would have to specify here but cannot,
# since the design of Python's static typing at the time of this writing
# is too broken for that:
#
# type LeafTransformFunc[L, R] = Callable[[RevTuplePath, L], R]
#
# Instead, we have to settle for...:
LeafTransformFunc = Callable[[RevTuplePath, Any], Any]
## Type of the `tree_node_handler` function-argument used for
## tree-iteration and tree-mapping.
#
# Again, this is the correct type we would have to put here but cannot:
#
# type NodeHandlerFunc[KT] = (
# Callable[[NT],
# None | tuple[Callable[[Iterable[tuple[KT, NT]]], NT],
# Iterator[tuple[KT, NT]]]])
#
# ...so, we have to instead settle for:
NodeHandlerFunc = (
Callable[[RevTuplePath, NT],
None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT],
Iterator[tuple[Any, NT]]]])
Predicate = Callable[[object], bool]
class InvalidTreeNodeError(ValueError):
"""Encountered a tree-node of invalid type."""
def linearize_revtuple_path(
revtuple_path: RevTuplePath,
present_as: Callable[[Iterator[T]], U] = tuple) -> U:
"""Translates a revtuple path to (typically) linear form.
With default `present_as`, this will map a path of the form
`(key_{N}, (key_{N-1}, ..., (root, ())))` into a tuple
(root, ..., key_{N-1}, key_{N}).
Args:
revtuple_path: A linked-list-as-recursive-pairs
reverse-order tuple-representation of the path.
Path-root is `()`, and node-key `x` relative to
earlier path `p` is represented as `(x, p)`.
present_as: Callable that consumes an iterator over
path-pieces - with the deepest-nesting level coming last -
turning it into a linearized path. Defaults to `tuple`.
Returns:
Linearized presentation of all the node-keys in the
recursive-path in order, deepest-down path component coming last.
"""
pieces = []
todo = revtuple_path
while todo:
node, todo = todo
pieces.append(node)
return present_as(reversed(pieces))
# This function itself has type `NodeHandlerFunc`, but Python does not
# allow us to here simply type-annotate it like this. We cannot even
# introduce an abbreviation for the complicated output-type,
# since that would have to be parametric in node-type `NT` (and `KT`).
def everything_is_a_leaf_node_handler(
revtuple_path: tuple,
node : NT) -> (
None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT],
Iterator[tuple[Any, NT]]]):
"""Processes a tree-node as required by pytree-iteration and -mapping.
Interface and signature are in alignment with the requirements for a
"node handler" function explained in the module-docstring.
Args:
revtuple_path: the path-to-root for this node.
node: a tree-node.
Returns:
`None`, i.e. classifying any kind of node as a leaf-node.
"""
del revtuple_path, node # Unused.
return None
def leaf_summary(path: RevTuplePath, x: object):
"""Produces a human-readable summary-string for a leaf-node.
Args:
path: revtuple representation of the path-to-root.
x: The leaf-value.
"""
del path # Ignored here.
tx = type(x)
mod = tx.__module__
modname = mod if isinstance(mod, str) else mod.__name__
type_str = f'{modname}.{tx.__qualname__}'
repr_str = repr(x)
repr_abbrev = repr_str if len(repr_str) < 40 else repr_str[:40] + ' ...'
# On str, int, float, etc. `{type_str}(repr(x))` would actually still be
# a (non-literal) Python-expression that would evaluate to the original value.
# However, we make no promises beyond "human-readable".
return f'{type_str}({repr_abbrev})'
# With respect to static type annotations, the limitations of Python's
# approach to static typing really become prominently visible here.
#
# Different arguments have type-parameters, but since there is no way
# to have parametric abbreviations such as `LeafTransformFunc[L, R]`,
# the only way we would have available to express relations between
# type-parameters would be to substitute in the not-abbreviated form of
# `NodeHandlerFunc` and `LeafTransformFunc`, giving us something monstrous.
# We instead here settle for "we cannot express that `tree` must
# have the same type as the input-type to `tree_node_handler` and use `Any`,
# and likewise for leaf_transform and the output.
def pytree_leaf_iter(
tree: Any,
leaf_transform: LeafTransformFunc,
node_handler: NodeHandlerFunc = everything_is_a_leaf_node_handler,
) -> Iterator[Any]:
# ...actual return type would be `Iterator[{what leaf_transform returns}]`.
"""Iterates over the leaves of a tree.
Args:
tree: The tree to iterate over.
leaf_transform: A callable `f` that will get applied
as `f(revtuple_path, leaf)`, where `revtuple_path`
is the revtuple representation of the path to the
leaf from the root.
node_handler: A "node handler" (see module docstring)
that processes nodes encountered during iterative traversal.
Yields:
Value of `leaf_transform(p, x)`, where `x` is the current leaf
and `p` is its revtuple-path to the root.
"""
# Note: Exit points for the code below are in non-obvious places
# and hence marked with " # ***EXIT***".
#
# Doing iteration properly is slightly nontrivial.
# One may be tempted to go for a very simple recursive implementation
# (with an extra pre-final `path` argument to `pytree_iter`):
#
# maybe_substructure = node_handler(path, tree)
# if maybe_substructure is None:
# # We are looking at a leaf-node.
# yield leaf_transform(path, tree)
# else:
# _, contents_iter = maybe_substructure
# for k, v in contents_iter:
# yield from pytree_iter(v, leaf_transform, (k, path), node_handler)
#
# That, however, would be flawed, since there is no a priori reason
# why a pytree may not be a very deeply nested structure - such as a
# long linked list. That would then risk raising `RecursionError`,
# and since Python by design(!) does not perform tail call elimination
# or any other kind of advanced CPS transforms, there is no recursive
# solution here. So, to do this properly, we have to do this iteratively.
#
# We are facing an annoying situation here: If `tree` itself is a leaf,
# we have two options: (a) wrapping it up in a one-node tree
# and processing that, or (b) special-casing "root is a leaf".
# Option (b) leads to some mild node-processing code-duplication
# for a single node (the root).
# Option (a) requires having special cases for node-processing that
# get looked at for every tree node. We go with option (b) here.
maybe_substructure = node_handler((), tree)
if maybe_substructure is None:
# The tree itself is a leaf.
yield leaf_transform((), tree)
return # ***EXIT***
# Otherwise, we are looking at a tree.
_, contents_iter = maybe_substructure
current_revtuple_path = ()
work_to_do = [contents_iter]
# Otherwise-unreachable sentinel for reliably identifying
# iterator-exhaustion without using exceptions:
sentinel = object()
while True:
current_iter = work_to_do[-1]
maybe_next_item = next(current_iter, sentinel)
if maybe_next_item is sentinel:
# We are done at this level.
work_to_do.pop()
if not work_to_do: return # ***EXIT***
current_revtuple_path = current_revtuple_path[1]
else:
path_piece, subtree = maybe_next_item
extended_revtuple_path = (path_piece, current_revtuple_path)
maybe_subtree_substructure = node_handler(extended_revtuple_path, subtree)
if maybe_subtree_substructure is None: # Case: subtree is a leaf.
yield leaf_transform(extended_revtuple_path, subtree)
else: # Case: subtree is a tree.
current_revtuple_path = (path_piece, current_revtuple_path)
_, items_iter = maybe_subtree_substructure
work_to_do.append(items_iter)
# The current design approach here would be appropriate for
# applying leaf-transforms while retaining the structure of the tree -
# which closely corresponds to e.g. a (a -> b) -> (Tree a -> Tree b) functor.
#
# It is not entirely clear whether this is the abstraction that we should
# consider as being appropriately generic to flesh out explicitly - rather
# than starting from a more general approach of which this then is a special
# case. Some background: https://ncatlab.org/nlab/show/recursion+scheme
#
# On the other hand, there is a lot of flexibility via whatever
# node-rebuilder a node-handler produces - this can do quite some reshaping
# of a tree, including dropping or duplicating nodes.
def pytree_map(
tree: Any,
leaf_transform,
node_handler: NodeHandlerFunc = everything_is_a_leaf_node_handler,
):
"""Maps a (potentially nested) Python value to another such value.
Args:
tree: The Python-object to be mapped.
leaf_transform: A callable `f` that will get applied
as `f(revtuple_path, leaf)`, where `revtuple_path`
is the revtuple representation of the path to the
leaf from the root. Must be side effect free.
node_handler: A "node handler" (see module docstring)
that processes nodes encountered during iterative traversal.
Must be side effect free.
Returns:
The outcome of translating `tree`.
"""
# Note: Exit points for the code below are in non-obvious places
# and hence marked with " # ***EXIT***".
#
# Otherwise-inaccessible sentinel object, for reliably identifying
# missing-values via identity-check against sentinel lookup-default.
sentinel = object()
# Code structure mostly follows pytree_leaf_iter.
maybe_substructure = node_handler((), tree)
if maybe_substructure is None:
return leaf_transform((), tree) # ***EXIT***
rebuilder, items_iter = maybe_substructure
current_revtuple_path = ()
# Per-level, we have a triplet of:
# (rebuilder, remaining_items_to_iterate_over, processed).
parts_for_assembly = [(rebuilder, items_iter, [])]
while True:
this_rebuilder, this_items_iter, this_done_pieces = parts_for_assembly[-1]
maybe_next_item = next(this_items_iter, sentinel)
if maybe_next_item is sentinel:
# We are done with all the items for this level.
parts_for_assembly.pop()
built_iter = this_rebuilder(this_done_pieces)
if not parts_for_assembly: # No outer structure, so at-top-level.
return built_iter # ***EXIT***
else: # We have outer structure.
parts_for_assembly[-1][-1].append(built_iter)
current_revtuple_path = current_revtuple_path[1]
continue # ...with next is-the-final-item-complete-check.
else:
# More constituents of the current item.
path_piece, subtree_item = maybe_next_item
extended_revtuple_path = (path_piece, current_revtuple_path)
maybe_subtree_substructure = node_handler(
extended_revtuple_path,
subtree_item)
if maybe_subtree_substructure is None:
this_done_pieces.append(
leaf_transform(extended_revtuple_path, subtree_item))
else:
# We have a subtree.
subtree_rebuilder, subtree_items_iter = maybe_subtree_substructure
current_revtuple_path = (path_piece,
current_revtuple_path)
parts_for_assembly.append(
(subtree_rebuilder, subtree_items_iter, []))
def deep_freeze(
tree,
*,
is_mapping : Predicate = lambda x: isinstance(x, collections.abc.Mapping),
is_set : Predicate = lambda x: isinstance(x, collections.abc.Set),
is_sequence : Predicate = lambda x: isinstance(x, (list, tuple)),
leaf_fn: Callable[[Any], Any] = lambda x: x,
):
"""Recursively freezes Set/Mapping/List/Tuple structures.
Args:
tree: The potentially deeply-nested object to deep-freeze.
is_mapping: Callable that decides whether a sub-object is a mapping.
Defaults to an `isinstance()` check for `collections.abc.Mapping`.
is_set: Callable that decides whether a sub-object is a set.
Defaults to an `isinstance()` check for `collections.abc.Set`.
is_sequence: Callable that decides whether a sub-object is a sequence.
Defaults to a check for being a `tuple` or `list` instance.
leaf_fn: Function to use for translating non-mapping/set/sequence
instances.
Returns:
Translated-to-deeply-immutable form of `tree`.
"""
idict = immutabledict.immutabledict
def freeze_node_handler(path, x):
if is_set(x):
return frozenset, ((None, y) for y in x)
if is_mapping(x):
# Mappings already have hashable, so
# (should-be-)deeply-immutable keys.
# Hence, we only need to deep-freeze the values.
#
# Note that non-`dict` mappings might not guarantee
# to respect iteration-order, so we have to be careful here:
items = list(x.items())
keys = [kv[0] for kv in items]
values = [kv[1] for kv in items]
return ((lambda ys: idict(zip(keys, ys))),
iter(items))
if is_sequence(x):
return tuple, enumerate(iter(x))
# Otherwise, this should not be traversed.
return None
def leaf_transform(revtuple_path, value):
del revtuple_path # Unused.
return leaf_fn(value)
return pytree_map(tree, leaf_transform, freeze_node_handler)

View File

@ -1,168 +0,0 @@
"""Basic tests for 'algebraic data type based pytree' transformations."""
import collections.abc
import sys
import unittest
import pytree_transforms
def _get_deep_pytree(packaging_fn, bottom, depth):
current = bottom
for n in reversed(range(depth)):
current = packaging_fn(n, current)
return current
def _dict_node_handler(p, d):
del p # Unused.
if isinstance(d, dict):
keys = d.keys()
newdict = lambda vals: dict(zip(keys, vals))
return (newdict, iter(d.items()))
else:
return None
class PyTreeTest(unittest.TestCase):
"""Basic correctness validation tests for PyTree transformations."""
def test_linearize_revtuple_path(self):
"""Tests guarantees given by `linearize_revtuple_path`."""
linearize_revtuple_path = pytree_transforms.linearize_revtuple_path
with self.subTest(guarantee='empty'):
self.assertEqual(linearize_revtuple_path(()), ())
with self.subTest(guarantee='typical'):
self.assertEqual(linearize_revtuple_path((30, (20, (10, ())))),
(10, 20, 30))
with self.subTest(guarantee='present_as'):
self.assertEqual(
linearize_revtuple_path(
(30, (20, (10, ()))), present_as=list),
[10, 20, 30])
def test_everything_is_a_leaf_node_handler(self):
"""Tests guarantees given by `everything_is_a_leaf_node_handler`."""
everything_is_a_leaf_node_handler = (
pytree_transforms.everything_is_a_leaf_node_handler)
self.assertEqual(everything_is_a_leaf_node_handler((), 'abc'),
None)
self.assertEqual(everything_is_a_leaf_node_handler(('b', ()),
dict(a=3)),
None)
def test_leaf_summary(self):
"""Tests guarantees given by `leaf_summary`."""
# Since the docstring only guarantees "a human-readable presentation",
# we can and should only do loose checks.
thing = (5678, 9531)
summary = pytree_transforms.leaf_summary(('key', ()), thing)
self.assertIsInstance(summary, str)
self.assertIn(str(thing[0]), summary)
self.assertIn(str(thing[1]), summary)
def test_pytree_leaf_iter(self):
"""Tests guarantees given by `pytree_leaf_iter`."""
pytree_leaf_iter = pytree_transforms.pytree_leaf_iter
def leaf_transform(path, leaf):
return repr(leaf) if path and path[0].startswith('R') else leaf
with self.subTest(guarantee='returns_iterator'):
result = pytree_leaf_iter(7, leaf_transform, _dict_node_handler)
self.assertIsInstance(result, collections.abc.Iterator)
with self.subTest(guarantee='totally_empty'):
result = list(pytree_leaf_iter({}, leaf_transform, _dict_node_handler))
self.assertEqual(result, [])
with self.subTest(guarantee='no_leaves'):
result = list(pytree_leaf_iter(dict(a={}),
leaf_transform, _dict_node_handler))
self.assertEqual(result, [])
with self.subTest(guarantee='is_leaf'):
result = list(pytree_leaf_iter(777, leaf_transform, _dict_node_handler))
self.assertEqual(result, [777])
with self.subTest(guarantee='generic'):
result = list(pytree_leaf_iter(
dict(n0=dict(n01=dict(n012=1002,
n013=1003,
Rn014=1004,
),
n02=1005),
n5=1006),
leaf_transform, _dict_node_handler))
self.assertEqual(result, [1002, 1003, '1004', 1005, 1006])
with self.subTest(guarantee='with_keys'):
result = list(pytree_leaf_iter(
dict(n0=dict(n01=dict(n012=1002,
n013=1003)),
n1=1004),
lambda p, s: (pytree_transforms.linearize_revtuple_path(p), s),
_dict_node_handler))
self.assertEqual(result,
[(('n0', 'n01', 'n012'), 1002),
(('n0', 'n01', 'n013'), 1003),
(('n1',), 1004)])
def test_pytree_map(self):
"""Tests guarantees given by `pytree_map`."""
pytree_map = pytree_transforms.pytree_map
leaf_transform = lambda p, s: repr(s)
tree1 = dict(t0=dict(t10=1001,
t11=dict(t110=1002,
t111=1003),
t12=dict(t120=1004,
t121=1005,
t122=1006)),
t1=1007)
with self.subTest(guarantee='no_leaves'):
result = pytree_map(dict(a={}),
leaf_transform,
_dict_node_handler)
self.assertEqual(result, dict(a={}))
with self.subTest(guarantee='is_leaf'):
result = pytree_map(777, leaf_transform, _dict_node_handler)
self.assertEqual(result, '777')
with self.subTest(guarantee='generic'):
result = pytree_map(tree1, leaf_transform, _dict_node_handler)
self.assertEqual(result['t0']['t10'], '1001')
def test_deeply_nested(self):
"""Tests correct behavior on deeply-nested data structures."""
pytree_leaf_iter = pytree_transforms.pytree_leaf_iter
pytree_map = pytree_transforms.pytree_map
#
depth = max(10**5, sys.getrecursionlimit() + 100)
deep_tree = _get_deep_pytree(lambda n, t: {n: t},
'leaf', depth)
with self.subTest(function='pytree_leaf_iter'):
leaves = list(pytree_leaf_iter(deep_tree,
lambda p, s: s.upper(),
_dict_node_handler))
self.assertEqual(leaves, ['LEAF'])
with self.subTest(function='pytree_map'):
mapped_deep_tree = pytree_map(deep_tree,
lambda p, s: s,
_dict_node_handler)
self.assertIsInstance(mapped_deep_tree, dict)
with self.subTest(function='combined'):
leaves = list(
pytree_leaf_iter(
pytree_map(deep_tree,
lambda p, s: s.capitalize(),
_dict_node_handler),
lambda p, s: s + s,
_dict_node_handler))
self.assertEqual(leaves, ['LeafLeaf'])
def test_deep_freeze(self):
"""Tests guarantees given by `deep_freeze`."""
frozen = pytree_transforms.deep_freeze(
dict(a=[1001, 1002, dict(b=(1003, [1004, {1005, 1006}]))]))
self.assertIsInstance(frozen, collections.abc.Mapping)
self.assertNotIsInstance(frozen, collections.abc.MutableMapping)
self.assertIsInstance(frozen['a'], tuple)
# `frozen` is hashable, and hashes to an integer.
self.assertIsInstance(hash(frozen), int)
if __name__ == '__main__':
unittest.main()

View File

@ -1,4 +0,0 @@
immutabledict>=4.2.0
numpy>=1.26.4
orbax-checkpoint>=0.0.0

View File

@ -67,6 +67,35 @@ void ForeachPackedAndRawType() {
}
}
template <class Test, class D>
void ForeachActivationType1(D d) {
Test test;
test(float(), d);
test(BF16(), d);
}
template <class Test, class D>
void ForeachActivationType2(D d) {
Test test;
test(float(), float(), d);
test(float(), BF16(), d);
test(BF16(), float(), d);
test(BF16(), BF16(), d);
}
template <class Test, class D>
void ForeachActivationType3(D d) {
Test test;
test(float(), float(), float(), d);
test(float(), float(), BF16(), d);
test(float(), BF16(), float(), d);
test(float(), BF16(), BF16(), d);
test(BF16(), float(), float(), d);
test(BF16(), float(), BF16(), d);
test(BF16(), BF16(), float(), d);
test(BF16(), BF16(), BF16(), d);
}
// Generates inputs: deterministic, within max SfpStream range.
template <typename MatT>
MatStorageT<MatT> GenerateMat(const Extents2D& extents,
@ -85,7 +114,7 @@ MatStorageT<MatT> GenerateMat(const Extents2D& extents,
row[c] = f;
}
Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
MakeSpan(compressed.Row(r), compressed.Cols()),
MakeSpan(compressed.Row(r), extents.cols),
/*packed_ofs=*/0);
});
@ -93,7 +122,8 @@ MatStorageT<MatT> GenerateMat(const Extents2D& extents,
return compressed;
}
// Same, but `extents` describes the transposed matrix.
// Same, but `extents` describes the transposed matrix and the computation of
// `f` swaps `r` and `c`.
template <typename MatT>
MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
const Allocator& allocator,
@ -112,7 +142,7 @@ MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
row[c] = f;
}
Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
MakeSpan(compressed.Row(r), compressed.Cols()),
MakeSpan(compressed.Row(r), extents.cols),
/*packed_ofs=*/0);
});

View File

@ -45,10 +45,11 @@ namespace gcpp {
// as NEON_WITHOUT_AES. Also skip SVE because SVE2_128 and SVE_256 cover most.
#define GEMMA_DISABLED_TARGETS (HWY_SCALAR | HWY_NEON | HWY_SVE)
#elif HWY_ARCH_X86
// Skip anything older than Haswell (2013); also use Zen4 for recent CPUs,
// because we do not use anything added by SPR (e.g. FP16) nor AVX 10.2.
// Skip anything older than Haswell (2013); use Zen4/SPR for recent CPUs.
// Although we do not use SPR's F16, Zen4 is only enabled for AMD. We do not
// yet use any AVX 10.2 features.
#define GEMMA_DISABLED_TARGETS \
(HWY_SCALAR | HWY_SSE2 | HWY_SSSE3 | HWY_SSE4 | HWY_AVX3_SPR | HWY_AVX10_2)
(HWY_SCALAR | HWY_SSE2 | HWY_SSSE3 | HWY_SSE4 | HWY_AVX10_2)
#endif // HWY_ARCH_*
#endif // GEMMA_DISABLED_TARGETS
@ -88,6 +89,26 @@ struct SfpStream {
};
#pragma pack(pop)
#pragma pack(push, 1)
struct I8Stream {
static constexpr size_t kGroupSize = 128;
using ScaleT = hwy::bfloat16_t;
// Returns number of I8Stream to allocate for the stream, which matches its
// size in bytes.
// TODO: should support other types beyond hwy::float32_t for scale and
// zero-point.
static constexpr size_t PackedEnd(size_t capacity) {
const size_t num_groups = hwy::DivCeil(capacity, kGroupSize);
return (sizeof(ScaleT) * num_groups) + // scale
(sizeof(ScaleT) * num_groups) + // zero-point
capacity; // 1 value per byte
}
int8_t i;
};
#pragma pack(pop)
// Non-uniform quantization: a compressed representation of f32 inputs that
// supports seeking at a granularity of 1 (for `DecompressAndZeroPad`) or
// two vectors (for `Decompress2`), and decoding to bf16/f32.
@ -186,12 +207,23 @@ constexpr bool IsNuqStream() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>();
}
// Tensor types for loading weights.
enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64 };
template <typename Packed>
constexpr bool IsI8Stream() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, I8Stream>();
}
template <typename Packed>
constexpr bool SupportsPointerArithmetic() {
return !IsNuqStream<Packed>() && !IsI8Stream<Packed>();
}
// Tensor types for loading weights. Not all of these are supported weight
// types, some are only used for `Activations`.
enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kU32, kU64, kI8 };
// These are used in `ModelConfig.Specifier`, hence the strings will not
// change, though new ones may be added.
static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16",
"sfp", "nuq", "f64"};
static constexpr const char* kTypeStrings[] = {
"unknown", "f32", "bf16", "sfp", "nuq", "f64", "u32", "u64", "i8"};
static constexpr size_t kNumTypes =
sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
static constexpr size_t kTypeBits[] = {
@ -201,6 +233,9 @@ static constexpr size_t kTypeBits[] = {
8 * sizeof(SfpStream),
4 /* NuqStream, actually 4.5 */,
8 * sizeof(double),
8 * sizeof(uint32_t),
8 * sizeof(uint64_t),
8 * sizeof(I8Stream),
};
static inline bool EnumValid(Type type) {
@ -221,6 +256,12 @@ Type TypeEnum() {
return Type::kNUQ;
} else if constexpr (hwy::IsSame<Packed, double>()) {
return Type::kF64;
} else if constexpr (hwy::IsSame<Packed, uint32_t>()) {
return Type::kU32;
} else if constexpr (hwy::IsSame<Packed, uint64_t>()) {
return Type::kU64;
} else if constexpr (hwy::IsSame<Packed, I8Stream>()) {
return Type::kI8;
} else {
HWY_DASSERT(false);
return Type::kUnknown;
@ -241,7 +282,9 @@ const char* TypeName() {
template <typename Packed>
constexpr bool IsCompressed() {
return hwy::IsSameEither<hwy::RemoveCvRef<Packed>, SfpStream, NuqStream>();
return hwy::IsSame<hwy::RemoveCvRef<Packed>, SfpStream>() ||
hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>() ||
hwy::IsSame<hwy::RemoveCvRef<Packed>, I8Stream>();
}
// Returns the number of `MatT` elements required to store `capacity` values,
@ -252,6 +295,8 @@ template <typename Packed>
constexpr size_t CompressedArrayElements(size_t capacity) {
if constexpr (hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>()) {
return NuqStream::PackedEnd(capacity);
} else if constexpr (hwy::IsSame<hwy::RemoveCvRef<Packed>, I8Stream>()) {
return I8Stream::PackedEnd(capacity);
} else {
return capacity;
}

View File

@ -20,7 +20,6 @@
#include <iostream>
#include <ostream>
#include <random>
#include <string>
#include <vector>
@ -37,17 +36,6 @@
namespace gcpp {
void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
if (inference.deterministic) {
// Nothing up my sleeve number, at least some upper bits set.
gen.seed(0x12345678);
} else {
// Depending on the library implementation, this may still be deterministic.
std::random_device rd; // NOLINT
gen.seed(rd());
}
}
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference)
: ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) {
@ -60,12 +48,9 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
ctx_);
}
InitGenerator(inference, gen_);
runtime_config_ = {
.max_generated_tokens = inference.max_generated_tokens,
.temperature = inference.temperature,
.gen = &gen_,
.verbosity = inference.verbosity,
};
inference.CopyTo(runtime_config_);
@ -93,16 +78,16 @@ QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
<< runtime_config_.max_generated_tokens
<< "\ttemperature: " << runtime_config_.temperature << "\n";
}
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
gcpp::TimingInfo timing_info{.verbosity = runtime_config_.verbosity};
runtime_config_.batch_stream_token = batch_stream_token;
gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_,
timing_info);
return result;
}
void GemmaEnv::QueryModel(
const std::vector<int>& tokens, const StreamFunc& stream_token) {
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
void GemmaEnv::QueryModel(const std::vector<int>& tokens,
const StreamFunc& stream_token) {
gcpp::TimingInfo timing_info{.verbosity = runtime_config_.verbosity};
const StreamFunc previous_stream_token = runtime_config_.stream_token;
runtime_config_.stream_token = stream_token;
gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_,
@ -110,7 +95,7 @@ void GemmaEnv::QueryModel(
runtime_config_.stream_token = previous_stream_token;
}
std::vector<QueryResult> GemmaEnv::BatchQueryModel(
QueryResultAndMetrics GemmaEnv::BatchQueryModelWithMetrics(
const QueriesPromptTokens& queries_prompt,
const hwy::Span<const size_t>& prefix_end) {
const size_t num_queries = queries_prompt.size();
@ -120,8 +105,14 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
const size_t pos,
const int token, float) {
HWY_ASSERT(query_index < num_queries);
if (token >= gemma_.Config().vocab_size) {
HWY_ABORT("Token %d >= vocab size %d", token, gemma_.Config().vocab_size);
}
std::string token_text;
HWY_ASSERT(gemma_.Tokenizer().Decode(std::vector<int>{token}, &token_text));
if (!gemma_.Tokenizer().Decode(std::vector<int>{token}, &token_text)) {
HWY_ABORT("Failed to decode token %d, tokenizer bytes %s\n", token,
gemma_.Tokenizer().Serialize().substr(0, 10).c_str());
}
res[query_index].response.append(token_text);
HWY_ASSERT(pos == res[query_index].tokens_generated);
res[query_index].tokens_generated += 1;
@ -149,29 +140,39 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
gcpp::AllQueries all_queries(queries_prompt, kv_caches, prefix_end);
gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
gemma_.GenerateBatch(runtime_config_, all_queries, env_, timing_info);
return res;
return {res, timing_info};
}
QueryResult GemmaEnv::QueryModel(std::string& input) {
std::vector<QueryResult> GemmaEnv::BatchQueryModel(
const QueriesPromptTokens& queries_prompt,
const hwy::Span<const size_t>& prefix_end) {
return BatchQueryModelWithMetrics(queries_prompt, prefix_end).query_results;
}
QueryResult GemmaEnv::QueryModel(const std::string& input) {
const std::vector<int> prompt = WrapAndTokenize(input);
return QueryModel(prompt);
}
QueryResultAndMetrics GemmaEnv::BatchQueryModelWithMetrics(
const std::vector<std::string>& prompt_strings) {
std::vector<PromptTokens> views;
views.reserve(prompt_strings.size());
std::vector<std::vector<int>> storage;
storage.reserve(prompt_strings.size());
for (auto& input : prompt_strings) {
storage.push_back(WrapAndTokenize(input));
views.push_back(PromptTokens(storage.back().data(), storage.back().size()));
}
QueriesPromptTokens span_of_views(views.data(), views.size());
return BatchQueryModelWithMetrics(span_of_views);
}
std::vector<QueryResult> GemmaEnv::BatchQueryModel(
const std::vector<std::string>& inputs) {
std::vector<std::vector<int>> prompts;
prompts.reserve(inputs.size());
for (auto& input : inputs) {
std::string mutable_prompt = input;
prompts.push_back(WrapAndTokenize(mutable_prompt));
}
std::vector<PromptTokens> prompt_vector;
prompt_vector.reserve(prompts.size());
for (auto& prompt : prompts) {
prompt_vector.push_back(PromptTokens(prompt.data(), prompt.size()));
}
QueriesPromptTokens prompt_span(prompt_vector.data(), prompt_vector.size());
return BatchQueryModel(prompt_span);
return BatchQueryModelWithMetrics(inputs).query_results;
}
float GemmaEnv::CrossEntropy(const std::string& input) {
@ -256,8 +257,8 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
dt, cpu100, static_cast<int>(threading.bind),
ctx.topology.TopologyString(), ctx.pools.PinString(),
CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()),
ctx.allocator.VectorBytes() * 8, CompiledConfig(), PROFILER_ENABLED,
ctx.allocator.TotalMiB());
ctx.cache_info.VectorBytes() * 8, CompiledConfig(),
PROFILER_ENABLED, ctx.allocator.TotalMiB());
}
}

View File

@ -18,7 +18,6 @@
#include <stddef.h>
#include <random>
#include <string>
#include <vector>
@ -32,8 +31,6 @@
namespace gcpp {
void InitGenerator(const InferenceArgs& inference, std::mt19937& gen);
// Return type for query model calls.
struct QueryResult {
std::string response;
@ -42,6 +39,14 @@ struct QueryResult {
size_t response_start_pos = 0;
};
// Return type for batch query model calls with metrics.
struct QueryResultAndMetrics {
// The query results for each query in the batch.
std::vector<QueryResult> query_results;
// The timing information for the batch query.
TimingInfo timing_info;
};
// Convenience class to load a model and run inference.
class GemmaEnv {
public:
@ -71,7 +76,7 @@ class GemmaEnv {
return tokens;
}
std::vector<int> WrapAndTokenize(std::string& input) const {
std::vector<int> WrapAndTokenize(const std::string& input) const {
return gcpp::WrapAndTokenize(gemma_.Tokenizer(), gemma_.ChatTemplate(),
gemma_.Config().wrapping, 0, input);
}
@ -82,21 +87,30 @@ class GemmaEnv {
return string;
}
// Adds turn structure to input, tokenizes and calls the below overload.
QueryResult QueryModel(const std::string& input);
// Runs inference on the given input and returns the top-1 result string and
// the number of tokens that were generated.
QueryResult QueryModel(const std::vector<int>& tokens);
// Runs inference on the given input and calls the callback for each token.
void QueryModel(const std::vector<int>& tokens,
const StreamFunc& stream_token);
// Similar to the above, but runs inference on a batch of inputs.
std::vector<QueryResult> BatchQueryModel(
const std::vector<std::string>& inputs);
// The default prefix_end means "causal attention".
std::vector<QueryResult> BatchQueryModel(
const QueriesPromptTokens& queries_prompt,
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>());
// Adds turn structure to input, tokenizes and calls the above overload.
QueryResult QueryModel(std::string& input);
std::vector<QueryResult> BatchQueryModel(
const std::vector<std::string>& inputs);
// Runs inference on the given input and calls the callback for each token.
void QueryModel(const std::vector<int>& tokens,
const StreamFunc& stream_token);
// Similar to the above, but returns timing information in addition to the
// query results.
QueryResultAndMetrics BatchQueryModelWithMetrics(
const std::vector<std::string>& prompt_strings);
QueryResultAndMetrics BatchQueryModelWithMetrics(
const QueriesPromptTokens& queries_prompt,
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>());
// Runs inference on the given input and returns the cross entropy, a measure
// of how well the model predicts the correct output. It is the average
@ -107,7 +121,6 @@ class GemmaEnv {
int Verbosity() const { return runtime_config_.verbosity; }
RuntimeConfig& MutableConfig() { return runtime_config_; }
std::mt19937& MutableGen() { return gen_; }
KVCache& MutableKVCache() { return kv_caches_[0]; }
MatMulEnv& MutableEnv() { return env_; }
@ -115,7 +128,6 @@ class GemmaEnv {
ThreadingContext ctx_;
MatMulEnv env_;
Gemma gemma_;
std::mt19937 gen_; // Random number generator.
std::vector<KVCache> kv_caches_; // Same number as query batch.
RuntimeConfig runtime_config_;
};

View File

@ -56,11 +56,10 @@ static std::string TokenString(const GemmaTokenizer& tokenizer, int token) {
return "'" + std::regex_replace(token_str, std::regex("\n"), "\\n") + "'";
}
void LogTopK(const GemmaTokenizer& tokenizer, const float* dist, size_t len,
size_t k) {
std::vector<std::pair<float, int>> sorted(len);
for (size_t i = 0; i < len; ++i) {
sorted[i] = std::make_pair(dist[i], static_cast<int>(i));
void LogTopK(const GemmaTokenizer& tokenizer, Logits logits, size_t k) {
std::vector<std::pair<float, int>> sorted(logits.size());
for (size_t i = 0; i < logits.size(); ++i) {
sorted[i] = std::make_pair(logits[i], static_cast<int>(i));
}
std::sort(sorted.begin(), sorted.end(),
[](const std::pair<float, int>& a, const std::pair<float, int>& b) {
@ -84,9 +83,8 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
void CallSoftmax(float* HWY_RESTRICT logits, size_t vocab_size,
hwy::Profiler& p) {
Softmax(logits, vocab_size, p, hwy::Profiler::Thread());
void CallSoftmax(Logits logits, hwy::Profiler& p) {
Softmax(logits, p, hwy::Profiler::GlobalIdx());
}
} // namespace HWY_NAMESPACE
@ -101,25 +99,26 @@ HWY_EXPORT(CallSoftmax);
float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
MatMulEnv& env, int verbosity) {
const StreamFunc stream_token = [](int, float) { return true; };
const BatchStreamFunc stream_token = [](size_t, size_t, int, float) {
return true;
};
const int vocab_size = gemma.Config().vocab_size;
float cross_entropy = std::log(vocab_size); // first token; == -log(1/v_s)
size_t pos = 1;
const SampleFunc sample_token = [&](float* probs,
size_t vocab_size) -> TokenAndProb {
const SampleFunc sample_token = [&](size_t qi, size_t pos, Logits logits,
size_t /*worker*/) -> TokenAndProb {
// input is logits, not yet probabilities
HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size, env.ctx.profiler);
HWY_DYNAMIC_DISPATCH(CallSoftmax)(logits, env.ctx.profiler);
// We are called for each token, but pos starts at 1. Clamping
// max_generated_tokens to prompt.size() should prevent overrun.
HWY_ASSERT(pos < prompt.size());
const int token = prompt[pos];
const float prob = probs[token];
const float prob = logits[token];
cross_entropy -= std::max(std::log(prob), -64.0f);
if (verbosity >= 4) {
LogTopK(gemma.Tokenizer(), probs, vocab_size, 10);
LogTopK(gemma.Tokenizer(), logits, 10);
}
if (verbosity >= 3) {
printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos, token,
@ -130,7 +129,6 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
cross_entropy / std::log(2.0) / (pos + 1));
}
++pos;
return TokenAndProb{.token = token, .prob = prob};
};
@ -139,9 +137,8 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
RuntimeConfig runtime = {
.max_generated_tokens = max_generated_tokens - 1,
.temperature = 0.0f,
.gen = nullptr,
.verbosity = verbosity,
.stream_token = stream_token,
.batch_stream_token = stream_token,
.sample_func = sample_token,
};
TimingInfo timing_info;

View File

@ -15,6 +15,7 @@
#include <stdio.h>
#include <algorithm>
#include <string>
#include <vector>
@ -37,7 +38,6 @@ class GemmaBatchBench : public ::testing::Test {
protected:
std::vector<std::string> BatchGemmaReply(
const std::vector<std::string>& inputs) {
s_env->SetMaxGeneratedTokens(24);
s_env->MutableConfig().temperature = 0.0f; // deterministic
s_env->MutableConfig().verbosity = 2;
std::vector<std::string> replies;
@ -49,55 +49,94 @@ class GemmaBatchBench : public ::testing::Test {
};
TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
const std::vector<std::string> questions = {
{"Write me a poem about Australia?"},
{"What's the history of Denmark?"},
{"Write me a comedy story about the USA."},
{"Teach me about GPU programming."},
{"Write me a story about the moon."},
{"Write me a story about the universe."},
{"Write a poem about planet earth."},
{"Tell me more about olympic sports."},
{"How would you describe Washington State?"},
{"Write me a story about Silicon Valley."},
{"Write me about your best friend."},
{"How would you describe a unicorn?"},
{"Tell me about world war history."},
{"Tell me about Google."},
std::vector<std::string> prompts = {
{"Describe dynamic programming."},
{"Explain how electric cars work."},
{"Explain to me how to use Google Maps."},
{"Explain to me how AI works."},
{"Write me a poem about France."},
{"What's the history of Great Britain?"},
{"Write me a comedy story about Florida."},
{"Teach me about dynamic programming."},
{"Write me a story about Jupiter."},
{"Write me a story about space ships."},
{"Write a poem about some random planet."},
{"Tell me more about team sports."},
{"How would you describe Michigan State?"},
{"Write me a story about Europe."},
{"Write me about your best colleague."},
{"How would you describe a horse?"},
{"Tell me about World War 2."},
{"How does AI work?"},
{"How would you describe a unicorn?"},
{"Please share some good cooking tips."},
{"Tell me about space travel."},
{"Explain to me how electric cars work."},
{"Teach me about GPU programming."},
{"Tell me a fact about World War 2."},
{"Tell me about Google."},
{"Tell me more about olympic sports."},
{"Tell me something about space travel."},
{"What is a horse?"},
{"What is Michigan State?"},
{"What's the history of Denmark?"},
{"Write a poem about planet earth."},
{"Write a story about Jupiter."},
{"Write about the moon."},
{"Write me a comedy story about Florida."},
{"Write me a poem about France."},
};
const std::vector<std::string> start = {
{"What is"}, {"When did"}, {"Where did"}, {"How did"}, {"Why did"}};
const std::vector<std::string> concepts = {"Socrates",
"Einstein",
"Leonardo",
"Cleopatra",
"Adele",
"Mars",
"Turing",
"Mozart",
"democracy",
"gravity",
"AI",
"evolution",
"physics",
"the internet",
"steam engine",
"inflation",
"electricity",
"the Sahara",
"NASA",
"Rome",
"the UN",
"Google",
"the Renaissance",
"Hamlet",
"poetry",
"Stoicism",
"geometry",
"DNA",
"Star Wars",
"1984"};
const std::vector<std::string> end = {"exist?", "work?", "happen?",
"lead to?", "believe?", "result in?"};
for (const std::string& s : start) {
for (const std::string& c : concepts) {
for (const std::string& e : end) {
prompts.push_back(s + " " + c + " " + e);
}
}
}
AesCtrEngine engine(true);
std::shuffle(prompts.begin(), prompts.end(), RngStream(engine, 123));
// Fills prompts round robin from `questions` until the desired batch size.
// Fills `inputs` by repeating from `prompts` until the desired batch size.
std::vector<std::string> inputs;
inputs.reserve(s_env->MutableConfig().decode_qbatch_size);
size_t qpos = 0;
for (size_t i = 0; i < inputs.capacity(); ++i) {
inputs.push_back(questions[qpos++]);
if (qpos == questions.size()) qpos = 0;
inputs.push_back(prompts[qpos++]);
if (qpos == prompts.size()) qpos = 0;
}
s_env->SetMaxGeneratedTokens(24);
std::vector<std::string> responses = BatchGemmaReply(inputs);
for (size_t i = 0; i < hwy::Unpredictable1() * 3; ++i) {
for (size_t i = 0; i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size());
++i) {
fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
}
PROFILER_PRINT_RESULTS();
// Run again: prefill will be faster due to autotuning. Fewer decode steps
// because those are already fast.
s_env->SetMaxGeneratedTokens(2);
responses = BatchGemmaReply(inputs);
PROFILER_PRINT_RESULTS();
}
} // namespace
} // namespace gcpp

View File

@ -115,7 +115,6 @@ TEST_F(GemmaTest, Multiturn) {
RuntimeConfig runtime_config{
.max_generated_tokens = 64,
.temperature = 0.0f,
.gen = &s_env->MutableGen(),
.verbosity = 2,
.batch_stream_token = stream_token,
};
@ -138,6 +137,10 @@ TEST_F(GemmaTest, Multiturn) {
// Reset the `response` string here, then check that the model actually has
// access to the previous turn by asking to reproduce.
response.clear();
// -1 because our prefill does not generate KVs for the last token. Do not
// just pass abs_pos - 1 because our callback checks pos == abs_pos.
HWY_ASSERT(abs_pos > 0);
--abs_pos;
model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
s_env->MutableEnv(), timing_info);
fprintf(stderr, "decoded: '%s'\n", response.c_str());
@ -155,9 +158,6 @@ TEST_F(GemmaTest, CrossEntropySmall) {
float entropy = s_env->CrossEntropy(kSmall);
fprintf(stderr, "per-token entropy: %f\n", entropy);
switch (config.model) {
case gcpp::Model::GRIFFIN_2B:
EXPECT_NEAR(entropy, 2.61f, 0.02f);
break;
case gcpp::Model::GEMMA2_2B:
EXPECT_NEAR(entropy, 1.14f, 0.02f);
break;

View File

@ -126,7 +126,6 @@ void Run(GemmaEnv& env, JsonArgs& json) {
gcpp::RuntimeConfig runtime_config = {
.max_generated_tokens = 30,
.temperature = 0.0f,
.gen = &env.MutableGen(),
.verbosity = env.Verbosity(),
.stream_token = stream_token,
};

View File

@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
include(FetchContent)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 1d16731233de45a365b43867f27d0a5f73925300)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee)
FetchContent_MakeAvailable(highway)
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 9045b2f60fa2b323dfac0eaef8fc17565036f9f9)
FetchContent_MakeAvailable(sentencepiece)

View File

@ -17,8 +17,8 @@
#include <cstdlib>
#include <cstring>
#include <functional>
#include <iostream>
#include <random>
#include <set>
#include <string>
#include <vector>
@ -44,7 +44,7 @@ int main(int argc, char** argv) {
for (int arg = 0; arg < argc; ++arg) {
// Find a --reject flag and consume everything after it.
if (strcmp(argv[arg], "--reject") == 0) {
while (++arg < argc) reject_tokens.insert(atoi(argv[arg]));
while (++arg < argc) reject_tokens.insert(atoi(argv[arg])); // NOLINT
}
}
@ -55,11 +55,6 @@ int main(int argc, char** argv) {
gcpp::KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
size_t generated = 0;
// Initialize random number generator
std::mt19937 gen;
std::random_device rd; // NOLINT
gen.seed(rd());
// Tokenize instructions.
std::string prompt = "Write a greeting to the world.";
const std::vector<int> tokens =
@ -84,7 +79,6 @@ int main(int argc, char** argv) {
gcpp::RuntimeConfig runtime_config = {
.max_generated_tokens = 1024,
.temperature = 1.0,
.gen = &gen,
.verbosity = 0,
.stream_token = stream_token,
.accept_token =

View File

@ -15,7 +15,7 @@ cc_library(
deps = [
"//:gemma_args",
"//:gemma_lib",
"//:matmul",
"//:matmul_env",
"//:threading_context",
"//:tokenizer",
"@highway//:hwy",

View File

@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
include(FetchContent)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 1d16731233de45a365b43867f27d0a5f73925300)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee)
FetchContent_MakeAvailable(highway)
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
FetchContent_MakeAvailable(sentencepiece)

View File

@ -18,7 +18,6 @@
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <random>
#include <set>
#include <string>
#include <vector>
@ -38,11 +37,7 @@ class SimplifiedGemma {
: ctx_(threading),
env_(ctx_),
gemma_(loader, inference, ctx_),
kv_cache_(gemma_.Config(), inference, ctx_.allocator) {
// Initialize random number generator
std::random_device rd;
gen_.seed(rd());
}
kv_cache_(gemma_.Config(), inference, ctx_.allocator) {}
SimplifiedGemma(int argc, char** argv)
: SimplifiedGemma(gcpp::LoaderArgs(argc, argv),
@ -76,7 +71,6 @@ class SimplifiedGemma {
gcpp::RuntimeConfig runtime_config = {
.max_generated_tokens = max_generated_tokens,
.temperature = temperature,
.gen = &gen_,
.verbosity = 0,
.stream_token = stream_token,
.accept_token =
@ -93,6 +87,5 @@ class SimplifiedGemma {
gcpp::MatMulEnv env_;
gcpp::Gemma gemma_;
gcpp::KVCache kv_cache_;
std::mt19937 gen_;
std::string validation_error_;
};

View File

@ -23,50 +23,24 @@
#include <atomic>
#include <vector>
#include "gemma/configs.h" // ModelConfig
#include "ops/matmul.h" // MatMulEnv
#include "ops/ops.h" // CreateInvTimescale
#include "util/allocator.h" // Allocator
#include "util/basics.h" // BF16
#include "util/mat.h" // MatStorageT
#include "gemma/configs.h" // ModelConfig
#include "ops/ops.h" // CreateInvTimescale
#include "util/basics.h" // BF16
#include "util/mat.h" // MatStorageT
#include "util/threading_context.h"
namespace gcpp {
struct GriffinActivations {
GriffinActivations(const ModelConfig& config, size_t batch_size,
const Allocator& allocator)
: griffin_x(
MatFactory("griffin_x", batch_size, config.model_dim, allocator)),
griffin_y(
MatFactory("griffin_y", batch_size, config.model_dim, allocator)),
griffin_gate_x(MatFactory("griffin_gate_x", batch_size,
config.model_dim, allocator)),
griffin_multiplier(MatFactory("griffin_mul", batch_size,
config.model_dim, allocator)) {}
void SetBatchSize(size_t batch_size) {
if (griffin_x.Rows() == 0) return;
griffin_x.OverrideRows(batch_size);
griffin_y.OverrideRows(batch_size);
griffin_gate_x.OverrideRows(batch_size);
griffin_multiplier.OverrideRows(batch_size);
}
MatStorageT<float> griffin_x;
MatStorageT<float> griffin_y;
MatStorageT<float> griffin_gate_x;
MatStorageT<float> griffin_multiplier;
};
struct AttentionActivations {
// Returns the scale value to use for the query in the attention computation.
// Also called by ops_test.
static inline float ChooseQueryScale(const ModelConfig& config) {
const LayerConfig& layer_config = config.layer_configs[0];
if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
return 1.0f / sqrtf(static_cast<float>(config.model_dim /
config.layer_configs[0].heads));
return 1.0f /
sqrtf(static_cast<float>(config.model_dim / layer_config.heads));
// QueryScaleType::SqrtKeySize
return 1.0f / sqrtf(static_cast<float>(config.layer_configs[0].qkv_dim));
return 1.0f / sqrtf(static_cast<float>(layer_config.qkv_dim));
}
AttentionActivations(
@ -82,7 +56,11 @@ struct AttentionActivations {
? layer_config.heads * 3 * layer_config.qkv_dim
: layer_config.heads * layer_config.qkv_dim,
allocator)),
q_T(MatFactory("q_T", layer_config.qkv_dim,
config.vocab_size == 0
? batch_size * layer_config.heads * 3
: batch_size * layer_config.heads,
allocator)),
pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size,
config.model_dim, allocator)),
att(MatFactory("att", batch_size, layer_config.heads * seq_len,
@ -116,11 +94,13 @@ struct AttentionActivations {
// If we forget any MatMul outputs here, debug builds print a warning but
// fill them in each MatMul call.
q.AllocateAndAttachRowPtrs(row_ptrs);
q_T.AllocateAndAttachRowPtrs(row_ptrs);
att_sums.AllocateAndAttachRowPtrs(row_ptrs);
}
void SetBatchSize(size_t batch_size) {
q.OverrideRows(batch_size);
// q_T rows are always qkv_dim!
pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size);
@ -131,6 +111,7 @@ struct AttentionActivations {
const ModelConfig& config;
MatStorageT<float> q; // query
MatStorageT<float> q_T; // Transposed to maximize attention speed.
MatStorageT<float> pre_att_rms_out;
MatStorageT<float> att; // attention vector
@ -143,36 +124,41 @@ struct AttentionActivations {
MatStorageT<float> inv_timescale_global;
hwy::Divisor div_seq_len;
// Unfortunately, some models (Griffin) have non-power-of-two heads.
// Unfortunately, some models have had non-power-of-two heads.
hwy::Divisor div_heads;
float query_scale;
};
struct Activations {
Activations(const ModelConfig& config, size_t batch_size, size_t seq_len,
const Allocator& allocator,
ThreadingContext& ctx,
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
: layer_config(config.layer_configs[0]),
x(MatFactory("x", batch_size, config.model_dim, allocator)),
logits(MatFactory("logits", batch_size, config.vocab_size, allocator)),
x(MatFactory("x", batch_size, config.model_dim, ctx.allocator)),
x_bf(MatFactory("x_bf", batch_size, config.model_dim, ctx.allocator)),
logits(
MatFactory("logits", batch_size, config.vocab_size, ctx.allocator)),
sampled(MatFactory("sampled", batch_size, 3, ctx.allocator)),
pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size,
config.model_dim, allocator)),
C1(MatFactory("C1", batch_size, layer_config.ff_hidden_dim, allocator)),
C2(MatFactory("C2", batch_size, layer_config.ff_hidden_dim, allocator)),
ffw_out(MatFactory("ffw_out", batch_size, config.model_dim, allocator)),
config.model_dim, ctx.allocator)),
C1(MatFactory("C1", batch_size, layer_config.ff_hidden_dim,
ctx.allocator)),
C2(MatFactory("C2", batch_size, layer_config.ff_hidden_dim,
ctx.allocator)),
ffw_out(
MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)),
attention(config, layer_config, batch_size, seq_len, allocator,
row_ptrs),
griffin(config, config.model == Model::GRIFFIN_2B ? batch_size : 0,
allocator) {
attention(config, layer_config, batch_size, seq_len, ctx.allocator,
row_ptrs) {
HWY_ASSERT(batch_size != 0);
// For MatMul outputs, precompute their row pointers.
// If we forget any MatMul outputs here, debug builds print a warning but
// fill them in each MatMul call.
x.AllocateAndAttachRowPtrs(row_ptrs);
x_bf.AllocateAndAttachRowPtrs(row_ptrs);
logits.AllocateAndAttachRowPtrs(row_ptrs);
C1.AllocateAndAttachRowPtrs(row_ptrs);
C2.AllocateAndAttachRowPtrs(row_ptrs);
@ -184,7 +170,9 @@ struct Activations {
// Negligible CPU time.
void SetBatchSize(size_t batch_size) {
x.OverrideRows(batch_size);
x_bf.OverrideRows(batch_size);
logits.OverrideRows(batch_size);
sampled.OverrideRows(batch_size);
pre_ffw_rms_out.OverrideRows(batch_size);
C1.OverrideRows(batch_size);
@ -192,23 +180,22 @@ struct Activations {
ffw_out.OverrideRows(batch_size);
attention.SetBatchSize(batch_size);
griffin.SetBatchSize(batch_size);
}
const LayerConfig& layer_config;
MatStorageT<float> x; // input
MatStorageT<float> logits;
MatStorageT<BF16> x_bf; // output of final RMSNorm, input to EmbeddingMatmul
MatStorageT<float> logits; // TODO: BF16 after Softmax supports that.
MatStorageT<uint32_t> sampled; // batch_size x 3 (padded)
// Gated FFW
MatStorageT<BF16> pre_ffw_rms_out;
// Norm may be large, so prefer to keep as f32.
MatStorageT<float> C1;
MatStorageT<float> C2;
MatStorageT<BF16> ffw_out;
MatStorageT<BF16> C1;
MatStorageT<BF16> C2;
MatStorageT<float> ffw_out;
AttentionActivations attention;
GriffinActivations griffin;
};
} // namespace gcpp

359
gemma/api_client.cc Normal file
View File

@ -0,0 +1,359 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Test client for API server
#include <iostream>
#include <string>
#include <sstream>
#include <cstdlib>
#include <cstring>
#include "httplib.h"
#include "nlohmann/json.hpp"
#include "gemma/gemma_args.h"
using json = nlohmann::json;
// ANSI color codes
const std::string RESET = "\033[0m";
const std::string BOLD = "\033[1m";
const std::string GREEN = "\033[32m";
const std::string BLUE = "\033[34m";
const std::string CYAN = "\033[36m";
const std::string YELLOW = "\033[33m";
const std::string RED = "\033[31m";
class APIClient {
public:
APIClient(const std::string& host, int port, const std::string& api_key = "", const std::string& model = "gemma3-4b")
: host_(host), port_(port), api_key_(api_key), model_(model), use_https_(port == 443), interactive_mode_(false) {
if (use_https_) {
ssl_client_ = std::make_unique<httplib::SSLClient>(host, port);
ssl_client_->set_read_timeout(60, 0);
ssl_client_->set_write_timeout(60, 0);
ssl_client_->enable_server_certificate_verification(false);
} else {
client_ = std::make_unique<httplib::Client>(host, port);
client_->set_read_timeout(60, 0);
client_->set_write_timeout(60, 0);
}
}
// Unified request processing for both public and local APIs
json ProcessRequest(const json& request, bool stream = true) {
bool is_public_api = !api_key_.empty();
std::string endpoint;
if (is_public_api) {
endpoint = stream ? "/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse"
: "/v1beta/models/gemini-2.0-flash:generateContent";
} else {
endpoint = stream ? "/v1beta/models/" + model_ + ":streamGenerateContent"
: "/v1beta/models/" + model_ + ":generateContent";
}
// Only show verbose output in non-interactive mode
if (!interactive_mode_) {
std::cout << "\n" << BOLD << BLUE << "📤 POST " << endpoint << RESET << std::endl;
std::cout << "Request: " << request.dump(2) << std::endl;
}
if (stream) {
return ProcessStreamingRequest(request, endpoint);
} else {
return ProcessNonStreamingRequest(request, endpoint);
}
}
void TestGenerateContent(const std::string& prompt, bool stream = true) {
json request = CreateAPIRequest(prompt);
json response = ProcessRequest(request, stream);
if (response.contains("error")) {
std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET << std::endl;
}
}
void TestListModels() {
std::cout << "\n" << BOLD << BLUE << "📤 GET /v1beta/models" << RESET << std::endl;
httplib::Headers headers;
if (!api_key_.empty()) {
headers.emplace("X-goog-api-key", api_key_);
}
auto res = use_https_ ? ssl_client_->Get("/v1beta/models", headers) : client_->Get("/v1beta/models", headers);
if (res && res->status == 200) {
json response = json::parse(res->body);
std::cout << GREEN << "✅ Available models:" << RESET << std::endl;
std::cout << response.dump(2) << std::endl;
} else {
std::cerr << RED << "❌ Request failed" << RESET << std::endl;
}
}
void InteractiveChat() {
std::cout << "\n" << BOLD << CYAN << "💬 Interactive Chat Mode (with session)" << RESET << std::endl;
std::cout << "Type ':gemma %q' to end.\n" << std::endl;
interactive_mode_ = true;
json messages;
while (true) {
std::cout << BOLD << BLUE << "You: " << RESET;
std::string input;
std::getline(std::cin, input);
if (input == ":gemma %q") {
std::cout << BOLD << YELLOW << "👋 Goodbye!" << RESET << std::endl;
break;
}
if (input.empty()) continue;
// Add user message with proper role
json user_message = {{"parts", {{{"text", input}}}}};
if (!api_key_.empty()) {
user_message["role"] = "user";
}
messages.push_back(user_message);
// Create request using unified logic
json request = CreateAPIRequest("", messages);
std::cout << BOLD << GREEN << "Assistant: " << RESET;
// Use unified processing - streaming for real-time output
json response = ProcessRequest(request, true);
if (response.contains("candidates") && !response["candidates"].empty()) {
auto& candidate = response["candidates"][0];
if (candidate.contains("content") && candidate["content"].contains("parts")) {
for (const auto& part : candidate["content"]["parts"]) {
if (part.contains("text")) {
std::string assistant_response = part["text"].get<std::string>();
// For streaming, the response is already displayed in real-time
// Just add to message history for context
json assistant_message = {{"parts", {{{"text", assistant_response}}}}};
if (!api_key_.empty()) {
assistant_message["role"] = "model";
}
messages.push_back(assistant_message);
}
}
}
} else if (response.contains("error")) {
std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET << std::endl;
}
std::cout << std::endl;
}
}
private:
json CreateAPIRequest(const std::string& prompt, const json& messages = json::array()) {
json request = {
{"generationConfig", {
{"temperature", 0.9},
{"topK", 1},
{"maxOutputTokens", 1024}
}}
};
if (messages.empty()) {
// Single prompt
json user_message = {{"parts", {{{"text", prompt}}}}};
if (!api_key_.empty()) {
user_message["role"] = "user";
}
request["contents"] = json::array({user_message});
} else {
// Use provided message history
request["contents"] = messages;
}
return request;
}
json ProcessNonStreamingRequest(const json& request, const std::string& endpoint) {
httplib::Headers headers = {{"Content-Type", "application/json"}};
if (!api_key_.empty()) {
headers.emplace("X-goog-api-key", api_key_);
}
auto res = use_https_ ? ssl_client_->Post(endpoint, headers, request.dump(), "application/json")
: client_->Post(endpoint, headers, request.dump(), "application/json");
if (res && res->status == 200) {
json response = json::parse(res->body);
if (!interactive_mode_) {
std::cout << "\n" << BOLD << GREEN << "📥 Response:" << RESET << std::endl;
std::cout << response.dump(2) << std::endl;
}
return response;
} else {
json error_response = {
{"error", {
{"message", "Request failed"},
{"status", res ? res->status : -1}
}}
};
if (res && !res->body.empty()) {
error_response["error"]["details"] = res->body;
}
std::cerr << RED << "❌ Request failed. Status: " << (res ? res->status : -1) << RESET << std::endl;
return error_response;
}
}
json ProcessStreamingRequest(const json& request, const std::string& endpoint) {
std::string accumulated_response;
// Use same SSE logic for both public and local APIs
httplib::Request req;
req.method = "POST";
req.path = endpoint;
req.set_header("Content-Type", "application/json");
if (!api_key_.empty()) {
req.set_header("X-goog-api-key", api_key_);
}
req.body = request.dump();
req.content_receiver = [&accumulated_response, this](const char* data, size_t data_length, uint64_t offset, uint64_t total_length) -> bool {
std::string chunk(data, data_length);
std::istringstream stream(chunk);
std::string line;
while (std::getline(stream, line)) {
if (line.substr(0, 6) == "data: ") {
std::string event_data = line.substr(6);
if (event_data == "[DONE]") {
if (!interactive_mode_) {
std::cout << "\n\n" << GREEN << "✅ Generation complete!" << RESET << std::endl;
}
} else {
try {
json event = json::parse(event_data);
if (event.contains("candidates") && !event["candidates"].empty()) {
auto& candidate = event["candidates"][0];
if (candidate.contains("content") && candidate["content"].contains("parts")) {
for (const auto& part : candidate["content"]["parts"]) {
if (part.contains("text")) {
std::string text = part["text"].get<std::string>();
std::cout << text << std::flush;
accumulated_response += text;
}
}
}
}
} catch (const json::exception& e) {
// Skip parse errors
}
}
}
}
return true;
};
httplib::Response res;
httplib::Error error;
bool success = use_https_ ? ssl_client_->send(req, res, error) : client_->send(req, res, error);
if (res.status == 200 && !accumulated_response.empty()) {
return json{
{"candidates", {{
{"content", {
{"parts", {{{"text", accumulated_response}}}}
}}
}}}
};
} else {
json error_response = {
{"error", {
{"message", "Streaming request failed"},
{"status", res.status}
}}
};
if (!res.body.empty()) {
error_response["error"]["details"] = res.body;
}
std::cerr << RED << "❌ Streaming request failed. Status: " << res.status << RESET << std::endl;
return error_response;
}
}
private:
std::unique_ptr<httplib::Client> client_;
std::unique_ptr<httplib::SSLClient> ssl_client_;
std::string host_;
int port_;
std::string api_key_;
std::string model_;
bool use_https_;
bool interactive_mode_;
};
int main(int argc, char* argv[]) {
gcpp::ClientArgs client_args(argc, argv);
if (gcpp::HasHelp(argc, argv)) {
std::cout << "\nAPI Client for gemma.cpp\n";
std::cout << "========================\n\n";
client_args.Help();
std::cout << std::endl;
std::cout << "Environment Variables:" << std::endl;
std::cout << " GOOGLE_API_KEY : Automatically use public Google API if set" << std::endl;
return 0;
}
// Check for GOOGLE_API_KEY environment variable
const char* env_api_key = std::getenv("GOOGLE_API_KEY");
if (env_api_key != nullptr && strlen(env_api_key) > 0) {
client_args.api_key = env_api_key;
client_args.host = "generativelanguage.googleapis.com";
client_args.port = 443;
}
// Handle API key override
if (!client_args.api_key.empty()) {
client_args.host = "generativelanguage.googleapis.com";
client_args.port = 443;
}
std::cout << BOLD << YELLOW << "🚀 Testing API Server at "
<< client_args.host << ":" << client_args.port << RESET << std::endl;
try {
APIClient client(client_args.host, client_args.port, client_args.api_key, client_args.model);
if (client_args.interactive) {
client.InteractiveChat();
} else {
client.TestListModels();
client.TestGenerateContent(client_args.prompt, true);
}
} catch (const std::exception& e) {
std::cerr << RED << "❌ Error: " << e.what() << RESET << std::endl;
std::cerr << "Make sure the API server is running:" << std::endl;
std::cerr << " ./build/gemma_api_server --tokenizer <path> --weights <path>" << std::endl;
return 1;
}
return 0;
}

516
gemma/api_server.cc Normal file
View File

@ -0,0 +1,516 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// HTTP API server for gemma.cpp with SSE support
#include <stdio.h>
#include <signal.h>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <string_view>
#include <vector>
#include <thread>
#include <atomic>
#include <chrono>
#include <sstream>
#include <iomanip>
#include <mutex>
#include <unordered_map>
// HTTP server library
#undef CPPHTTPLIB_OPENSSL_SUPPORT
#undef CPPHTTPLIB_ZLIB_SUPPORT
#include "httplib.h"
// JSON library
#include "nlohmann/json.hpp"
#include "compression/types.h"
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/tokenizer.h"
#include "ops/matmul.h"
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/profiler.h"
using json = nlohmann::json;
namespace gcpp {
static std::atomic<bool> server_running{true};
// Server state holding model and KV caches
struct ServerState {
std::unique_ptr<Gemma> gemma;
MatMulEnv* env;
ThreadingContext* ctx;
// Session-based KV cache storage
struct Session {
std::unique_ptr<KVCache> kv_cache;
size_t abs_pos = 0;
std::chrono::steady_clock::time_point last_access;
};
std::unordered_map<std::string, Session> sessions;
std::mutex sessions_mutex;
std::mutex inference_mutex;
// Cleanup old sessions after 30 minutes of inactivity
void CleanupOldSessions() {
std::lock_guard<std::mutex> lock(sessions_mutex);
auto now = std::chrono::steady_clock::now();
for (auto it = sessions.begin(); it != sessions.end();) {
if (now - it->second.last_access > std::chrono::minutes(30)) {
it = sessions.erase(it);
} else {
++it;
}
}
}
// Get or create session with KV cache
Session& GetOrCreateSession(const std::string& session_id) {
std::lock_guard<std::mutex> lock(sessions_mutex);
auto& session = sessions[session_id];
if (!session.kv_cache) {
session.kv_cache = std::make_unique<KVCache>(gemma->Config(), InferenceArgs(), env->ctx.allocator);
}
session.last_access = std::chrono::steady_clock::now();
return session;
}
};
// Generate a unique session ID
std::string GenerateSessionId() {
static std::atomic<uint64_t> counter{0};
std::stringstream ss;
ss << "session_" << std::hex
<< std::chrono::steady_clock::now().time_since_epoch().count() << "_"
<< counter.fetch_add(1);
return ss.str();
}
// Wraps messages with start_of_turn markers - handles both with and without roles
std::string WrapMessagesWithTurnMarkers(const json& contents) {
std::string prompt;
for (const auto& content : contents) {
if (content.contains("parts")) {
// Check if role is specified (public API format) or not (local format)
std::string role = content.value("role", "");
for (const auto& part : content["parts"]) {
if (part.contains("text")) {
std::string text = part["text"];
if (role == "user") {
prompt += "<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
} else if (role == "model") {
prompt += text + "\n";
} else if (role.empty()) {
// Local format without roles - for now, treat as user input
prompt += "<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
}
}
}
}
}
return prompt;
}
// Parse generation config
RuntimeConfig ParseGenerationConfig(const json& request) {
RuntimeConfig config;
config.verbosity = 0;
// Set defaults matching public API
config.temperature = 1.0f;
config.top_k = 1;
config.max_generated_tokens = 8192;
if (request.contains("generationConfig")) {
auto& gen_config = request["generationConfig"];
if (gen_config.contains("temperature")) {
config.temperature = gen_config["temperature"].get<float>();
}
if (gen_config.contains("topK")) {
config.top_k = gen_config["topK"].get<int>();
}
if (gen_config.contains("maxOutputTokens")) {
config.max_generated_tokens = gen_config["maxOutputTokens"].get<size_t>();
}
}
return config;
}
// Unified response formatter - creates consistent format regardless of request type
json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false) {
json response = {
{"candidates", {{
{"content", {
{"parts", {{{"text", text}}}},
{"role", "model"}
}},
{"index", 0}
}}},
{"promptFeedback", {{"safetyRatings", json::array()}}}
};
// Only add finishReason for non-streaming chunks
if (!is_streaming_chunk) {
response["candidates"][0]["finishReason"] = "STOP";
}
return response;
}
// Handle generateContent endpoint (non-streaming)
void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) {
try {
json request = json::parse(req.body);
// Get or create session
std::string session_id = request.value("sessionId", GenerateSessionId());
auto& session = state.GetOrCreateSession(session_id);
// Extract prompt from API format
std::string prompt;
if (request.contains("contents")) {
prompt = WrapMessagesWithTurnMarkers(request["contents"]);
} else {
res.status = 400;
res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json");
return;
}
// Lock for inference
std::lock_guard<std::mutex> lock(state.inference_mutex);
// Set up runtime config
RuntimeConfig runtime_config = ParseGenerationConfig(request);
// Collect full response
std::string full_response;
runtime_config.stream_token = [&full_response](int token, float) {
// Skip EOS token
return true;
};
// Tokenize prompt
std::vector<int> tokens = WrapAndTokenize(
state.gemma->Tokenizer(), state.gemma->ChatTemplate(),
state.gemma->Config().wrapping, session.abs_pos, prompt);
// Run inference with KV cache
TimingInfo timing_info = {.verbosity = 0};
size_t prefix_end = 0;
// Temporarily redirect output to capture response
std::stringstream output;
runtime_config.stream_token = [&output, &state, &session, &tokens](int token, float) {
// Skip prompt tokens
if (session.abs_pos < tokens.size()) {
session.abs_pos++;
return true;
}
session.abs_pos++;
// Check for EOS
if (state.gemma->Config().IsEOS(token)) {
return true;
}
// Decode token
std::string token_text;
state.gemma->Tokenizer().Decode(std::vector<int>{token}, &token_text);
output << token_text;
return true;
};
state.gemma->Generate(runtime_config, tokens, session.abs_pos, prefix_end,
*session.kv_cache, *state.env, timing_info);
// Create response
json response = CreateAPIResponse(output.str(), false);
response["usageMetadata"] = {
{"promptTokenCount", tokens.size()},
{"candidatesTokenCount", session.abs_pos - tokens.size()},
{"totalTokenCount", session.abs_pos}
};
res.set_content(response.dump(), "application/json");
} catch (const json::exception& e) {
res.status = 400;
res.set_content(
json{{"error",
{{"message", std::string("JSON parsing error: ") + e.what()}}}}
.dump(),
"application/json");
} catch (const std::exception& e) {
res.status = 500;
res.set_content(
json{{"error", {{"message", std::string("Server error: ") + e.what()}}}}
.dump(),
"application/json");
}
}
// Handle streamGenerateContent endpoint with SSE)
void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) {
try {
json request = json::parse(req.body);
// Get or create session
std::string session_id = request.value("sessionId", GenerateSessionId());
auto& session = state.GetOrCreateSession(session_id);
// Extract prompt from API format
std::string prompt;
if (request.contains("contents")) {
prompt = WrapMessagesWithTurnMarkers(request["contents"]);
} else {
res.status = 400;
res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json");
return;
}
// Set up SSE headers
res.set_header("Content-Type", "text/event-stream");
res.set_header("Cache-Control", "no-cache");
res.set_header("Connection", "keep-alive");
res.set_header("X-Session-Id", session_id);
// Set up chunked content provider for SSE
res.set_chunked_content_provider(
"text/event-stream",
[&state, request, prompt, session_id](size_t offset, httplib::DataSink& sink) {
try {
// Lock for inference
std::lock_guard<std::mutex> lock(state.inference_mutex);
auto& session = state.GetOrCreateSession(session_id);
// Set up runtime config
RuntimeConfig runtime_config = ParseGenerationConfig(request);
// Tokenize prompt
std::vector<int> tokens = WrapAndTokenize(
state.gemma->Tokenizer(), state.gemma->ChatTemplate(),
state.gemma->Config().wrapping, session.abs_pos, prompt);
// Stream token callback
std::string accumulated_text;
auto stream_token = [&](int token, float) {
// Skip prompt tokens
if (session.abs_pos < tokens.size()) {
session.abs_pos++;
return true;
}
session.abs_pos++;
// Check for EOS
if (state.gemma->Config().IsEOS(token)) {
return true;
}
// Decode token
std::string token_text;
state.gemma->Tokenizer().Decode(std::vector<int>{token}, &token_text);
accumulated_text += token_text;
// Send SSE event using unified formatter
json event = CreateAPIResponse(token_text, true);
std::string sse_data = "data: " + event.dump() + "\n\n";
sink.write(sse_data.data(), sse_data.size());
return true;
};
runtime_config.stream_token = stream_token;
// Run inference with KV cache
TimingInfo timing_info = {.verbosity = 0};
size_t prefix_end = 0;
state.gemma->Generate(runtime_config, tokens, session.abs_pos,
prefix_end, *session.kv_cache, *state.env,
timing_info);
// Send final event using unified formatter
json final_event = CreateAPIResponse("", false);
final_event["usageMetadata"] = {
{"promptTokenCount", tokens.size()},
{"candidatesTokenCount", session.abs_pos - tokens.size()},
{"totalTokenCount", session.abs_pos}
};
std::string final_sse = "data: " + final_event.dump() + "\n\n";
sink.write(final_sse.data(), final_sse.size());
// Send done event
sink.write("data: [DONE]\n\n", 15);
// Ensure all data is sent
sink.done();
return false; // End streaming
} catch (const std::exception& e) {
json error_event = {{"error", {{"message", e.what()}}}};
std::string error_sse = "data: " + error_event.dump() + "\n\n";
sink.write(error_sse.data(), error_sse.size());
return false;
}
}
);
} catch (const json::exception& e) {
res.status = 400;
res.set_content(
json{{"error",
{{"message", std::string("JSON parsing error: ") + e.what()}}}}
.dump(),
"application/json");
}
}
// Handle models list endpoint
void HandleListModels(ServerState& state, const InferenceArgs& inference, const httplib::Request& req, httplib::Response& res) {
json response = {
{"models", {{
{"name", "models/" + inference.model},
{"version", "001"},
{"displayName", inference.model},
{"description", inference.model + " model running locally"},
{"inputTokenLimit", 8192},
{"outputTokenLimit", 8192},
{"supportedGenerationMethods", json::array({"generateContent", "streamGenerateContent"})},
{"temperature", 1.0},
{"topK", 1}
}}}
};
res.set_content(response.dump(), "application/json");
}
// void HandleShutdown(int signal) {
// std::cerr << "\nShutting down server..." << std::endl;
// server_running = false;
// }
void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference) {
std::cerr << "Loading model..." << std::endl;
// Initialize model
ThreadingContext ctx(threading);
MatMulEnv env(ctx);
ServerState state;
state.gemma = std::make_unique<Gemma>(loader, inference, ctx);
state.env = &env;
state.ctx = &ctx;
httplib::Server server;
// Set up routes
server.Get("/", [&inference](const httplib::Request&, httplib::Response& res) {
res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" + inference.model + ":generateContent", "text/plain");
});
// API endpoints
server.Get("/v1beta/models", [&state, &inference](const httplib::Request& req, httplib::Response& res) {
HandleListModels(state, inference, req, res);
});
std::string model_endpoint = "/v1beta/models/" + inference.model;
server.Post(model_endpoint + ":generateContent", [&state](const httplib::Request& req, httplib::Response& res) {
HandleGenerateContentNonStreaming(state, req, res);
});
server.Post(model_endpoint + ":streamGenerateContent", [&state](const httplib::Request& req, httplib::Response& res) {
HandleGenerateContentStreaming(state, req, res);
});
// Periodic cleanup of old sessions
std::thread cleanup_thread([&state]() {
while (server_running) {
std::this_thread::sleep_for(std::chrono::minutes(5));
state.CleanupOldSessions();
}
});
std::cerr << "Starting API server on port " << inference.port << std::endl;
std::cerr << "Model loaded successfully" << std::endl;
std::cerr << "Endpoints:" << std::endl;
std::cerr << " POST /v1beta/models/" << inference.model << ":generateContent" << std::endl;
std::cerr << " POST /v1beta/models/" << inference.model << ":streamGenerateContent (SSE)" << std::endl;
std::cerr << " GET /v1beta/models" << std::endl;
if (!server.listen("0.0.0.0", inference.port)) {
std::cerr << "Failed to start server on port " << inference.port << std::endl;
}
cleanup_thread.join();
}
} // namespace gcpp
int main(int argc, char** argv) {
gcpp::InternalInit();
gcpp::LoaderArgs loader(argc, argv);
gcpp::ThreadingArgs threading(argc, argv);
gcpp::InferenceArgs inference(argc, argv);
if (gcpp::HasHelp(argc, argv)) {
std::cerr << "\n\nAPI server for gemma.cpp\n";
std::cout << "========================\n\n";
std::cerr << "Usage: " << argv[0] << " --weights <path> --tokenizer <path> [options]\n";
std::cerr << "\nOptions:\n";
std::cerr << " --port PORT Server port (default: 8080)\n";
std::cerr << " --model MODEL Model name for endpoints (default: gemma3-4b)\n";
std::cerr << "\n";
std::cerr << "\n*Model Loading Arguments*\n\n";
loader.Help();
std::cerr << "\n*Threading Arguments*\n\n";
threading.Help();
std::cerr << "\n*Inference Arguments*\n\n";
inference.Help();
std::cerr << "\n";
return 0;
}
// Arguments are now handled by InferenceArgs
// // Set up signal handler
// signal(SIGINT, gcpp::HandleShutdown);
// signal(SIGTERM, gcpp::HandleShutdown);
gcpp::RunServer(loader, threading, inference);
return 0;
}

View File

@ -19,7 +19,7 @@
#include <vector>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "util/threading_context.h"
#include "util/zones.h"
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
@ -29,7 +29,7 @@
#include "gemma/gemma.h"
#include "gemma/weights.h"
#include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "util/threading_context.h"
#include "hwy/profiler.h"
// Compiles this file for multiple architectures via "foreach_target.h", to
@ -42,6 +42,7 @@
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#include "gemma/flash_attention.h"
#include "ops/ops-inl.h"
HWY_BEFORE_NAMESPACE();
@ -55,8 +56,7 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
const float* HWY_RESTRICT q,
const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
hwy::Profiler& p, const size_t worker) {
static const auto zone = p.AddZone("Gen.Attention.QDotK");
PROFILER_ZONE3(p, worker, zone);
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kGenAttentionQDotK));
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
// Slightly faster: no wraparound.
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
@ -72,11 +72,11 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
}
}
static void PositionalEncodingQK(float* qk, const size_t layer_idx,
const LayerWeightsPtrs& layer,
const AttentionActivations& activations,
hwy::Profiler& p, const size_t worker,
const size_t pos, const float mul = 1.0f) {
void PositionalEncodingQK(float* qk, const size_t layer_idx,
const LayerWeightsPtrs& layer,
const AttentionActivations& activations,
hwy::Profiler& p, const size_t worker,
const size_t pos, const float mul) {
const size_t qkv_dim = layer.layer_config.qkv_dim;
const PostQKType& post_qk = layer.layer_config.post_qk;
// qk is either q or k, so qkv_dim is the length we operate on.
@ -89,7 +89,7 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx,
// PostQKType::Rope
if (post_qk == PostQKType::HalfRope) {
Rope(qk, qkv_dim / 2, inv_timescale, pos, p, worker);
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, p, worker);
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim);
} else {
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, p, worker);
}
@ -113,7 +113,7 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), p,
worker);
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), p, worker);
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
}
} else {
{
@ -122,8 +122,7 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
}
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
const size_t pos_mod = div_seq_len.Remainder(pos);
MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), p,
worker);
MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols());
}
}
}
@ -144,7 +143,7 @@ void SingleDotSoftmaxWeightedSum(
// Apply rope and scaling to Q.
if (layer.query_norm_scale.HasPtr()) {
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), 0, q,
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q,
layer.layer_config.qkv_dim, p, worker);
});
}
@ -156,8 +155,9 @@ void SingleDotSoftmaxWeightedSum(
// SoftMax with optional SoftCap yields "probabilities" in att.
const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
MaybeLogitsSoftCap(att_cap, att, att_len, p, worker);
Softmax(att, att_len, p, worker, /*temperature=*/1.0f);
const Logits logits(att, att_len);
MaybeLogitsSoftCap(att_cap, logits, p, worker);
Softmax(logits, p, worker, /*temperature=*/1.0f);
WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out, p,
worker);
@ -165,8 +165,7 @@ void SingleDotSoftmaxWeightedSum(
// The attention window usually starts at 0 unless `pos` is larger than
// the attention window size, then it is `pos` - window_size + 1.
static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config,
size_t layer_idx) {
size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) {
const size_t att_window_size = config.attention_window_sizes[layer_idx];
return pos - HWY_MIN(att_window_size - 1, pos);
}
@ -175,7 +174,12 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer,
AttentionActivations& activations, QBatch& qbatch,
ThreadingContext& ctx) {
static const auto zone = ctx.profiler.AddZone("Gen.Attention.DotSoftmax.par");
static const auto root_zone =
ctx.profiler.AddZone("Gen.Attention.DotSoftmaxWeightedSumInclusive",
hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(ctx.profiler, 0, root_zone);
const auto zone =
GetProfilerZone(Zones::kGenAttentionDotSoftmaxWeightedSumPar);
const hwy::Divisor div_qbatch(qbatch.Size());
const LayerConfig& layer_config = layer.layer_config;
@ -233,9 +237,10 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
{
PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
// Full parallelism is helpful, SmallParallelFor is insufficient.
ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads,
ctx.pools, func);
// Full parallelism is helpful, kAcrossClusters is insufficient.
HierarchicalParallelFor(
num_tokens * div_qbatch.GetDivisor() * layer_config.heads, ctx.pools,
func);
}
}
@ -251,7 +256,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
AttentionActivations& activations,
const QBatch& qbatch, const int flags,
MatMulEnv& env) {
PROFILER_ZONE("Gen.Attention.QKV");
static const auto zone = env.ctx.profiler.AddZone(
"Gen.Attention.ComputeQKV", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
const hwy::Divisor div_qbatch(qbatch.Size());
const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor();
const LayerConfig& layer_config = layer.layer_config;
@ -275,19 +283,19 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
const size_t cache_pos =
activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx);
env.row_ptrs[2][interleaved_idx] = reinterpret_cast<uint8_t*>(
env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size);
}
kv_rows.AttachRowPtrs(env.row_ptrs[2].get());
kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
/*add=*/nullptr, env, kv_rows);
// Apply positional encodings for K.
// Note that 2D parallelism is not worth the fork/join overhead because the
// tasks are very lightweight.
env.ctx.pools.Pool(0).Run(
0, kv_heads * num_interleaved,
[&](uint64_t task, size_t thread) HWY_ATTR {
ParallelFor(
ParallelismStrategy::kFlat, kv_heads * num_interleaved, env.ctx,
/*cluster_idx=*/0, [&](size_t task, size_t worker) HWY_ATTR {
const size_t head = task % kv_heads;
const size_t interleaved_idx = task / kv_heads;
const size_t qi = div_qbatch.Remainder(interleaved_idx);
@ -307,13 +315,13 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
// Apply further processing to K.
if (layer.key_norm_scale.HasPtr()) {
CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim,
env.ctx.profiler, thread);
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, kv_f32,
qkv_dim, env.ctx.profiler, worker);
});
}
PositionalEncodingQK(kv_f32, layer_idx, layer, activations,
env.ctx.profiler, thread, pos);
env.ctx.profiler, worker, pos, /*mul=*/1.0f);
CompressPerThread tls;
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
});
@ -324,8 +332,11 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
AttentionActivations& activations,
MatMulEnv& env) {
PROFILER_ZONE("Gen.Attention.SumHeads");
static const auto zone = env.ctx.profiler.AddZone(
"Gen.Attention.SumHeads", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
const LayerConfig& layer_config = layer.layer_config;
(void)layer_config; // For HWY_DASSERT
// att_weights and att_out are concatenated heads, each of length
// layer_config.qkv_dim. Thus the [num_interleaved,
// layer_config.model_dim] matmul output is the sum over heads. Compare
@ -333,10 +344,7 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
// encoded)
HWY_DASSERT(layer_config.model_dim != 0 && layer_config.heads != 0 &&
layer_config.qkv_dim != 0);
const float* add = layer_config.softmax_attn_output_biases
? layer.attention_output_biases.PackedScale1()
: nullptr;
CallMatMul(activations.att_out, layer.att_weights, add, env,
CallMatMul(activations.att_out, layer.att_weights, /*add=*/nullptr, env,
activations.att_sums);
}
@ -346,7 +354,7 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
MatMulEnv& env, int flags) {
static const auto zone =
env.ctx.profiler.AddZone("Gen.Attention", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
const LayerConfig& layer_config = layer.layer_config;
HWY_DASSERT(!layer_config.IsMHA()); // No longer supported.
@ -355,8 +363,15 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
(void)layer_config; // only used in HWY_DASSERT
ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
env.ctx);
if (flags & kAttentionUseOld) {
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
env.ctx);
} else {
// * 2 does not help on Turin.
FlashAttention(num_tokens,
/*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1,
layer_idx, layer, activations, qbatch, env.ctx);
}
SumHeads(layer, activations, env);
}

View File

@ -28,6 +28,14 @@ namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void PositionalEncodingQK(float* qk, size_t layer_idx, \
const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, \
hwy::Profiler& p, size_t worker, size_t pos, \
float mul); \
\
size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx); \
\
void SingleDotSoftmaxWeightedSum( \
const size_t pos, const size_t start_pos, const size_t last_pos, \
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \

View File

@ -23,7 +23,6 @@
#include <string>
#include <vector>
#include "evals/benchmark_helper.h" // InitGenerator
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/tokenizer.h" // WrapAndTokenize
@ -135,8 +134,6 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
std::stringstream ss;
result_buffer.clear();
InitGenerator(inference_args, gen);
// Ensure we have an active conversation
if (!active_conversation || !active_conversation->kv_cache) {
LogDebug("Generate called with null active_conversation or kv_cache");
@ -174,8 +171,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
// set up runtime config
TimingInfo timing_info = {};
RuntimeConfig runtime_config = {.gen = &gen,
.stream_token = stream_token,
RuntimeConfig runtime_config = {.stream_token = stream_token,
.use_spinning = threading_args.spin};
inference_args.CopyTo(runtime_config);
size_t prefix_end = 0;
@ -256,7 +252,6 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
// If not multiturn, or Paligemma (which handles turns differently),
// reset the *active* conversation's position.
active_conversation->abs_pos = 0;
InitGenerator(inference_args, gen);
} else {
// Multi-turn Gemma: Rewind position in the active conversation
// The last token was either EOS, then it should be ignored because it is

View File

@ -17,7 +17,6 @@
#define THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
#include <memory> // For std::shared_ptr, std::make_shared
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
@ -107,10 +106,6 @@ class GemmaContext {
// Set deterministic flag
void SetDeterministic(bool value) {
inference_args.deterministic = value;
// Reset the random number generator for deterministic generation
if (value) {
gen.seed(0x87654321);
}
LogDebug("Setting deterministic flag to configured value");
}
@ -289,9 +284,6 @@ class GemmaContext {
// Model itself (don't move this, needs to be below the args above)
Gemma model;
// Random generator (remains global for the context)
std::mt19937 gen;
// Static members for logging
static GemmaLogCallback s_log_callback;
static void* s_log_user_data;

View File

@ -133,78 +133,6 @@ static ModelConfig ConfigGemma2_2B() {
return config;
}
static LayerConfig LayerConfigGemmaTiny(size_t model_dim) {
LayerConfig config;
config.model_dim = model_dim;
config.ff_hidden_dim = 256;
config.heads = 4;
config.kv_heads = 1;
config.qkv_dim = 16;
return config;
}
static ModelConfig ConfigGemmaTiny() {
ModelConfig config = ConfigNoSSM();
config.display_name = "GemmaTiny";
config.model = Model::GEMMA_TINY;
config.wrapping = PromptWrapping::GEMMA_IT;
config.model_dim = 32;
config.vocab_size = 32; // at least two f32 vectors
config.max_seq_len = 32;
LayerConfig layer_config = LayerConfigGemmaTiny(config.model_dim);
config.num_layers = 2;
config.layer_configs = {config.num_layers, layer_config};
config.query_scale = QueryScaleType::SqrtKeySize;
config.attention_window_sizes = FixedAttentionWindowSizes<2>(32);
config.att_cap = 50.0f;
config.final_cap = 30.0f;
config.eos_id = 11;
config.secondary_eos_id = 11;
return config;
}
static LayerConfig LayerConfigGriffin2B(size_t model_dim) {
LayerConfig config;
config.model_dim = model_dim;
config.griffin_dim = model_dim;
config.ff_hidden_dim = 7680;
config.heads = 10;
config.kv_heads = 1;
config.qkv_dim = 256;
config.conv1d_width = 4;
HWY_DASSERT(config.conv1d_width <= kMaxConv1DWidth);
config.ff_biases = true;
config.softmax_attn_output_biases = true;
config.optimized_gating = false;
config.type = LayerAttentionType::kGriffinRecurrentBlock;
config.activation = ActivationType::Gelu;
config.post_qk = PostQKType::HalfRope;
return config;
}
static ModelConfig ConfigGriffin2B() {
ModelConfig config = ConfigNoSSM();
config.display_name = "Griffin2B";
config.model = Model::GRIFFIN_2B;
// Griffin uses local attention, so max_seq_len is actually the local
// attention window.
config.model_dim = 2560;
config.vocab_size = kVocabSize;
config.max_seq_len = 2048;
LayerConfig layer_config = LayerConfigGriffin2B(config.model_dim);
config.num_layers = 26;
config.layer_configs = {config.num_layers, layer_config};
for (size_t i = 2; i < config.num_layers; i += 3) {
config.layer_configs[i].type = LayerAttentionType::kGemma;
config.layer_configs[i].griffin_dim = 0;
}
config.attention_window_sizes =
FixedAttentionWindowSizes<26>(config.max_seq_len);
config.use_local_attention = true;
config.final_cap = 0.0f;
return config;
}
static LayerConfig LayerConfigVit(size_t model_dim) {
LayerConfig config;
config.model_dim = model_dim;
@ -510,10 +438,6 @@ static ModelConfig ConfigFromModel(Model model) {
return ConfigGemma2_9B();
case Model::GEMMA2_27B:
return ConfigGemma2_27B();
case Model::GRIFFIN_2B:
return ConfigGriffin2B();
case Model::GEMMA_TINY:
return ConfigGemmaTiny();
case Model::PALIGEMMA2_3B_224:
return ConfigPaliGemma2_3B_224();
case Model::PALIGEMMA2_3B_448:
@ -547,10 +471,6 @@ const char* ModelPrefix(Model model) {
return "9b";
case Model::GEMMA2_27B:
return "27b";
case Model::GRIFFIN_2B:
return "gr2b";
case Model::GEMMA_TINY:
return "tiny";
case Model::PALIGEMMA2_3B_224:
return "paligemma2-3b-224";
case Model::PALIGEMMA2_3B_448:
@ -710,8 +630,10 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool print) const {
ModelConfig a = *this;
ModelConfig b = other;
// Called by `OverwriteWithCanonical`, so ignore the fields it will set.
a.display_name = b.display_name;
a.model = b.model;
// Order matters: overwrite `b` with `a` because that is the known-good config
// when called by `OverwriteWithCanonical`.
b.display_name = a.display_name;
b.model = a.model;
// The following are not yet set by config_converter.py, so we here ignore
// them for purposes of comparison, and there overwrite the converter's config
@ -719,12 +641,12 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool print) const {
// these fields will be set.
// `vit_config` is also not yet set, but we must not ignore it because
// otherwise PaliGemma models will be indistinguishable for `configs_test`.
a.pool_dim = b.pool_dim; // ViT
a.eos_id = b.eos_id;
a.secondary_eos_id = b.secondary_eos_id;
a.scale_base_names = b.scale_base_names;
for (size_t i = 0; i < a.layer_configs.size(); ++i) {
a.layer_configs[i].optimized_gating = b.layer_configs[i].optimized_gating;
b.pool_dim = a.pool_dim; // ViT
b.eos_id = a.eos_id;
b.secondary_eos_id = a.secondary_eos_id;
b.scale_base_names = a.scale_base_names;
for (size_t i = 0; i < b.layer_configs.size(); ++i) {
b.layer_configs[i].optimized_gating = a.layer_configs[i].optimized_gating;
}
return AllEqual(a, b, print);
@ -748,13 +670,10 @@ bool ModelConfig::OverwriteWithCanonical() {
Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) {
switch (layers) {
case 2:
return Model::GEMMA_TINY;
case 18:
return Model::GEMMA3_270M;
case 26:
if (layer_types & kDeducedGriffin) return Model::GRIFFIN_2B;
if (layer_types & kDeducedViT) return Model::GEMMA3_1B;
return Model::GEMMA2_2B;
case 27:

View File

@ -26,14 +26,19 @@
#include <vector>
#include "compression/types.h" // Type
#include "io/fields.h" // IFieldsVisitor
#include "io/io.h" // Path
#include "io/fields.h" // IFieldsVisitor
#include "io/io.h" // Path
#include "util/basics.h"
namespace gcpp {
static constexpr size_t kMaxConv1DWidth = 4;
static constexpr size_t kMaxQKVDim = 1024;
HWY_INLINE_VAR constexpr int kAttentionUseOld = 2;
HWY_INLINE_VAR constexpr size_t kMaxQKVDim = 1024;
#ifndef GEMMA_FUSED_FFN
#define GEMMA_FUSED_FFN 1
#endif // !GEMMA_FUSED_FFN
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
enum class PromptWrapping {
@ -68,14 +73,11 @@ static inline bool EnumValid(PromptWrapping wrapping) {
enum class LayerAttentionType {
kGemma,
kGriffinRecurrentBlock,
kVit,
};
static inline bool EnumValid(LayerAttentionType type) {
return type == LayerAttentionType::kGemma ||
type == LayerAttentionType::kGriffinRecurrentBlock ||
type == LayerAttentionType::kVit;
return type == LayerAttentionType::kGemma || type == LayerAttentionType::kVit;
}
// Post attention and ffw normalization type.
@ -163,9 +165,8 @@ enum class Model {
// 1 and 2 are obsolete.
GEMMA2_9B = 3,
GEMMA2_27B,
GRIFFIN_2B,
GEMMA_TINY, // for testing only
GEMMA2_2B,
// 5 and 6 are obsolete.
GEMMA2_2B = 7,
// 8 and 9 are obsolete.
PALIGEMMA2_3B_224 = 10,
PALIGEMMA2_3B_448,
@ -199,13 +200,19 @@ static inline bool IsPaliGemma(Model model) {
return false;
}
static inline bool IsObsolete(Model model) {
const size_t i = static_cast<size_t>(model);
if (i == 5 || i == 6 || i == 8 || i == 9) return true;
return false;
}
// Visits every valid model enum, skipping `UNKNOWN` and `kSentinel`.
template <class Func>
void ForEachModel(const Func& func) {
for (size_t i = static_cast<size_t>(Model::GEMMA2_9B);
i < static_cast<size_t>(Model::kSentinel); ++i) {
if (i == 8 || i == 9) continue;
func(static_cast<Model>(i));
const Model model = static_cast<Model>(i);
if (!IsObsolete(model)) func(model);
}
}
@ -214,7 +221,7 @@ static inline bool EnumValid(Model model) {
if (model == Model::UNKNOWN) return true;
const size_t i = static_cast<size_t>(model);
if (i >= static_cast<size_t>(Model::GEMMA2_9B) &&
i < static_cast<size_t>(Model::kSentinel) && i != 8 && i != 9) {
i < static_cast<size_t>(Model::kSentinel) && !IsObsolete(model)) {
return true;
}
return false;
@ -235,15 +242,20 @@ struct LayerConfig : public IFields {
// Source of truth for field ordering.
void VisitFields(IFieldsVisitor& visitor) override {
// Formerly used for Griffin.
uint32_t unused_griffin_dim = 0;
uint32_t unused_conv1d_width = 0;
bool unused_softmax_attn_output_biases = false;
visitor(model_dim);
visitor(griffin_dim);
visitor(unused_griffin_dim);
visitor(ff_hidden_dim);
visitor(heads);
visitor(kv_heads);
visitor(qkv_dim);
visitor(conv1d_width);
visitor(unused_conv1d_width);
visitor(ff_biases);
visitor(softmax_attn_output_biases);
visitor(unused_softmax_attn_output_biases);
visitor(optimized_gating);
visitor(post_norm);
visitor(type);
@ -263,14 +275,11 @@ struct LayerConfig : public IFields {
bool IsMHA() const { return heads == kv_heads; }
uint32_t model_dim = 0;
uint32_t griffin_dim = 0;
uint32_t ff_hidden_dim = 0;
uint32_t heads = 0;
uint32_t kv_heads = 0;
uint32_t qkv_dim = 0; // length of Q, K, V vectors (contiguous).
uint32_t conv1d_width = 0; // Griffin only
uint32_t qkv_dim = 0; // length of Q, K, V vectors (contiguous).
bool ff_biases = false;
bool softmax_attn_output_biases = false; // for Griffin
bool optimized_gating = true; // for Gemma3
PostNormType post_norm = PostNormType::None;
LayerAttentionType type = LayerAttentionType::kGemma;
@ -358,7 +367,8 @@ struct ModelConfig : public IFields {
visitor(final_cap);
visitor(absolute_pe);
visitor(use_local_attention);
bool unused_use_local_attention = false; // formerly used for Griffin
visitor(unused_use_local_attention);
visitor(query_scale);
visitor(layer_configs);
visitor(attention_window_sizes);
@ -421,7 +431,7 @@ struct ModelConfig : public IFields {
}
size_t KVCacheCols() const {
size_t num_layers = layer_configs.size();
const size_t num_layers = layer_configs.size();
return num_layers * layer_configs[0].CacheLayerSize();
}
@ -454,7 +464,6 @@ struct ModelConfig : public IFields {
float final_cap = 0.0f;
bool absolute_pe = false;
bool use_local_attention = false; // Griffin only
QueryScaleType query_scale = QueryScaleType::SqrtKeySize;
std::vector<LayerConfig> layer_configs;
std::vector<uint32_t> attention_window_sizes;
@ -478,7 +487,6 @@ struct ModelConfig : public IFields {
ModelConfig GetVitConfig(const ModelConfig& config);
enum DeducedLayerTypes {
kDeducedGriffin = 1,
kDeducedViT = 2,
kDeduced448 = 4, // For ViT, 448x448 resolution instead of 224x224.
};

775
gemma/flash_attention.cc Normal file
View File

@ -0,0 +1,775 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stddef.h>
#include <stdint.h>
#include <algorithm>
#include <cmath>
#include <limits>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "util/threading_context.h"
#include "util/zones.h"
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
#include "gemma/activations.h"
#include "gemma/configs.h" // kMaxQKVDim
#include "gemma/gemma.h"
#include "gemma/weights.h"
#include "util/threading.h"
#include "hwy/profiler.h"
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/flash_attention.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#include "gemma/attention.h"
#include "ops/matmul-inl.h"
#include "ops/ops-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
static constexpr size_t kNFx8HTileSize = 8;
// Transposes q into q_t.
// Both are 4D tensors stuffed into a 2-D MatPtrT.
// q has shape [batch, qbatch][head, qkv_dim].
// q_t has shape [qkv_dim][qbatch, head, batch] in order to make the maximum
// possible consecutive elements have the same KV.
static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
const size_t qbatch_size, ThreadingContext& ctx) {
const auto zone = GetProfilerZone(Zones::kFlashAttentionTransposeQ);
// Group floats by the number of floats in a cache line.
const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float);
const size_t num_heads = q.Cols() / q_t.Rows();
const size_t batch_size = q.Rows() / qbatch_size;
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
PROFILER_ZONE3(ctx.profiler, worker, zone);
for (size_t lane = 0; lane < kNF; ++lane) {
size_t q_row = task * kNF + lane;
if (q_row >= q_t.Rows()) break;
float* HWY_RESTRICT qt_row = q_t.Row(q_row);
for (size_t qi = 0; qi < qbatch_size; ++qi) {
for (size_t h = 0; h < num_heads; ++h) {
for (size_t b = 0; b < batch_size; ++b) {
qt_row[(qi * num_heads + h) * batch_size + b] =
q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row];
}
}
}
}
};
{
// Better than kFlat.
size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF);
ParallelFor(ParallelismStrategy::kHierarchical, num_tasks, ctx,
/*cluster_idx=*/0, func);
}
}
// Updates q in place for RMSNorm and positional encoding.
void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
MatPtrT<KV_t>& q, const size_t layer_idx,
const LayerWeightsPtrs& layer,
const AttentionActivations& activations,
ThreadingContext& ctx) {
const auto zone =
GetProfilerZone(Zones::kFlashAttentionRmsNormAndPositionalEncoding);
const float query_scale = activations.query_scale;
const hwy::Divisor div_qbatch(qbatch.Size());
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
PROFILER_ZONE3(ctx.profiler, worker, zone);
size_t qi = div_qbatch.Remainder(task);
size_t batch_idx = div_qbatch.Divide(task);
for (size_t h = 0; h < layer.layer_config.heads; ++h) {
const size_t tq_idx = qbatch.Size() * batch_idx + qi;
// Find the token position in the query and calculate
// the range of cache positions to attend to.
const size_t pos = qbatch.Pos(qi) + batch_idx;
float* HWY_RESTRICT q_row =
q.Row(tq_idx) + h * layer.layer_config.qkv_dim;
// Apply rope and scaling to Q.
if (layer.query_norm_scale.HasPtr()) {
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q_row,
layer.layer_config.qkv_dim, ctx.profiler, worker);
});
}
PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx.profiler,
worker, pos, query_scale);
}
};
{
// kHierarchical is not worth the extra sync overhead because the tasks are
// very lightweight.
ParallelFor(ParallelismStrategy::kFlat, num_tokens * qbatch.Size(), ctx,
/*cluster_idx=*/0, func);
}
}
// Handles a single v row of flash attention for a single q.k dot product.
void HWY_INLINE SingleFlashAttentionStep(float x, float cap, float& old_max,
float& old_d,
const float* HWY_RESTRICT v,
const size_t v_cols,
float* HWY_RESTRICT att_out) {
if (cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
x = cap * std::tanh(x / cap);
}
float m = std::max(x, old_max);
x = std::exp(x - m);
float scale = old_d * std::exp(old_max - m);
old_d = x + scale;
old_max = m;
float one_over_d = 1.0f / old_d;
scale *= one_over_d;
x *= one_over_d;
MulByConst(scale, att_out, v_cols);
MulByConstAndAdd(x, v, att_out, v_cols);
}
// Calculates the complete attention outputs for a single row of q.
void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
const float* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
const MatPtrT<KV_t>& v, const size_t layer_idx,
const LayerWeightsPtrs& layer,
const AttentionActivations& activations,
float* HWY_RESTRICT att_out, hwy::Profiler& p,
const size_t worker) {
PROFILER_ZONE3(p, worker,
GetProfilerZone(Zones::kFlashAttentionSingleFlashAttention));
const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
float m = Dot(q, k.Row(pos_mod), k.Cols());
if (float cap = activations.config.att_cap; cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
m = cap * std::tanh(m / cap);
}
float d = 1.0f;
// This is just a copy of the first token.
MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), p, worker);
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
const size_t pos_mod = activations.div_seq_len.Remainder(pos);
float x = Dot(q, k.Row(pos_mod), k.Cols());
SingleFlashAttentionStep(x, activations.config.att_cap, m, d,
v.Row(pos_mod), v.Cols(), att_out);
}
}
// Computes and returns a single vector of NF Q.K dot products, which represents
// the dot products of NF rows of Q for a single K timestep.
template <class DF, class VF = hn::Vec<DF>>
VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
const size_t k_pos, const MatPtrT<KV_t>& q,
const MatPtrT<KV_t>& k) {
hn::TFromD<DF> results[hn::MaxLanes(df)];
for (size_t i = 0; i < hn::Lanes(df); ++i) {
results[i] = Dot(q.Row(0) + q_offsets[i], k.Row(k_pos), k.Cols());
}
return hn::LoadU(df, results);
}
// Returns an NF Q rows by 8 K rows tile of Q.K dot products, in single
// precision.
// This is the result of NF rows of Q against 8 K timesteps, with positions
// given by k_pos[0..7]. Q has been transposed so that the NF rows are read in
// consecutive elements, and other columns by adding q_stride.
template <class DF, class VF = hn::Vec<DF>>
void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
const MatPtrT<KV_t>& k, const size_t* k_pos, VF& sum0,
VF& sum1, VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6,
VF& sum7) {
constexpr size_t kHTileSize = kNFx8HTileSize;
sum0 = hn::Zero(df);
sum1 = hn::Zero(df);
sum2 = hn::Zero(df);
sum3 = hn::Zero(df);
sum4 = hn::Zero(df);
sum5 = hn::Zero(df);
sum6 = hn::Zero(df);
sum7 = hn::Zero(df);
const float* HWY_RESTRICT k_row[kHTileSize];
for (int i = 0; i < kHTileSize; ++i) {
k_row[i] = k.Row(k_pos[i]);
}
for (size_t i = 0; i < k.Cols(); ++i) {
VF q_vec = hn::Load(df, q);
VF k_0 = hn::Set(df, k_row[0][i]);
sum0 = hn::MulAdd(q_vec, k_0, sum0);
VF k_1 = hn::Set(df, k_row[1][i]);
sum1 = hn::MulAdd(q_vec, k_1, sum1);
VF k_2 = hn::Set(df, k_row[2][i]);
sum2 = hn::MulAdd(q_vec, k_2, sum2);
VF k_3 = hn::Set(df, k_row[3][i]);
sum3 = hn::MulAdd(q_vec, k_3, sum3);
VF k_4 = hn::Set(df, k_row[4][i]);
sum4 = hn::MulAdd(q_vec, k_4, sum4);
VF k_5 = hn::Set(df, k_row[5][i]);
sum5 = hn::MulAdd(q_vec, k_5, sum5);
VF k_6 = hn::Set(df, k_row[6][i]);
sum6 = hn::MulAdd(q_vec, k_6, sum6);
VF k_7 = hn::Set(df, k_row[7][i]);
sum7 = hn::MulAdd(q_vec, k_7, sum7);
q += q_stride;
}
}
// Returns the element-wise maximum of 8 vectors, in a single vector.
template <class DF, class VF = hn::Vec<DF>>
VF HWY_INLINE ElementwiseMaxOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
const VF& x3, const VF& x4, const VF& x5,
const VF& x6, const VF& x7) {
VF m0 = hn::Max(x0, x1);
VF m1 = hn::Max(x2, x3);
VF m2 = hn::Max(x4, x5);
VF m3 = hn::Max(x6, x7);
m0 = hn::Max(m0, m1);
m2 = hn::Max(m2, m3);
return hn::Max(m0, m2);
}
// Returns the element-wise sum of 8 vectors, in a single vector.
template <class DF, class VF = hn::Vec<DF>>
VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
const VF& x3, const VF& x4, const VF& x5,
const VF& x6, const VF& x7) {
VF sum0 = hn::Add(x0, x1);
VF sum1 = hn::Add(x2, x3);
VF sum2 = hn::Add(x4, x5);
VF sum3 = hn::Add(x6, x7);
sum0 = hn::Add(sum0, sum1);
sum2 = hn::Add(sum2, sum3);
return hn::Add(sum0, sum2);
}
// Sweeps a tile of NF Q rows by 8 K timesteps accumulators from start_pos to
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
// max_last_pos].
void TileFlashAttention(
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
const StridedView<float>& qT, const MatPtrT<KV_t>& k,
const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos,
const size_t min_last_pos, const size_t max_last_pos,
const MatPtrT<KV_t>& v, const size_t layer_idx,
const LayerWeightsPtrs& layer, const AttentionActivations& activations,
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE3(p, worker,
GetProfilerZone(Zones::kFlashAttentionTileFlashAttention));
constexpr int kHTileSize = kNFx8HTileSize;
using DF = hn::ScalableTag<float>;
const DF df;
using VF = hn::Vec<DF>;
using DI = hn::ScalableTag<uint32_t>;
const DI di;
using VI = hn::Vec<DI>;
const int kVTileSize = hn::Lanes(df);
for (int i = 0; i < kVTileSize; ++i) {
hwy::ZeroBytes(att_out.Row(0) + out_offsets[i],
v.Cols() * sizeof(att_out.Row(0)[0]));
}
VI lasts = hn::LoadU(di, last_pos);
VF old_m = hn::Set(df, -std::numeric_limits<float>::max() / 2.0f);
VF old_d = hn::Zero(df);
const float* HWY_RESTRICT qT_row = qT.Row(0);
const size_t qT_stride = qT.Stride();
size_t position = start_pos;
while (position + kHTileSize - 1 <= min_last_pos) {
size_t k_pos[kHTileSize];
for (size_t i = 0; i < kHTileSize; ++i) {
k_pos[i] = activations.div_seq_len.Remainder(position + i);
}
VF x0, x1, x2, x3, x4, x5, x6, x7;
QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6,
x7);
if (activations.config.att_cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
VF cap = hn::Set(df, activations.config.att_cap);
VF one_over_cap = hn::Div(hn::Set(df, 1.0f), cap);
x0 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x0, one_over_cap)));
x1 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x1, one_over_cap)));
x2 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x2, one_over_cap)));
x3 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x3, one_over_cap)));
x4 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x4, one_over_cap)));
x5 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x5, one_over_cap)));
x6 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x6, one_over_cap)));
x7 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x7, one_over_cap)));
}
VF m = ElementwiseMaxOf8(df, x0, x1, x2, x3, x4, x5, x6, x7);
m = hn::Max(old_m, m);
x0 = hn::Exp(df, x0 - m);
x1 = hn::Exp(df, x1 - m);
x2 = hn::Exp(df, x2 - m);
x3 = hn::Exp(df, x3 - m);
x4 = hn::Exp(df, x4 - m);
x5 = hn::Exp(df, x5 - m);
x6 = hn::Exp(df, x6 - m);
x7 = hn::Exp(df, x7 - m);
VF scale = hn::Mul(old_d, hn::Exp(df, old_m - m));
old_d = ElementwiseSumOf8(df, x0, x1, x2, x3, x4, x5, x6, x7);
old_d = hn::Add(scale, old_d);
old_m = m;
VF one_over_d = hn::Div(hn::Set(df, 1.0f), old_d);
scale = hn::Mul(scale, one_over_d);
x0 = hn::Mul(x0, one_over_d);
x1 = hn::Mul(x1, one_over_d);
x2 = hn::Mul(x2, one_over_d);
x3 = hn::Mul(x3, one_over_d);
x4 = hn::Mul(x4, one_over_d);
x5 = hn::Mul(x5, one_over_d);
x6 = hn::Mul(x6, one_over_d);
x7 = hn::Mul(x7, one_over_d);
MulByConstAndAddTile(df, scale, x0, x1, x2, x3, x4, x5, x6, x7, v, k_pos,
att_out.Row(0), out_offsets, v.Cols());
position += kHTileSize;
}
while (position <= max_last_pos) {
size_t k_pos = activations.div_seq_len.Remainder(position);
VF x0 = QDotKVector(df, q_offsets, k_pos, q, k);
if (activations.config.att_cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the vector.
VF cap = hn::Set(df, activations.config.att_cap);
VF one_over_cap = hn::Div(hn::Set(df, 1.0f), cap);
x0 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x0, one_over_cap)));
}
// Past the last position, x0 doesn't count.
auto mask = hn::Gt(hn::Set(di, position), lasts);
VF causal_offset = hn::MaskedSet(df, RebindMask(df, mask),
std::numeric_limits<float>::max() / 2.0f);
x0 = hn::Sub(x0, causal_offset);
VF m = hn::Max(old_m, x0);
x0 = hn::Exp(df, x0 - m);
VF scale = hn::Mul(old_d, hn::Exp(df, old_m - m));
old_m = m;
old_d = hn::Add(scale, x0);
VF one_over_d = hn::Div(hn::Set(df, 1.0f), old_d);
x0 = hn::Mul(x0, one_over_d);
scale = hn::Mul(scale, one_over_d);
MulByConstAndAddVector(df, scale, x0, v, k_pos, att_out.Row(0), out_offsets,
v.Cols());
++position;
}
}
// Returns an 4 Q rows by NF K tile of Q.K dot products, in single precision.
// This is the result of 4 rows of Q against NF K timesteps, with positions
// given by k_offsets[0..NF].
template <class DF, class VF = hn::Vec<DF>>
void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<KV_t>& k,
const int32_t* HWY_RESTRICT k_offsets, VF& sum0, VF& sum1,
VF& sum2, VF& sum3) {
sum0 = hn::Zero(df);
sum1 = hn::Zero(df);
sum2 = hn::Zero(df);
sum3 = hn::Zero(df);
const float* HWY_RESTRICT k_base = k.Row(0);
using DI = hn::ScalableTag<int32_t>;
const DI di;
using VI = hn::Vec<DI>;
VI k_offsets_vec = hn::LoadU(di, k_offsets);
for (size_t i = 0; i < k.Cols(); ++i) {
VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec);
VF q_0 = hn::Set(df, q[q_offsets[0] + i]);
sum0 = hn::MulAdd(q_0, k_vec, sum0);
VF q_1 = hn::Set(df, q[q_offsets[1] + i]);
sum1 = hn::MulAdd(q_1, k_vec, sum1);
VF q_2 = hn::Set(df, q[q_offsets[2] + i]);
sum2 = hn::MulAdd(q_2, k_vec, sum2);
VF q_3 = hn::Set(df, q[q_offsets[3] + i]);
sum3 = hn::MulAdd(q_3, k_vec, sum3);
}
}
// Handles NF v rows of flash attention for NF q.k dot products from one q row.
template <class DF, class VF = hn::Vec<DF>>
float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
float& old_d) {
float m = hn::ReduceMax(df, x);
m = std::max(m, old_max);
x = hn::Exp(df, x - hn::Set(df, m));
float scale = old_d * std::exp(old_max - m);
old_d = hn::ReduceSum(df, x) + scale;
old_max = m;
float one_over_d = 1.0f / old_d;
scale *= one_over_d;
x = hn::Mul(x, hn::Set(df, one_over_d));
return scale;
}
// Sweeps a tile of 4 Q rows by NF K timesteps accumulators from start_pos to
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
// max_last_pos].
void TileFlashAttention4(
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
const MatPtrT<KV_t>& k, const size_t start_pos,
const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
const LayerWeightsPtrs& layer, const AttentionActivations& activations,
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE3(p, worker,
GetProfilerZone(Zones::kFlashAttentionTileFlashAttention4));
using DF = hn::ScalableTag<float>;
const DF df;
using VF = hn::Vec<DF>;
constexpr size_t kMaxNF = hn::MaxLanes(df);
const size_t kHTileSize = hn::Lanes(df);
HWY_DASSERT(kHTileSize <= kMaxNF);
constexpr size_t kVTileSize = 4;
float scales[kVTileSize];
for (size_t i = 0; i < kVTileSize; ++i) {
hwy::ZeroBytes(att_out.Row(0) + out_offsets[i],
v.Cols() * sizeof(att_out.Row(0)[0]));
}
float old_m0 = -std::numeric_limits<float>::max() / 2.0f;
float old_m1 = -std::numeric_limits<float>::max() / 2.0f;
float old_m2 = -std::numeric_limits<float>::max() / 2.0f;
float old_m3 = -std::numeric_limits<float>::max() / 2.0f;
float old_d0 = 0.0f;
float old_d1 = 0.0f;
float old_d2 = 0.0f;
float old_d3 = 0.0f;
size_t position = start_pos;
while (position + kHTileSize - 1 <= min_last_pos) {
int32_t k_offsets[kMaxNF];
size_t v_pos[kMaxNF];
for (size_t i = 0; i < kHTileSize; ++i) {
v_pos[i] = activations.div_seq_len.Remainder(position + i);
k_offsets[i] = k.Row(v_pos[i]) - k.Row(0);
}
VF x0, x1, x2, x3;
QDotKTilex4(df, q.Row(0), q_offsets, k, k_offsets, x0, x1, x2, x3);
if (activations.config.att_cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
VF cap = hn::Set(df, activations.config.att_cap);
VF one_over_cap = hn::Div(hn::Set(df, 1.0f), cap);
x0 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x0, one_over_cap)));
x1 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x1, one_over_cap)));
x2 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x2, one_over_cap)));
x3 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x3, one_over_cap)));
}
scales[0] = SingleFlashAttentionRowVector(df, x0, old_m0, old_d0);
scales[1] = SingleFlashAttentionRowVector(df, x1, old_m1, old_d1);
scales[2] = SingleFlashAttentionRowVector(df, x2, old_m2, old_d2);
scales[3] = SingleFlashAttentionRowVector(df, x3, old_m3, old_d3);
MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0),
out_offsets, v.Cols());
position += kHTileSize;
}
while (position <= max_last_pos) {
size_t k_pos = activations.div_seq_len.Remainder(position);
if (position <= last_pos[0]) {
// Past the last position, x0 doesn't count.
float x0 = Dot(q.Row(0) + q_offsets[0], k.Row(k_pos), k.Cols());
SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[0]);
}
if (position <= last_pos[1]) {
// Past the last position, x1 doesn't count.
float x1 = Dot(q.Row(0) + q_offsets[1], k.Row(k_pos), k.Cols());
SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[1]);
}
if (position <= last_pos[2]) {
// Past the last position, x2 doesn't count.
float x2 = Dot(q.Row(0) + q_offsets[2], k.Row(k_pos), k.Cols());
SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[2]);
}
if (position <= last_pos[3]) {
// Past the last position, x3 doesn't count.
float x3 = Dot(q.Row(0) + q_offsets[3], k.Row(k_pos), k.Cols());
SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[3]);
}
++position;
}
}
// Rounds n to a number that can be used as the number of Q rows in a tile
// of flash attention.
static size_t RoundToSuitablePowerOf2(size_t n) {
if (n < 4) return 1;
if (n < 8) return 4;
if (n < 16) return 8;
if (n < 32) return 16;
return 32;
}
// The vertical tile size is determined by the ability to use tiling and the
// target_parallelism. In practice the possible tile sizes in order of
// preference for efficiency are kNF, 4, 1, where kNF is likely to be 4 8 or
// 16. The final tile size is chosen to be the largest possible that allows
// for target_parallelism parallel tasks.
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens,
size_t total_tasks, size_t target_parallelism) {
const size_t kMaxEqualK =
RoundToSuitablePowerOf2(num_head_groups * num_tokens);
const size_t kMinTileSize = (total_tasks / 4 >= target_parallelism) ? 4 : 1;
return (kNF <= kMaxEqualK && total_tasks / kNF >= target_parallelism)
? kNF
: std::min(kMinTileSize, kMaxEqualK);
}
// The nominal aim of attention is to combine 3 inputs Q[L,D], K[L,D], V[L,D]
// into a single output O[L,D].
// Conventional attention first computes A[L,L] = Q . KT
// followed by A = softmax(A) (over invididual rows).
// Then A is multiplied by V to get O[L,D].
// For each row of O, this takes a read of one row of Q L times, all of K,
// 3 write/reads of one row of A, read all of V, and read/write the one row of O
// L times. Ignoring the computation for now, and focusing just on memory,
// the one row of O takes L(4D+3) reads and L(D+3) writes.
// For the whole of Q, this is L^2(4D+3) reads and L^2(D+3) writes.
//
// Flash attention fuses these operations together, and has 3 operating modes:
// 1. NF rows of the result computed using tiles of registers of shape NFx8.
// 2. 4 rows of the result computed using tiles of registers of shape 4xNF.
// 3. One row (of Q and the result) at a time.
// In all cases the intermediate result (Q.KT) is never stored to memory.
// NF is the number of float lanes in a register, being 16 for AVX3. The softmax
// is converted to streaming form using the algorithm from:
// https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf.
// Q is transposed to Q_T[D,L] to make the dot product computation efficient.
//
// In mode 1:
// QDotKTileFloat computes NF Q rows x 8 K timesteps of Q.K dot products in one
// go, reducing reads of Q by 8 and reads of K by NF. The streaming softmax is
// computed entirely in registers, and a further NF registers to accumulate the
// results of the product of the softmax and V, reduce the number of reads of V
// by NF, and the reads/writes of O by 8.
// The reads are thus reduced to 2DL^2(1/8+1/NF) and writes reduced to DL^2/8,
// which on AVX3 is an overall reduction by about a factor of 10.
// Mode 1 can only be accessed if there is a large Qbatch size, or in multi-turn
// prefill, since in other cases, there is either a single K timestep (prefill)
// or a single num_heads set of Q rows (decode).
//
// In mode 2, the 4 rows of Q are computed against NF K timesteps in a tile,
// reducing the reads of Q by NF, and the reads of K by 4. The softmax and
// accumulation of the result is done in registers, cutting the reads of V by 4.
// The reads/writes of O are reduced by a factor of NF.
// The overall reduction is limited by the need to use gather to load K.
// Transposing K would be possible, but is complicated by the wraparound.
// Mode 2 can be used in all cases when there are at least 4 attention heads,
// but it may be prefereable to use mode 3 when the batch size is small to
// maximise parallelism.
//
// In mode 3, a single row of Q is computed against a single K timestep at a
// time, using SingleFlashAttention. In this case there is no reduction in the
// reads of Q or K, or V, or O, but the reads/writes of the intermediate A are
// still eliminated.
//
// A further complication is that real attention is not as simple as documented
// in the paper and above. There are multiple query heads, differing KV, and
// different sequence lengths, so a lot of the work in FlashAttention is making
// sure that a collection of q rows with the same KV and sequence length are
// grouped together so that mode 1 or 2 can be used, and choosing which of the
// 3 modes to use for best efficiency.
void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
const size_t layer_idx, const LayerWeightsPtrs& layer,
AttentionActivations& activations, QBatch& qbatch,
ThreadingContext& ctx) {
static const auto root_zone = ctx.profiler.AddZone(
"FlashAttention.Inclusive", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(ctx.profiler, 0, root_zone);
const auto zone = GetProfilerZone(Zones::kFlashAttentionFlashAttention);
RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, layer_idx,
layer, activations, ctx);
const hwy::Divisor div_qbatch(qbatch.Size());
const LayerConfig& layer_config = layer.layer_config;
const size_t qkv_dim = layer_config.qkv_dim;
// A "head group" in the context of GQA refers to a collection of query
// heads that share the same key and value heads.
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
const size_t cache_layer_size = layer_config.CacheLayerSize();
const size_t seq_len =
static_cast<size_t>(activations.div_seq_len.GetDivisor());
const size_t token_batch = num_tokens * div_qbatch.GetDivisor();
const size_t total_tasks = token_batch * layer_config.heads;
using DF = hn::ScalableTag<float>;
const DF df;
const size_t kNF = hn::Lanes(df);
constexpr size_t kMaxNF = hn::MaxLanes(df);
HWY_DASSERT(kNF <= kMaxNF);
const size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, num_tokens,
total_tasks, target_parallelism);
// Only transpose Q if we are using tiling.
if (kVTileSize == kNF) {
size_t max_last = 0, min_start = std::numeric_limits<size_t>::max();
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
size_t pos = qbatch.Pos(qi);
const size_t start = StartPos(pos, activations.config, layer_idx);
pos += num_tokens - 1;
const size_t end = qbatch.PrefixEnd(qi);
if (end > 0 && end - 1 > pos) {
pos = end - 1;
}
max_last = std::max(max_last, pos);
min_start = std::min(min_start, start);
}
if (max_last - min_start + 1 >= kNFx8HTileSize) {
// q has shape [batch, qbatch][head, qkv_dim].
// We transpose it to [qkv_dim][qbatch, head, batch] in order to make the
// maximum possible number of consecutive columns have the same KV
// matrices. Each thread will process a tile of NF columns of QT so the
// starting column index of QT is just the task index * kVTileSize.
TransposeQ(activations.q, activations.q_T, qbatch.Size(), ctx);
}
}
const size_t num_thread_tasks = hwy::DivCeil(total_tasks, kVTileSize);
const hwy::Divisor div_tokens(num_tokens);
// All layers should have the same number of heads.
HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads);
// For each head/token/query, compute fused flash Q.K, softmax and weighted V.
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
PROFILER_ZONE3(ctx.profiler, worker, zone);
// Offsets into original Q for each row in the tile.
uint32_t q_offsets[kMaxNF];
// Offsets into att_out for each row in the tile.
uint32_t out_offsets[kMaxNF];
// Start positions for each row in the tile.
size_t start_positions[kMaxNF];
// Last positions for each row in the tile. Inclusive.
uint32_t last_pos[kMaxNF];
// min and max last positions across all rows in the tile determines when
// TileFlashAttention switches to single vector mode to handle the
// ragged sequence lengths.
size_t min_last_pos = std::numeric_limits<size_t>::max();
size_t max_last_pos = 0;
// Indices into the qbatch.KV for each row in the tile.
size_t qi_indices[kMaxNF];
// Indices into the kv_cache for each row in the tile.
size_t kv_offsets[kMaxNF];
// first_task is [qbatch, head, token].
const size_t first_task = task * kVTileSize;
const size_t last_task = first_task + kVTileSize - 1;
bool use_tile_attention = kVTileSize > 1 && last_task < total_tasks;
for (size_t offset = 0;
offset < kVTileSize && first_task + offset < total_tasks; ++offset) {
const size_t batch_idx = div_tokens.Remainder(first_task + offset);
const size_t qh = div_tokens.Divide(first_task + offset);
const size_t head = activations.div_heads.Remainder(qh);
const size_t qi = activations.div_heads.Divide(qh);
const size_t tq_idx = div_qbatch.GetDivisor() * batch_idx + qi;
qi_indices[offset] = qi;
// Find the token position in the query and calculate
// the range of cache positions to attend to.
const size_t pos = qbatch.Pos(qi) + batch_idx;
const size_t start_pos = StartPos(pos, activations.config, layer_idx);
start_positions[offset] = start_pos;
size_t last = pos;
const size_t prefix_end = qbatch.PrefixEnd(qi);
if (prefix_end > 0 && prefix_end - 1 > last) {
// last_pos in QDotK and WeightedSumV is inclusive.
last = prefix_end - 1;
}
last_pos[offset] = last;
min_last_pos = HWY_MIN(min_last_pos, last);
max_last_pos = HWY_MAX(max_last_pos, last);
q_offsets[offset] =
activations.q.Row(tq_idx) + head * qkv_dim - activations.q.Row(0);
out_offsets[offset] = activations.att_out.Row(tq_idx) + head * qkv_dim -
activations.att_out.Row(0);
const size_t kv_index = head / kHeadGroups;
const size_t head_offset = kv_index * qkv_dim * 2;
kv_offsets[offset] = layer_idx * cache_layer_size + head_offset;
// If any of the parameters in this if statement differ within this task,
// then we can't use TileFlashAttention. TileFlashAttention requires that
// all rows in the tile have the same K and V matrices, and Q starts at
// the same position. The end positions do not have to be the equal.
if (start_positions[offset] != start_positions[0] ||
qi_indices[offset] != qi_indices[0] ||
kv_offsets[offset] != kv_offsets[0]) {
use_tile_attention = false;
}
}
for (size_t offset = 0;
offset < kVTileSize && first_task + offset < total_tasks; ++offset) {
auto& kv_cache = qbatch.KV(qi_indices[offset]).kv_cache;
MatPtrT<KV_t> k("k_view", Extents2D(seq_len, qkv_dim));
k.SetPtr(kv_cache.Row(0) + kv_offsets[offset], kv_cache.Stride());
MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
v.SetPtr(kv_cache.Row(0) + kv_offsets[offset] + qkv_dim,
kv_cache.Stride());
if (use_tile_attention) {
// To avoid duplicating the code to setup K and V, the call to
// TileFlashAttention is inside the loop over tasks, even though it
// handles all rows in the task at once.
StridedView<float> qT =
StridedView<float>(activations.q_T.Row(0) + first_task, kVTileSize,
activations.q_T.Stride());
if (kVTileSize == kNF) {
// We can still use TileFlashAttention even if we didn't transpose Q
// above. The condition used for transposing Q above is more general
// and easier to compute than the condition used within
// TileFlashAttention that min_last_pos - start_positions[offset] <
// kNFx8HTileSize. In this case, qT is never used. Some tasks might
// use qT and some might not, which is why the more general condition
// is used above to catch all cases where qT will be used.
TileFlashAttention(activations.q, q_offsets, qT, k,
start_positions[offset], last_pos, min_last_pos,
max_last_pos, v, layer_idx, layer, activations,
activations.att_out, out_offsets, ctx.profiler,
worker);
} else if (kVTileSize == 4) {
TileFlashAttention4(
activations.q, q_offsets, k, start_positions[offset], last_pos,
min_last_pos, max_last_pos, v, layer_idx, layer, activations,
activations.att_out, out_offsets, ctx.profiler, worker);
} else {
HWY_UNREACHABLE;
}
break;
} else {
SingleFlashAttention(start_positions[offset], last_pos[offset],
activations.q.Row(0) + q_offsets[offset], k, v,
layer_idx, layer, activations,
activations.att_out.Row(0) + out_offsets[offset],
ctx.profiler, worker);
}
}
};
{
PROFILER_ZONE("Gen.FlashAttention.ForkJoin");
// Full parallelism is helpful, SmallParallelFor is insufficient.
HierarchicalParallelFor(num_thread_tasks, ctx.pools, func);
}
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();

64
gemma/flash_attention.h Normal file
View File

@ -0,0 +1,64 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_H_
// Declares FlashAttention for all SIMD targets.
#include <stddef.h>
#include "gemma/gemma.h"
#include "hwy/highway.h"
namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void RMSNormAndPositionalEncoding(size_t num_tokens, const QBatch& qbatch, \
MatPtrT<KV_t>& q, size_t layer_idx, \
const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, \
ThreadingContext& ctx); \
\
void SingleFlashAttention(size_t start_pos, size_t last_pos, \
const float* HWY_RESTRICT q, \
const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, \
float* HWY_RESTRICT att_out, hwy::Profiler& p, \
size_t worker); \
\
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
size_t total_tasks, size_t target_parallelism); \
\
void FlashAttention(size_t num_tokens, size_t target_parallelism, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
AttentionActivations& activations, QBatch& qbatch, \
ThreadingContext& ctx); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE
// Function declarations for each SIMD target. Allows direct call from the
// per-target namespace. We may later replace this with dynamic dispatch if
// the overhead is acceptable.
HWY_VISIT_TARGETS(GEMMA_DECL_FLASH_ATTENTION)
#undef GEMMA_DECL_FLASH_ATTENTION
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_H_

View File

@ -0,0 +1,187 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstring>
#include <numeric>
#include <vector>
#include "compression/types.h"
#include "gemma/activations.h"
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/kv_cache.h"
#include "gemma/weights.h"
#include "ops/matmul.h"
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
#include <stddef.h>
#include <stdio.h>
#include <algorithm> // std::max
#include <cmath> // std::abs
#include <memory>
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/flash_attention_test.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#include "gemma/attention.h"
#include "gemma/configs.h"
#include "gemma/flash_attention.h"
#include "hwy/tests/test_util-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
void SetMat(const size_t offset, MatPtrT<float>& mat) {
const size_t kOuter = mat.Extents().rows;
const size_t kInner = mat.Extents().cols;
const float i_scale = 1.0f / kInner;
const float j_scale = 1.0f / kOuter;
for (size_t i = 0; i < kOuter; ++i) {
float* row = mat.Row(i);
for (size_t j = 0; j < kInner; ++j) {
row[j] =
static_cast<float>((i * kInner * i_scale + (j + offset) * j_scale));
}
}
}
std::unique_ptr<MatStorageT<float>> MakeCopyOfMat(const MatPtrT<float>& mat,
const Allocator& allocator) {
auto copy = std::make_unique<MatStorageT<float>>("TestMat", mat.Extents(),
allocator, MatPadding::kOdd);
CopyMat(mat, *copy);
return copy;
}
void AssertClose(const MatPtrT<float>& a, const MatPtrT<float>& b) {
// Avoid comparing the padding bytes, which are uninitialized.
for (size_t r = 0; r < a.Rows(); ++r) {
const float* HWY_RESTRICT a_row = a.Row(r);
const float* HWY_RESTRICT b_row = b.Row(r);
for (size_t c = 0; c < a.Cols(); ++c) {
float rel_abs_delta = std::abs(a_row[c] - b_row[c]);
if (rel_abs_delta > 0.0f) {
rel_abs_delta /= std::max(std::abs(a_row[c]), std::abs(b_row[c]));
}
EXPECT_LT(rel_abs_delta, 1e-5)
<< "a[" << r << "," << c << "]=" << a_row[c] << ", b[" << r << ","
<< c << "]=" << b_row[c];
}
}
}
void TestFlashAttention(size_t target_parallelism) {
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
constexpr size_t kOuter = 1024;
constexpr size_t kInner = 256;
ModelConfig config(Model::GEMMA2_2B, Type::kF32, PromptWrapping::GEMMA_PT);
config.att_cap = 1024.0f;
TensorInfoRegistry tensor_info_registry(config);
const LayerConfig& layer_config = config.layer_configs[0];
const LayerWeightsPtrs layers(0, layer_config, tensor_info_registry);
InferenceArgs inference_args;
RuntimeConfig runtime_config;
KVCache kv_cache(config, inference_args, ctx.allocator);
MatMulEnv env(ctx);
Activations activations(config, runtime_config.prefill_tbatch_size,
kv_cache.SeqLen(), env.ctx, env.row_ptrs);
std::vector<int> tokens(kOuter);
std::iota(tokens.begin(), tokens.end(), 1);
PromptTokens prompt(tokens);
AllQueries all_queries(hwy::Span<const PromptTokens>(&prompt, 1),
hwy::Span<KVCache>(&kv_cache, 1));
QBatch qbatch(/*start=*/0, /*max_size=*/kOuter, all_queries);
const size_t batch_size = kOuter;
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
AttentionActivations attention(config, layer_config, batch_size, kOuter,
ctx.allocator, row_ptrs);
const size_t qkv_dim = layer_config.qkv_dim;
ASSERT_EQ(qkv_dim, kInner);
const hwy::Divisor div_qbatch(qbatch.Size());
// A "head group" in the context of GQA refers to a collection of query
// heads that share the same key and value heads.
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
const size_t seq_len =
static_cast<size_t>(attention.div_seq_len.GetDivisor());
auto& kvc = qbatch.KV(0).kv_cache;
for (size_t h = 0; h < layer_config.heads; ++h) {
// Make strided views into the kv cache for
// this query and head.
const size_t head_offset = (h / kHeadGroups) * qkv_dim * 2;
MatPtrT<KV_t> k("k_view", Extents2D(seq_len, qkv_dim));
k.SetPtr(kvc.Row(0) + head_offset, kvc.Stride());
MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
v.SetPtr(kvc.Row(0) + head_offset + qkv_dim, kvc.Stride());
SetMat(h + layer_config.heads, k);
SetMat(h + layer_config.heads * 2, v);
}
SetMat(1, attention.q);
DotSoftmaxWeightedSum(tokens.size(), 0, layers, attention, qbatch, ctx);
// Copy the output to saved_att to allow for comparison.
auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator);
SetMat(1, attention.q);
using DF = hn::ScalableTag<float>;
const DF df;
const size_t kNF = hn::Lanes(df);
const size_t total_tasks =
tokens.size() * div_qbatch.GetDivisor() * layer_config.heads;
const size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, tokens.size(),
total_tasks, target_parallelism);
printf("FlashAttention: target_parallelism=%zu, kNF=%zu, kVTileSize=%zu\n",
target_parallelism, kNF, kVTileSize);
FlashAttention(tokens.size(), target_parallelism, 0, layers, attention,
qbatch, ctx);
AssertClose(attention.att_out, *saved_att);
ctx.profiler.PrintResults();
}
void TestAttention() {
TestFlashAttention(8192);
TestFlashAttention(2048);
TestFlashAttention(256);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_BEFORE_TEST(FlashAttentionTest);
HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestAttention);
HWY_AFTER_TEST();
} // namespace gcpp
#endif

View File

@ -24,6 +24,7 @@
#include "ops/matmul.h"
#include "util/mat.h"
#include "util/threading.h"
#include "util/zones.h"
#include "hwy/profiler.h"
// Include guard (still compiled once per target)
@ -43,14 +44,14 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
template <typename T>
void Activation(ActivationType activation, T* HWY_RESTRICT c1,
const T* HWY_RESTRICT c2, const size_t count, hwy::Profiler& p,
// For use by Vit even if !GEMMA_FUSED_FFN.
template <typename T1, typename T2>
void Activation(ActivationType activation, T1* HWY_RESTRICT c1,
const T2* HWY_RESTRICT c2, const size_t count, hwy::Profiler& p,
const size_t worker) {
static const auto zone = p.AddZone("Gen.Activation");
PROFILER_ZONE3(p, worker, zone);
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kGenActivation));
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<T>;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
// ActivationType::Gelu
if (c2 == nullptr) { // No multiplier, just Gelu.
@ -58,38 +59,77 @@ void Activation(ActivationType activation, T* HWY_RESTRICT c1,
return;
};
// Has multiplier, Gelu(c1) * c2.
hn::Transform1(DF(), c1, count, c2, [](DF df, VF v, VF mul) HWY_ATTR {
return hn::Mul(mul, Gelu(df, v));
});
Decompress1AndCompressInplace(DF(), c1, count, c2, /*p1_ofs=*/0,
[](DF df, VF v1, VF v2) HWY_ATTR -> VF {
return hn::Mul(v2, Gelu(df, v1));
});
}
// No C2 multiplier.
// No C2 multiplier - used by Vit.
template <class Mat>
void ActivationBatched(ActivationType activation, Mat& c1,
ThreadingContext& ctx) {
void ActivationBatched(
ActivationType activation, Mat& c1, ThreadingContext& ctx,
size_t cluster_idx = 0,
ParallelismStrategy parallelism = ParallelismStrategy::kFlat) {
using T = typename Mat::T;
SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
// Cast to correct type so type deduction works.
Activation(activation, c1.Row(task), static_cast<const T*>(nullptr),
c1.Cols(), ctx.profiler, worker);
});
ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
[&](uint64_t task, size_t worker) {
// Cast to correct type so type deduction works.
Activation(activation, c1.Row(task),
static_cast<const T*>(nullptr), c1.Cols(),
ctx.profiler, worker);
});
}
template <class Mat>
HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat& c1,
const Mat* c2, ThreadingContext& ctx) {
using T = typename Mat::T;
#if GEMMA_FUSED_FFN
// Called during `TwoMatMul`.
static inline void Activation(ActivationType activation, const RowPtrsBF C1,
const IndexRange range_r,
const IndexRange range_c, const StridedViewBF C2,
hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kGenActivationFused));
const size_t cols = range_c.Num();
HWY_DASSERT(C2.Cols() == cols);
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
// ActivationType::Gelu
// Gated: Gelu(c1) * c2.
for (size_t ir = 0; ir < range_r.Num(); ++ir) {
Decompress1AndCompressInplace(
DF(), C1.Row(range_r.begin() + ir) + range_c.begin(), cols, C2.Row(ir),
/*p1_ofs*/ 0, [](DF df, VF v1, VF v2) HWY_ATTR -> VF {
return hn::Mul(v2, Gelu(df, v1));
});
}
}
#endif // GEMMA_FUSED_FFN
// Only used if !GEMMA_FUSED_FFN, but define anyway so that we can check
// using if constexpr rather than #if, which interferes with code folding.
template <class Mat1, class Mat2>
HWY_NOINLINE void ActivationBatched(
ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx,
size_t cluster_idx = 0,
ParallelismStrategy parallelism = ParallelismStrategy::kFlat) {
HWY_DASSERT(c1.SameShape(*c2));
if (c2 && c2->HasPtr()) {
SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(),
ctx.profiler, worker);
});
ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
[&](uint64_t task, size_t worker) {
Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(),
ctx.profiler, worker);
});
} else { // No multiplier
SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
Activation(activation, c1.Row(task), static_cast<const T*>(nullptr),
c1.Cols(), ctx.profiler, worker);
});
ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
[&](uint64_t task, size_t worker) {
Activation(activation, c1.Row(task),
static_cast<const typename Mat2::T*>(nullptr),
c1.Cols(), ctx.profiler, worker);
});
}
}
@ -115,30 +155,34 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer,
Activations& activations, MatMulEnv& env) {
static const auto zone =
env.ctx.profiler.AddZone("Gen.FFW", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
const LayerConfig& layer_config = layer.layer_config;
const size_t ffh_hidden_dim = layer_config.ff_hidden_dim;
const bool add_bias = layer_config.ff_biases;
const float* bias1 =
add_bias ? layer.ffw_gating_biases.PackedScale1() : nullptr;
const float* bias2 = add_bias ? bias1 + ffh_hidden_dim : nullptr;
const float* output_bias =
add_bias ? layer.ffw_output_biases.PackedScale1() : nullptr;
HWY_DASSERT(!layer_config.ff_biases); // Only used in Vit.
#if GEMMA_FUSED_FFN
const auto fused = [&](RowPtrsBF C1, IndexRange range_r, IndexRange range_c,
StridedViewBF C2, size_t worker) {
Activation(layer_config.activation, C1, range_r, range_c, C2,
env.ctx.profiler, worker);
};
MMOptions options;
options.SetFunc(fused);
CallTwoMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1,
layer.gating_einsum_w2, env, activations.C1, options);
#else
// Compute the hidden layer activations.
CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, bias1, env,
CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, nullptr, env,
activations.C1);
CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w2, bias2, env,
CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w2, nullptr, env,
activations.C2);
// Activation (Gelu) and maybe multiply by gate. Store activations in act.
ActivationBatched(layer_config.activation, activations.C1, &activations.C2,
env.ctx);
#endif
// Hidden layer -> output layer.
CallMatMul(activations.C1, layer.linear_w, output_bias, env,
activations.ffw_out);
CallMatMul(activations.C1, layer.linear_w, nullptr, env, activations.ffw_out);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)

View File

@ -19,6 +19,7 @@
#include "gemma/gemma.h"
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "util/zones.h"
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
@ -34,7 +35,6 @@
// After highway.h
#include "gemma/attention.h" // includes highway.h
#include "gemma/gemma-inl.h"
#include "gemma/griffin.h" // includes highway.h
#include "gemma/vit.h" // includes highway.h
#ifndef GEMMA_CC_ONCE
@ -61,6 +61,9 @@
#include "hwy/base.h"
#include "hwy/timer.h"
// Require opt-in to debug/introspection functions to eliminate their overhead.
HWY_INLINE_VAR constexpr bool kObserver = false;
#endif // GEMMA_CC_ONCE
HWY_BEFORE_NAMESPACE();
@ -71,17 +74,9 @@ void Attention(LayerAttentionType type, const size_t num_tokens,
const size_t layer_idx, const LayerWeightsPtrs& layer,
Activations& activations, QBatch& qbatch, MatMulEnv& env) {
if (type == LayerAttentionType::kGemma) {
// TODO: remove flag to enable FlashAttention.
GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch,
env,
/*flags=*/0);
} else {
HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
// KVCache conv1d_cache and rglru_cache have one row per *Griffin* layer,
// so map `layer` to the Griffin layer index.
const size_t griffin_layer =
activations.attention.config.NumLayersOfTypeBefore(type, layer_idx);
GriffinRecurrent(num_tokens, griffin_layer, &layer, activations, qbatch,
env);
env, HWY_NATIVE_DOT_BF16 ? kAttentionUseOld : 0);
}
}
@ -127,9 +122,10 @@ static float EmbeddingScaling(size_t model_dim) {
hwy::ConvertScalarTo<BF16>(sqrtf(static_cast<float>(model_dim))));
}
// `batch_idx` indicates which row of `x` to write to.
// `pos` is the *token*'s position, not the start of the batch, because this is
// called for batches of tokens in prefill, but batches of queries in decode.
// `x_row` indicates which row of `x` to write to.
// `pos` is the *token*'s position for `AddAbsolutePositionalEmbeddings`, not
// the start of the batch, because this is called for batches of tokens in
// prefill, but batches of queries in decode.
//
// For GEMMA_VLM, image tokens are copied into -2 locations (per the Gemma 3
// spec) until we run out of image tokens. This allows for a multi-image prompt
@ -137,30 +133,33 @@ static float EmbeddingScaling(size_t model_dim) {
// calling application.
// Returns new image_token_position.
static HWY_NOINLINE size_t
EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt,
const ModelConfig& model_config, const WeightsPtrs& weights,
MatStorageT<float>& x, ThreadingContext& ctx,
const ImageTokens* image_tokens = nullptr,
size_t image_token_position = 0) {
static const auto zone =
ctx.profiler.AddZone("Gen.Embed", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
// Image tokens just need to be copied.
if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
image_tokens != nullptr && token == -2 &&
image_token_position < image_tokens->Rows()) {
hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(qi),
hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(x_row),
x.Cols() * x.ElementBytes());
return image_token_position + 1;
}
if (model_config.wrapping == PromptWrapping::PALIGEMMA &&
image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) {
hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(qi),
hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(x_row),
x.Cols() * x.ElementBytes());
return image_token_position;
}
const size_t model_dim = model_config.model_dim;
const float emb_scaling = EmbeddingScaling(model_dim);
const size_t worker = 0; // Not yet parallelized.
HWY_DASSERT(token >= 0);
HWY_DASSERT(token < static_cast<int>(model_config.vocab_size));
@ -174,14 +173,13 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
const auto embedding_span =
MakeSpan(weights_t->Row(0), embedding_ofs + model_dim);
const hn::ScalableTag<float> df;
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(qi),
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(x_row),
model_dim);
MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim,
ctx.profiler, worker);
MulByConst(emb_scaling * weights_t->Scale(), x.Row(x_row), model_dim);
});
if (model_config.absolute_pe) {
AddAbsolutePositionalEmbeddings(x.Row(qi), model_dim, pos);
AddAbsolutePositionalEmbeddings(x.Row(x_row), model_dim, pos);
}
return image_token_position;
}
@ -249,24 +247,12 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
for (size_t ti = 0; ti < tbatch_size; ++ti) {
const size_t pos = qbatch_1.Pos(0) + ti;
const size_t pos_in_prompt = tbatch_start + ti;
HWY_DASSERT(pos_in_prompt < prompt_size);
const int token = qbatch_1.Prompt(0)[pos_in_prompt];
image_token_position = EmbedMMToken(
token, ti, pos, pos_in_prompt, config, weights, activations.x,
env.ctx, runtime_config.image_tokens, image_token_position);
}
// Transformer with one batch of tokens from a single query.
for (size_t layer_idx = 0; layer_idx < config.layer_configs.size();
++layer_idx) {
TransformerLayer(tbatch_size, layer_idx, *weights.GetLayer(layer_idx),
activations, qbatch_1, env);
}
// NOTE: we unconditionally call StreamToken, even if EOS.
for (size_t ti = 0; ti < tbatch_size; ++ti) {
const size_t pos = qbatch_1.Pos(0) + ti;
const size_t pos_in_prompt = tbatch_start + ti;
const int token = qbatch_1.Prompt(0)[pos_in_prompt];
// NOTE: we unconditionally call StreamToken, even if EOS.
if (pos_in_prompt < prompt_size - 1) {
runtime_config.StreamToken(qbatch_1.QueryIdx(0), pos, token, 0.0f);
} else {
@ -276,6 +262,14 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
}
}
// Transformer with one batch of tokens from a single query. No need to
// set `PrevToken` because we already did the embedding above.
for (size_t layer_idx = 0; layer_idx < config.layer_configs.size();
++layer_idx) {
TransformerLayer(tbatch_size, layer_idx, *weights.GetLayer(layer_idx),
activations, qbatch_1, env);
}
qbatch_1.MutablePos(0) += tbatch_size;
} // for tbatch_start
if (attend_to_last_token) {
@ -290,19 +284,33 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
}
}
static void MaybeObserve(const RuntimeConfig& runtime_config,
Activations& activations, QBatch& qbatch,
int layer_idx) {
if constexpr (kObserver) {
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
runtime_config.activations_observer(
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
activations);
}
}
}
// Embeds PrevToken (one from each query) and calls each TransformerLayer.
// Called by query-batched `PrefillQBatch` and `DecodeStepT`, but not the
// token-batched `PrefillTBatch`.
// Called by query-batched `PrefillQBatch` and `GenerateT`, but not the
// token-batched `PrefillTBatch`, which supports image embedding.
static HWY_NOINLINE void Transformer(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const WeightsPtrs& weights,
Activations& activations, QBatch& qbatch,
MatMulEnv& env) {
if (HWY_UNLIKELY(runtime_config.layers_output)) {
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
const float token_f = qbatch.PrevToken(qi);
runtime_config.layers_output(qbatch.QueryIdx(qi), qbatch.Pos(qi),
"tokens", -1, &token_f, 1);
if constexpr (kObserver) {
if (HWY_UNLIKELY(runtime_config.layers_output)) {
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
const float token_f = qbatch.PrevToken(qi);
runtime_config.layers_output(qbatch.QueryIdx(qi), qbatch.Pos(qi),
"tokens", -1, &token_f, 1);
}
}
}
@ -316,16 +324,11 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
TransformerLayer(/*num_tokens=*/1, layer_idx, *weights.GetLayer(layer_idx),
activations, qbatch, env);
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
runtime_config.activations_observer(
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
activations);
}
MaybeObserve(runtime_config, activations, qbatch, layer_idx);
}
}
// Populates KV cache for the batch queries, one token at a time. Only called
// for autoregressive (non-prefix-LM) prefill, so `queries_prefix_end` == 0.
// Populates KV cache for the batch queries, one token at a time.
static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
const ModelConfig& config,
const RuntimeConfig& runtime_config,
@ -337,6 +340,8 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
non_eos.Set(qi);
// Should only be called for autoregressive (non-prefix-LM) prefill.
HWY_DASSERT(qbatch.PrefixEnd(qi) == 0);
}
@ -358,7 +363,7 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
}
// The input (PrevToken) is one token from each query in the batch.
// Do not call DecodeStepT because it computes logits for token
// Do not call `SampleAndStream` because it computes logits for token
// probabilities, which are not required for the prompt tokens.
Transformer(config, runtime_config, weights, activations, qbatch, env);
}
@ -369,122 +374,143 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
}
// Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent
// `DecodeStepT`, and increments `MutablePos`. Also updates `non_eos` if the
// `Transformer`, and increments `MutablePos`. Also updates `non_eos` if the
// query is at the end of its sequence.
static void StreamAndUpdateEOS(const size_t qi, int token, const float prob,
const ModelConfig& config,
static void StreamAndUpdateEOS(const size_t qi, size_t pos, int token,
const float prob, const ModelConfig& config,
const RuntimeConfig& runtime_config,
QBatch& qbatch, hwy::BitSet4096<>& non_eos) {
QBatch& qbatch, bool update_pos,
hwy::BitSet4096<>& non_eos) {
HWY_DASSERT(non_eos.Get(qi)); // otherwise, should not be called.
if (HWY_UNLIKELY(!runtime_config.StreamToken(qbatch.QueryIdx(qi),
qbatch.Pos(qi), token, prob))) {
if (HWY_UNLIKELY(
!runtime_config.StreamToken(qbatch.QueryIdx(qi), pos, token, prob))) {
// User decided to stop: set token to primary EOS to trigger IsEOS below.
token = config.eos_id;
HWY_DASSERT(config.IsEOS(token));
}
qbatch.PrevToken(qi) = token;
qbatch.MutablePos(qi) += 1;
qbatch.MutablePos(qi) += update_pos ? 1 : 0;
// Primary or secondary EOS: mark query as EOS, but still increment (for
// multi-turn, we should still keep the prior EOS).
if (HWY_UNLIKELY(config.IsEOS(token))) non_eos.Clear(qi);
}
// For a batch of queries, runs Transformer, computes logits, samples and
// streams the token.
static void DecodeStepT(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const WeightsPtrs& weights,
const SampleFunc& sample_token,
Activations& activations, QBatch& qbatch,
MatMulEnv& env, hwy::BitSet4096<>& non_eos,
TimingInfo& timing_info) {
// Must be called after Transformer: either after prefill, or during decode.
// Computes logits, samples and streams the token.
static void SampleAndStream(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const WeightsPtrs& weights,
const SampleFunc& sample_token,
Activations& activations, QBatch& qbatch,
MatMulEnv& env, hwy::BitSet4096<>& non_eos,
TimingInfo& timing_info) {
HWY_DASSERT(qbatch.Size() == activations.x.Rows());
Transformer(config, runtime_config, weights, activations, qbatch, env);
RMSNormBatched(activations.x, weights.final_norm_scale, activations.x_bf,
env.ctx);
RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx);
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
runtime_config.activations_observer(
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations);
}
MaybeObserve(runtime_config, activations, qbatch, -1);
{
PROFILER_ZONE("Gen.EmbeddingMatmul");
static const auto zone = env.ctx.profiler.AddZone(
"Gen.EmbeddingMatmul", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(env.ctx.profiler, /*worker=*/0, zone);
// Compute logits from last layer activations.
CallMatMul(activations.x, weights.embedder_input_embedding,
CallMatMul(activations.x_bf, weights.embedder_input_embedding,
/*add=*/nullptr, env, activations.logits);
}
PROFILER_ZONE("Gen.Softcap+Sample+Stream");
const size_t worker = 0; // TODO: parallelize
non_eos.Foreach([&](size_t qi) {
float* HWY_RESTRICT logits = activations.logits.Row(qi);
MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size,
env.ctx.profiler, worker);
const TokenAndProb tp = sample_token(logits, config.vocab_size);
timing_info.NotifyGenerated();
StreamAndUpdateEOS(qi, tp.token, tp.prob, config, runtime_config, qbatch,
non_eos);
MaybeLogitsSoftCapBatched(config.final_cap, activations.logits, non_eos,
env.ctx);
timing_info.NotifyGenerated(non_eos.Count());
ParallelFor(
ParallelismStrategy::kFlat, qbatch.Size(), env.ctx,
/*cluster_idx=*/0, [&](size_t qi, size_t worker) {
if (!non_eos.Get(qi)) return;
// We streamed all prefill tokens, but pos is still one behind
// because we started generation at pos = prompt.size() - 1.
// We want the pos argument to match the number of calls to
// `StreamToken`, as expected by the caller.
const size_t pos = qbatch.Pos(qi) + 1;
const TokenAndProb tp =
sample_token(qi, pos, activations.logits.RowSpan(qi), worker);
// `sampled` is padded, which prevents false sharing.
activations.sampled.Row(qi)[0] = static_cast<uint32_t>(pos);
activations.sampled.Row(qi)[1] = static_cast<uint32_t>(tp.token);
activations.sampled.Row(qi)[2] = hwy::BitCastScalar<uint32_t>(tp.prob);
});
// Sequentially, because `StreamToken` is not yet thread-safe.
non_eos.Foreach([&](size_t qi) {
const size_t pos = activations.sampled.Row(qi)[0];
const int token = static_cast<int>(activations.sampled.Row(qi)[1]);
const float prob =
hwy::BitCastScalar<float>(activations.sampled.Row(qi)[2]);
StreamAndUpdateEOS(qi, pos, token, prob, config, runtime_config, qbatch,
/*update_pos=*/true, non_eos);
});
}
static HWY_INLINE SampleFunc
ChooseSampleFunc(const RuntimeConfig& runtime_config, ThreadingContext& ctx) {
ChooseSampleFunc(const RuntimeConfig& runtime_config,
const AesCtrEngine& engine, ThreadingContext& ctx) {
// If user provided a sample_func, use it.
if (runtime_config.sample_func) return runtime_config.sample_func;
static const auto zone = ctx.profiler.AddZone("Gen.Sample Top1");
const size_t worker = 0; // TODO: parallelize
// Fast path for top-1 with no accept_token.
if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
return [&](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
PROFILER_ZONE3(ctx.profiler, worker, zone);
return Top1OfSoftmax(logits, vocab_size);
};
return [&](size_t /*qi*/, size_t /*pos*/, Logits logits, size_t worker)
HWY_ATTR -> TokenAndProb {
PROFILER_ZONE3(ctx.profiler, worker,
GetProfilerZone(Zones::kGenSampleTop1));
return Top1OfSoftmax(logits);
};
}
// General case: Softmax with top-k sampling.
return [&](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
PROFILER_ZONE("Gen.Sample general");
return [&](size_t qi, size_t pos, Logits logits,
size_t worker) HWY_ATTR -> TokenAndProb {
PROFILER_ZONE3(ctx.profiler, worker,
GetProfilerZone(Zones::kGenSampleTopK));
// We want a different sequence for each batch element and position.
const uint64_t stream = (static_cast<uint64_t>(qi) << 32) | pos;
RngStream gen(engine, stream);
return FusedSoftmaxAndSampleTopK(
logits, runtime_config.top_k, vocab_size, *runtime_config.gen,
runtime_config.temperature, runtime_config.accept_token, ctx.profiler,
worker);
logits, runtime_config.top_k, gen, runtime_config.temperature,
runtime_config.accept_token, ctx.profiler, worker);
};
}
// Decode: generates one continuation token for each query in `qbatch`.
static void GenerateT(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const WeightsPtrs& weights, Activations& activations,
QBatch& qbatch, MatMulEnv& env, TimingInfo& timing_info) {
// Griffin assumes that the recurrent block cache is zero-initialized.
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
if (qbatch.MutablePos(qi) == 0) {
qbatch.KV(qi).ZeroGriffinCache(); // No-op for non-Griffin models.
}
}
const AesCtrEngine& engine, const WeightsPtrs& weights,
Activations& activations, QBatch& qbatch, MatMulEnv& env,
TimingInfo& timing_info) {
size_t max_prompt_size = 0;
bool all_prefix_end_are_zero = true;
size_t total_prefill_tokens = 0; // only for throughput stats.
const size_t seq_len = qbatch.KV(0).SeqLen();
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
const PromptTokens& prompt = qbatch.Prompt(qi);
// Sanity check: prompts should not be empty. Note that multi-turn prompts
// start with <end_of_turn>.
HWY_ASSERT(prompt.size() != 0);
max_prompt_size = HWY_MAX(max_prompt_size, prompt.size());
// Prefill stops before size - 1 because the last prompt token is the
// first input token for generation.
total_prefill_tokens += prompt.size() - 1;
// Sanity check: prompts should not be empty, nor start with EOS.
HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
all_prefix_end_are_zero &= qbatch.PrefixEnd(qi) == 0;
// We use a single divisor, so all sequence lengths must be the same.
@ -518,8 +544,13 @@ static void GenerateT(const ModelConfig& config,
// Stream the last prompt token from each query, fill activations.gen_tokens.
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
StreamAndUpdateEOS(qi, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config,
runtime_config, qbatch, non_eos);
const size_t pos = qbatch.Pos(qi); // during prefill, pos is still correct.
// In autoregressive mode, we have not prefilled the last token, so do
// not advance.
const bool update_pos = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi));
StreamAndUpdateEOS(qi, pos, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f,
config, runtime_config, qbatch, update_pos, non_eos);
}
size_t max_gen_steps = runtime_config.max_generated_tokens;
@ -529,30 +560,31 @@ static void GenerateT(const ModelConfig& config,
max_gen_steps = seq_len - max_prompt_size;
}
const SampleFunc sample_token = ChooseSampleFunc(runtime_config, env.ctx);
const SampleFunc sample_token =
ChooseSampleFunc(runtime_config, engine, env.ctx);
{
timing_info.generate_start = hwy::platform::Now();
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
DecodeStepT(config, runtime_config, weights, sample_token, activations,
qbatch, env, non_eos, timing_info);
}
timing_info.NotifyGenerateDone();
timing_info.generate_start = hwy::platform::Now();
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
Transformer(config, runtime_config, weights, activations, qbatch, env);
SampleAndStream(config, runtime_config, weights, sample_token, activations,
qbatch, env, non_eos, timing_info);
}
timing_info.NotifyGenerateDone();
}
void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
const ModelConfig& config,
const RuntimeConfig& runtime_config,
const WeightsPtrs& weights, KVCache& kv_cache,
MatMulEnv& env, TimingInfo& timing_info) {
const AesCtrEngine& engine, const WeightsPtrs& weights,
KVCache& kv_cache, MatMulEnv& env,
TimingInfo& timing_info) {
Activations activations(config, runtime_config.prefill_tbatch_size,
kv_cache.SeqLen(), env.ctx.allocator, env.row_ptrs);
kv_cache.SeqLen(), env.ctx, env.row_ptrs);
AllQueries all_queries(prompt, pos, prefix_end,
hwy::Span<KVCache>(&kv_cache, 1));
QBatch qbatch(/*start=*/0, /*max_size=*/1, all_queries);
GenerateT(config, runtime_config, weights, activations, qbatch, env,
GenerateT(config, runtime_config, engine, weights, activations, qbatch, env,
timing_info);
}
@ -560,19 +592,20 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
// queries, and calls `GenerateT` on each batch.
void GenerateBatchT(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const WeightsPtrs& weights, AllQueries& all_queries,
MatMulEnv& env, TimingInfo& timing_info) {
const AesCtrEngine& engine, const WeightsPtrs& weights,
AllQueries& all_queries, MatMulEnv& env,
TimingInfo& timing_info) {
const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size,
runtime_config.prefill_tbatch_size);
Activations activations(config, max_batch_size,
all_queries[0].kv_cache.SeqLen(), env.ctx.allocator,
all_queries[0].kv_cache.SeqLen(), env.ctx,
env.row_ptrs);
for (size_t start = 0; start < all_queries.NumQueries();
start += runtime_config.decode_qbatch_size) {
QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries);
// Generate a batch of one token for each of `qbatch.Size()` queries.
GenerateT(config, runtime_config, weights, activations, qbatch, env,
GenerateT(config, runtime_config, engine, weights, activations, qbatch, env,
timing_info);
}
}
@ -589,8 +622,8 @@ void GenerateImageTokensT(const ModelConfig& config,
const size_t num_tokens = vit_config.max_seq_len;
prefill_runtime_config.prefill_tbatch_size =
num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
Activations prefill_activations(vit_config, num_tokens, num_tokens,
env.ctx.allocator, env.row_ptrs);
Activations prefill_activations(vit_config, num_tokens, num_tokens, env.ctx,
env.row_ptrs);
// Weights are for the full PaliGemma model, not just the ViT part.
PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
prefill_activations, env);
@ -613,7 +646,8 @@ Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
model_(reader_, loader.tokenizer, loader.wrapping),
weights_(model_.Config()),
chat_template_(model_.Tokenizer(), model_.Config().model),
inference_(inference) {
inference_(inference),
aes_ctr_engine_(inference.deterministic) {
// Negligible CPU time in the ctor body (except ReadFromBlobs).
weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, loader, inference,
mat_owners_, ctx);
@ -637,9 +671,9 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
TimingInfo& timing_info) const {
env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
HWY_DYNAMIC_DISPATCH(GenerateSingleT)(prompt, pos, prefix_end,
model_.Config(), runtime_config,
weights_, kv_cache, env, timing_info);
HWY_DYNAMIC_DISPATCH(GenerateSingleT)(
prompt, pos, prefix_end, model_.Config(), runtime_config, aes_ctr_engine_,
weights_, kv_cache, env, timing_info);
env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}
@ -650,7 +684,8 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
HWY_DYNAMIC_DISPATCH(GenerateBatchT)(model_.Config(), runtime_config,
weights_, all_queries, env, timing_info);
aes_ctr_engine_, weights_, all_queries,
env, timing_info);
env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}

View File

@ -127,7 +127,7 @@ class QBatch {
max_size_(max_size),
queries_(queries),
size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) {
HWY_ASSERT(max_size_ <= 4096); // non_eos uses `BitSet4096`.
HWY_ASSERT(max_size_ <= kMaxBatchSize);
HWY_DASSERT(size_ != 0);
HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
}
@ -177,9 +177,11 @@ struct TimingInfo {
// be sure to populate prefill_start and generate_start before calling
// NotifyGenerated.
void NotifyGenerated() {
++tokens_generated;
if (HWY_UNLIKELY(tokens_generated == 1)) {
void NotifyGenerated(size_t batch_size) {
generation_steps += 1;
const bool is_first = (tokens_generated == 0);
tokens_generated += batch_size;
if (HWY_UNLIKELY(is_first)) {
time_to_first_token = hwy::platform::Now() - prefill_start;
if (verbosity >= 1) {
double prefill_tok_sec =
@ -191,7 +193,7 @@ struct TimingInfo {
prefill_tok_sec, static_cast<int>(time_to_first_token * 1000));
}
}
if (verbosity >= 2 && tokens_generated % 128 == 0) {
if (HWY_UNLIKELY(verbosity >= 2 && tokens_generated % 1024 == 0)) {
double gen_tok_sec = static_cast<double>(tokens_generated) /
(hwy::platform::Now() - generate_start);
fprintf(stderr,
@ -223,15 +225,16 @@ struct TimingInfo {
double time_to_first_token = 0;
double generate_duration = 0;
size_t tokens_generated = 0;
size_t generation_steps = 0;
};
// After construction, all methods are const and thread-compatible if using
// separate ThreadingContext for each thread.
// separate `ThreadingContext` and `MatMulEnv` for each concurrent `Generate`.
class Gemma {
public:
// Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
// `ctx` is only used to read tensors, but it is typically also referenced
// by the `MatMulEnv` passed to the Generate* methods.
// `ctx` is only used to read tensors and not stored. Calls to `Generate*`
// may reference the same, or other `ThreadingContext` via `MatMulEnv`.
Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
ThreadingContext& ctx);
~Gemma();
@ -247,6 +250,8 @@ class Gemma {
// `pos` is the position in the KV cache. Users are responsible for
// incrementing it in the `*StreamFunc`, or setting to zero for single-turn.
// All `Generate*` may be called concurrently if `env` and the
// `ThreadingContext` it references are both distinct.
void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
size_t pos, KVCache& kv_cache, MatMulEnv& env,
TimingInfo& timing_info) const {
@ -275,6 +280,7 @@ class Gemma {
WeightsPtrs::Mode weight_read_mode_;
GemmaChatTemplate chat_template_;
InferenceArgs inference_;
AesCtrEngine aes_ctr_engine_;
};
} // namespace gcpp

View File

@ -22,11 +22,9 @@
#include <stdio.h>
#include <functional>
#include <random>
#include <string>
#include "io/io.h" // Path
#include "ops/matmul.h" // MMStorage::kMax*
#include "io/io.h" // Path
#include "util/args.h"
#include "util/basics.h" // Tristate
#include "util/mat.h"
@ -90,10 +88,10 @@ using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
// If not empty, AcceptFunc is called with token. It should return false for
// tokens you don't want to generate and true for tokens you want to generate.
using AcceptFunc = std::function<bool(int, float)>;
// If not empty, SampleFunc is called with the logits for the next token, which
// it may modify/overwrite, and its return value is the next generated token
// together with its probability.
using SampleFunc = std::function<TokenAndProb(float*, size_t)>;
// If not empty, SampleFunc is called concurrently from worker thread(s) with
// query_idx, pos, logits for the next token (which it may modify/overwrite),
// and worker. It returns the next generated token and its probability.
using SampleFunc = std::function<TokenAndProb(size_t, size_t, Logits, size_t)>;
// If not empty, LayersOutputFunc is called for layer outputs, specified with:
// - index of query within containing batch (if any); zero otherwise.
// - position in the tokens sequence
@ -116,6 +114,7 @@ using ActivationsObserverFunc =
struct RuntimeConfig {
// If non-null, `batch_stream_token` is called for each token in the batch,
// otherwise `stream_token`. `query_idx` is absolute, not batch-relative.
// This is called sequentially from the main thread.
bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
PROFILER_ZONE("Gen.StreamToken");
if (batch_stream_token) {
@ -136,8 +135,7 @@ struct RuntimeConfig {
// Sampling-related parameters.
float temperature; // Temperature for sampling.
size_t top_k = 1; // Top-k for sampling.
std::mt19937* gen; // Random number generator used for sampling.
size_t top_k = 1; // Top-k for sampling.
int verbosity; // Controls verbosity of printed messages.
@ -183,6 +181,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
bool multiturn;
Path image_file;
int port; // Server port
std::string model; // Model name for API endpoints
std::string prompt; // Bypasses std::getline
// For prompts longer than the Linux terminal's 4K line edit buffer.
Path prompt_file;
@ -218,6 +218,12 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
"resets every turn)");
visitor(image_file, "image_file", Path(), "Image file to load.");
// Since it is not used in the CLI version, the print_verbosity is set
// higher than others.
visitor(port, "port", 8080, "Server port (default: 8080)", 3);
visitor(model, "model", std::string("gemma3-4b"),
"Model name for API endpoints (default: gemma3-4b)", 3);
visitor(prompt, "prompt", std::string(""),
"Initial prompt for non-interactive mode. When specified, "
"generates a response and exits.",
@ -240,17 +246,17 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
runtime_config.max_generated_tokens = max_generated_tokens;
runtime_config.prefill_tbatch_size = prefill_tbatch_size;
runtime_config.decode_qbatch_size = decode_qbatch_size;
if (prefill_tbatch_size > MMStorage::kMaxM) {
if (prefill_tbatch_size > kMaxBatchSize) {
HWY_ABORT(
"prefill_tbatch_size %zu > kMaxM %zu: specify a smaller value, "
"or increase the constant in MMStorage.\n",
prefill_tbatch_size, MMStorage::kMaxM);
"prefill_tbatch_size %zu > kMaxBatchSize %zu: specify a "
"smaller value, or increase kMaxBatchSize.\n",
prefill_tbatch_size, kMaxBatchSize);
}
if (decode_qbatch_size > MMStorage::kMaxM) {
if (decode_qbatch_size > kMaxBatchSize) {
HWY_ABORT(
"decode_qbatch_size %zu > kMaxM %zu: specify a smaller value, "
"or increase the constant in MMStorage.\n",
decode_qbatch_size, MMStorage::kMaxM);
"decode_qbatch_size %zu > kMaxBatchSize %zu: specify a "
"smaller value, or increase kMaxBatchSize.\n",
decode_qbatch_size, kMaxBatchSize);
}
runtime_config.temperature = temperature;
@ -258,6 +264,35 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
}
};
struct ClientArgs : public ArgsBase<ClientArgs> {
ClientArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
ClientArgs() { Init(); };
std::string host;
int port;
std::string api_key;
std::string model;
std::string prompt;
bool interactive;
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(host, "host", std::string("localhost"),
"Server host (default: localhost)");
visitor(port, "port", 8080,
"Server port (default: 8080)");
visitor(api_key, "api_key", std::string(""),
"Use public API with key (changes host to "
"generativelanguage.googleapis.com:443)");
visitor(model, "model", std::string("gemma3-4b"),
"Model name to use (default: gemma3-4b)");
visitor(prompt, "prompt", std::string("Hello! How are you?"),
"Prompt for generation (default: 'Hello! How are you?')");
visitor(interactive, "interactive", false,
"Start interactive chat mode (0 = no, 1 = yes)");
}
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_

View File

@ -1,192 +0,0 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stddef.h>
#include <stdint.h>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
#include "gemma/activations.h"
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/weights.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/griffin.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "ops/matvec-inl.h"
#include "ops/ops-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,
const LayerWeightsPtrs* layer_weights,
Activations& activations, QBatch& qbatch,
MatMulEnv& env) {
PROFILER_ZONE("Gen.Griffin");
hwy::ThreadPool& pool = env.ctx.pools.Pool(0);
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D df;
const size_t model_dim = layer_weights->layer_config.model_dim;
HWY_DASSERT(model_dim % hn::Lanes(df) == 0);
const size_t heads = layer_weights->layer_config.heads;
const size_t conv_1d_width = layer_weights->layer_config.conv1d_width;
HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even");
const size_t kHeadDim = model_dim / heads;
const size_t kMatrixSize = kHeadDim * kHeadDim;
const size_t num_interleaved = num_tokens * qbatch.Size();
const hwy::Divisor div_qbatch(static_cast<uint32_t>(qbatch.Size()));
GriffinActivations& griffin = activations.griffin;
// X / Y linear layers.
// TODO: MatMul
HWY_DASSERT(griffin.griffin_y.Rows() == griffin.griffin_x.Rows());
HWY_DASSERT(num_interleaved == griffin.griffin_y.Rows());
CallUpcastedSame(
&layer_weights->griffin.linear_x_w, &layer_weights->griffin.linear_y_w,
[&](const auto* wx, const auto* wy) {
for (size_t r = 0; r < num_interleaved; ++r) {
float* HWY_RESTRICT y = griffin.griffin_y.Row(r);
float* HWY_RESTRICT x = griffin.griffin_x.Row(r);
TwoMatVecAdd(
*wx, *wy, 0, model_dim, model_dim,
activations.attention.pre_att_rms_out.Row(r),
/*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(),
/*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(),
/*out0=*/x, /*out1=*/y, pool);
Gelu(y, model_dim);
}
});
// Conv1D.
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
const size_t qi = div_qbatch.Remainder(interleaved_idx);
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
const size_t pos = qbatch.Pos(qi) + batch_idx;
float* HWY_RESTRICT x = griffin.griffin_x.Row(qi);
// cache[i] = input at time t-i.
float* HWY_RESTRICT cache[kMaxConv1DWidth];
cache[0] = x;
for (size_t i = 1; i < conv_1d_width; i++) {
cache[i] =
qbatch.KV(qi).conv1d_cache.Row(griffin_layer) +
((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim;
}
for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
auto xv = hn::Load(df, x + i);
auto accum0 =
hn::Load(df, layer_weights->griffin.conv_biases.PackedScale1() + i);
auto accum1 = hn::Zero(df);
for (size_t l = 0; 2 * l < conv_1d_width; l++) {
auto wv0 =
hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() +
(conv_1d_width - 1 - 2 * l) * model_dim + i);
auto wv1 =
hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() +
(conv_1d_width - 2 - 2 * l) * model_dim + i);
accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
}
hn::Store(hn::Add(accum0, accum1), df, x + i);
hn::Store(xv, df, cache[HWY_MAX(conv_1d_width, 1) - 1] + i);
}
}
// RGLRU
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
const size_t qi = div_qbatch.Remainder(interleaved_idx);
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
const size_t pos = qbatch.Pos(qi) + batch_idx;
float* HWY_RESTRICT x = griffin.griffin_x.Row(qi);
float* HWY_RESTRICT y = griffin.griffin_y.Row(qi);
float* HWY_RESTRICT gate_x = griffin.griffin_gate_x.Row(qi);
float* HWY_RESTRICT a = griffin.griffin_multiplier.Row(qi);
float* HWY_RESTRICT rnn_state =
qbatch.KV(qi).rglru_cache.Row(griffin_layer);
pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
size_t head_offset = head * kHeadDim;
CallUpcasted(&layer_weights->griffin.gate_w, [&](const auto* gate_w) {
TwoOfsMatVecAddLoop(
*gate_w, kMatrixSize * head, kMatrixSize * (heads + head), kHeadDim,
kHeadDim, x + head_offset,
/*add0=*/layer_weights->griffin.gate_biases.PackedScale1() +
head_offset,
/*add1=*/layer_weights->griffin.gate_biases.PackedScale1() +
model_dim + head_offset,
/*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
});
Sigmoid(gate_x + head_offset, kHeadDim);
Sigmoid(a + head_offset, kHeadDim);
const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
HWY_ATTR { return hn::Mul(x, gate_x); };
hn::Transform1(D(), a + head_offset, kHeadDim,
layer_weights->griffin.a.PackedScale1() + head_offset,
fn_mul);
hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
fn_mul);
// RNN scan
HWY_DASSERT(kHeadDim % hn::Lanes(df) == 0);
for (size_t i = 0; i < kHeadDim; i += hn::Lanes(df)) {
auto log_a = hn::Load(df, a + head_offset + i);
auto gated_x = hn::Load(df, x + head_offset + i);
auto rnn = hn::Load(df, rnn_state + head_offset + i);
auto a = hn::Exp(df, log_a);
auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0f)));
if (pos == 0) {
x_multiplier = hn::Set(df, 1.0f);
}
auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn));
hn::Store(new_x, df, rnn_state + head_offset + i);
// Join branches.
auto yv = hn::Load(df, y + head_offset + i);
auto pre_out = hn::Mul(yv, new_x);
hn::Store(pre_out, df, x + head_offset + i);
}
});
} // interleaved_idx
// Final linear layer.
CallMatMul(griffin.griffin_x, layer_weights->griffin.linear_out_w,
layer_weights->griffin.linear_out_biases.PackedScale1(), env,
activations.attention.att_sums);
} // GriffinRecurrent
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();

View File

@ -1,47 +0,0 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_
// Declares GriffinRecurrent for all SIMD targets.
#include <stddef.h>
#include "gemma/gemma.h"
#include "hwy/highway.h"
namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_GRIFFIN(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void GriffinRecurrent(size_t num_tokens, size_t griffin_layer, \
const LayerWeightsPtrs* layer_weights, \
Activations& activations, QBatch& qbatch, \
MatMulEnv& env); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE
// Function declarations for each SIMD target. Allows direct call from the
// per-target namespace. We may later replace this with dynamic dispatch if
// the overhead is acceptable.
HWY_VISIT_TARGETS(GEMMA_DECL_GRIFFIN)
#undef GEMMA_DECL_GRIFFIN
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_

View File

@ -24,26 +24,6 @@
namespace gcpp {
void KVCache::ZeroGriffinCache() {
if (conv1d_cache.Rows() == 0) return;
ZeroInit(conv1d_cache);
ZeroInit(rglru_cache);
}
static size_t GriffinLayers(const ModelConfig& config) {
return config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock);
}
static size_t GriffinConv1dCols(const ModelConfig& config) {
size_t conv1d_width = 0;
for (const auto& layer_config : config.layer_configs) {
conv1d_width = HWY_MAX(conv1d_width, layer_config.conv1d_width);
}
// The row offset, in blocks of model_dim is computed mod (conv1d_width - 1),
// hence allocate conv1d_width * model_dim total columns.
return conv1d_width * config.model_dim;
}
// Number of rows for KV cache. Note that both rows and cols are u32, and
// the total number of elements can exceed 2^32.
static size_t CappedSeqLen(const ModelConfig& config,
@ -56,30 +36,18 @@ static size_t CappedSeqLen(const ModelConfig& config,
return inference_args.seq_len;
}
KVCache::KVCache(const Extents2D& conv1d_extents,
const Extents2D& rglru_extents, const Extents2D& kv_extents,
const Allocator& allocator)
: conv1d_cache("conv1d_cache", conv1d_extents, allocator, MatPadding::kOdd),
rglru_cache("rglru_cache", rglru_extents, allocator, MatPadding::kOdd),
kv_cache("kv", kv_extents, allocator, MatPadding::kOdd),
KVCache::KVCache(const Extents2D& kv_extents, const Allocator& allocator)
: kv_cache("kv", kv_extents, allocator, MatPadding::kOdd),
allocator_(allocator) {}
KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
const Allocator& allocator)
: KVCache(
Extents2D(GriffinLayers(config), GriffinConv1dCols(config)),
Extents2D(GriffinLayers(config), config.model_dim),
Extents2D(CappedSeqLen(config, inference_args), config.KVCacheCols()),
allocator) {}
KVCache KVCache::Copy() {
KVCache copy(conv1d_cache.Extents(), rglru_cache.Extents(),
kv_cache.Extents(), allocator_);
if (conv1d_cache.Rows() != 0) {
CopyMat(conv1d_cache, copy.conv1d_cache);
CopyMat(rglru_cache, copy.rglru_cache);
}
KVCache copy(kv_cache.Extents(), allocator_);
CopyMat(kv_cache, copy.kv_cache);

View File

@ -35,24 +35,15 @@ struct KVCache {
// copy ctor to make the cost explicit.
KVCache Copy();
// Zero-initialize the Griffin recurrent block cache, i.e. the conv1d_cache
// and rglru_cache.
void ZeroGriffinCache();
size_t SeqLen() const { return kv_cache.Rows(); }
// [griffin_layers, griffin_conv1d_cols * model_dim]
MatStorageT<float> conv1d_cache;
MatStorageT<float> rglru_cache; // [griffin_layers, model_dim]
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
private:
const Allocator& allocator_;
// For use by other ctor and Copy()
KVCache(const Extents2D& conv1d_extents, const Extents2D& rglru_extents,
const Extents2D& kv_extents, const Allocator& allocator);
KVCache(const Extents2D& kv_extents, const Allocator& allocator);
};
} // namespace gcpp

View File

@ -112,6 +112,8 @@ class TypePrefix {
return Type::kSFP;
case '2':
return Type::kNUQ;
case 'I':
return Type::kI8;
default:
// The other types were not written to pre-2025 files, hence no need to
// encode and check for them here.
@ -221,9 +223,6 @@ static int DeduceLayerTypes(const BlobReader& reader) {
int layer_types = 0;
for (size_t key_idx = 0; key_idx < reader.Keys().size(); ++key_idx) {
const std::string& key = reader.Keys()[key_idx];
if (key.find("gr_conv_w") != std::string::npos) { // NOLINT
return kDeducedGriffin;
}
if (key.find("qkv_ein_w") != std::string::npos) { // NOLINT
layer_types |= kDeducedViT;
}
@ -293,7 +292,7 @@ static std::vector<float> ReadScales(BlobReader& reader,
const ModelConfig& config) {
std::vector<float> scales;
// Check first to prevent `CallWithSpan` from printing a warning. This blob is
// optional even in pre-2025 format; Griffin was the first to include it.
// optional even in pre-2025 format.
if (reader.Find(kDecoratedScalesName)) {
HWY_ASSERT(reader.CallWithSpan<float>(
kDecoratedScalesName,

View File

@ -18,7 +18,6 @@
#include <stdio.h>
#include <iostream>
#include <random>
#include <string>
#include <string_view>
#include <vector>
@ -98,9 +97,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
size_t prompt_size = 0;
const ModelConfig& config = gemma.Config();
std::mt19937 gen;
InitGenerator(inference, gen);
const bool have_image = !inference.image_file.path.empty();
Image image;
const size_t pool_dim = config.vit_config.pool_dim;
@ -117,8 +113,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
HWY_ASSERT(image.ReadPPM(inference.image_file.path));
const size_t image_size = config.vit_config.image_size;
image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {.gen = &gen,
.verbosity = inference.verbosity,
RuntimeConfig runtime_config = {.verbosity = inference.verbosity,
.use_spinning = threading.spin};
double image_tokens_start = hwy::platform::Now();
gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image,
@ -132,7 +127,12 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
}
// callback function invoked for each generated token.
auto stream_token = [&](int token, float) {
auto batch_stream_token = [&](size_t query_idx, size_t pos, int token,
float) {
std::string token_text;
HWY_ASSERT(gemma.Tokenizer().Decode(std::vector<int>{token}, &token_text));
HWY_ASSERT(pos == abs_pos);
++abs_pos;
const bool in_prompt = tokens_generated_this_turn < prompt_size;
const bool first_response_token = tokens_generated_this_turn == prompt_size;
@ -148,8 +148,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
}
return true;
}
std::string token_text;
HWY_ASSERT(gemma.Tokenizer().Decode(std::vector<int>{token}, &token_text));
if (first_response_token) {
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
if (inference.verbosity >= 1) {
@ -185,9 +183,8 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
// Set up runtime config.
TimingInfo timing_info = {.verbosity = inference.verbosity};
RuntimeConfig runtime_config = {.gen = &gen,
.verbosity = inference.verbosity,
.stream_token = stream_token,
RuntimeConfig runtime_config = {.verbosity = inference.verbosity,
.batch_stream_token = batch_stream_token,
.use_spinning = threading.spin};
inference.CopyTo(runtime_config);
std::vector<int> prompt;
@ -223,6 +220,9 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
if (inference.verbosity >= 1) {
std::cerr << "\n[ Reading prompt ] " << std::flush;
}
// -1 because our prefill does not generate KVs for the last token. Do not
// just pass abs_pos - 1 because our callback checks pos == abs_pos.
if (abs_pos > 0) --abs_pos;
gemma.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache, env,
timing_info);
std::cout << "\n\n";
@ -233,7 +233,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
// Prepare for the next turn. Works only for PaliGemma.
if (!inference.multiturn || config.wrapping == PromptWrapping::PALIGEMMA) {
abs_pos = 0; // Start a new turn at position 0.
InitGenerator(inference, gen);
} else {
// The last token was either EOS, then it should be ignored because it is
// never part of the dialog, see Table 5 in the Gemma-2 paper:
@ -255,7 +254,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
ThreadingContext ctx(threading);
MatMulEnv env(ctx);
if (inference.verbosity >= 2) env.print_best = true;
if (inference.verbosity >= 3) env.print_best = true;
const Gemma gemma(loader, inference, ctx);
KVCache kv_cache(gemma.Config(), inference, ctx.allocator);

View File

@ -277,122 +277,6 @@ void TensorInfoRegistry::AddImageLayerTensors(const ModelConfig& config,
});
}
void TensorInfoRegistry::AddGriffinLayerTensors(const LayerConfig& layer_config,
const size_t layer_idx) {
const std::string suffix = LayerSuffix(layer_idx);
Add(suffix, {
.base_name = "gr_lin_x_w",
.source_names = {"recurrent_block/linear_x/kernel"},
.axes = {1, 0},
.shape = {layer_config.griffin_dim, layer_config.griffin_dim},
});
Add(suffix, {
.base_name = "gr_lin_x_b",
.source_names = {"recurrent_block/linear_x/bias"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "gr_lin_y_w",
.source_names = {"recurrent_block/linear_y/kernel"},
.axes = {1, 0},
.shape = {layer_config.griffin_dim, layer_config.griffin_dim},
});
Add(suffix, {
.base_name = "gr_lin_y_b",
.source_names = {"recurrent_block/linear_y/bias"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "gr_lin_out_w",
.source_names = {"recurrent_block/linear_out/kernel"},
.axes = {1, 0},
.shape = {layer_config.griffin_dim, layer_config.griffin_dim},
});
Add(suffix, {
.base_name = "gr_lin_out_b",
.source_names = {"recurrent_block/linear_out/bias"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
});
Add(suffix,
{
.base_name = "gr_conv_w",
.source_names = {"recurrent_block/conv_1d/w"},
.axes = {0, 1},
.shape = {layer_config.conv1d_width, layer_config.griffin_dim},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "gr_conv_b",
.source_names = {"recurrent_block/conv_1d/b"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "gr1_gate_w",
.source_names = {"recurrent_block/rg_lru/input_gate/w"},
.axes = {0, 2, 1},
.shape = {layer_config.heads,
layer_config.griffin_dim / layer_config.heads,
layer_config.griffin_dim / layer_config.heads},
.concat_names = {"gr_gate_w", "gr2_gate_w"},
});
Add(suffix, {
.base_name = "gr2_gate_w",
.source_names = {"recurrent_block/rg_lru/a_gate/w"},
.axes = {0, 2, 1},
.shape = {layer_config.heads,
layer_config.griffin_dim / layer_config.heads,
layer_config.griffin_dim / layer_config.heads},
.concat_names = {""},
});
Add(suffix, {
.base_name = "gr_gate_w",
.source_names = {"recurrent_block/rg_lru/gate/w"},
.axes = {0, 2, 1},
.shape = {2 * layer_config.heads,
layer_config.griffin_dim / layer_config.heads,
layer_config.griffin_dim / layer_config.heads},
});
Add(suffix, {
.base_name = "gr1_gate_b",
.source_names = {"recurrent_block/rg_lru/input_gate/b"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.concat_names = {"gr_gate_b", "gr2_gate_b"},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "gr2_gate_b",
.source_names = {"recurrent_block/rg_lru/a_gate/b"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.concat_names = {""},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "gr_gate_b",
.source_names = {"recurrent_block/rg_lru/input_gate/b"},
.axes = {0, 1},
.shape = {2 * layer_config.griffin_dim},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "gr_a",
.source_names = {"recurrent_block/rg_lru/a_param"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
.scaled_softplus = true,
});
}
void TensorInfoRegistry::AddLayerTensors(const ModelConfig& config,
const LayerConfig& layer_config,
const size_t layer_idx) {
@ -553,10 +437,6 @@ void TensorInfoRegistry::AddLayerTensors(const ModelConfig& config,
.shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim},
.cols_take_extra_dims = true,
});
if (config.model == Model::GRIFFIN_2B) {
AddGriffinLayerTensors(layer_config, layer_idx);
}
}
TensorInfoRegistry::TensorInfoRegistry(const ModelConfig& config) {

View File

@ -46,7 +46,7 @@ struct TensorInfo {
// The highest permissible compression for this tensor. The default is
// kNUQ, which provides maximum compression. Other values such as kBF16
// or kF32 can be used to limit the compression to a specific type.
Type min_size = Type::kNUQ;
Type min_size = Type::kI8;
// Whether to apply scaled softplus to the data.
bool scaled_softplus = false;
// Whether the columns or the rows take any extra dimensions.
@ -124,8 +124,6 @@ class TensorInfoRegistry {
void AddModelTensors(const ModelConfig& config);
void AddLayerTensors(const ModelConfig& config,
const LayerConfig& layer_config, size_t layer_idx);
void AddGriffinLayerTensors(const LayerConfig& layer_config,
size_t layer_idx);
void AddImageLayerTensors(const ModelConfig& config,
const LayerConfig& layer_config,

View File

@ -95,7 +95,7 @@ class VitAttention {
float* HWY_RESTRICT q =
activations_.attention.q.Row(token) + head * 3 * qkv_dim;
// TODO: shift to MatMul with A.scale once MatMul is confirmed working
MulByConst(query_scale, q, qkv_dim, env_.ctx.profiler, worker);
MulByConst(query_scale, q, qkv_dim);
hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
});
@ -110,8 +110,7 @@ class VitAttention {
CallMatMul(Q, K, nullptr, env_, C);
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
float* HWY_RESTRICT c = C.Row(task);
Softmax(c, C.Cols(), env_.ctx.profiler, worker);
Softmax(C.RowSpan(task), env_.ctx.profiler, worker);
});
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
@ -121,8 +120,7 @@ class VitAttention {
for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim,
env_.ctx.profiler, worker);
MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim);
}
});
}
@ -145,7 +143,7 @@ class VitAttention {
// Compute Q.K scores, which are "logits" stored in head_att.
float* HWY_RESTRICT q =
activations_.attention.q.Row(token) + head * 3 * qkv_dim;
MulByConst(query_scale, q, qkv_dim, env_.ctx.profiler, worker);
MulByConst(query_scale, q, qkv_dim);
float* HWY_RESTRICT head_att =
activations_.attention.att.Row(token) + head * seq_len;
for (size_t i = 0; i < seq_len; ++i) {
@ -154,7 +152,7 @@ class VitAttention {
head_att[i] = Dot(q, k, qkv_dim); // score = q.k
}
// SoftMax yields "probabilities" in head_att.
Softmax(head_att, seq_len, env_.ctx.profiler, worker);
Softmax(Logits(head_att, seq_len), env_.ctx.profiler, worker);
// Compute weighted sum of v into att_out.
float* HWY_RESTRICT att_out =
activations_.attention.att_out.Row(token) + head * qkv_dim;
@ -162,8 +160,7 @@ class VitAttention {
for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(head_att[i], v, att_out, qkv_dim,
env_.ctx.profiler, worker);
MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
}
});
}
@ -335,8 +332,9 @@ void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights,
// Apply soft embedding norm before input projection.
CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), 0, activations.x.Row(0),
vit_model_dim, env.ctx.profiler, hwy::Profiler::Thread());
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0,
activations.x.Row(0), vit_model_dim, env.ctx.profiler,
hwy::Profiler::GlobalIdx());
});
}

View File

@ -30,9 +30,9 @@
#include "gemma/gemma_args.h"
#include "gemma/model_store.h"
#include "io/blob_store.h"
#include "ops/matmul.h" // MMParallel
#include "util/mat.h"
#include "util/threading_context.h"
#include "util/zones.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
@ -88,7 +88,7 @@ void LayerWeightsPtrs::InitAttWeights(std::vector<MatOwner>& mat_owners,
// For FFN. Fast, only updates pointers.
void LayerWeightsPtrs::SplitW1() {
// Used for Gemma and Griffin layers; FFWVit uses different tensors.
// Used for Gemma layers; FFWVit uses different tensors.
if (layer_config.type == LayerAttentionType::kVit) return;
// Files have both or neither of w1 and w2.
@ -147,15 +147,222 @@ void LayerWeightsPtrs::SplitAttW1() {
qkv_einsum_w.SetPtr(nullptr, qkv_einsum_w.Cols());
}
static void HWY_MAYBE_UNUSED InitAttWeightsI8(
const LayerConfig& layer_config, MatPtrT<I8Stream>& attn_vec_einsum_w,
MatPtrT<I8Stream>& att_weights, std::vector<MatOwner>& mat_owners,
const Allocator& allocator) {
if (!attn_vec_einsum_w.HasPtr()) return;
HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kI8);
att_weights.SetType(Type::kI8);
{
static std::mutex m;
std::lock_guard<std::mutex> lock(m);
mat_owners.emplace_back();
mat_owners.back().AllocateFor(att_weights, allocator, MatPadding::kPacked);
}
const size_t model_dim = layer_config.model_dim;
const size_t heads = layer_config.heads;
const size_t qkv_dim = layer_config.qkv_dim;
// Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim].
hwy::AlignedFreeUniquePtr<float[]> attn_vec_einsum_w_tmp =
hwy::AllocateAligned<float>(model_dim * heads * qkv_dim);
hwy::AlignedFreeUniquePtr<float[]> att_weights_tmp =
hwy::AllocateAligned<float>(model_dim * heads * qkv_dim);
const hwy::HWY_NAMESPACE::ScalableTag<float> df;
HWY_NAMESPACE::DecompressAndZeroPad(df, attn_vec_einsum_w.Span(), 0,
attn_vec_einsum_w_tmp.get(),
model_dim * heads * qkv_dim);
for (size_t m = 0; m < model_dim; ++m) {
float* HWY_RESTRICT out_row = att_weights_tmp.get() + m * heads * qkv_dim;
for (size_t h = 0; h < heads; ++h) {
hwy::CopyBytes(
attn_vec_einsum_w_tmp.get() + h * model_dim * qkv_dim + m * qkv_dim,
out_row + h * qkv_dim, qkv_dim * sizeof(float));
}
}
CompressWorkingSet work;
hwy::ThreadPool pool(0);
HWY_NAMESPACE::Compress(att_weights_tmp.get(), model_dim * heads * qkv_dim,
work, att_weights.Span(),
/*packed_ofs=*/0, pool);
att_weights.SetScale(attn_vec_einsum_w.Scale());
}
static void HWY_MAYBE_UNUSED SplitW1I8(const LayerConfig& layer_config,
MatPtrT<I8Stream>& gating_einsum_w,
MatPtrT<I8Stream>& gating_einsum_w1,
MatPtrT<I8Stream>& gating_einsum_w2,
std::vector<MatOwner>& mat_owners,
const Allocator& allocator) {
// Files have both or neither of w1 and w2.
HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr());
// w is mutually exclusive with w1 and w2 in the file.
HWY_ASSERT(gating_einsum_w.HasPtr() ^ gating_einsum_w1.HasPtr());
// Done if we already read split tensors.
if (gating_einsum_w1.HasPtr() && !gating_einsum_w.HasPtr()) return;
// Nothing to do if w is not present.
if (!gating_einsum_w.HasPtr()) return;
HWY_ASSERT(gating_einsum_w.GetType() == Type::kI8);
const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
const size_t model_dim = gating_einsum_w.Cols();
HWY_ASSERT(gating_einsum_w.Rows() == 2 * ff_hidden_dim);
HWY_ASSERT(gating_einsum_w1.Rows() == ff_hidden_dim);
HWY_ASSERT(gating_einsum_w2.Rows() == ff_hidden_dim);
HWY_ASSERT(gating_einsum_w1.Cols() == model_dim);
HWY_ASSERT(gating_einsum_w2.Cols() == model_dim);
gating_einsum_w1.SetType(Type::kI8);
gating_einsum_w2.SetType(Type::kI8);
{
static std::mutex m;
std::lock_guard<std::mutex> lock(m);
mat_owners.emplace_back();
mat_owners.back().AllocateFor(gating_einsum_w1, allocator,
MatPadding::kPacked);
mat_owners.emplace_back();
mat_owners.back().AllocateFor(gating_einsum_w2, allocator,
MatPadding::kPacked);
}
const size_t total_size = gating_einsum_w.Rows() * gating_einsum_w.Cols();
hwy::AlignedFreeUniquePtr<float[]> w_tmp =
hwy::AllocateAligned<float>(total_size);
const hwy::HWY_NAMESPACE::ScalableTag<float> df;
HWY_NAMESPACE::DecompressAndZeroPad(df, gating_einsum_w.Span(), 0,
w_tmp.get(), total_size);
const size_t split_size = ff_hidden_dim * model_dim;
float* w1_tmp = w_tmp.get();
float* w2_tmp = w_tmp.get() + split_size;
CompressWorkingSet work;
hwy::ThreadPool pool(0);
HWY_NAMESPACE::Compress(w1_tmp, split_size, work, gating_einsum_w1.Span(), 0,
pool);
HWY_NAMESPACE::Compress(w2_tmp, split_size, work, gating_einsum_w2.Span(), 0,
pool);
gating_einsum_w1.SetScale(1.0f);
gating_einsum_w2.SetScale(1.0f);
gating_einsum_w.SetPtr(nullptr, gating_einsum_w.Cols());
}
static void HWY_MAYBE_UNUSED SplitAttW1I8(const LayerConfig& layer_config,
MatPtrT<I8Stream>& qkv_einsum_w,
MatPtrT<I8Stream>& qkv_einsum_w1,
MatPtrT<I8Stream>& qkv_einsum_w2,
std::vector<MatOwner>& mat_owners,
const Allocator& allocator) {
// w is mutually exclusive with w1 in the file.
HWY_ASSERT(qkv_einsum_w.HasPtr() ^ qkv_einsum_w1.HasPtr());
// Done if we already read split tensors.
if (qkv_einsum_w1.HasPtr() && !qkv_einsum_w.HasPtr()) return;
// Nothing to do if w is not present.
if (!qkv_einsum_w.HasPtr()) return;
HWY_ASSERT(qkv_einsum_w.GetType() == Type::kI8);
const size_t model_dim = qkv_einsum_w.Cols();
const size_t w1_rows = layer_config.heads * layer_config.qkv_dim;
const size_t w2_rows = layer_config.kv_heads * 2 * layer_config.qkv_dim;
HWY_ASSERT(qkv_einsum_w.Rows() == w1_rows + w2_rows);
HWY_ASSERT(qkv_einsum_w1.Rows() == w1_rows);
HWY_ASSERT(qkv_einsum_w2.Rows() == w2_rows);
HWY_ASSERT(qkv_einsum_w1.Cols() == model_dim);
HWY_ASSERT(qkv_einsum_w2.Cols() == model_dim);
qkv_einsum_w1.SetType(Type::kI8);
qkv_einsum_w2.SetType(Type::kI8);
{
static std::mutex m;
std::lock_guard<std::mutex> lock(m);
mat_owners.emplace_back();
mat_owners.back().AllocateFor(qkv_einsum_w1, allocator,
MatPadding::kPacked);
mat_owners.emplace_back();
mat_owners.back().AllocateFor(qkv_einsum_w2, allocator,
MatPadding::kPacked);
}
const size_t total_size = qkv_einsum_w.Rows() * qkv_einsum_w.Cols();
hwy::AlignedFreeUniquePtr<float[]> w_tmp =
hwy::AllocateAligned<float>(total_size);
const hwy::HWY_NAMESPACE::ScalableTag<float> df;
HWY_NAMESPACE::DecompressAndZeroPad(df, qkv_einsum_w.Span(), 0, w_tmp.get(),
total_size);
const size_t w1_size = w1_rows * model_dim;
const size_t w2_size = w2_rows * model_dim;
float* w1_tmp = w_tmp.get();
float* w2_tmp = w_tmp.get() + w1_size;
CompressWorkingSet work;
hwy::ThreadPool pool(0);
HWY_NAMESPACE::Compress(w1_tmp, w1_size, work, qkv_einsum_w1.Span(), 0, pool);
HWY_NAMESPACE::Compress(w2_tmp, w2_size, work, qkv_einsum_w2.Span(), 0, pool);
qkv_einsum_w1.SetScale(1.0f);
qkv_einsum_w2.SetScale(1.0f);
qkv_einsum_w.SetPtr(nullptr, qkv_einsum_w.Cols());
}
// Must be called after reading weights via `ForEachTensor`.
// TODO: exporters should bake this into the weights already.
// WARNING: called from multiple threads; `mat_owners` requires a lock.
void LayerWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
const Allocator& allocator) {
// TODO(janwas): handle NUQ
InitAttWeights(mat_owners, allocator);
SplitW1();
SplitAttW1();
if (attn_vec_einsum_w.GetType() == Type::kI8) {
MatPtrT<I8Stream> attn_vec_einsum_w_i8(attn_vec_einsum_w);
MatPtrT<I8Stream> att_weights_i8(att_weights);
InitAttWeightsI8(layer_config, attn_vec_einsum_w_i8, att_weights_i8,
mat_owners, allocator);
attn_vec_einsum_w = attn_vec_einsum_w_i8;
att_weights = att_weights_i8;
} else {
InitAttWeights(mat_owners, allocator);
}
if (gating_einsum_w.GetType() == Type::kI8) {
MatPtrT<I8Stream> gating_einsum_w_i8(gating_einsum_w);
MatPtrT<I8Stream> gating_einsum_w1_i8(gating_einsum_w1);
MatPtrT<I8Stream> gating_einsum_w2_i8(gating_einsum_w2);
SplitW1I8(layer_config, gating_einsum_w_i8, gating_einsum_w1_i8,
gating_einsum_w2_i8, mat_owners, allocator);
gating_einsum_w = gating_einsum_w_i8;
gating_einsum_w1 = gating_einsum_w1_i8;
gating_einsum_w2 = gating_einsum_w2_i8;
} else {
SplitW1();
}
if (qkv_einsum_w.GetType() == Type::kI8) {
MatPtrT<I8Stream> qkv_einsum_w_i8(qkv_einsum_w);
MatPtrT<I8Stream> qkv_einsum_w1_i8(qkv_einsum_w1);
MatPtrT<I8Stream> qkv_einsum_w2_i8(qkv_einsum_w2);
SplitAttW1I8(layer_config, qkv_einsum_w_i8, qkv_einsum_w1_i8,
qkv_einsum_w2_i8, mat_owners, allocator);
qkv_einsum_w = qkv_einsum_w_i8;
qkv_einsum_w1 = qkv_einsum_w1_i8;
qkv_einsum_w2 = qkv_einsum_w2_i8;
} else {
SplitAttW1();
}
}
static void HWY_MAYBE_UNUSED InitAttWeightsNUQ(
@ -226,15 +433,16 @@ void WeightsPtrs::CopyFrom(const WeightsPtrs& other) {
// ideally already happen in the importer. Called by `ReadFromBlobs`.
void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
ThreadingContext& ctx) {
// TODO: use 1D parallel-for helper function
hwy::ThreadPool& pool = ctx.pools.Pool();
pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
GetLayer(layer)->Fixup(mat_owners, ctx.allocator);
});
const size_t cluster_idx = 0;
ParallelFor(ParallelismStrategy::kFlat, c_layers.size(), ctx, cluster_idx,
[&](uint64_t layer, size_t /*worker*/) {
GetLayer(layer)->Fixup(mat_owners, ctx.allocator);
});
pool.Run(0, vit_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
VitLayer(layer)->Fixup(mat_owners, ctx.allocator);
});
ParallelFor(ParallelismStrategy::kFlat, vit_layers.size(), ctx, cluster_idx,
[&](uint64_t layer, size_t /*worker*/) {
VitLayer(layer)->Fixup(mat_owners, ctx.allocator);
});
}
std::vector<uint32_t> WeightsPtrs::AddTensorDataToWriter(
@ -320,8 +528,6 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
const size_t start = owners.size();
owners.resize(start + tensors.size());
MMParallel parallel(ctx);
// Allocate in parallel because faulting in large tensors is slow.
ctx.pools.Pool().Run(
0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
@ -339,7 +545,6 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
owners[start + task].AllocateFor(*tensor.mat, ctx.allocator,
tensor.padding);
BindB(*tensor.mat, tensor.mat->ElementBytes(), parallel);
});
}
@ -382,41 +587,46 @@ static void DecompressToBF16(MatPtr& mat,
static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
const BlobReader& reader, ThreadingContext& ctx) {
static const auto zone =
ctx.profiler.AddZone("Startup.Weights.ReadAllToBF16");
ctx.pools.Pool().Run(0, tensors.size(), [&](uint64_t task, size_t thread) {
PROFILER_ZONE3(ctx.profiler, thread, zone);
const TensorToRead& tensor = tensors[task];
MatPtr& mat = *tensor.mat;
const auto zone = GetProfilerZone(Zones::kStartupWeightsReadAllToBF16);
// Especially TSAN is slow enough to warrant hierarchical parallelism.
const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD
? ParallelismStrategy::kHierarchical
: ParallelismStrategy::kFlat;
ParallelFor(strategy, tensors.size(), ctx, /*cluster_idx=*/0,
[&](uint64_t task, size_t thread) {
PROFILER_ZONE3(ctx.profiler, thread, zone);
const TensorToRead& tensor = tensors[task];
MatPtr& mat = *tensor.mat;
if (tensor.keep_type) {
HWY_ASSERT(reader.file().Read(tensor.range.offset, tensor.range.bytes,
mat.Packed()));
return;
}
if (tensor.keep_type) {
HWY_ASSERT(reader.file().Read(
tensor.range.offset, tensor.range.bytes, mat.Packed()));
return;
}
// Read to a temporary buffer.
const hwy::AlignedFreeUniquePtr<uint8_t[]> buf =
hwy::AllocateAligned<uint8_t>(tensor.range.bytes);
HWY_ASSERT(
reader.file().Read(tensor.range.offset, tensor.range.bytes, buf.get()));
// Read to a temporary buffer.
const hwy::AlignedFreeUniquePtr<uint8_t[]> buf =
hwy::AllocateAligned<uint8_t>(tensor.range.bytes);
HWY_ASSERT(reader.file().Read(tensor.range.offset,
tensor.range.bytes, buf.get()));
if constexpr (GEMMA_ENABLE_NUQ) {
if (tensor.prev_type == Type::kNUQ) {
return DecompressToBF16<NuqStream>(*tensor.mat, buf);
}
}
switch (tensor.prev_type) {
case Type::kF32:
return DecompressToBF16<float>(*tensor.mat, buf);
case Type::kBF16:
return DecompressToBF16<BF16>(*tensor.mat, buf);
case Type::kSFP:
return DecompressToBF16<SfpStream>(*tensor.mat, buf);
default:
HWY_ABORT("Unsupported type %s", TypeName(tensor.prev_type));
}
});
if constexpr (GEMMA_ENABLE_NUQ) {
if (tensor.prev_type == Type::kNUQ) {
return DecompressToBF16<NuqStream>(*tensor.mat, buf);
}
}
switch (tensor.prev_type) {
case Type::kF32:
return DecompressToBF16<float>(*tensor.mat, buf);
case Type::kBF16:
return DecompressToBF16<BF16>(*tensor.mat, buf);
case Type::kSFP:
return DecompressToBF16<SfpStream>(*tensor.mat, buf);
default:
HWY_ABORT("Unsupported type %s",
TypeName(tensor.prev_type));
}
});
}
// Mode == kRead:
@ -424,8 +634,6 @@ static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
static std::vector<IOBatch> MakeBatches(
const std::vector<TensorToRead>& tensors, const uint64_t file_bytes) {
PROFILER_ZONE("Startup.Weights.MakeBatches");
// Batches must be contiguous but blobs are padded, hence at least one
// batch per tensor, and more when tensor rows exceed the batch size.
std::vector<IOBatch> batches;
batches.reserve(tensors.size());
@ -436,20 +644,28 @@ static std::vector<IOBatch> MakeBatches(
HWY_ASSERT(range.End() <= file_bytes);
batches.emplace_back(offset, range.key_idx);
const size_t file_bytes_per_row = mat.Cols() * mat.ElementBytes();
const size_t mem_stride_bytes = mat.Stride() * mat.ElementBytes();
uint8_t* row_bytes = mat.RowBytes(0);
for (size_t r = 0; r < mat.Rows(); ++r) {
if (!batches.back().Add(row_bytes, file_bytes_per_row)) { // Full batch.
batches.emplace_back(offset, range.key_idx);
// Adding to an empty batch is always successful.
HWY_ASSERT(batches.back().Add(row_bytes, file_bytes_per_row));
if (mat.IsPacked()) {
HWY_ASSERT(range.bytes == mat.PackedBytes());
if (!batches.back().Add(mat.Packed(), range.bytes)) {
// This should not happen if tensors are < 2GB.
// If it does, we need to chunk. For now, let's assume it doesn't.
HWY_ABORT("Packed tensor too large for a single IO batch.");
}
offset += range.bytes;
} else {
const size_t file_bytes_per_row = mat.Cols() * mat.ElementBytes();
const size_t mem_stride_bytes = mat.Stride() * mat.ElementBytes();
uint8_t* row_bytes = mat.RowBytes(0);
for (size_t r = 0; r < mat.Rows(); ++r) {
if (!batches.back().Add(row_bytes,
file_bytes_per_row)) { // Full batch.
batches.emplace_back(offset, range.key_idx);
// Adding to an empty batch is always successful.
HWY_ASSERT(batches.back().Add(row_bytes, file_bytes_per_row));
}
offset += file_bytes_per_row;
row_bytes += mem_stride_bytes;
}
offset += file_bytes_per_row;
// Must zero-initialize the in-memory row padding, see MatMul.
hwy::ZeroBytes(row_bytes + file_bytes_per_row,
mem_stride_bytes - file_bytes_per_row);
row_bytes += mem_stride_bytes;
}
HWY_ASSERT(offset == range.End());
}
@ -463,20 +679,22 @@ static std::vector<IOBatch> MakeBatches(
static void ReadBatches(const BlobReader& reader,
const std::vector<IOBatch>& batches,
ThreadingContext& ctx) {
static const auto zone = ctx.profiler.AddZone("Startup.Weights.ReadBatches");
const auto zone = GetProfilerZone(Zones::kStartupWeightsReadBatches);
// >5x speedup from parallel reads when cached.
ctx.pools.Pool().Run(0, batches.size(), [&](uint64_t i, size_t thread) {
PROFILER_ZONE3(ctx.profiler, thread, zone);
const IOBatch& batch = batches[i];
const std::string& key = reader.Keys()[batch.KeyIdx()];
const uint64_t bytes_read = batch.Read(reader.file());
if (bytes_read != batch.TotalBytes()) {
HWY_ABORT("Read failed for %s from %zu, %zu bytes; got %zu.", key.c_str(),
static_cast<size_t>(batch.Offset()),
static_cast<size_t>(batch.TotalBytes()),
static_cast<size_t>(bytes_read));
}
});
ParallelFor(ParallelismStrategy::kHierarchical,
batches.size(), ctx, /*cluster_idx=*/0,
[&](uint64_t task, size_t thread) {
PROFILER_ZONE3(ctx.profiler, thread, zone);
const IOBatch& batch = batches[task];
const std::string& key = reader.Keys()[batch.KeyIdx()];
const uint64_t bytes_read = batch.Read(reader.file());
if (bytes_read != batch.TotalBytes()) {
HWY_ABORT("Read failed for %s from %zu, %zu bytes; got %zu.",
key.c_str(), static_cast<size_t>(batch.Offset()),
static_cast<size_t>(batch.TotalBytes()),
static_cast<size_t>(bytes_read));
}
});
}
// Aborts on error. Updates `mode` to the actual mode used. Returns mapped

View File

@ -57,8 +57,7 @@ struct TensorArgs {
// the _w1/_w2 tensors are not always present.
kMaybeRead = 1,
// Avoid padding tensor rows when reading. Used for some Griffin tensors
// whose index computations do not use Row() accessors.
// Avoid padding tensor rows when reading.
kPacked = 2,
};
const int flags;
@ -102,17 +101,6 @@ struct LayerWeightsPtrs {
qkv_einsum_w1(finder_("qkv1_w")),
qkv_einsum_w2(finder_("qkv2_w")),
attention_output_biases(finder_("attn_ob")),
griffin({.linear_x_w = finder_("gr_lin_x_w"),
.linear_x_biases = finder_("gr_lin_x_b"),
.linear_y_w = finder_("gr_lin_y_w"),
.linear_y_biases = finder_("gr_lin_y_b"),
.linear_out_w = finder_("gr_lin_out_w"),
.linear_out_biases = finder_("gr_lin_out_b"),
.conv_w = finder_("gr_conv_w"),
.conv_biases = finder_("gr_conv_b"),
.gate_w = finder_("gr_gate_w"),
.gate_biases = finder_("gr_gate_b"),
.a = finder_("gr_a")}),
// MultiHeadDotProductAttention.
vit({.attn_out_w = finder_("attn_out_w"),
.attn_out_b = finder_("attn_out_b"),
@ -156,20 +144,6 @@ struct LayerWeightsPtrs {
MatPtr qkv_einsum_w2;
MatPtrT<float> attention_output_biases;
struct {
MatPtr linear_x_w;
MatPtrT<float> linear_x_biases;
MatPtr linear_y_w;
MatPtrT<float> linear_y_biases;
MatPtr linear_out_w;
MatPtrT<float> linear_out_biases;
MatPtrT<float> conv_w;
MatPtrT<float> conv_biases;
MatPtr gate_w;
MatPtrT<float> gate_biases;
MatPtrT<float> a;
} griffin;
struct {
// MultiHeadDotProductAttention.
MatPtr attn_out_w; // at least BF16.
@ -244,20 +218,6 @@ struct LayerWeightsPtrs {
func(TENSOR_ARGS(qkv_einsum_w, kMaybeRead));
func(TENSOR_ARGS(qkv_einsum_w1, kMaybeRead));
func(TENSOR_ARGS(qkv_einsum_w2, kMaybeRead));
} else {
func(TENSOR_ARGS(griffin.linear_x_w, kMustRead));
func(TENSOR_ARGS(griffin.linear_x_biases, kMustRead));
func(TENSOR_ARGS(griffin.linear_y_w, kMustRead));
func(TENSOR_ARGS(griffin.linear_y_biases, kMustRead));
func(TENSOR_ARGS(griffin.linear_out_w, kMustRead));
func(TENSOR_ARGS(griffin.linear_out_biases, kMustRead));
// conv_w and gate_w are not accessed via Row(), hence must not be padded.
// Note that *biases are 1D, hence packing/padding does not matter.
func(TENSOR_ARGS(griffin.conv_w, kMustRead | TensorArgs::kPacked));
func(TENSOR_ARGS(griffin.conv_biases, kMustRead));
func(TENSOR_ARGS(griffin.gate_w, kMustRead | TensorArgs::kPacked));
func(TENSOR_ARGS(griffin.gate_biases, kMustRead));
func(TENSOR_ARGS(griffin.a, kMustRead));
}
{
func(TENSOR_ARGS(gating_einsum_w, kMaybeRead));
@ -281,11 +241,6 @@ struct LayerWeightsPtrs {
func(TENSOR_ARGS(ffw_gating_biases, kMustRead));
func(TENSOR_ARGS(ffw_output_biases, kMustRead));
}
if (layer_config.softmax_attn_output_biases &&
layer_config.type == LayerAttentionType::kGemma) {
func(TENSOR_ARGS(attention_output_biases, kMustRead));
}
} // `ForEachTensor`
// Zero-initializes all allocated tensors in the layer.

View File

@ -28,7 +28,6 @@
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/timer.h"
namespace gcpp {
@ -104,27 +103,31 @@ BlobVec ReserveMemory(const RangeVec& ranges, BytePtr& all_blobs, size_t& pos) {
// Reads one set of blobs in parallel (helpful if in disk cache).
// Aborts on error.
void ReadBlobs(BlobReader& reader, const RangeVec& ranges, BlobVec& blobs,
hwy::ThreadPool& pool) {
ThreadingContext& ctx, size_t cluster_idx) {
HWY_ASSERT(reader.Keys().size() == blobs.size());
HWY_ASSERT(ranges.size() == blobs.size());
pool.Run(0, blobs.size(), [&](size_t i, size_t /*thread*/) {
HWY_ASSERT(ranges[i].bytes == blobs[i].size());
reader.file().Read(ranges[i].offset, ranges[i].bytes, blobs[i].data());
});
ParallelFor(ParallelismStrategy::kWithinCluster, blobs.size(), ctx,
cluster_idx, [&](size_t i, size_t /*thread*/) {
HWY_ASSERT(ranges[i].bytes == blobs[i].size());
reader.file().Read(ranges[i].offset, ranges[i].bytes,
blobs[i].data());
});
}
// Parallelizes ReadBlobs across (two) packages, if available.
void ReadBothBlobs(BlobReader& reader1, BlobReader& reader2,
const RangeVec& ranges1, const RangeVec& ranges2,
size_t total_bytes, BlobVec& blobs1, BlobVec& blobs2,
NestedPools& pools) {
ThreadingContext& ctx) {
const double t0 = hwy::platform::Now();
HWY_WARN("Reading %zu GiB, %zux%zu cores: ", total_bytes >> 30,
pools.AllPackages().NumWorkers(), pools.Pool().NumWorkers());
pools.AllPackages().Run(0, 2, [&](size_t task, size_t pkg_idx) {
ReadBlobs(task ? reader2 : reader1, task ? ranges2 : ranges1,
task ? blobs2 : blobs1, pools.Pool(pkg_idx));
});
HWY_WARN("Reading %zu GiB, %zu clusters: ", total_bytes >> 30,
ctx.pools.NumClusters());
ParallelFor(ParallelismStrategy::kAcrossClusters, 2, ctx, 0,
[&](const size_t task, size_t cluster_idx) {
ReadBlobs(task ? reader1 : reader2, task ? ranges1 : ranges2,
task ? blobs1 : blobs2, ctx, cluster_idx);
});
const double t1 = hwy::platform::Now();
HWY_WARN("%.1f GB/s\n", total_bytes / (t1 - t0) * 1E-9);
}
@ -181,29 +184,23 @@ size_t BlobDifferences(const ByteSpan data1, const ByteSpan data2,
}
void CompareBlobs(const KeyVec& keys, BlobVec& blobs1, BlobVec& blobs2,
size_t total_bytes, NestedPools& pools) {
size_t total_bytes, ThreadingContext& ctx) {
HWY_WARN("Comparing %zu blobs in parallel: ", keys.size());
const double t0 = hwy::platform::Now();
std::atomic<size_t> blobs_equal{};
std::atomic<size_t> blobs_diff{};
const IndexRangePartition ranges = StaticPartition(
IndexRange(0, keys.size()), pools.AllPackages().NumWorkers(), 1);
ParallelizeOneRange(
ranges, pools.AllPackages(),
[&](const IndexRange& range, size_t pkg_idx) {
pools.Pool(pkg_idx).Run(
range.begin(), range.end(), [&](size_t i, size_t /*thread*/) {
const size_t mismatches =
BlobDifferences(blobs1[i], blobs2[i], keys[i]);
if (mismatches != 0) {
HWY_WARN("key %s has %zu mismatches in %zu bytes!\n",
keys[i].c_str(), mismatches, blobs1[i].size());
blobs_diff.fetch_add(1);
} else {
blobs_equal.fetch_add(1);
}
});
});
ParallelFor(ParallelismStrategy::kHierarchical, keys.size(), ctx, 0,
[&](size_t i, size_t /*thread*/) {
const size_t mismatches =
BlobDifferences(blobs1[i], blobs2[i], keys[i]);
if (mismatches != 0) {
HWY_WARN("key %s has %zu mismatches in %zu bytes!\n",
keys[i].c_str(), mismatches, blobs1[i].size());
blobs_diff.fetch_add(1);
} else {
blobs_equal.fetch_add(1);
}
});
const double t1 = hwy::platform::Now();
HWY_WARN("%.1f GB/s; total blob matches=%zu, mismatches=%zu\n",
total_bytes / (t1 - t0) * 1E-9, blobs_equal.load(),
@ -230,9 +227,9 @@ void ReadAndCompareBlobs(const Path& path1, const Path& path2) {
ThreadingArgs args;
ThreadingContext ctx(args);
ReadBothBlobs(reader1, reader2, ranges1, ranges2, total_bytes, blobs1, blobs2,
ctx.pools);
ctx);
CompareBlobs(reader1.Keys(), blobs1, blobs2, total_bytes, ctx.pools);
CompareBlobs(reader1.Keys(), blobs1, blobs2, total_bytes, ctx);
}
} // namespace gcpp

View File

@ -106,7 +106,13 @@ class FilePosix : public File {
for (;;) {
// pread seems to be faster than lseek + read when parallelized.
const auto bytes_read = pread(fd_, bytes + pos, size - pos, offset + pos);
if (bytes_read <= 0) break;
if (bytes_read <= 0) {
HWY_WARN(
"Read failure at pos %zu within size %zu with offset %zu and "
"errno %d\n",
pos, size, offset, errno);
break;
}
pos += bytes_read;
HWY_ASSERT(pos <= size);
if (pos == size) break;
@ -120,7 +126,13 @@ class FilePosix : public File {
for (;;) {
const auto bytes_written =
pwrite(fd_, bytes + pos, size - pos, offset + pos);
if (bytes_written <= 0) break;
if (bytes_written <= 0) {
HWY_WARN(
"Write failure at pos %zu within size %zu with offset %zu and "
"errno %d\n",
pos, size, offset, errno);
break;
}
pos += bytes_written;
HWY_ASSERT(pos <= size);
if (pos == size) break;
@ -226,21 +238,26 @@ void InternalInit() {
}
uint64_t IOBatch::Read(const File& file) const {
#if GEMMA_IO_PREADV
HWY_ASSERT(!spans_.empty());
ssize_t bytes_read;
for (;;) {
bytes_read =
preadv(file.Handle(), reinterpret_cast<const iovec*>(spans_.data()),
static_cast<int>(spans_.size()), offset_);
if (bytes_read >= 0) break;
if (errno == EINTR) continue; // signal: retry
HWY_WARN("preadv failed, errno %d.", errno);
return 0;
#if GEMMA_IO_PREADV
if (file.Handle() != -1) {
ssize_t bytes_read;
for (;;) {
bytes_read =
preadv(file.Handle(), reinterpret_cast<const iovec*>(spans_.data()),
static_cast<int>(spans_.size()), offset_);
if (bytes_read >= 0) break;
if (errno == EINTR) continue; // signal: retry
HWY_WARN("preadv(%d) for %4zu spans from offset %12zu failed, errno %d.",
file.Handle(), spans_.size(), offset_, errno);
return 0;
}
return static_cast<uint64_t>(bytes_read);
}
return static_cast<uint64_t>(bytes_read);
#else
#endif // GEMMA_IO_PREADV
// preadv disabled or no handle: use normal reads (higher kernel overhead).
uint64_t total = 0;
uint64_t offset = offset_;
for (const IOSpan& span : spans_) {
@ -249,7 +266,6 @@ uint64_t IOBatch::Read(const File& file) const {
offset += span.bytes;
}
return total;
#endif
}
} // namespace gcpp

View File

@ -68,7 +68,7 @@ class File {
// modify internal state. This is only expected to be called once per file.
virtual MapPtr Map() = 0;
// For use by `IOBatch::Read`.
// Returns handle for use by `IOBatch::Read`, or -1 if not supported.
virtual int Handle() const { return -1; }
};

View File

@ -111,8 +111,8 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
// Ensure usage conditions are set before autotuning. Both binding and
// spinning may materially affect the choice of config. No harm in calling
// BindB/C if there is a single package: they will be a no-op.
BindB(b_trans, sizeof(TC), env.parallel);
BindC(C, env.parallel);
BindB(env.ctx, b_trans, sizeof(TC));
BindC(env.ctx, C);
C.AllocateAndAttachRowPtrs(env.row_ptrs);
Tristate use_spinning = Tristate::kDefault;
@ -160,10 +160,10 @@ void BenchAllMatMul() {
ctx.pools.PinString());
MatMulEnv env(ctx);
for (size_t batch_size : {1, 4, 128, 512}) {
for (size_t batch_size : {128, 512}) {
constexpr bool kAdd = false;
BenchMatMul<BF16, SFP, BF16>(batch_size, 24576, 3072, kAdd, env);
BenchMatMul<BF16, SFP, BF16>(batch_size, 3072, 24576, kAdd, env);
BenchMatMul<BF16, BF16, BF16>(batch_size, 24576, 3072, kAdd, env);
BenchMatMul<BF16, BF16, BF16>(batch_size, 3072, 24576, kAdd, env);
}
PROFILER_PRINT_RESULTS();

View File

@ -157,15 +157,16 @@ HWY_MAYBE_UNUSED double ConditionNumber(const VT* HWY_RESTRICT v, size_t num) {
// promoted or even DEMOTED to bf16. Runs at about half the speed of f32 FMA.
struct DotKernelDouble {
// Only `CompressTraits<float>` can `Decompress2` to `double`, so both have
// to be `float` in order to have `Raw = double`. Note that if either type is
// smaller than `float`, we may demote the other type from `float` to `BF16`.
// to be `float` in order to have `Raw = double`. To avoid loss of accuracy,
// if either is float, we decompress both to float, otherwise `BF16`.
template <typename VT, typename WT>
using Raw = hwy::If<IsF32<VT>() && IsF32<WT>(), double, BF16>;
using Raw = hwy::If<IsF32<VT>() && IsF32<WT>(), double,
hwy::If<IsF32<VT>() || IsF32<WT>(), float, BF16>>;
using State = double;
// Raw = double
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
HWY_INLINE void Update4(DRaw dd, const VR w0, const VR w1, const VR w2,
HWY_INLINE void Update4(DRaw dr, const VR w0, const VR w1, const VR w2,
const VR w3, const VR v0, const VR v1, const VR v2,
const VR v3, VR& sum0, VR& sum1, VR& sum2, VR& sum3,
VR&, VR&, VR&, VR&) const {
@ -175,6 +176,41 @@ struct DotKernelDouble {
sum3 = hn::MulAdd(w3, v3, sum3);
}
// Raw = float
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F32_D(DRaw),
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
HWY_INLINE void Update4(DRaw dr, const VR w0, const VR w1, const VR w2,
const VR w3, const VR v0, const VR v1, const VR v2,
const VR v3, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
VS&, VS&, VS&, VS&) const {
const hn::Repartition<double, DRaw> dd;
using VD = hn::Vec<decltype(dd)>;
VD w0d = hn::PromoteLowerTo(dd, w0);
VD w1d = hn::PromoteLowerTo(dd, w1);
VD w2d = hn::PromoteLowerTo(dd, w2);
VD w3d = hn::PromoteLowerTo(dd, w3);
VD v0d = hn::PromoteLowerTo(dd, v0);
VD v1d = hn::PromoteLowerTo(dd, v1);
VD v2d = hn::PromoteLowerTo(dd, v2);
VD v3d = hn::PromoteLowerTo(dd, v3);
sum0 = hn::MulAdd(w0d, v0d, sum0);
sum1 = hn::MulAdd(w1d, v1d, sum1);
sum2 = hn::MulAdd(w2d, v2d, sum2);
sum3 = hn::MulAdd(w3d, v3d, sum3);
w0d = hn::PromoteUpperTo(dd, w0);
w1d = hn::PromoteUpperTo(dd, w1);
w2d = hn::PromoteUpperTo(dd, w2);
w3d = hn::PromoteUpperTo(dd, w3);
v0d = hn::PromoteUpperTo(dd, v0);
v1d = hn::PromoteUpperTo(dd, v1);
v2d = hn::PromoteUpperTo(dd, v2);
v3d = hn::PromoteUpperTo(dd, v3);
sum0 = hn::MulAdd(w0d, v0d, sum0);
sum1 = hn::MulAdd(w1d, v1d, sum1);
sum2 = hn::MulAdd(w2d, v2d, sum2);
sum3 = hn::MulAdd(w3d, v3d, sum3);
}
// Raw = BF16
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
@ -217,11 +253,26 @@ struct DotKernelDouble {
// Raw = double
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
HWY_INLINE void Update1(DRaw dd, const VR w0, const VR v0, VR& sum0,
HWY_INLINE void Update1(DRaw dr, const VR w0, const VR v0, VR& sum0,
VR&) const {
sum0 = hn::MulAdd(w0, v0, sum0);
}
// Raw = float
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F32_D(DRaw),
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
HWY_INLINE void Update1(DRaw dr, const VR w0, const VR v0, VS& sum0,
VS&) const {
const hn::Repartition<double, DRaw> dd;
using VD = hn::Vec<decltype(dd)>;
VD w0d = hn::PromoteLowerTo(dd, w0);
VD v0d = hn::PromoteLowerTo(dd, v0);
sum0 = hn::MulAdd(w0d, v0d, sum0);
w0d = hn::PromoteUpperTo(dd, w0);
v0d = hn::PromoteUpperTo(dd, v0);
sum0 = hn::MulAdd(w0d, v0d, sum0);
}
// Raw = BF16
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>

View File

@ -750,7 +750,7 @@ class DotStats {
void CheckMuls() const {
// Comp2 is between Compensated and Kahan.
ASSERT_INSIDE(kComp2, 1.001, s_muls[kComp2].Mean(), 1.4);
ASSERT_INSIDE(kComp2, 1.001f, s_muls[kComp2].Max(), 2.4f);
ASSERT_INSIDE(kComp2, 1.001f, s_muls[kComp2].Max(), 6.8f);
ASSERT_INSIDE(kComp2, 1.0, s_muls[kComp2].GeometricMean(), 1.2);
// Compensated and Double are very accurate.
@ -812,7 +812,7 @@ class DotStats {
// Forward relative error, lower is better.
void CheckRel() const {
ASSERT_INSIDE(kComp2, 2E-4, s_rels[kComp2].GeometricMean(), 4E-3);
ASSERT_INSIDE(kComp2, 2E-4, s_rels[kComp2].GeometricMean(), 7E-3);
ASSERT_INSIDE(kComp2, 1E-5f, s_rels[kComp2].Max(), 1.23f);
// Compensated and Double are very accurate.
@ -822,22 +822,22 @@ class DotStats {
ASSERT_LESS(kDouble, s_rels[kDouble].Max(), 8E-6f);
// Naive and OnlyTwoProd are considerably higher, but not huge.
ASSERT_INSIDE(kNaive, 1E-3, s_rels[kNaive].GeometricMean(), 8E-2);
ASSERT_INSIDE(kNaive, 1E-3, s_rels[kNaive].GeometricMean(), 3.5E-1);
ASSERT_INSIDE(kOnlyTwoProd, 1E-3, s_rels[kOnlyTwoProd].GeometricMean(),
0.072);
7.5E-2);
// Kahan (FastTwoSum) is decent:
ASSERT_INSIDE(kKahan, 3E-4, s_rels[kKahan].GeometricMean(), 3.5E-3);
ASSERT_INSIDE(kKahan, 3E-4, s_rels[kKahan].GeometricMean(), 1E-2);
ASSERT_INSIDE(kKahan, 6E-4f, s_rels[kKahan].Max(), 0.7f);
// TwoProducts and TwoSums are a bit better.
ASSERT_INSIDE(kAddTwoProd, 2.2E-4, s_rels[kAddTwoProd].GeometricMean(),
3E-3);
ASSERT_INSIDE(kAddTwoProd, 4E-4f, s_rels[kAddTwoProd].Max(), 0.19f);
1.1E-2);
ASSERT_INSIDE(kAddTwoProd, 4E-4f, s_rels[kAddTwoProd].Max(), 1.0f);
ASSERT_INSIDE(kAddTwoSum, 1.5E-4, s_rels[kAddTwoSum].GeometricMean(),
2.6E-3);
1.1E-2);
ASSERT_INSIDE(kPairwise, 4.5E-4, s_rels[kPairwise].GeometricMean(), 1.5E-2);
ASSERT_INSIDE(kPairwise, 4.5E-4, s_rels[kPairwise].GeometricMean(), 5.2E-2);
// Extremely high error on aarch64.
ASSERT_INSIDE(kPairwise, 1.1E-3f, s_rels[kPairwise].Max(), 2E3f);
}
@ -851,13 +851,13 @@ class DotStats {
ASSERT_LESS(kDouble, s_rels[kDouble].Max(), 8E-6f);
// Naive and OnlyTwoProd are considerably higher than others
ASSERT_INSIDE(kNaive, 1.5E-8f, s_rels[kNaive].Max(), 3080.f);
ASSERT_INSIDE(kOnlyTwoProd, 1.5E-8f, s_rels[kNaive].Max(), 3080.f);
ASSERT_INSIDE(kNaive, 1.5E-8f, s_rels[kNaive].Max(), 1.4E4f);
ASSERT_INSIDE(kOnlyTwoProd, 1.5E-8f, s_rels[kNaive].Max(), 1.4E4f);
// Kahan (FastTwoSum) is not much better here!
ASSERT_INSIDE(kKahan, 6E-10f, s_rels[kKahan].Max(), 0.7f);
// But TwoProducts/TwoSums help a bit.
ASSERT_INSIDE(kAddTwoProd, 9E-10f, s_rels[kAddTwoProd].Max(), 0.19f);
ASSERT_INSIDE(kAddTwoProd, 9E-10f, s_rels[kAddTwoProd].Max(), 1.0f);
ASSERT_INSIDE(kAddTwoSum, 5E-10f, s_rels[kAddTwoSum].Max(), 0.34f);
// Extremely high error on aarch64.
@ -893,7 +893,7 @@ class DotStats {
};
// Returns normalized value in [-1, 1).
float RandomFloat(std::mt19937& rng) {
float RandomFloat(RngStream& rng) {
const uint32_t exp = hwy::BitCastScalar<uint32_t>(1.0f);
const uint32_t mantissa_mask = hwy::MantissaMask<float>();
const uint32_t representation = exp | (rng() & mantissa_mask);
@ -908,7 +908,7 @@ float RandomFloat(std::mt19937& rng) {
// error from the Dot algorithms, not the compression.
template <typename Packed>
void GenerateWellConditionedInputs(const size_t num, float* HWY_RESTRICT raw,
std::mt19937& rng,
RngStream& rng,
const PackedSpan<Packed>& packed,
CompressWorkingSet& work) {
std::uniform_int_distribution<int> e_dist(0, 6);
@ -934,7 +934,7 @@ void GenerateWellConditionedInputs(const size_t num, float* HWY_RESTRICT raw,
// Sum and Dot Product". `num` is the (arbitrary) size of w, v, and buf.
template <typename WT, typename VT>
double GenerateIllConditionedInputs(const size_t num, WT* w, VT* HWY_RESTRICT v,
std::mt19937& rng) {
RngStream& rng) {
PROFILER_FUNC;
const size_t half = HWY_MAX(1, num / 2); // generate at least one random
HWY_DASSERT(half != 0);
@ -1002,8 +1002,8 @@ struct TestShortDotsT {
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
CompressWorkingSet work;
std::mt19937 rng;
rng.seed(12345);
AesCtrEngine engine(/*deterministic=*/true);
RngStream rng(engine, 0);
hwy::Stats s_l1[kVariants];
@ -1101,7 +1101,6 @@ void TestAllDot() {
// Limit workers because we only support `kMaxWorkers`.
ThreadingArgs threading_args;
threading_args.max_packages = 1;
threading_args.max_clusters = 1;
threading_args.max_lps = kMaxWorkers - 1;
ThreadingContext ctx(threading_args);
@ -1109,9 +1108,10 @@ void TestAllDot() {
{ // ensure no profiler zones are active
const hn::ScalableTag<float> df;
std::mt19937 rngs[kMaxWorkers];
AesCtrEngine engine(/*deterministic=*/true);
RngStream rngs[kMaxWorkers];
for (size_t i = 0; i < kMaxWorkers; ++i) {
rngs[i].seed(12345 + 65537 * i);
rngs[i] = RngStream(engine, i);
}
constexpr size_t kReps = hn::AdjustedReps(40);
@ -1124,8 +1124,9 @@ void TestAllDot() {
MatPadding::kOdd);
std::array<DotStats, kMaxWorkers> all_stats;
ctx.pools.Cluster(0, 0).Run(
0, kReps, [&](const uint32_t rep, size_t thread) {
ParallelFor(
ParallelismStrategy::kWithinCluster, kReps, ctx, 0,
[&](size_t rep, size_t thread) {
float* HWY_RESTRICT pa = a.Row(thread);
float* HWY_RESTRICT pb = b.Row(thread);
double* HWY_RESTRICT buf = bufs.Row(thread);

View File

@ -1,192 +0,0 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "compression/types.h"
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
#include <stddef.h>
#include <stdio.h>
#include <algorithm> // std::max
#include <cmath> // std::abs
#include <memory>
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "ops/gemma_matvec_test.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#include "ops/matvec-inl.h"
#include "hwy/tests/test_util-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
FloatPtr SimpleMatVecAdd(const MatStorageT<float>& mat, const FloatPtr& vec,
const FloatPtr& add) {
const size_t num = mat.Rows() * mat.Cols();
FloatPtr raw_mat = hwy::AllocateAligned<float>(num);
FloatPtr out = hwy::AllocateAligned<float>(mat.Rows());
HWY_ASSERT(raw_mat && out);
const hn::ScalableTag<float> df;
DecompressAndZeroPad(df, mat.Span(), 0, raw_mat.get(), num);
for (size_t idx_row = 0; idx_row < mat.Rows(); idx_row++) {
out[idx_row] = 0.0f;
for (size_t idx_col = 0; idx_col < mat.Cols(); idx_col++) {
out[idx_row] += raw_mat[mat.Cols() * idx_row + idx_col] * vec[idx_col];
}
out[idx_row] *= mat.Scale();
out[idx_row] += add[idx_row];
}
return out;
}
template <typename MatT, size_t kOuter, size_t kInner>
std::unique_ptr<MatStorageT<float>> GenerateMat(size_t offset,
const Allocator& allocator,
hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
const Extents2D extents(kOuter, kInner);
auto mat = std::make_unique<MatStorageT<float>>("TestMat", extents, allocator,
MatPadding::kPacked);
FloatPtr raw_mat = hwy::AllocateAligned<float>(extents.Area());
HWY_ASSERT(raw_mat);
const float scale = 1.0f / kInner;
pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) {
for (size_t j = 0; j < kInner; j++) {
raw_mat[i * kInner + j] =
static_cast<float>((i * kInner + j + offset) * scale);
}
});
Compress(raw_mat.get(), extents.Area(), ws, mat->Span(), 0, pool);
mat->SetScale(1.9f); // Arbitrary value, different from 1.
return mat;
}
template <size_t length>
FloatPtr GenerateVec(size_t offset) {
FloatPtr vec = hwy::AllocateAligned<float>(length);
HWY_ASSERT(vec);
for (size_t idx = 0; idx < length; idx++) {
vec[idx] = static_cast<float>(idx + offset);
}
return vec;
}
template <size_t length>
void AssertClose(const FloatPtr& a, const FloatPtr& b) {
for (size_t idx = 0; idx < length; idx++) {
const float rel_abs_delta = std::abs(a[idx] - b[idx]) /
std::max(std::abs(a[idx]), std::abs(b[idx]));
EXPECT_LT(rel_abs_delta, 2e-6)
<< "a[" << idx << "]=" << a[idx] << ", b[" << idx << "]=" << b[idx];
}
}
void TestMatVecAdd() {
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
hwy::ThreadPool& pool = ctx.pools.Pool();
constexpr size_t kOuter = 128 * 3;
constexpr size_t kInner = 128 * 5;
auto mat = GenerateMat<float, kOuter, kInner>(0, ctx.allocator, pool);
FloatPtr vec = GenerateVec<kInner>(0);
FloatPtr add = GenerateVec<kOuter>(0);
FloatPtr expected_out = SimpleMatVecAdd(*mat, vec, add);
FloatPtr actual_out = hwy::AllocateAligned<float>(kOuter);
HWY_ASSERT(vec && add && expected_out && actual_out);
MatVecAdd(*mat, 0, kOuter, kInner, vec.get(), add.get(), actual_out.get(),
pool);
AssertClose<kOuter>(actual_out, expected_out);
}
void TestTwoMatVecAdd() {
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
hwy::ThreadPool& pool = ctx.pools.Pool();
constexpr size_t kOuter = 128 * 3;
constexpr size_t kInner = 128 * 5;
auto mat0 = GenerateMat<float, kOuter, kInner>(0, ctx.allocator, pool);
auto mat1 = GenerateMat<float, kOuter, kInner>(1, ctx.allocator, pool);
FloatPtr vec = GenerateVec<kInner>(0);
FloatPtr add0 = GenerateVec<kOuter>(0);
FloatPtr add1 = GenerateVec<kOuter>(1);
FloatPtr expected_out0 = SimpleMatVecAdd(*mat0, vec, add0);
FloatPtr expected_out1 = SimpleMatVecAdd(*mat1, vec, add1);
FloatPtr actual_out0 = hwy::AllocateAligned<float>(kOuter);
FloatPtr actual_out1 = hwy::AllocateAligned<float>(kOuter);
HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 &&
expected_out1 && actual_out1);
TwoMatVecAdd(*mat0, *mat1, 0, kOuter, kInner, vec.get(), add0.get(),
add1.get(), actual_out0.get(), actual_out1.get(), pool);
AssertClose<kOuter>(actual_out0, expected_out0);
AssertClose<kOuter>(actual_out1, expected_out1);
}
void TestTwoOfsMatVecAddLoop() {
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
hwy::ThreadPool& pool = ctx.pools.Pool();
constexpr size_t kOuter = 128 * 3;
constexpr size_t kInner = 128 * 5;
auto mat = GenerateMat<float, kOuter, kInner>(0, ctx.allocator, pool);
FloatPtr vec = GenerateVec<kInner>(0);
FloatPtr add0 = GenerateVec<kOuter>(0);
FloatPtr add1 = GenerateVec<kOuter>(1);
FloatPtr expected_out0 = SimpleMatVecAdd(*mat, vec, add0);
FloatPtr expected_out1 = SimpleMatVecAdd(*mat, vec, add1);
FloatPtr actual_out0 = hwy::AllocateAligned<float>(kOuter);
FloatPtr actual_out1 = hwy::AllocateAligned<float>(kOuter);
HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 &&
expected_out1 && actual_out1);
TwoOfsMatVecAddLoop(*mat, 0, 0, kOuter, kInner, vec.get(), add0.get(),
add1.get(), actual_out0.get(), actual_out1.get());
AssertClose<kOuter>(actual_out0, expected_out0);
AssertClose<kOuter>(actual_out1, expected_out1);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_BEFORE_TEST(MatVecTest);
HWY_EXPORT_AND_TEST_P(MatVecTest, TestMatVecAdd);
HWY_EXPORT_AND_TEST_P(MatVecTest, TestTwoMatVecAdd);
HWY_EXPORT_AND_TEST_P(MatVecTest, TestTwoOfsMatVecAddLoop);
HWY_AFTER_TEST();
} // namespace gcpp
#endif

File diff suppressed because it is too large Load Diff

View File

@ -21,7 +21,6 @@
#include <stdint.h>
#include <stdio.h>
#include <atomic>
#include <vector>
#include "util/allocator.h"
@ -63,23 +62,20 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim,
// and holds most of their arguments in member variables.
class GenerateCandidates {
public:
GenerateCandidates(const Allocator& allocator, size_t M, size_t K, size_t N,
size_t sizeof_TC, size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np, bool print_config)
: allocator_(allocator),
GenerateCandidates(const CacheInfo& cache, size_t M, size_t K, size_t N,
size_t num_B, size_t sizeof_TC, bool print_config)
: cache_(cache),
M_(M),
K_(K),
N_(N),
num_B_(num_B),
sizeof_TC_(sizeof_TC),
max_mr_(max_mr),
nr_(nr),
// These influence kc/nc, but are also stored in `MMConfig` for
// `RangesOf*`. Must be a vector multiple. The previous/next cache line
// is likely still in L1, but we expect K > 1000 and might as well round
// up to the line size.
kc_multiple_(HWY_MIN(K, allocator.LineBytes() / sizeof(BF16))),
nc_multiple_(allocator.StepBytes() / sizeof_TC),
ranges_np_(ranges_np),
// up to the line size. Both A and B are BF16.
kc_multiple_(HWY_MIN(K, cache.LineBytes() / sizeof(BF16))),
nc_multiple_(cache.StepBytes() / sizeof_TC),
print_config_(print_config) {}
std::vector<MMConfig> operator()() const {
@ -89,24 +85,21 @@ class GenerateCandidates {
for (size_t mr : MR()) {
for (MMOrder order : Orders(mr)) {
const std::vector<int>& all_inner_tasks = InnerTasks(order);
const std::vector<MMOut>& all_outs = Outs(order);
for (size_t kc : KC(mr, order)) {
for (size_t mc : MC(mr, kc, order)) {
for (size_t nc : NC(mr, mc, kc, order)) {
for (int inner_tasks : all_inner_tasks) {
for (MMOut out : all_outs) {
const MMConfig config(K_, N_, mr, mc, kc, nc, kc_multiple_,
nc_multiple_, order, out, inner_tasks);
const size_t M_tasks = config.RangesOfMC(M_).NumTasks();
const size_t K_tasks = config.RangesOfKC(K_).NumTasks();
const MMConfig config(K_, N_, mr, mc, kc, nc, kc_multiple_,
nc_multiple_, order, inner_tasks);
const size_t M_tasks = config.RangesOfMC(M_).NumTasks();
const size_t K_tasks = config.RangesOfKC(K_).NumTasks();
// Blocks only make sense when there are multiple M tasks.
if (IsBlock(order) != (M_tasks > 1)) continue;
// Single KC only makes sense when there is a single K task.
if (IsOneKC(order) != (K_tasks == 1)) continue;
// Blocks only make sense when there are multiple M tasks.
if (IsBlock(order) != (M_tasks > 1)) continue;
// Single KC only makes sense when there is a single K task.
if (IsOneKC(order) != (K_tasks == 1)) continue;
candidates.push_back(config);
}
candidates.push_back(config);
}
}
}
@ -132,10 +125,10 @@ class GenerateCandidates {
SizeVec all_mr;
all_mr.reserve(3);
// AVX2's 16 registers are not enough for four rows, but SSE4 may benefit.
if (M_ >= max_mr_ && !is_avx2) all_mr.push_back(max_mr_);
if (M_ >= kMaxMR && !is_avx2) all_mr.push_back(kMaxMR);
// Allow for AVX-512 but not SSE4 (for which 4 are usually better). Also
// enable if not enough rows for 4.
if (M_ >= 2 && (M_ < max_mr_ || (!is_sse && !is_wasm))) {
if (M_ >= 2 && (M_ < kMaxMR || (!is_sse && !is_wasm))) {
all_mr.push_back(size_t{2});
}
// Even SSE4 usually prefers 2 rows; only enable for single rows.
@ -158,7 +151,7 @@ class GenerateCandidates {
}
}
// The number of A and B columns to read between updating `partial`.
// The number of A and B columns to read between updating `C`.
SizeVec KC(size_t mr, MMOrder order) const {
// `LoopKC` handles up to `mr` rows of A.
const size_t rows_a = HWY_MIN(M_, mr);
@ -172,22 +165,22 @@ class GenerateCandidates {
// TB=NUQ due to less amortization of the table loads. Due to the low L1
// latency, the packing is still effectively fused into `LoopKC`. It may
// be better to round up and accept a few L2 accesses in exchange for
// fewer loops over K, and thus fewer writes to `partial`. Hence we do not
// fewer loops over K, and thus fewer writes to `C`. Hence we do not
// subtract the output and buf, and allow using more than the actual L1
// size. This results in an overestimate, and the loop below will propose
// the next few smaller values for the autotuner to evaluate.
const size_t bytes_ab = allocator_.L1Bytes() * 3;
const size_t col_bytes = rows_a * sizeof(BF16) + nr_ * sizeof(BF16);
const size_t bytes_ab =
cache_.L1Bytes() * (sizeof(BF16) + sizeof(SfpStream));
const size_t col_bytes = rows_a * sizeof(BF16) + kNR * sizeof(BF16);
size_t kc_max = hwy::DivCeil(bytes_ab, col_bytes);
kc_max =
RoundDownWithFloor(HWY_MIN(kc_max, MMStorage::kMaxKC), kc_multiple_);
kc_max = RoundDownWithFloor(HWY_MIN(kc_max, kMaxKC), kc_multiple_);
kc_max = HWY_MIN(kc_max, K_);
SizeVec all_kc(1, kc_max);
// Avoid proposing kc > K.
if (K_ > kc_multiple_) {
// Generally it is best to use the full `kc` (fewer writes to `partial`),
// Generally it is best to use the full `kc` (fewer writes to `C`),
// but a bit less can be better if it evenly divides `K`, or enables an
// `mc` that evenly divides `M`. Try several smaller values.
@ -204,7 +197,7 @@ class GenerateCandidates {
}
if (print_config_ && all_kc.size() > 1) {
fprintf(stderr, "KC: ");
fprintf(stderr, "num_B %zu: KC: ", num_B_);
for (size_t kc : all_kc) {
fprintf(stderr, "%zu ", kc);
}
@ -218,22 +211,22 @@ class GenerateCandidates {
SizeVec MC(size_t mr, size_t kc, MMOrder order) const {
// Typically 12-24K. The B rows are pinned in L1, but also occupy L2 because
// it is typically inclusive.
const size_t bytes_b = nr_ * kc * (sizeof(SfpStream) + sizeof(BF16));
const size_t bytes_b = kNR * kc * (sizeof(SfpStream) + sizeof(BF16));
// Choose the largest feasible `mc_max` (A/C rows) to maximize reuse of the
// packed B. We want `mc * kc` elements of A to fit in L2, alongside
// `bytes_b` plus `mc` cache lines because resident-A updates `mc` rows of
// partial.
const size_t bytes_per_mc = kc * sizeof(BF16) + allocator_.LineBytes();
size_t mc_max = hwy::DivCeil(allocator_.L2Bytes() - bytes_b, bytes_per_mc);
mc_max = HWY_MIN(mc_max, MMStorage::kMaxM);
// `bytes_b` plus `mc` cache lines because resident-A updates `mc` C rows.
const size_t bytes_per_mc = kc * sizeof(BF16) + cache_.LineBytes();
size_t mc_max = hwy::DivCeil(cache_.L2Bytes() - bytes_b, bytes_per_mc);
mc_max = HWY_MIN(mc_max, HWY_MIN(kMaxBatchSize, kMaxMC));
HWY_DASSERT(mc_max != 0);
mc_max = HWY_MIN(mc_max, M_);
mc_max = hwy::RoundDownTo(mc_max, mr);
SizeVec all_mc(1, mc_max);
// Larger MC is better for non-blocks, otherwise we want more small options.
const size_t reps = !IsBlock(order) ? 2 : 3;
// Larger MC is better for non-blocks, otherwise we want more small options,
// especially for two B.
const size_t reps = !IsBlock(order) ? 2 : (2 + num_B_);
size_t prev = mc_max;
for (size_t rep = 0; rep < reps; ++rep) {
@ -248,7 +241,7 @@ class GenerateCandidates {
}
if (print_config_ && all_mc.size() > 1) {
fprintf(stderr, "MC: ");
fprintf(stderr, "num_B %zu: MC: ", num_B_);
for (size_t mc : all_mc) {
fprintf(stderr, "%zu ", mc);
}
@ -260,43 +253,40 @@ class GenerateCandidates {
// The number of (possibly L3 resident) B rows per `NT_MT` task.
SizeVec NC(size_t mr, size_t mc, size_t kc, MMOrder order) const {
const size_t np_max = ranges_np_.TaskSize();
size_t nc_max = np_max;
const size_t out_bytes = IsOneKC(order) ? sizeof_TC_ : sizeof(double);
size_t nc_max = kMaxNC;
// Only if there will be reuse of B: choose the largest `nc_max` (C cols)
// such that `nc x kc` of B and `mc x nc` of `partial` or `C` fit in L3.
// Otherwise, leave it unbounded.
// such that `nc x kc` of B and `mc x nc` of `C` fit in L3. Otherwise,
// leave it unbounded.
if (M_ > mr) {
const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * out_bytes);
nc_max = hwy::DivCeil(allocator_.L3Bytes(), bytes_per_nc);
nc_max = HWY_MIN(HWY_MIN(nc_max, MMStorage::kMaxN), np_max);
const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * sizeof_TC_);
nc_max = HWY_MIN(hwy::DivCeil(cache_.L3Bytes(), bytes_per_nc), kMaxNC);
}
nc_max = HWY_MIN(nc_max, N_);
HWY_DASSERT(nc_max != 0);
nc_max = RoundDownWithFloor(nc_max, nc_multiple_);
// If there are going to be multiple ranges, anything more than half would
// be imbalanced and suboptimal.
if (nc_max < np_max && nc_max >= np_max / 2) {
nc_max = RoundDownWithFloor(np_max / 2, nc_multiple_);
if (nc_max < N_ && nc_max >= N_ / 2) {
nc_max = RoundDownWithFloor(N_ / 2, nc_multiple_);
}
// Non-block calls ForNP, which ignores `range_nc` and uses `range_np`.
if (!IsBlock(order)) return SizeVec(1, np_max);
if (!IsBlock(order)) return SizeVec(1, N_);
SizeVec all_nc(1, nc_max);
// Avoid proposing nc > N.
if (np_max > nc_multiple_) {
if (N_ > nc_multiple_) {
// Large L3, but its behavior and characteristics varies across platforms,
// hence autotune a wider range of nc than the other dimensions.
size_t reps = 10;
size_t reps = 9 + num_B_;
// For small M, we can afford larger NC, hence allow fewer small options.
if (M_ <= 2 * mr) reps -= 1;
size_t prev = nc_max;
for (size_t rep = 0; rep < reps; ++rep) {
const size_t div =
PrevDivisor(nc_multiple_, prev, np_max, nc_multiple_);
const size_t div = PrevDivisor(nc_multiple_, prev, N_, nc_multiple_);
prev = div ? div : RoundDownWithFloor(prev / 2, nc_multiple_);
all_nc.push_back(prev);
if (prev == nc_multiple_) break;
@ -313,7 +303,7 @@ class GenerateCandidates {
}
if (print_config_ && all_nc.size() > 1) {
fprintf(stderr, "NC: ");
fprintf(stderr, "num_B %zu: NC: ", num_B_);
for (size_t nc : all_nc) {
fprintf(stderr, "%zu ", nc);
}
@ -337,152 +327,80 @@ class GenerateCandidates {
return inner_tasks;
}
// Whether to parallelize FillC or enable direct writes to C.
std::vector<MMOut> Outs(MMOrder order) const {
std::vector<MMOut> outs;
for (size_t out_idx = 0;; ++out_idx) {
const MMOut out = static_cast<MMOut>(out_idx);
if (StringFromOut(out) == nullptr) return outs; // done
// kParM only makes sense if we have more than one row of A.
if (out == MMOut::kParM && M_ == 1) continue;
// Blocks are already parallelized.
if (out == MMOut::kParM && IsBlock(order)) continue;
// Direct only works for a single kc range.
if ((out == MMOut::kDirect) != IsOneKC(order)) continue;
// For non-block, kCopy does not beat kDirect.
if (out == MMOut::kCopy && IsOneKC(order) && !IsBlock(order)) continue;
outs.push_back(out);
}
}
const Allocator& allocator_;
const CacheInfo& cache_;
const size_t M_;
const size_t K_;
const size_t N_;
const size_t num_B_;
const size_t sizeof_TC_;
const size_t max_mr_;
const size_t nr_;
const size_t kc_multiple_;
const size_t nc_multiple_;
IndexRangePartition ranges_np_;
const bool print_config_;
};
} // namespace
// Facade to avoid exposing `GenerateCandidates` in the header.
std::vector<MMConfig> MMCandidates(const Allocator& allocator, size_t M,
size_t K, size_t N, size_t sizeof_TC,
size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np,
std::vector<MMConfig> MMCandidates(const CacheInfo& cache, size_t M, size_t K,
size_t N, size_t num_B, size_t sizeof_TC,
bool print_config) {
return GenerateCandidates(allocator, M, K, N, sizeof_TC, max_mr, nr,
ranges_np, print_config)();
}
// Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote
// memory accesses or false sharing, unless there are insufficient per-package
// rows for that.
static size_t NPMultiple(const Allocator& allocator, size_t N,
size_t sizeof_TC, size_t nr, size_t num_packages) {
size_t np_multiple = allocator.BasePageBytes() / sizeof_TC;
// If binding, `np_multiple` is typically 1024 and `num_packages` > 1. For
// `N` < 4096, this can cause significant load imbalance. If split unevenly,
// choose a smaller multiple.
if (N % (np_multiple * num_packages)) {
const size_t min_multiple = allocator.LineBytes() / sizeof_TC;
np_multiple =
PrevDivisor(min_multiple, np_multiple, N / num_packages, min_multiple);
if (HWY_UNLIKELY(np_multiple == 0)) {
np_multiple = min_multiple;
}
// This happens in tests with small N, hence do not assert.
if (N % (np_multiple * num_packages) && N >= 128) {
static std::atomic_flag warned = ATOMIC_FLAG_INIT;
if (!warned.test_and_set()) {
HWY_WARN(
"NPMultiple: N=%zu still not divisible by np_multiple=%zu * "
"num_packages=%zu\n",
N, np_multiple, num_packages);
}
np_multiple = nr;
}
}
return np_multiple;
}
IndexRangePartition MMParallel::RangesOfNP(size_t max_packages, size_t N,
size_t sizeof_TC, size_t nr) const {
const size_t num_packages = HWY_MIN(max_packages, ctx_.pools.NumPackages());
return StaticPartition(
IndexRange(0, N), num_packages,
NPMultiple(ctx_.allocator, N, sizeof_TC, nr, num_packages));
return GenerateCandidates(cache, M, K, N, num_B, sizeof_TC, print_config)();
}
MatMulEnv::MatMulEnv(ThreadingContext& ctx)
: ctx(ctx), parallel(ctx), storage(ctx.allocator, parallel) {
: ctx(ctx), A_BF(ctx.allocator), C_tiles(ctx) {
const size_t num_clusters = ctx.pools.NumClusters();
per_cluster.resize(num_clusters);
for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(kMaxBatchSize)); // C
}
char cpu100[100];
have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(MMStorage::kMaxM)); // A
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(MMStorage::kMaxN)); // B
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(MMStorage::kMaxM)); // C
}
void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
Allocator& allocator = parallel.allocator();
void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC) {
Allocator& allocator = ctx.allocator;
if (!allocator.ShouldBind()) return;
if (B.Rows() == 1) return;
PROFILER_ZONE("Startup.BindB");
const IndexRangePartition ranges_np =
parallel.RangesOfNP(kMaxPackages, B.Rows(), sizeof_TC, kNR);
for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
const IndexRange& rows_b = ranges_np.Range(pkg_idx);
const size_t node = parallel.Node(pkg_idx);
uintptr_t begin = reinterpret_cast<uintptr_t>(B.RowBytes(rows_b.begin()));
uintptr_t end = begin + rows_b.Num() * B.Stride() * B.ElementBytes();
// B row padding is less than the page size, so only bind the subset that
// is page-aligned.
begin = hwy::RoundUpTo(begin, allocator.BasePageBytes());
end = hwy::RoundDownTo(end, allocator.BasePageBytes());
if (HWY_LIKELY(begin != end)) {
allocator.BindMemory(reinterpret_cast<void*>(begin), end - begin, node);
}
const size_t node = ctx.topology.GetCluster(0).Node();
uintptr_t begin = reinterpret_cast<uintptr_t>(B.RowBytes(0));
uintptr_t end = begin + B.Rows() * B.Stride() * B.ElementBytes();
// B row padding is less than the page size, so only bind the subset that
// is page-aligned.
begin = hwy::RoundUpTo(begin, allocator.BasePageBytes());
end = hwy::RoundDownTo(end, allocator.BasePageBytes());
if (HWY_LIKELY(begin != end)) {
allocator.BindMemory(reinterpret_cast<void*>(begin), end - begin, node);
}
}
// C is BF16/float, or double for partial
void BindC(MatPtr& C, MMParallel& parallel) {
Allocator& allocator = parallel.allocator();
// C is BF16/float
void BindC(ThreadingContext& ctx, MatPtr& C) {
Allocator& allocator = ctx.allocator;
if (!allocator.ShouldBind()) return;
PROFILER_ZONE("Startup.BindC");
const IndexRangePartition ranges_np =
parallel.RangesOfNP(kMaxPackages, C.Cols(), C.ElementBytes(), kNR);
bool ok = true;
for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
const IndexRange& cols_c = ranges_np.Range(pkg_idx);
// `BindMemory` requires page alignment. These are in bytes.
const size_t begin = hwy::RoundUpTo(cols_c.begin() * C.ElementBytes(),
allocator.BasePageBytes());
const size_t end = hwy::RoundDownTo(cols_c.end() * C.ElementBytes(),
allocator.BasePageBytes());
const IndexRange cols_c(0, C.Cols());
// `BindMemory` requires page alignment. These are in bytes.
const size_t begin = hwy::RoundUpTo(cols_c.begin() * C.ElementBytes(),
allocator.BasePageBytes());
const size_t end = hwy::RoundDownTo(cols_c.end() * C.ElementBytes(),
allocator.BasePageBytes());
const size_t node = parallel.Node(pkg_idx);
for (size_t im = 0; im < C.Rows(); ++im) {
ok &= allocator.BindMemory(C.RowBytes(im) + begin, end - begin, node);
}
const size_t node = ctx.topology.GetCluster(0).Node();
bool ok = true;
for (size_t im = 0; im < C.Rows(); ++im) {
ok &= allocator.BindMemory(C.RowBytes(im) + begin, end - begin, node);
}
if (HWY_UNLIKELY(!ok)) {
HWY_WARN("Failed to bind C (%zux%zu), %zu packages.", C.Rows(), C.Cols(),
ranges_np.NumTasks());
HWY_WARN("Failed to bind C (%zux%zu).", C.Rows(), C.Cols());
}
}

View File

@ -21,12 +21,12 @@
#include <stddef.h>
#include <stdint.h>
#include <memory> // std::unique_ptr
#include <vector>
// IWYU pragma: begin_exports
#include "util/basics.h"
#include "util/mat.h"
#include "util/threading.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h"
@ -43,93 +43,153 @@ namespace gcpp {
// at least the product of the FMA latency (3..5) times the throughput (2).
// This and `mr` are limited by the number of registers, which is generally
// 32 but 16 for AVX2. `kNR` == 4 enables the `StoreInterleaved4` transpose in
// `MMAddHorizontalSumsIntoPartial`. We ensure `C.Cols() % kNR == 0`.
constexpr size_t kNR = 4;
// `MMStoreHorizontalSumsIntoC`. We ensure `C.Cols() % kNR == 0`.
HWY_INLINE_VAR constexpr size_t kNR = 4;
// Choosing `kMaxMR == kNR` minimizes the ratio of loads to FMA, because
// we load `kNR + kMaxMR` vectors per `kMaxMR * kNR` element tile.
// In general, `M` (batch size) is not a multiple of `kMaxMR`. Thus functions
// that load or store a tile are parameterized on `kRowsAC`: usually `kMaxMR`,
// or less on ISAs with fewer registers, or for the last few rows of A.
static constexpr size_t kMaxMR = 4;
HWY_INLINE_VAR constexpr size_t kMaxMR = 4;
// Mostly stateless, can be constructed on the fly by weights.cc. Captures the
// the ThreadingContext to shorten call sites.
class MMParallel {
public:
// `ctx` must outlive this object.
MMParallel(ThreadingContext& ctx) : ctx_(ctx) {
if (ctx_.pools.NumPackages() > kMaxPackages) {
HWY_WARN("CPU and --max_packages allow %zu > matmul.h kMaxPackages %zu.",
ctx_.pools.NumPackages(), kMaxPackages);
}
}
// For `MMTilesC`.
HWY_INLINE_VAR constexpr size_t kMaxMC = 512;
HWY_INLINE_VAR constexpr size_t kMaxNC = 16384;
Allocator& allocator() const { return ctx_.allocator; }
// Upper bound for per-worker B storage on the stack. Chosen such that one row
// of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024;
// Initial static partitioning of B rows across packages.
IndexRangePartition RangesOfNP(size_t max_packages, size_t N,
size_t sizeof_TC, size_t nr) const;
// Policy classes for parallelism, implementing some of `ParallelismStrategy`.
// For `BindB` and `BindC`.
size_t Node(size_t pkg_idx) const {
return ctx_.topology.GetCluster(pkg_idx, 0).Node();
}
// Calls `func(pkg_idx)` for each package in parallel.
struct MMParallelNone {
template <class Func>
void ForPkg(const size_t max_packages, const Func& func) {
if constexpr (kMaxPackages > 1) {
ctx_.pools.AllPackages().Run(
0, HWY_MIN(max_packages, ctx_.pools.NumPackages()),
[&](uint64_t task, size_t pkg_idx) {
HWY_DASSERT(task == pkg_idx);
(void)task;
func(pkg_idx);
});
} else {
func(/*pkg_idx=*/0);
void ForN(ThreadingContext& ctx, const IndexRange& range_n,
size_t /*n_multiple*/, size_t inner_tasks, size_t cluster_idx,
const Func& func) const {
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
const size_t worker = ctx.Worker(cluster_idx);
func(range_n, worker);
}
template <class Func>
void ForRangesMC_NC(ThreadingContext& ctx,
const IndexRangePartition& ranges_mc,
const IndexRangePartition& ranges_nc, size_t cluster_idx,
const Func& func) const {
const size_t worker = ctx.Worker(cluster_idx);
for (size_t i = 0; i < ranges_mc.NumTasks(); ++i) {
const IndexRange range_mc = ranges_mc.Range(i);
for (size_t j = 0; j < ranges_nc.NumTasks(); ++j) {
const IndexRange range_nc = ranges_nc.Range(j);
func(range_mc, range_nc, worker);
}
}
}
// Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is
template <class Func>
void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
size_t cluster_idx, const Func& func) const {
const size_t worker = ctx.Worker(cluster_idx);
for (uint64_t row_a = range_mc.begin(); row_a < range_mc.end(); ++row_a) {
func(row_a, worker);
}
}
};
struct MMParallelWithinCluster {
template <class Func>
void ForN(ThreadingContext& ctx, const IndexRange& range_n, size_t n_multiple,
size_t inner_tasks, size_t cluster_idx, const Func& func) const {
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
const size_t base = ctx.Worker(cluster_idx);
const IndexRangePartition ranges_n = StaticPartition(
range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
ParallelizeOneRange(ranges_n, cluster,
[&](const IndexRange& worker_range, size_t worker) {
func(worker_range, base + worker);
});
}
template <class Func>
void ForRangesMC_NC(ThreadingContext& ctx,
const IndexRangePartition& ranges_mc,
const IndexRangePartition& ranges_nc, size_t cluster_idx,
const Func& func) const {
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
const size_t base = ctx.Worker(cluster_idx);
// Low-batch: avoid Divide/Remainder.
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
ParallelizeOneRange(ranges_nc, cluster,
[&](const IndexRange& range_nc, size_t worker) {
func(ranges_mc.Range(0), range_nc, base + worker);
});
} else {
ParallelizeTwoRanges(
ranges_mc, ranges_nc, cluster,
[&](const IndexRange& range_mc, const IndexRange& range_nc,
size_t worker) { func(range_mc, range_nc, base + worker); });
}
}
template <class Func>
void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
size_t cluster_idx, const Func& func) const {
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
const size_t base = ctx.Worker(cluster_idx);
cluster.Run(
range_mc.begin(), range_mc.end(),
[&](uint64_t row_a, size_t worker) { func(row_a, base + worker); });
}
};
struct MMParallelHierarchical {
// Cluster/CCX-aware parallel-for over B rows in `range_n`. `n_multiple` is
// the granularity of per-cluster tasks. Calls `func(worker_range, worker)`.
template <class Func>
void ForNP(const IndexRange& range_np, size_t nx_multiple, size_t inner_tasks,
size_t pkg_idx, const Func& func) {
void ForN(ThreadingContext& ctx, const IndexRange& range_n, size_t n_multiple,
size_t inner_tasks, HWY_MAYBE_UNUSED size_t caller_cluster_idx,
const Func& func) const {
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage();
HWY_DASSERT(caller_cluster_idx == 0);
// Single cluster: parallel-for over static partition of `range_np`.
hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx);
// Single cluster: parallel-for over static partition of `range_n`.
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
const size_t num_clusters = all_clusters.NumWorkers();
if (num_clusters == 1) {
hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, 0);
const IndexRangePartition worker_ranges = StaticPartition(
range_np, cluster.NumWorkers() * inner_tasks, nx_multiple);
const size_t cluster_idx = 0;
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
const IndexRangePartition ranges_n = StaticPartition(
range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
return ParallelizeOneRange(
worker_ranges, cluster,
[&](const IndexRange& worker_range, size_t thread) {
func(worker_range, pkg_base + thread);
ranges_n, cluster,
[&](const IndexRange& worker_range, size_t worker) {
func(worker_range, worker);
});
}
// Assign each cluster a sub-range of `range_np` (typically hundreds).
const IndexRangePartition nx_ranges =
StaticPartition(range_np, num_clusters, nx_multiple);
// Assign each cluster a sub-range of `range_n` (typically hundreds).
const IndexRangePartition ranges_n =
StaticPartition(range_n, num_clusters, n_multiple);
ParallelizeOneRange(
nx_ranges, all_clusters,
[&](const IndexRange& nx_range, const size_t cluster_idx) {
hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
const size_t cluster_base =
pkg_base + cluster_idx * ctx_.pools.MaxWorkersPerCluster();
ranges_n, all_clusters,
[&](const IndexRange& n_range, const size_t cluster_idx) {
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
const size_t cluster_base = ctx.Worker(cluster_idx);
// Parallel-for over sub-ranges of `cluster_range` within the cluster.
const IndexRangePartition worker_ranges = StaticPartition(
nx_range, cluster.NumWorkers() * inner_tasks, nx_multiple);
n_range, cluster.NumWorkers() * inner_tasks, n_multiple);
ParallelizeOneRange(
worker_ranges, cluster,
[&](const IndexRange& worker_range, size_t thread) {
func(worker_range, cluster_base + thread);
[&](const IndexRange& worker_range, size_t worker) {
func(worker_range, cluster_base + worker);
});
});
}
@ -137,31 +197,32 @@ class MMParallel {
// Cluster/CCX-aware parallel-for over blocks (separate subranges of A and B
// rows). Calls `func(range_mc, range_nc, worker)`.
template <class Func>
void ForRangesMC_NC(const IndexRangePartition& ranges_mc,
const IndexRangePartition& ranges_nc, size_t pkg_idx,
const Func& func) {
const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage();
hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx);
void ForRangesMC_NC(ThreadingContext& ctx,
const IndexRangePartition& ranges_mc,
const IndexRangePartition& ranges_nc,
HWY_MAYBE_UNUSED size_t caller_cluster_idx,
const Func& func) const {
HWY_DASSERT(caller_cluster_idx == 0);
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
// `all_clusters` is a pool with one worker per cluster in a package.
const size_t num_clusters = all_clusters.NumWorkers();
// Single (big) cluster: collapse two range indices into one parallel-for
// to reduce the number of fork-joins.
if (num_clusters == 1) {
const size_t cluster_idx = 0;
hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
// Low-batch: avoid Divide/Remainder.
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
return ParallelizeOneRange(
ranges_nc, cluster, [&](const IndexRange& range_nc, size_t thread) {
func(ranges_mc.Range(0), range_nc, pkg_base + thread);
ranges_nc, cluster, [&](const IndexRange& range_nc, size_t worker) {
func(ranges_mc.Range(0), range_nc, worker);
});
} else {
return ParallelizeTwoRanges(
ranges_mc, ranges_nc, cluster,
[&](const IndexRange& range_mc, const IndexRange& range_nc,
size_t thread) {
func(range_mc, range_nc, pkg_base + thread);
});
size_t worker) { func(range_mc, range_nc, worker); });
}
}
@ -170,139 +231,95 @@ class MMParallel {
ParallelizeOneRange(
ranges_nc, all_clusters,
[&](const IndexRange range_nc, size_t cluster_idx) {
const size_t cluster_base =
pkg_base + cluster_idx * ctx_.pools.MaxWorkersPerCluster();
hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
const size_t cluster_base = ctx.Worker(cluster_idx);
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
ParallelizeOneRange(ranges_mc, cluster,
[&](const IndexRange& range_mc, size_t thread) {
func(range_mc, range_nc, cluster_base + thread);
[&](const IndexRange& range_mc, size_t worker) {
func(range_mc, range_nc, cluster_base + worker);
});
});
}
// Calls `func(row_a, worker)` in parallel.
template <class Func>
void ForRangeMC(const IndexRange& range_mc, size_t pkg_idx,
const Func& func) {
const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage();
ctx_.pools.Pool(pkg_idx).Run(
range_mc.begin(), range_mc.end(),
[&](uint64_t row_a, size_t thread) { func(row_a, pkg_base + thread); });
void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
size_t caller_cluster_idx, const Func& func) const {
HierarchicalParallelFor(range_mc.Num(), ctx.pools,
[&](size_t task, size_t worker) {
func(range_mc.begin() + task, worker);
});
}
private:
ThreadingContext& ctx_;
};
void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel);
// C is BF16/float, or double for partial.
void BindC(MatPtr& C, MMParallel& parallel);
template <class Func, typename... Args>
void DispatchParallelism(ParallelismStrategy parallelism, const Func& func,
Args&&... args) {
switch (parallelism) {
case ParallelismStrategy::kNone:
return func(MMParallelNone(), std::forward<Args>(args)...);
case ParallelismStrategy::kWithinCluster:
return func(MMParallelWithinCluster(), std::forward<Args>(args)...);
case ParallelismStrategy::kHierarchical:
return func(MMParallelHierarchical(), std::forward<Args>(args)...);
default:
HWY_UNREACHABLE;
}
}
// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
#pragma pack(push, 1) // power of two size
template <typename T>
class StridedView {
void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC);
// C is BF16/float.
void BindC(ThreadingContext& ctx, MatPtr& C);
// Space for converting A=F32 to BF16 before the matmul. This is faster than
// on-the-fly when native BF16 is available: it only happens once, not per B
// tile row, and the cache footprint is smaller.
class MMEntireA {
public:
StridedView(T* HWY_RESTRICT row0, size_t cols, size_t stride)
: row0_(row0),
cols_(static_cast<uint32_t>(cols)),
stride_(static_cast<uint32_t>(stride)) {
HWY_DASSERT(stride >= cols);
}
// Compile-time bounds on matrix columns to enable pre-allocating storage
// and reusing it across `MatMul` calls. Sufficient for Gemma 2 27B.
static constexpr size_t kMaxK = 36 * 1024;
T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
size_t Cols() const { return static_cast<size_t>(cols_); }
explicit MMEntireA(const Allocator& allocator)
// 288 MiB. Must be padded, see `DoDecompressA`.
: A_("A_bf", Extents2D(kMaxBatchSize, kMaxK), allocator,
MatPadding::kOdd) {}
size_t Stride() const { return static_cast<size_t>(stride_); }
void SetStride(size_t stride) {
HWY_DASSERT(stride >= Cols());
stride_ = stride;
}
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
StridedView<T> View(size_t r, size_t c, size_t cols) const {
HWY_DASSERT(c < Cols());
HWY_DASSERT(cols <= Cols() - c);
return StridedView<T>(Row(r) + c, cols, stride_);
StridedViewBF A(const Extents2D& extents) const {
HWY_DASSERT(extents.rows <= kMaxBatchSize);
return StridedViewBF(A_, 0, 0, extents.cols);
}
private:
T* HWY_RESTRICT row0_;
uint32_t cols_;
uint32_t stride_;
MatStorageT<BF16> A_;
};
#pragma pack(pop)
using StridedViewBF = StridedView<BF16>;
using StridedViewD = StridedView<double>;
// Per-package storage for packed A, and one global C-shaped `partial` for
// accumulating partial dot products (sections of K).
class MMStorage {
// One tile of C per *worker* (required for `kNT_MT*`).
class MMTilesC {
public:
// Compile-time bounds on matrix dimensions to enable pre-allocating storage
// and reusing it across `MatMul` calls. The resulting allocations are 256 MiB
// per package and 512 MiB, respectively.
static constexpr size_t kMaxM = 4096;
static constexpr size_t kMaxK = 64 * 1024;
static constexpr size_t kMaxN = 256 * 1024;
// Upper bound for per-worker B storage on the stack. Chosen such that one row
// of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
static constexpr size_t kMaxKC = 8 * 1024;
// Internally threaded; must not be called concurrently with the same
// `ThreadingContext` (used via `parallel`).
MMStorage(const Allocator& allocator, MMParallel& parallel)
: // Per-worker copies of `partial` would be wasteful. We instead
// allocate one instance of the maximum matrix extents because threads
// write at false-sharing-free granularity.
partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN), allocator,
MatPadding::kOdd),
// Same stride independent of the actual C.Cols() so we can pre-bind.
partial_(partial_storage_.Row(0), kMaxN, partial_storage_.Stride()) {
// Per-package allocation so each can decompress A into its own copy.
// Must be padded, see `DoDecompressA`.
parallel.ForPkg(kMaxPackages, [&](size_t pkg_idx) {
pkg_A_[pkg_idx].reset(new MatStorageT<BF16>(
"pkg_A", Extents2D(kMaxM, kMaxK), allocator, MatPadding::kOdd));
if (allocator.ShouldBind()) {
const size_t node = parallel.Node(pkg_idx);
size_t bytes = pkg_A_[pkg_idx]->Rows() * pkg_A_[pkg_idx]->Stride() *
pkg_A_[pkg_idx]->ElementBytes();
bytes = hwy::RoundDownTo(bytes, allocator.BasePageBytes());
if (!allocator.BindMemory(pkg_A_[pkg_idx]->Row(0), bytes, node)) {
HWY_WARN("Failed to bind memory for package %zu", pkg_idx);
}
}
});
// Avoid cross-package accesses.
BindC(partial_storage_, parallel);
explicit MMTilesC(const ThreadingContext& ctx) {
const size_t max_workers = ctx.pools.MaxWorkers();
C_.reserve(max_workers);
for (size_t worker = 0; worker < max_workers; ++worker) {
C_.push_back(MatStorageT<BF16>("Ctile", Extents2D(kMaxBatchSize, kMaxNC),
ctx.allocator, MatPadding::kOdd));
}
}
// Returns per-package matrix view.
StridedViewBF A(size_t pkg_idx, const Extents2D& extents) const {
HWY_DASSERT(extents.rows <= kMaxM);
HWY_DASSERT(extents.cols <= kMaxK);
return StridedViewBF(const_cast<BF16*>(pkg_A_[pkg_idx]->Row(0)),
extents.cols, pkg_A_[pkg_idx]->Stride());
StridedViewBF C(const Extents2D& extents, size_t worker) const {
HWY_DASSERT(extents.rows <= kMaxBatchSize);
HWY_DASSERT(worker < C_.size());
return StridedViewBF(C_[worker], 0, 0, extents.cols);
}
StridedViewD Partial() const { return partial_; }
private:
std::unique_ptr<MatStorageT<BF16>> pkg_A_[kMaxPackages];
MatStorageT<double> partial_storage_;
StridedViewD partial_;
std::vector<MatStorageT<BF16>> C_;
};
//------------------------------------------------------------------------------
// Autotuning
// Naming convention: outer loop first, T suffix means threaded. This refers to
// the loops *around* `A2C0`, which contains loops over mc/kc. The outermost
// `ranges_np` loop across packages is implicit and applies to all of these.
// the loops *around* `A2C0`, which contains loops over mc/kc.
//
// Parallelizing across K (A/B columns) is undesirable because the resulting
// partial dot products require synchronization or reduction across threads.
@ -310,20 +327,42 @@ enum class MMOrder : uint8_t {
// Single M, parallel N, sequential K (inside the parallel section to
// reduce fork-joins). Similar to GotoBLAS, good for large N vs. M and K.
kNT_K,
// Specialization of `kNT_K` for a single K task with `kDirect`.
// Specialization of `kNT_K` for a single K task with `MMSetC`.
kNT,
// Parallelize over blocks of M and N: good when both are large. We no longer
// support `kMT_NT_K`: no advantage on Skylake, and `kNT_MT_K` is 1.5x as
// fast on Zen4.
kNT_MT_K,
kNT_MT, // Specialization of `kNT_MT_K` for a single K task with `kDirect`.
kNT_MT, // Specialization of `kNT_MT_K` for a single K task with `MMSetC`.
// Resident C (`kK_M_NT`) should be good for large K relative to M and N.
// However, it does not (much) outperform `kNT_K` on SKX and Zen4. There are
// no kN* because we expect M (batch size) to be small relative to K and N.
// no kM* because we expect M (batch size) to be small relative to K and N.
};
// Tag types for `DispatchOrder`.
struct MMOrderNT_K {};
struct MMOrderNT {};
struct MMOrderNT_MT_K {};
struct MMOrderNT_MT {};
template <class Func, typename... Args>
void DispatchOrder(MMOrder order, const Func& func, Args&&... args) {
switch (order) {
case MMOrder::kNT_K:
return func(MMOrderNT_K(), std::forward<Args>(args)...);
case MMOrder::kNT:
return func(MMOrderNT(), std::forward<Args>(args)...);
case MMOrder::kNT_MT_K:
return func(MMOrderNT_MT_K(), std::forward<Args>(args)...);
case MMOrder::kNT_MT:
return func(MMOrderNT_MT(), std::forward<Args>(args)...);
default:
HWY_UNREACHABLE;
}
}
static inline bool IsBlock(MMOrder order) {
return order == MMOrder::kNT_MT_K || order == MMOrder::kNT_MT;
}
@ -347,29 +386,6 @@ static inline const char* StringFromOrder(MMOrder order) {
}
}
// How/where to write the A2C0 result. This determines the `tag` argument to
// that function, which governs whether we call `MMStoreHorizontalSumsIntoC` or
// `MMAddHorizontalSumsIntoPartial`.
enum class MMOut : uint8_t {
kCopy, // accumulate into partial, scale/add to C
kDirect, // single kc task, write directly to C
kParM // kCopy but parallel over M
// kParN is not better on SKX/Zen4.
};
static inline const char* StringFromOut(MMOut out) {
switch (out) {
case MMOut::kDirect:
return "Direct";
case MMOut::kCopy:
return "Copy";
case MMOut::kParM:
return "ParM";
default:
return nullptr;
}
}
// How to parallelize the per-package `DecompressA`. To reduce combinatorial
// explosion, we tune this separately from `MMConfig`.
enum class MMParA : uint8_t { kNone, kK1 = 1, kK2 = 2, kK4 = 4, kM };
@ -403,10 +419,9 @@ class MMConfig {
MMConfig() = default; // for std::vector
// `mr` is the number of A rows per call to `MMKernel::LoopKC`.
// `MMOrder` is how to parallelize the outer loops.
// `MMOut` is how/whether to parallelize filling the C result.
// `inner_tasks` chooses the within-cluster task granularity in `ForNP`.
// `inner_tasks` chooses the within-cluster task granularity in `ForN`.
MMConfig(size_t K, size_t N, size_t mr, size_t mc, size_t kc, size_t nc,
size_t kc_multiple, size_t nc_multiple, MMOrder order, MMOut out,
size_t kc_multiple, size_t nc_multiple, MMOrder order,
int inner_tasks)
: mr_(static_cast<uint32_t>(mr)),
mc_(static_cast<uint32_t>(mc)),
@ -415,7 +430,6 @@ class MMConfig {
nc_multiple_(static_cast<uint32_t>(nc_multiple)),
kc_multiple_(static_cast<uint32_t>(kc_multiple)),
order_(order),
out_(out),
inner_tasks_(static_cast<uint8_t>(inner_tasks)),
reserved_{} {
HWY_DASSERT(mr == 1 || mr == 2 || mr == 4);
@ -431,7 +445,6 @@ class MMConfig {
HWY_WARN("nc %zu not a multiple of nc_multiple %zu", nc, nc_multiple);
}
HWY_DASSERT(StringFromOrder(order_) != nullptr);
HWY_DASSERT(StringFromOut(out_) != nullptr);
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
}
@ -443,12 +456,11 @@ class MMConfig {
IndexRangePartition RangesOfKC(size_t K) const {
return MaxSizePartition(IndexRange(0, K), kc_, kc_multiple_);
}
IndexRangePartition RangesOfNC(IndexRange range_np) const {
return MaxSizePartition(range_np, nc_, nc_multiple_);
IndexRangePartition RangesOfNC(size_t N) const {
return MaxSizePartition(IndexRange(0, N), nc_, nc_multiple_);
}
MMOrder Order() const { return order_; }
MMOut Out() const { return out_; }
// No `OuterTasks` because static partitioning across clusters is sufficient.
size_t InnerTasks() const { return static_cast<size_t>(inner_tasks_); }
@ -467,17 +479,14 @@ class MMConfig {
uint32_t nc_multiple_;
uint32_t kc_multiple_;
MMOrder order_;
MMOut out_;
uint8_t inner_tasks_;
HWY_MAYBE_UNUSED uint8_t reserved_[5];
HWY_MAYBE_UNUSED uint8_t reserved_[6];
};
static_assert(sizeof(MMConfig) == 32); // for faster indexing
#pragma pack(pop)
std::vector<MMConfig> MMCandidates(const Allocator& allocator, size_t M,
size_t K, size_t N, size_t sizeof_TC,
size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np,
std::vector<MMConfig> MMCandidates(const CacheInfo& cache, size_t M, size_t K,
size_t N, size_t num_B, size_t sizeof_TC,
bool print_config);
// State machine for choosing the best `TConfig`, which is `MMConfig` for the
@ -584,7 +593,7 @@ class MMAutoTune {
// `MMOrder::kNT[_K]` are no longer allowed. They require a single MC range,
// but choosing the same config for a larger M can result in multiple MC ranges.
// Thus M less than this must have unique keys/configs.
static constexpr size_t kMaxTilesM = 8;
HWY_INLINE_VAR constexpr size_t kMaxTilesM = 8;
// Map of previously seen dimensions to index via linear search.
class MMKeys {
@ -601,28 +610,30 @@ class MMKeys {
static constexpr Key kPadding = 0;
// Compresses the dimensions into a single Key for faster comparison.
static Key KeyFromDims(size_t M, size_t K, size_t N) {
static Key KeyFromDims(size_t M, size_t K, size_t N, size_t num_B) {
HWY_DASSERT(M < (Key{1} << 16)); // batch sizes are smaller
HWY_DASSERT(K < (Key{1} << 24));
HWY_DASSERT(N < (Key{1} << 24));
HWY_DASSERT(K < (Key{1} << 20));
HWY_DASSERT(N < (Key{1} << 20));
HWY_DASSERT(num_B == 1 || num_B == 2);
const Key key = static_cast<Key>(BucketM(M)) | (static_cast<Key>(K) << 16) |
(static_cast<Key>(N) << 40);
(static_cast<Key>(N) << 40) |
(static_cast<Key>(num_B) << 60);
HWY_DASSERT(key != kPadding);
return key;
}
// We leave the search to callers so they can use dynamic-dispatched SIMD,
// which is not possible in this header.
// We leave the search to callers so they can use per-target SIMD, which is
// not possible in this header.
hwy::Span<const Key> Keys() const {
return hwy::Span<const Key>(keys_.get(), num_unique_);
}
// Must only be called if not already present in `Keys()`.
void Append(Key key, const Allocator& allocator) {
void Append(Key key, size_t vector_bytes) {
// Dynamic allocation because the test checks many more dimensions than
// would be reasonable to pre-allocate. DIY for alignment and padding.
if (HWY_UNLIKELY(num_unique_ >= capacity_)) {
const size_t NU64 = allocator.VectorBytes() / sizeof(Key);
const size_t NU64 = vector_bytes / sizeof(Key);
// Start at one vector so the size is always a multiple of N.
if (HWY_UNLIKELY(capacity_ == 0)) {
capacity_ = hwy::DivCeil(NU64, 2); // will be doubled below
@ -649,26 +660,13 @@ class MMKeys {
// Per-MatMul-shape state.
struct MMPerKey {
MMPerKey(size_t max_packages, size_t N, size_t sizeof_TC, size_t nr,
MMParallel& parallel)
: ranges_np(parallel.RangesOfNP(max_packages, N, sizeof_TC, nr)) {
HWY_DASSERT(ranges_np.NumTasks() <= max_packages);
}
// Only profile if enabled and the main autotuner finished (the par_a
// autotuner is per-package and we want to avoid synchronization).
bool WantProfile() const { return PROFILER_ENABLED != 0 && autotune.Best(); }
const IndexRangePartition ranges_np;
MMAutoTune<MMConfig> autotune;
MMAutoTune<MMParA> autotune_par_a[kMaxPackages];
MMAutoTune<MMParA> autotune_par_a;
};
// Stores state shared across MatMul calls. Non-copyable. `ctx` must outlive
// `MatMulEnv`.
struct MatMulEnv {
// Internally threaded; must not be called concurrently with the same
// `ThreadingContext`.
explicit MatMulEnv(ThreadingContext& ctx);
ThreadingContext& ctx;
@ -681,40 +679,110 @@ struct MatMulEnv {
// Whether to print the best config immediately after autotuning finished.
bool print_best = false;
MMParallel parallel;
MMStorage storage;
MMKeys keys;
std::vector<MMPerKey> per_key;
MMEntireA A_BF;
MMTilesC C_tiles;
struct PerCluster {
MMKeys keys;
std::vector<MMPerKey> per_key;
// Prevents false sharing.
HWY_MAYBE_UNUSED uint8_t
padding[HWY_ALIGNMENT - sizeof(MMKeys) - sizeof(per_key)];
};
std::vector<PerCluster> per_cluster;
// Storage for arbitrary output rows, see `MatPtr::AllocateAndAttachRowPtrs`.
// Most MatMul callers use strided MatPtr, but GemmaAttention::ComputeQKV
// writes to differing KV positions per query / output row.
// The first three allocations are sufficient for any A, B, C, respectively,
// but also potentially overwritten by each MatMul. Subsequent entries are
// precomputed for tensors and not overwritten. Per-tensor allocations make
// it likelier that asan detects bugs such as use after free, overrun, and
// The first `num_clusters` entries are sufficient for any C argument, and
// must be indexed by `options.cluster_idx`. Note that they are potentially
// overwritten by each `MatMul`. Subsequent entries are for specific tensors
// and only written once by their allocator. A per-tensor allocation makes it
// likelier that asan detects bugs such as use after free, overrun, and
// dangling references.
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
};
// Arguments to MatMul() that are independent of the A/B/C types.
// Reduces register pressure compared to individual values/references.
// Called via `CallClosure`, which consumes the first (opaque) argument. User
// functions are called with the entire C matrix, the sub-ranges of M (rows)
// and N (cols) that this thread has just filled, a view into a second tile
// (only for `TwoMatmul`), and the worker thread index (see `ParallelFor`).
typedef void (*MMFunc)(const void* opaque, RowPtrsBF, IndexRange, IndexRange,
StridedViewBF, size_t);
class MMOptions {
// Same technique as in `hwy::ThreadPool` and C++23 `std::function_ref`:
// type-erasure without allocation.
template <class Closure>
static void CallClosure(const void* opaque, RowPtrsBF C1, IndexRange range_r,
IndexRange range_c, StridedViewBF C2, size_t worker) {
(*reinterpret_cast<const Closure*>(opaque))(C1, range_r, range_c, C2,
worker);
}
public:
// `closure` must remain alive until the end of (Two)MatMul.
template <class Closure>
void SetFunc(const Closure& closure) {
func = static_cast<MMFunc>(&CallClosure<Closure>);
opaque = &closure;
}
void MaybeCallFunc(RowPtrsBF C1, IndexRange range_r, IndexRange range_c,
StridedViewBF C2, size_t worker) const {
if (func != nullptr) {
func(opaque, C1, range_r, range_c, C2, worker);
}
}
MMFunc func = nullptr; // called if non-null and `TC` is BF16.
const void* opaque = nullptr;
uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`.
ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical;
};
// Arguments to MatMul() that are independent of the A/B/C types. Reduces
// register pressure compared to individual values/references. Also used for
// passing through `DispatchOrder`.
struct MMArgs {
MMArgs(MatMulEnv& env, MMPerKey& per_key, double scale,
const float* HWY_RESTRICT add, const StridedViewD& partial)
: env(&env),
per_key(&per_key),
scale(scale),
MMArgs(MatMulEnv& env, size_t M, size_t K, size_t N, float scale_A,
const float* HWY_RESTRICT add, MMOptions options,
const MMAutoTune<MMConfig>& autotune, const MMConfig& config)
: env(env),
line_bytes(env.ctx.cache_info.LineBytes()),
range_n(0, N),
scale_A(scale_A),
add(add),
partial(partial) {}
options(options),
MatMulEnv* env;
MMPerKey* per_key;
autotune(autotune),
mr(config.MR()),
ranges_mc(config.RangesOfMC(M)),
ranges_kc(config.RangesOfKC(K)),
ranges_nc(config.RangesOfNC(N)),
order(config.Order()),
inner_tasks(config.InnerTasks()) {}
double scale;
MatMulEnv& env;
const size_t line_bytes; // from `env`, for `Stride`.
// MatMul arguments:
const IndexRange range_n; // entire N
// There can be two B, so do not yet multiply together the A and B scales.
const float scale_A;
const float* HWY_RESTRICT add;
// Same size as C, threads write at false-sharing-free granularity.
StridedViewD partial;
const MMOptions options;
const MMAutoTune<MMConfig>& autotune; // for `MaybeEnter`
// From `MMConfig`:
const size_t mr;
const IndexRangePartition ranges_mc;
const IndexRangePartition ranges_kc;
const IndexRangePartition ranges_nc;
const MMOrder order;
const size_t inner_tasks;
};
// Wrapper over hwy::Zone that is only enabled when autotuning finished.
@ -731,11 +799,12 @@ class MMZone {
}
}
// `name` must be a string literal.
template <class AutoTune>
void MaybeEnter(size_t thread, hwy::profiler::ZoneHandle zone,
const MMArgs& args) {
if (args.per_key->WantProfile()) {
new (&data_) Zone(args.env->ctx.profiler, thread, zone);
const MatMulEnv& env, const AutoTune* auto_tune) {
// Only if enabled and autotuning finished.
if (PROFILER_ENABLED && auto_tune->Best()) {
new (&data_) Zone(env.ctx.profiler, thread, zone);
HWY_DASSERT(data_ != 0);
}
}
@ -746,7 +815,8 @@ class MMZone {
};
#else
struct MMZone {
void MaybeEnter(size_t, hwy::profiler::ZoneHandle, const MMArgs&) {}
void MaybeEnter(size_t, hwy::profiler::ZoneHandle, const MatMulEnv&,
const void*) {}
};
#endif // PROFILER_ENABLED

View File

@ -28,8 +28,8 @@
#define GEMMA_MATMUL_DEFINE_ONE(TA, TB, TC) \
MMPerKey* MatMulStatic(const MatPtrT<TA>& A, const MatPtrT<TB>& B, \
const float* HWY_RESTRICT add, MatMulEnv& env, \
MatPtrT<TC>& C) { \
return MatMul(A, B, add, env, C); \
MatPtrT<TC>& C, MMOptions options) { \
return MatMul(A, B, add, env, C, options); \
}
#if defined(THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_MATMUL_STATIC_INL_H_) == \
@ -53,6 +53,14 @@ namespace HWY_NAMESPACE {
// included from matmul_static_*.cc.
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DEFINE_ONE, GEMMA_MATMUL_TB) // NOLINT
HWY_MAYBE_UNUSED void TwoMatMulStatic(const MatPtrT<BF16>& A, // NOLINT
const MatPtrT<GEMMA_MATMUL_TB>& B1,
const MatPtrT<GEMMA_MATMUL_TB>& B2,
MatMulEnv& env, MatPtrT<BF16>& C,
MMOptions options) {
TwoMatMul(A, B1, B2, env, C, options);
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();

View File

@ -35,15 +35,22 @@
#define GEMMA_MATMUL_DECL_ONE(TA, TB, TC) \
MMPerKey* MatMulStatic(const MatPtrT<TA>& A, const MatPtrT<TB>& B, \
const float* HWY_RESTRICT add, MatMulEnv& env, \
MatPtrT<TC>& C);
MatPtrT<TC>& C, MMOptions options);
#define GEMMA_MATMUL_FOR_B(TB) \
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, TB) \
void TwoMatMulStatic(const MatPtrT<BF16>& A, const MatPtrT<TB>& B1, \
const MatPtrT<TB>& B2, MatMulEnv& env, \
MatPtrT<BF16>& C, MMOptions options);
// Passed to HWY_VISIT_TARGETS; declares all overloads for all targets.
#define GEMMA_MATMUL_DECL(TARGET, NAMESPACE) \
namespace NAMESPACE { \
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, BF16) \
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, float) \
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, NuqStream) \
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, SfpStream) \
GEMMA_MATMUL_FOR_B(BF16) \
GEMMA_MATMUL_FOR_B(float) \
GEMMA_MATMUL_FOR_B(NuqStream) \
GEMMA_MATMUL_FOR_B(SfpStream) \
GEMMA_MATMUL_FOR_B(I8Stream) \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE

29
ops/matmul_static_i8.cc Normal file
View File

@ -0,0 +1,29 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "ops/matmul_static_i8.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_MATMUL_TB I8Stream
#include "ops/matmul_static-inl.h"

View File

@ -29,6 +29,8 @@
#include <stddef.h>
#include <stdio.h>
#include <atomic>
#include "ops/matmul.h"
#include "util/basics.h"
#include "util/mat.h"
@ -118,17 +120,26 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
const float max_abs = MaxAbs(a_batch) * MaxAbs(b_trans_batch);
const double eps_bf16 = hwy::ConvertScalarTo<double>(hwy::Epsilon<BF16>());
const double eps_f32 = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
double tolerance = 12 * norm * eps_f32;
// Dot() also rounds F32,BF16 to BF16, but not with F32,F32, so increase the
// tolerance there.
if (IsF32<TA>() && IsF32<TB>()) {
tolerance += 4 * max_abs * eps_bf16;
// Dot() uses double-precision summation.
double tolerance = 20 * norm * eps_f32;
// If either is F32, Dot() promotes F32 or even F64, but MatMul demotes the
// F32 to BF16, so add extra tolerance.
if (IsF32<TA>() || IsF32<TB>()) {
tolerance += 2 * max_abs * eps_bf16;
}
if (tolerance > 500.0) {
HWY_WARN("high tolerance %f norm %f maxabs %f\n", tolerance, norm, max_abs);
}
const double max_rel = 1.0 + hwy::ConvertScalarTo<double>(hwy::Epsilon<TC>());
const double rel_tolerance =
1.0 + hwy::ConvertScalarTo<double>(hwy::Epsilon<TC>());
double max_rel = 0.0;
size_t worst_r = 0;
size_t worst_c = 0;
double worst_actual = 0.0;
double worst_expected = 0.0;
size_t num_outside = 0;
for (size_t r = 0; r < A.Rows(); r++) {
const float* expected_row = c_slow_batch.Row(r);
const float* actual_row = c_batch.Row(r);
@ -143,15 +154,24 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
const double min = HWY_MIN(expected_value, actual_value);
const double rel = max / HWY_MAX(min, 1E-6);
if (rel > max_rel) {
hwy::Abort(__FILE__, line,
"(%zu,%zu): expected %f, actual %f, norm %f maxabs %f "
"tolerance %f rel %E max_rel %E\n",
r, c, expected_value, actual_value, norm, max_abs,
tolerance, rel, max_rel);
worst_expected = expected_value;
worst_actual = actual_value;
worst_r = r;
worst_c = c;
max_rel = rel;
++num_outside;
}
}
}
}
if (max_rel > rel_tolerance) {
hwy::Abort(__FILE__, line,
"(%zu,%zu): expected %f, actual %f, norm %f maxabs %f "
"tolerance %f rel %E max_rel %E num_outside %zu\n",
worst_r, worst_c, worst_expected, worst_actual, norm, max_abs,
tolerance, max_rel, rel_tolerance, num_outside);
}
}
// B is already transposed.
@ -171,29 +191,22 @@ HWY_INLINE void MatMulSlow(const MatPtrT<TA> A, const MatPtrT<TB> B,
const IndexRange all_cols_c(0, C.Cols());
NestedPools& pools = env.ctx.pools;
hwy::ThreadPool& all_packages = pools.AllPackages();
const IndexRangePartition get_row_c =
StaticPartition(all_rows_c, all_packages.NumWorkers(), 1);
hwy::ThreadPool& all_clusters = pools.AllClusters();
const size_t multiple = env.ctx.allocator.QuantumBytes() / sizeof(TB);
const IndexRangePartition get_col_c =
StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple);
ParallelizeOneRange(
get_row_c, all_packages,
[&](const IndexRange& rows_c, size_t package_idx) HWY_ATTR {
hwy::ThreadPool& all_clusters = pools.AllClusters(package_idx);
const size_t multiple = env.ctx.allocator.QuantumBytes() / sizeof(TB);
const IndexRangePartition get_col_c =
StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple);
ParallelizeOneRange(
get_col_c, all_clusters,
[&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR {
for (size_t r : rows_c) {
TC* HWY_RESTRICT C_row = C.Row(r);
for (size_t c : cols_c) {
const float add = add_row ? add_row[c] : 0.0f;
C_row[c] = hwy::ConvertScalarTo<TC>(
add + scale * Dot(df, b_span, c * B.Stride(), A.Row(r),
A.Cols()));
}
}
});
get_col_c, all_clusters,
[&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR {
for (size_t r : all_rows_c) {
TC* HWY_RESTRICT C_row = C.Row(r);
for (size_t c : cols_c) {
const float add = add_row ? add_row[c] : 0.0f;
const float dot =
Dot(df, b_span, c * B.Stride(), A.Row(r), A.Cols());
C_row[c] = hwy::ConvertScalarTo<TC>(add + scale * dot);
}
}
});
}
@ -228,7 +241,9 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
MatStorageT<TC> C_slow("C_slow", C_extents, env.ctx.allocator,
MatPadding::kOdd);
MatStorageT<TC> C("C", C_extents, env.ctx.allocator, MatPadding::kOdd);
MatStorageT<TC> C2("C", C_extents, env.ctx.allocator, MatPadding::kOdd);
C.AllocateAndAttachRowPtrs(env.row_ptrs);
C2.AllocateAndAttachRowPtrs(env.row_ptrs);
MatStorageT<float> add_storage =
add ? GenerateMat<float>(Extents2D(1, cols_bc), env.ctx.allocator,
@ -240,10 +255,52 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
MatMulSlow(A, BT, add_row, env, C_slow);
// A few reps to get coverage of the various autotuned code paths.
MMOptions options;
for (size_t rep = 0; rep < 16; ++rep) {
MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C);
MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C, options);
AssertClose(A, BT, C_slow, C, env, line);
if (per_key->autotune.Best()) break;
// Check before TwoMatMulStatic(), which can invalidate per_key.
const bool autotune_done = !!per_key->autotune.Best();
// Ensure the tiled view returns the same result as C.
if constexpr (IsBF16<TA>() && IsBF16<TC>()) {
// The total view area should match the entire C matrix.
std::atomic<size_t> total_view_area = 0;
const auto fused = [&](RowPtrsBF C2_rows, IndexRange range_r,
IndexRange range_c, StridedViewBF C2_view,
size_t worker) {
total_view_area.fetch_add(range_r.Num() * range_c.Num());
HWY_ASSERT(range_c.Num() <= C2_view.Cols());
HWY_ASSERT(worker < env.ctx.pools.MaxWorkers());
for (size_t ir = 0; ir < range_r.Num(); ++ir) {
const size_t r = range_r.begin() + ir;
for (size_t ic = 0; ic < range_c.Num(); ++ic) {
const size_t c = range_c.begin() + ic;
const float expected =
hwy::ConvertScalarTo<float>(C2_rows.Row(r)[c]);
const float actual =
hwy::ConvertScalarTo<float>(C2_view.Row(ir)[ic]);
const float L1 = hwy::ScalarAbs(actual - expected);
if (L1 > 1E-6f) {
HWY_ABORT("%zu: ir %zu ic %zu L1 %f expected %f actual %f.",
worker, ir, ic, L1, expected, actual);
}
}
}
};
options.SetFunc(fused);
TwoMatMulStatic(A, BT, BT, env, C2, options);
HWY_ASSERT_EQ(C.Extents().Area(), total_view_area.load());
options.func = nullptr; // reset for next call
// TwoMatMulStatic() does not support adding a bias vector.
if (!add) {
AssertClose(A, BT, C, C2, env, line);
}
}
if (autotune_done) break;
}
}
@ -256,34 +313,28 @@ void TestTiny() {
if (first_target == 0) first_target = HWY_TARGET;
if (HWY_TARGET != first_target) return;
for (size_t max_packages : {1, 2}) {
ThreadingArgs threading_args;
threading_args.bind = Tristate::kTrue;
threading_args.max_packages = max_packages;
ThreadingContext ctx(threading_args);
MatMulEnv env(ctx);
NestedPools& pools = env.ctx.pools;
ThreadingArgs threading_args;
threading_args.bind = Tristate::kTrue;
ThreadingContext ctx(threading_args);
MatMulEnv env(ctx);
NestedPools& pools = env.ctx.pools;
if constexpr (GEMMA_DISABLE_TOPOLOGY || kMaxPackages == 1) {
if (max_packages == 2) break; // we only have one package
} else {
// If less than the limit, we have already tested all num_packages.
if (env.ctx.topology.FullTopology().packages.size() < max_packages) break;
}
fprintf(stderr, "TestTiny %zu: %s %s\n", max_packages,
env.ctx.topology.TopologyString(), pools.PinString());
fprintf(stderr, "TestTiny: %s %s\n", env.ctx.topology.TopologyString(),
pools.PinString());
pools.MaybeStartSpinning(threading_args.spin);
pools.MaybeStartSpinning(threading_args.spin);
for (size_t M = 1; M <= 12; ++M) {
for (size_t K = 1; K <= 64; K *= 2) {
for (size_t N = 4; N <= 64; N += max_packages * 4) {
TestMatMul<F32, F32, F32>(M, K, N, /*add=*/false, env, __LINE__);
}
for (size_t M = 1; M <= 12; ++M) {
for (size_t K = 1; K <= 64; K *= 2) {
for (size_t N = 4; N <= 64; N += 4) {
TestMatMul<F32, F32, F32>(M, K, N, /*add=*/false, env, __LINE__);
TestMatMul<BF16, F32, F32>(M, K, N, /*add=*/false, env, __LINE__);
TestMatMul<F32, BF16, F32>(M, K, N, /*add=*/false, env, __LINE__);
TestMatMul<BF16, BF16, F32>(M, K, N, /*add=*/false, env, __LINE__);
}
}
pools.MaybeStopSpinning(threading_args.spin);
}
pools.MaybeStopSpinning(threading_args.spin);
}
void TestAllMatMul() {
@ -297,6 +348,7 @@ void TestAllMatMul() {
ThreadingArgs threading_args;
threading_args.bind = Tristate::kTrue;
ThreadingContext ctx(threading_args);
MatMulEnv env(ctx);
NestedPools& pools = env.ctx.pools;
@ -334,6 +386,10 @@ void TestAllMatMul() {
TestMatMul<F32, SFP>(256, 256, 256, /*add=*/false, env, __LINE__);
TestMatMul<BF16, SFP>(256, 256, 256, /*add=*/true, env, __LINE__);
// Non-vector-multiple K.
TestMatMul<F32, BF16>(128, 258, 128, /*add=*/true, env, __LINE__);
TestMatMul<BF16, BF16>(128, 258, 128, /*add=*/true, env, __LINE__);
// minimal non-square test. kColsARowsB must be at least 2 vectors.
TestMatMul<F32>(35, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<BF16>(34, 128, 32, /*add=*/true, env, __LINE__);

View File

@ -1,303 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Include guard for non-SIMD code.
#ifndef THIRD_PARTY_GEMMA_CPP_OPS_MATVEC_INL_H_
#define THIRD_PARTY_GEMMA_CPP_OPS_MATVEC_INL_H_
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
#endif // THIRD_PARTY_GEMMA_CPP_OPS_MATVEC_INL_H_
// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_MATVEC_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_MATVEC_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_MATVEC_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_MATVEC_TOGGLE
#endif
#include "compression/compress-inl.h"
#include "ops/dot-inl.h"
#include "ops/matmul.h"
#include "util/mat.h" // MatPtrT
#include "hwy/contrib/math/math-inl.h"
#include "hwy/contrib/matvec/matvec-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
// For callers that pass `MatPtrT`, which is not necessarily packed - callers
// should use Stride() to compute `w_ofs`.
template <typename WT, typename VT>
HWY_INLINE float Dot(const MatPtrT<WT>& w, size_t w_ofs, const VT* vec_aligned,
size_t num) {
const hn::ScalableTag<VT> d;
return w.Scale() * Dot(d, w.PaddedSpan(), w_ofs, vec_aligned, num);
}
// ArrayT is MatPtrT.
// Simple version without tiling nor threading, but two offsets/outputs and
// always with addition.
template <typename ArrayT, typename VecT, typename AddT>
HWY_INLINE void TwoOfsMatVecAddLoop(const ArrayT& mat, const size_t mat_ofs0,
const size_t mat_ofs1, const size_t outer,
const size_t inner,
const VecT* HWY_RESTRICT vec_aligned,
const AddT* HWY_RESTRICT add0,
const AddT* HWY_RESTRICT add1,
float* HWY_RESTRICT out0,
float* HWY_RESTRICT out1) {
PROFILER_ZONE("TwoOfsMatVecAddLoop");
for (size_t idx_row = 0; idx_row < outer; ++idx_row) {
const size_t row_ofs0 = mat_ofs0 + idx_row * mat.Stride();
const size_t row_ofs1 = mat_ofs1 + idx_row * mat.Stride();
out0[idx_row] = hwy::ConvertScalarTo<float>(add0[idx_row]) +
Dot(mat, row_ofs0, vec_aligned, inner);
out1[idx_row] = hwy::ConvertScalarTo<float>(add1[idx_row]) +
Dot(mat, row_ofs1, vec_aligned, inner);
}
}
HWY_INLINE constexpr size_t MaxCols() {
// Vec + mat rows should fit into 32 KiB L1.
return 2048;
}
template <size_t kOuter>
HWY_INLINE constexpr size_t RowsPerStrip() {
// Aim for 128 work items to reduce pool overhead. Must be at least one
// vector; prefer a power of two for faster division.
constexpr size_t kLanes = hn::ScalableTag<float>().MaxLanes();
constexpr size_t kRowsPerStrip =
kOuter < 128 ? kLanes
: HWY_MAX(kLanes, 1ULL << hwy::FloorLog2(kOuter / 128));
return kRowsPerStrip;
}
HWY_INLINE size_t RowsPerStrip(const size_t outer) {
// Aim for 128 work items to reduce pool overhead. Must be at least one
// vector; prefer a power of two for faster division.
constexpr size_t kLanes = hn::ScalableTag<float>().MaxLanes();
return outer < 128 ? kLanes
: HWY_MAX(kLanes, 1ULL << hwy::FloorLog2(outer / 128));
}
namespace detail {
// For each i = [0, num_rows), compute partial (length `num_cols`) dot product
// of row i with `vec_aligned` and add into `out[i]`. The upper-left
// coordinate of the tile is r0, c0.
template <class DF, typename ArrayT, typename VecT>
HWY_INLINE void AccumulatePartialDotProducts(
DF df, const ArrayT& mat, size_t mat_ofs, size_t r0, size_t c0,
size_t num_rows, size_t num_cols, const VecT* HWY_RESTRICT vec_aligned,
float* HWY_RESTRICT out) {
for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat.Stride();
out[idx_row] += Dot(mat, row_ofs + c0, vec_aligned + c0, num_cols);
}
}
// Same as AccumulatePartialDotProducts, but sets out[i] to the first partial
// dot product + init (if kInit), which avoids having to zero-initialize and
// accumulate.
template <bool kInit, class DF, typename ArrayT, typename VecT, typename InitT>
HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
size_t mat_ofs, size_t r0, size_t c0,
size_t num_rows, size_t num_cols,
const VecT* HWY_RESTRICT vec_aligned,
const InitT* HWY_RESTRICT init,
float* HWY_RESTRICT out) {
for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat.Stride();
if constexpr (kInit) {
out[idx_row] = hwy::ConvertScalarTo<float>(init[idx_row + r0]) +
Dot(mat, row_ofs + c0, vec_aligned + c0, num_cols);
} else {
out[idx_row] = Dot(mat, row_ofs + c0, vec_aligned + c0, num_cols);
}
}
}
// Adds together partial dot products for all tiles with the same r0 (a
// horizontal strip of the entire matrix); the result is the full dot product
// for rows r in [r0, r0 + num_rows) + optionally the add vector, which we
// store into in out[r - r0].
template <bool kAdd, class DF, typename ArrayT, typename VecT, typename AddT>
HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
size_t mat_ofs, size_t r0,
size_t num_rows, size_t num_cols,
const VecT* HWY_RESTRICT vec_aligned,
const AddT* HWY_RESTRICT add,
float* HWY_RESTRICT out) {
HWY_DASSERT(num_cols <= mat.Cols());
// Tall and skinny: set `out` to the single dot product.
if (num_cols < MaxCols()) {
SetFirstPartialDotProducts<kAdd>(df, mat, mat_ofs, r0, 0, num_rows,
num_cols, vec_aligned, add, out);
return;
}
// We have at least MaxCols, so start by setting `out` to that:
SetFirstPartialDotProducts<kAdd>(df, mat, mat_ofs, r0, 0, num_rows, MaxCols(),
vec_aligned, add, out);
// For further multiples of MaxCols, accumulate. Remainders handled below.
size_t c0 = MaxCols();
for (; c0 <= num_cols - MaxCols(); c0 += MaxCols()) {
AccumulatePartialDotProducts(df, mat, mat_ofs, r0, c0, num_rows, MaxCols(),
vec_aligned, out);
}
if (c0 < num_cols) { // Final cols
AccumulatePartialDotProducts(df, mat, mat_ofs, r0, c0, num_rows,
num_cols - c0, vec_aligned, out);
}
}
} // namespace detail
// Stores dot products of rows with `vec_aligned` + add the values from `add`
// (if kAdd), then stores them to `out`.
template <bool kAdd, typename ArrayT, typename VecT, typename AddT>
HWY_INLINE void MatVecT(const ArrayT& mat, const size_t mat_ofs,
const size_t outer, const size_t inner,
const VecT* HWY_RESTRICT const vec_aligned,
const AddT* HWY_RESTRICT const add,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
PROFILER_ZONE("MatVecAdd");
const hn::ScalableTag<float> df;
const size_t rows_per_strip = RowsPerStrip(outer);
const size_t num_strips = outer / rows_per_strip;
// For each entire strip.
pool.Run(0, num_strips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
PROFILER_ZONE("MatVec.lambda");
const size_t r0 = strip * rows_per_strip;
detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, r0, rows_per_strip,
inner, vec_aligned, add, out + r0);
});
// Remaining rows
const size_t r0 = num_strips * rows_per_strip;
if (r0 < outer) {
PROFILER_ZONE("MatVec remainder");
const size_t num_rows = outer - r0;
detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, r0, num_rows, inner,
vec_aligned, add, out + r0);
}
}
// With addition
template <typename ArrayT, typename VecT, typename AddT>
HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
const size_t outer, const size_t inner,
const VecT* HWY_RESTRICT const vec_aligned,
const AddT* HWY_RESTRICT const add,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
return MatVecT</*kAdd=*/true>(mat, mat_ofs, outer, inner, vec_aligned, add,
out, pool);
}
// Without addition
template <typename ArrayT, typename VecT>
HWY_INLINE void MatVec(const ArrayT& mat, const size_t mat_ofs,
const size_t outer, const size_t inner,
const VecT* HWY_RESTRICT const vec_aligned,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
MatVecT</*kAdd=*/false>(mat, mat_ofs, outer, inner, vec_aligned,
/*add=*/static_cast<VecT*>(nullptr), out, pool);
}
// Two matrices, same vector
template <bool kAdd, typename ArrayT1, typename ArrayT2, typename VecT,
typename AddT>
HWY_NOINLINE void TwoMatVecT(const ArrayT1& mat0, const ArrayT2& mat1,
const size_t mat_ofs, size_t outer, size_t inner,
const VecT* HWY_RESTRICT vec_aligned,
const AddT* HWY_RESTRICT add0,
const AddT* HWY_RESTRICT add1,
float* HWY_RESTRICT out0, float* HWY_RESTRICT out1,
hwy::ThreadPool& pool) {
PROFILER_ZONE("TwoMatVecAdd");
const hn::ScalableTag<float> df;
const size_t rows_per_strip = RowsPerStrip(outer);
const size_t num_strips = outer / rows_per_strip;
// For each entire strip.
pool.Run(0, num_strips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
PROFILER_ZONE("TwoMatVec.lambda");
const size_t r0 = strip * rows_per_strip;
detail::FullDotProductsForStrip<kAdd>(df, mat0, mat_ofs, r0, rows_per_strip,
inner, vec_aligned, add0, out0 + r0);
detail::FullDotProductsForStrip<kAdd>(df, mat1, mat_ofs, r0, rows_per_strip,
inner, vec_aligned, add1, out1 + r0);
});
// Remaining rows
const size_t r0 = num_strips * rows_per_strip;
if (r0 < outer) {
PROFILER_ZONE("TwoMatVec remainder");
const size_t num_rows = outer - r0;
detail::FullDotProductsForStrip<kAdd>(df, mat0, mat_ofs, r0, num_rows,
inner, vec_aligned, add0, out0 + r0);
detail::FullDotProductsForStrip<kAdd>(df, mat1, mat_ofs, r0, num_rows,
inner, vec_aligned, add1, out1 + r0);
}
}
// With addition
template <typename ArrayT1, typename ArrayT2, typename VecT, typename AddT>
HWY_NOINLINE void TwoMatVecAdd(
const ArrayT1& mat0, const ArrayT2& mat1, const size_t mat_ofs,
const size_t outer, const size_t inner,
const VecT* HWY_RESTRICT vec_aligned, const AddT* HWY_RESTRICT add0,
const AddT* HWY_RESTRICT add1, float* HWY_RESTRICT out0,
float* HWY_RESTRICT out1, hwy::ThreadPool& pool) {
return TwoMatVecT</*kAdd=*/true>(mat0, mat1, mat_ofs, outer, inner,
vec_aligned, add0, add1, out0, out1, pool);
}
// Without addition
template <typename ArrayT1, typename ArrayT2, typename VecT>
HWY_NOINLINE void TwoMatVec(const ArrayT1& mat0, const ArrayT2& mat1,
const size_t mat_ofs, const size_t outer,
const size_t inner,
const VecT* HWY_RESTRICT vec_aligned,
float* HWY_RESTRICT out0, float* HWY_RESTRICT out1,
hwy::ThreadPool& pool) {
TwoMatVecT</*kAdd=*/false, ArrayT1, ArrayT2, VecT, VecT>(
mat0, mat1, mat_ofs, outer, inner, vec_aligned, /*add0=*/nullptr,
/*add1=*/nullptr, out0, out1, pool);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // NOLINT

File diff suppressed because it is too large Load Diff

View File

@ -14,6 +14,7 @@
// limitations under the License.
#include "compression/types.h"
#include "util/zones.h"
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
@ -47,6 +48,7 @@
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "compression/test_util-inl.h"
#include "ops/ops-inl.h"
#include "hwy/tests/test_util-inl.h"
@ -56,6 +58,12 @@ namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
static RngStream MakeRng() {
static AesCtrEngine engine(/*deterministic=*/true);
static uint64_t stream = 0;
return RngStream(engine, ++stream);
}
template <class Test>
struct ForeachCountAndMisalign {
template <typename T, class D>
@ -83,48 +91,6 @@ T Random(hwy::RandomState& rng) {
HWY_MAX(hwy::ConvertScalarTo<double>(hwy::LowestValue<T>()), val));
}
HWY_NOINLINE void SimpleAddFrom(const float* HWY_RESTRICT other,
float* HWY_RESTRICT x, size_t size) {
for (size_t i = 0; i < size; ++i) {
x[i] += other[i];
}
}
HWY_NOINLINE void SimpleMulBy(const float* HWY_RESTRICT other,
float* HWY_RESTRICT x, size_t size) {
for (size_t i = 0; i < size; ++i) {
x[i] *= other[i];
}
}
HWY_NOINLINE void SimpleMulByConst(float c, float* HWY_RESTRICT x,
size_t size) {
for (size_t i = 0; i < size; ++i) {
x[i] *= c;
}
}
HWY_NOINLINE void SimpleMulByConstAndAdd(float c, const float* HWY_RESTRICT x,
float* HWY_RESTRICT out, size_t size) {
for (size_t i = 0; i < size; ++i) {
out[i] += x[i] * c;
}
}
HWY_NOINLINE void SimpleSoftmax(float* HWY_RESTRICT x, size_t size) {
HWY_DASSERT(size != 0);
float sum = 0.0;
const float maxval = *std::max_element(x, x + size);
for (size_t i = 0; i < size; ++i) {
x[i] = std::exp(x[i] - maxval);
sum += x[i];
}
const float scale = 1.0f / sum;
for (size_t i = 0; i < size; ++i) {
x[i] *= scale;
}
}
template <size_t k>
HWY_NOINLINE std::discrete_distribution<int> SourceCreateDistribution(
std::array<float, k>& top_k, float temperature) {
@ -141,7 +107,8 @@ HWY_NOINLINE std::discrete_distribution<int> SourceCreateDistribution(
return std::discrete_distribution<int>(std::begin(top_k), std::end(top_k));
}
struct TestAddFrom {
class TestAddFrom {
public:
template <class D>
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
hwy::RandomState& rng) {
@ -166,14 +133,30 @@ struct TestAddFrom {
}
SimpleAddFrom(o, e, count);
InitProfilerZones(hwy::Profiler::Get());
AddFrom(o, x, count, hwy::Profiler::Get(), /*worker=*/0);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
__LINE__);
}
private:
template <typename T1, typename T2>
static HWY_NOINLINE void SimpleAddFrom(const T1* HWY_RESTRICT other,
T2* HWY_RESTRICT x, size_t size) {
for (size_t i = 0; i < size; ++i) {
x[i] = hwy::ConvertScalarTo<T2>(hwy::ConvertScalarTo<float>(x[i]) +
hwy::ConvertScalarTo<float>(other[i]));
}
}
};
struct TestMulByConstAndAdd {
void TestAllAddFrom() {
hn::ForPartialVectors<ForeachCountAndMisalign<TestAddFrom>>()(float());
}
class TestMulByConstAndAdd {
public:
template <class D>
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
hwy::RandomState& rng) {
@ -199,14 +182,33 @@ struct TestMulByConstAndAdd {
T constant = Random<T>(rng);
SimpleMulByConstAndAdd(constant, o, e, count);
MulByConstAndAdd(constant, o, x, count, hwy::Profiler::Get(), /*worker=*/0);
InitProfilerZones(hwy::Profiler::Get());
MulByConstAndAdd(constant, o, x, count);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
__LINE__);
}
private:
template <typename T1, typename T2>
static HWY_NOINLINE void SimpleMulByConstAndAdd(float c,
const T1* HWY_RESTRICT x,
T2* HWY_RESTRICT out,
size_t size) {
for (size_t i = 0; i < size; ++i) {
out[i] = hwy::ConvertScalarTo<T2>(hwy::ConvertScalarTo<float>(out[i]) +
hwy::ConvertScalarTo<float>(x[i]) * c);
}
}
};
struct TestMulByConst {
void TestAllMulByConstAndAdd() {
hn::ForPartialVectors<ForeachCountAndMisalign<TestMulByConstAndAdd>>()(
float());
}
class TestMulByConst {
public:
template <class D>
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
hwy::RandomState& rng) {
@ -229,14 +231,68 @@ struct TestMulByConst {
T constant = Random<T>(rng);
SimpleMulByConst(constant, e, count);
MulByConst(constant, x, count, hwy::Profiler::Get(), /*worker=*/0);
InitProfilerZones(hwy::Profiler::Get());
MulByConst(constant, x, count);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
__LINE__);
}
private:
template <typename T1>
static HWY_NOINLINE void SimpleMulByConst(float c, T1* HWY_RESTRICT x,
size_t size) {
for (size_t i = 0; i < size; ++i) {
x[i] = hwy::ConvertScalarTo<T1>(hwy::ConvertScalarTo<float>(x[i]) * c);
}
}
};
struct TestSoftmax {
void TestAllMulByConst() {
hn::ForPartialVectors<ForeachCountAndMisalign<TestMulByConst>>()(float());
}
struct TestMulByConstTo {
template <class D>
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
hwy::RandomState& rng) {
if (misalign_b == 0) return;
using T = hn::TFromD<D>;
hwy::AlignedFreeUniquePtr<T[]> px =
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
hwy::AlignedFreeUniquePtr<T[]> pe =
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
hwy::AlignedFreeUniquePtr<T[]> pactual =
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
HWY_ASSERT(px && pe && pactual);
T* x = px.get() + misalign_a;
T* e = pe.get() + misalign_a;
T* actual = pe.get() + misalign_a;
T constant = Random<T>(rng);
for (size_t i = 0; i < count; ++i) {
x[i] = Random<T>(rng);
e[i] = hwy::ConvertScalarTo<T>(hwy::ConvertScalarTo<float>(x[i]) *
hwy::ConvertScalarTo<float>(constant));
}
InitProfilerZones(hwy::Profiler::Get());
MulByConstTo(constant, x, actual, count, hwy::Profiler::Get(),
/*worker=*/0);
hwy::AssertArraySimilar(e, actual, count, hwy::TargetName(HWY_TARGET),
__FILE__, __LINE__);
}
};
void TestAllMulByConstTo() {
hn::ForPartialVectors<ForeachCountAndMisalign<TestMulByConstTo>>()(float());
}
class TestSoftmax {
public:
template <class D>
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
hwy::RandomState& rng) {
@ -259,7 +315,8 @@ struct TestSoftmax {
}
SimpleSoftmax(e, count);
Softmax(x, count, hwy::Profiler::Get(), /*worker=*/0);
InitProfilerZones(hwy::Profiler::Get());
Softmax(Logits(x, count), hwy::Profiler::Get(), /*worker=*/0);
T sum = 0.0f;
for (size_t i = 0; i < count; ++i) {
@ -270,8 +327,27 @@ struct TestSoftmax {
}
ASSERT_NEAR(sum, 1.0, 1e-6);
}
private:
static HWY_NOINLINE void SimpleSoftmax(float* HWY_RESTRICT x, size_t size) {
HWY_DASSERT(size != 0);
float sum = 0.0;
const float maxval = *std::max_element(x, x + size);
for (size_t i = 0; i < size; ++i) {
x[i] = std::exp(x[i] - maxval);
sum += x[i];
}
const float scale = 1.0f / sum;
for (size_t i = 0; i < size; ++i) {
x[i] *= scale;
}
}
};
void TestAllSoftmax() {
hn::ForPartialVectors<ForeachCountAndMisalign<TestSoftmax>>()(float());
}
template <size_t k>
struct TestCreateDistribution {
void operator()(hwy::RandomState& rng) {
@ -291,43 +367,60 @@ struct TestCreateDistribution {
}
};
void TestAllAddFrom() {
hn::ForPartialVectors<ForeachCountAndMisalign<TestAddFrom>>()(float());
}
void TestAllMulByConst() {
hn::ForPartialVectors<ForeachCountAndMisalign<TestMulByConst>>()(float());
}
void TestAllMulByConstAndAdd() {
hn::ForPartialVectors<ForeachCountAndMisalign<TestMulByConstAndAdd>>()(
float());
}
void TestAllSoftmax() {
hn::ForPartialVectors<ForeachCountAndMisalign<TestSoftmax>>()(float());
}
void TestAllCreateDistribution() {
TestCreateDistribution<2048>();
TestCreateDistribution<5000>();
}
void TestSigmoid() {
std::vector<float> values;
for (int i = -150; i <= 150; ++i) {
values.push_back(.1f * i);
}
std::vector<float> result = values;
Sigmoid(result.data(), result.size());
struct TestSigmoid {
template <typename T, class D>
void operator()(T, D) const {
std::vector<T> values;
for (int i = -150; i <= 150; ++i) {
values.push_back(hwy::ConvertScalarTo<T>(.1f * i));
}
std::vector<T> result = values;
Sigmoid(result.data(), result.size());
for (size_t i = 0; i < values.size(); i++) {
const float max_error = 0.00007;
float value = values[i];
float approx = result[i];
float expected = (1 / (1 + std::exp(-values[i])));
EXPECT_NEAR(approx, expected, max_error) << "Input: " << value;
for (size_t i = 0; i < values.size(); i++) {
const float max_error = IsBF16<T>() ? 0.2f : 0.00007f;
const float value = hwy::ConvertScalarTo<float>(values[i]);
const float actual = hwy::ConvertScalarTo<float>(result[i]);
const float expected = (1 / (1 + std::exp(-value)));
EXPECT_NEAR(expected, actual, max_error)
<< (IsBF16<T>() ? "bf16" : "float");
}
}
};
static HWY_NOINLINE void TestAllSigmoid() {
ForeachActivationType1<TestSigmoid>(hn::ScalableTag<float>());
}
struct TestGelu {
template <typename T, class D>
void operator()(T, D) const {
std::vector<T> values;
for (int i = -150; i <= 150; ++i) {
values.push_back(hwy::ConvertScalarTo<T>(.1f * i));
}
std::vector<T> result = values;
Gelu(result.data(), result.size());
for (size_t i = 0; i < values.size(); i++) {
const float max_error = IsBF16<T>() ? 0.2f : 0.00007f;
const float x = hwy::ConvertScalarTo<float>(values[i]);
const float actual = hwy::ConvertScalarTo<float>(result[i]);
const float expected =
x * (0.5f + 0.5f * tanh(x * (0.79788f + 0.035677f * x * x)));
EXPECT_NEAR(expected, actual, max_error)
<< (IsBF16<T>() ? "bf16" : "float");
}
}
};
static HWY_NOINLINE void TestAllGelu() {
ForeachActivationType1<TestGelu>(hn::ScalableTag<float>());
}
static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
@ -357,10 +450,9 @@ void TestRopeAndMulBy() {
const size_t dim_qkv = config.layer_configs[0].qkv_dim;
MatStorageT<float> x("x", dim_qkv, ctx.allocator);
std::mt19937 gen;
gen.seed(0x12345678);
RngStream rng = MakeRng();
std::normal_distribution<float> r{0.0, 5.0};
auto random_float = [&r, &gen] { return r(gen); };
auto random_float = [&r, &rng] { return r(rng); };
for (size_t i = 0; i < dim_qkv; ++i) {
x.Row(0)[i] = random_float();
@ -421,7 +513,8 @@ void TestRopeAndMulBy() {
}
template <typename T>
HWY_NOINLINE float ScalarSquaredL2(const T* HWY_RESTRICT a, size_t size) {
static HWY_NOINLINE float ScalarSquaredL2(const T* HWY_RESTRICT a,
size_t size) {
double sum = 0.0;
for (size_t i = 0; i < size; ++i) {
const float f = hwy::ConvertScalarTo<float>(a[i]);
@ -431,9 +524,11 @@ HWY_NOINLINE float ScalarSquaredL2(const T* HWY_RESTRICT a, size_t size) {
}
// Supports bf16 and f32 inputs/outputs, which can be in-place.
// Shared between TestRMSNorm and TestRMSNormInplace.
template <typename XT, typename WT, typename OT>
HWY_NOINLINE void ScalarRMSNorm(const XT* x, const WT* HWY_RESTRICT weight,
OT* out, size_t size) {
static HWY_NOINLINE void ScalarRMSNorm(const XT* x,
const WT* HWY_RESTRICT weight, OT* out,
size_t size) {
constexpr float kEps = 1e-6f;
float ss = ScalarSquaredL2(x, size);
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps);
@ -445,42 +540,76 @@ HWY_NOINLINE void ScalarRMSNorm(const XT* x, const WT* HWY_RESTRICT weight,
}
}
template <typename XT, typename WT, typename OT>
void TestRMSNorm(hwy::RandomState& rng) {
constexpr size_t kSize = 128;
HWY_ALIGN XT vec[kSize];
HWY_ALIGN WT weight[kSize];
HWY_ALIGN OT expected[kSize];
HWY_ALIGN OT actual[kSize];
struct TestRMSNorm {
template <typename XT, typename WT, typename OT, class D>
void operator()(XT, WT, OT, D) const {
hwy::RandomState rng;
for (size_t i = 0; i < kSize; ++i) {
vec[i] = hwy::ConvertScalarTo<XT>(RandomGaussian(rng));
weight[i] = hwy::ConvertScalarTo<WT>(RandomGaussian(rng));
}
constexpr size_t kSize = 128;
HWY_ALIGN XT vec[kSize];
HWY_ALIGN WT weight[kSize];
HWY_ALIGN OT expected[kSize];
HWY_ALIGN OT actual[kSize];
ScalarRMSNorm(vec, weight, expected, kSize);
RMSNorm(vec, weight, 0, actual, kSize, hwy::Profiler::Get(), /*worker=*/0);
for (size_t i = 0; i < kSize; ++i) {
vec[i] = hwy::ConvertScalarTo<XT>(RandomGaussian(rng));
weight[i] = hwy::ConvertScalarTo<WT>(RandomGaussian(rng));
}
for (size_t i = 0; i < kSize; i++) {
const float e = hwy::ConvertScalarTo<float>(expected[i]);
const float a = hwy::ConvertScalarTo<float>(actual[i]);
if (!IsNear(e, a, 1e-5f)) {
HWY_ABORT("RMSNorm %s %s %s mismatch at %zu: %E %E\n", TypeName<XT>(),
TypeName<WT>(), TypeName<OT>(), i, e, a);
ScalarRMSNorm(vec, weight, expected, kSize);
InitProfilerZones(hwy::Profiler::Get());
RMSNorm(vec, weight, /*w_ofs=*/0, actual, kSize, hwy::Profiler::Get(),
/*worker=*/0);
for (size_t i = 0; i < kSize; i++) {
const float e = hwy::ConvertScalarTo<float>(expected[i]);
const float a = hwy::ConvertScalarTo<float>(actual[i]);
if (!IsNear(e, a, 1e-5f)) {
HWY_ABORT("RMSNorm %s %s %s mismatch at %zu: %E %E\n", TypeName<XT>(),
TypeName<WT>(), TypeName<OT>(), i, e, a);
}
}
}
}
};
void TestAllRMSNorm() {
hwy::RandomState rng;
TestRMSNorm<float, float, float>(rng);
TestRMSNorm<float, float, BF16>(rng);
TestRMSNorm<float, BF16, float>(rng);
TestRMSNorm<float, BF16, BF16>(rng);
TestRMSNorm<BF16, float, float>(rng);
TestRMSNorm<BF16, float, BF16>(rng);
TestRMSNorm<BF16, BF16, float>(rng);
TestRMSNorm<BF16, BF16, BF16>(rng);
ForeachActivationType3<TestRMSNorm>(hn::ScalableTag<float>());
}
struct TestRMSNormInplace {
template <typename XT, typename WT, class D>
void operator()(XT, WT, D) const {
hwy::RandomState rng;
constexpr size_t kSize = 128;
HWY_ALIGN XT expected[kSize];
HWY_ALIGN XT actual[kSize];
HWY_ALIGN WT weight[kSize];
for (size_t i = 0; i < kSize; ++i) {
expected[i] = hwy::ConvertScalarTo<XT>(RandomGaussian(rng));
actual[i] = expected[i];
weight[i] = hwy::ConvertScalarTo<WT>(RandomGaussian(rng));
}
ScalarRMSNorm(expected, weight, expected, kSize);
InitProfilerZones(hwy::Profiler::Get());
RMSNormInplace(weight, /*w_ofs=*/0, actual, kSize, hwy::Profiler::Get(),
/*worker=*/0);
for (size_t i = 0; i < kSize; i++) {
const float e = hwy::ConvertScalarTo<float>(expected[i]);
const float a = hwy::ConvertScalarTo<float>(actual[i]);
if (!IsNear(e, a, 1e-5f)) {
HWY_ABORT("RMSNormInplace %s %s mismatch at %zu: %E %E\n",
TypeName<XT>(), TypeName<WT>(), i, e, a);
}
}
}
};
void TestAllRMSNormInplace() {
ForeachActivationType2<TestRMSNormInplace>(hn::ScalableTag<float>());
}
void TestLayerNormSimple() {
@ -497,129 +626,127 @@ void TestLayerNormSimple() {
for (size_t i = 0; i < kSize; i++) {
const float max_error = 1e-6f;
float value = values[i];
float res = result[i];
// out = (x - 0.0) * 1.2 * 0.9999995 + 0.1 = 1.2999994 / -1.0999994;
float expected = (i % 2 == 0) ? 1.2999994f : -1.0999994f;
EXPECT_NEAR(res, expected, max_error) << "Input: " << value;
EXPECT_NEAR(res, expected, max_error);
}
}
// Computes mean mu and mean of squares mu2 of a vector. Used in
// ScalarLayerNorm.
template <typename T>
HWY_NOINLINE void ScalarMus(const T* HWY_RESTRICT a, size_t size, double& mu,
double& mu2) {
HWY_ASSERT(size > 0);
double sum = 0.0;
double sum2 = 0.0;
for (size_t i = 0; i < size; ++i) {
const float f = hwy::ConvertScalarTo<float>(a[i]);
sum += f;
sum2 += f * f;
}
mu = sum / size;
mu2 = sum2 / size;
}
class TestLayerNorm {
public:
template <typename XT, typename WT, typename OT, class D>
void operator()(XT, WT, OT, D) const {
hwy::RandomState rng;
constexpr size_t kSize = 128;
XT vec[kSize];
WT weight[kSize];
WT bias[kSize];
OT expected[kSize];
OT actual[kSize];
// Compare py/flax/linen/normalization.py.
// out = (x - mean) * scale * rsqrt(var + epsilon) + bias
template <typename XT, typename WT, typename OT>
HWY_NOINLINE void ScalarLayerNorm(const XT* x, const WT* HWY_RESTRICT scale,
const WT* HWY_RESTRICT bias, OT* out,
size_t size) {
constexpr double kEps = 1e-6;
double mu, mu2;
ScalarMus(x, size, mu, mu2);
double var = mu2 - mu * mu;
constexpr double kZero = 0.0;
var = HWY_MAX(var, kZero);
var = 1.0 / sqrt(var + kEps);
for (size_t j = 0; j < size; j++) {
const float v = hwy::ConvertScalarTo<float>(x[j]);
const float s = hwy::ConvertScalarTo<float>(scale[j]);
const float b = hwy::ConvertScalarTo<float>(bias[j]);
out[j] = hwy::ConvertScalarTo<OT>((v - mu) * s * var + b);
}
}
for (size_t i = 0; i < kSize; ++i) {
vec[i] = hwy::ConvertScalarTo<XT>(RandomGaussian(rng));
weight[i] = hwy::ConvertScalarTo<WT>(RandomGaussian(rng));
bias[i] = hwy::ConvertScalarTo<WT>(RandomGaussian(rng));
}
template <typename XT, typename WT, typename OT>
void TestLayerNorm(hwy::RandomState& rng) {
constexpr size_t kSize = 128;
XT vec[kSize];
WT weight[kSize];
WT bias[kSize];
OT expected[kSize];
OT actual[kSize];
double expected_mu, expected_mu2;
ScalarMus(vec, kSize, expected_mu, expected_mu2);
double actual_mu, actual_mu2;
ComputeMoments(vec, kSize, actual_mu, actual_mu2);
for (size_t i = 0; i < kSize; ++i) {
vec[i] = hwy::ConvertScalarTo<XT>(RandomGaussian(rng));
weight[i] = hwy::ConvertScalarTo<WT>(RandomGaussian(rng));
bias[i] = hwy::ConvertScalarTo<WT>(RandomGaussian(rng));
}
ScalarLayerNorm(vec, weight, bias, expected, kSize);
LayerNorm(vec, weight, bias, actual, kSize);
double expected_mu, expected_mu2;
ScalarMus(vec, kSize, expected_mu, expected_mu2);
double actual_mu, actual_mu2;
ComputeMoments(vec, kSize, actual_mu, actual_mu2);
ScalarLayerNorm(vec, weight, bias, expected, kSize);
LayerNorm(vec, weight, bias, actual, kSize);
for (size_t i = 0; i < kSize; i++) {
const float e = hwy::ConvertScalarTo<float>(expected[i]);
const float a = hwy::ConvertScalarTo<float>(actual[i]);
if (!IsNear(e, a, 1e-5f)) {
HWY_ABORT("LayerNorm %s %s %s mismatch at %zu: %E %E\n", TypeName<XT>(),
TypeName<WT>(), TypeName<OT>(), i, e, a);
for (size_t i = 0; i < kSize; i++) {
const float e = hwy::ConvertScalarTo<float>(expected[i]);
const float a = hwy::ConvertScalarTo<float>(actual[i]);
if (!IsNear(e, a, 1e-5f)) {
HWY_ABORT("LayerNorm %s %s %s mismatch at %zu: %E %E\n", TypeName<XT>(),
TypeName<WT>(), TypeName<OT>(), i, e, a);
}
}
}
}
private:
// Computes mean mu and mean of squares mu2 of a vector. Used in
// ScalarLayerNorm.
template <typename T>
static HWY_NOINLINE void ScalarMus(const T* HWY_RESTRICT a, size_t size,
double& mu, double& mu2) {
HWY_ASSERT(size > 0);
double sum = 0.0;
double sum2 = 0.0;
for (size_t i = 0; i < size; ++i) {
const float f = hwy::ConvertScalarTo<float>(a[i]);
sum += f;
sum2 += f * f;
}
mu = sum / size;
mu2 = sum2 / size;
}
// Compare py/flax/linen/normalization.py.
// out = (x - mean) * scale * rsqrt(var + epsilon) + bias
template <typename XT, typename WT, typename OT>
static HWY_NOINLINE void ScalarLayerNorm(const XT* x,
const WT* HWY_RESTRICT scale,
const WT* HWY_RESTRICT bias, OT* out,
size_t size) {
constexpr double kEps = 1e-6;
double mu, mu2;
ScalarMus(x, size, mu, mu2);
double var = mu2 - mu * mu;
constexpr double kZero = 0.0;
var = HWY_MAX(var, kZero);
var = 1.0 / sqrt(var + kEps);
for (size_t j = 0; j < size; j++) {
const float v = hwy::ConvertScalarTo<float>(x[j]);
const float s = hwy::ConvertScalarTo<float>(scale[j]);
const float b = hwy::ConvertScalarTo<float>(bias[j]);
out[j] = hwy::ConvertScalarTo<OT>((v - mu) * s * var + b);
}
}
};
void TestAllLayerNorm() {
hwy::RandomState rng;
TestLayerNorm<float, float, float>(rng);
TestLayerNorm<float, float, BF16>(rng);
TestLayerNorm<float, BF16, float>(rng);
TestLayerNorm<float, BF16, BF16>(rng);
ForeachActivationType3<TestLayerNorm>(hn::ScalableTag<float>());
}
void TestSampleTopK() {
hwy::Profiler& p = hwy::Profiler::Get();
InitProfilerZones(p);
const size_t worker = 0;
const size_t kSize = 52;
std::vector<float> logits(kSize);
std::vector<float> logits_vec(kSize);
Logits logits(logits_vec.data(), kSize);
// Create a vector going from -100 to -100+51=49 and take Softmax.
std::iota(logits.begin(), logits.end(), -100.0f);
Softmax(logits.data(), kSize, p, worker);
std::mt19937 gen;
gen.seed(0x12345678);
Softmax(logits, p, worker);
RngStream rng = MakeRng();
float temperature = 1.0f;
// SampleTopK<1> should return the argmax.
std::function<bool(int, float)> accept_token;
int sample =
SampleTopK(logits.data(), /*k=*/1, kSize, gen, temperature, accept_token);
int sample = SampleTopK(logits, /*k=*/1, rng, temperature, accept_token);
EXPECT_EQ(sample, 51); // Last is largest.
// Only accept even tokens, expect the last (largest) even index.
accept_token = [](int i, float) { return i % 2 == 0; };
sample =
SampleTopK(logits.data(), /*k=*/1, kSize, gen, temperature, accept_token);
sample = SampleTopK(logits, /*k=*/1, rng, temperature, accept_token);
EXPECT_EQ(sample, 50); // Last even index.
// Reset the logits to a positive, increasing sequence and take Softmax.
std::iota(logits.begin(), logits.end(), 1.0f);
Softmax(logits.data(), kSize, p, worker);
Softmax(logits, p, worker);
// Sample from the top 3, expect one of the top 3 even indices.
for (int i = 0; i < 100; ++i) {
sample = SampleTopK(logits.data(), /*k=*/3, kSize, gen, temperature,
accept_token);
sample = SampleTopK(logits, /*k=*/3, rng, temperature, accept_token);
EXPECT_TRUE(sample == 50 || sample == 48 || sample == 46);
}
// Now set the temperature to 0.0f, which should always return the argmax,
// even for k=3.
temperature = 0.0f;
for (int i = 0; i < 100; ++i) {
sample = SampleTopK(logits.data(), /*k=*/3, kSize, gen, temperature,
accept_token);
sample = SampleTopK(logits, /*k=*/3, rng, temperature, accept_token);
EXPECT_EQ(sample, 50);
}
}
@ -646,12 +773,15 @@ namespace gcpp {
HWY_BEFORE_TEST(OpsTest);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllAddFrom);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstTo);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSigmoid);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllGelu);
HWY_EXPORT_AND_TEST_P(OpsTest, TestRopeAndMulBy);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllRMSNorm);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllRMSNormInplace);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllLayerNorm);
HWY_EXPORT_AND_TEST_P(OpsTest, TestLayerNormSimple);
HWY_EXPORT_AND_TEST_P(OpsTest, TestSampleTopK);

View File

@ -27,42 +27,37 @@ void PaliGemmaHelper::InitVit(const std::string& path) {
HWY_ASSERT(image.ReadPPM(path));
const size_t image_size = config.vit_config.image_size;
image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {.gen = &env_->MutableGen(),
.verbosity = 0};
RuntimeConfig runtime_config = {.verbosity = 0};
gemma.GenerateImageTokens(runtime_config, env_->MutableKVCache().SeqLen(),
image, *image_tokens_, env_->MutableEnv());
}
std::string PaliGemmaHelper::GemmaReply(const std::string& prompt_text) const {
const Gemma& model = *(env_->GetGemma());
env_->MutableGen().seed(0x12345678);
std::string response;
auto stream_token = [&](int token, float) {
std::string token_text;
HWY_ASSERT(
model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
response += token_text;
return true;
};
std::string response;
auto stream_token = [&](int token, float) {
std::string token_text;
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
response += token_text;
return true;
};
std::string mutable_prompt = prompt_text;
std::vector<int> tokens = env_->WrapAndTokenize(mutable_prompt);
tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
std::vector<int> tokens = env_->WrapAndTokenize(prompt_text);
tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
RuntimeConfig runtime_config = {.max_generated_tokens = 512,
// PrefixLM sees/attends to all tokens.
.prefill_tbatch_size = tokens.size(),
.gen = &env_->MutableGen(),
.verbosity = 0,
.stream_token = stream_token,
.image_tokens = image_tokens_.get()};
RuntimeConfig runtime_config = {.max_generated_tokens = 512,
// PrefixLM sees/attends to all tokens.
.prefill_tbatch_size = tokens.size(),
.verbosity = 0,
.stream_token = stream_token,
.image_tokens = image_tokens_.get()};
const size_t prefix_end = tokens.size();
TimingInfo timing_info = {.verbosity = 0};
model.Generate(runtime_config, tokens, /*pos=*/0, prefix_end,
env_->MutableKVCache(), env_->MutableEnv(), timing_info);
return response;
const size_t prefix_end = tokens.size();
TimingInfo timing_info = {.verbosity = 0};
model.Generate(runtime_config, tokens, /*pos=*/0, prefix_end,
env_->MutableKVCache(), env_->MutableEnv(), timing_info);
return response;
}
} // namespace gcpp

View File

@ -53,12 +53,14 @@ PYBIND11_MODULE(configs, py_module) {
.value("kF32", Type::kF32)
.value("kBF16", Type::kBF16)
.value("kSFP", Type::kSFP)
.value("kNUQ", Type::kNUQ);
.value("kNUQ", Type::kNUQ)
.value("kF64", Type::kF64)
.value("kU32", Type::kU32)
.value("kU64", Type::kU64)
.value("kI8", Type::kI8);
enum_<LayerAttentionType>(py_module, "LayerAttentionType")
.value("kGemma", LayerAttentionType::kGemma)
.value("kGriffinRecurrentBlock",
LayerAttentionType::kGriffinRecurrentBlock)
.value("kVit", LayerAttentionType::kVit);
enum_<PostNormType>(py_module, "PostNormType")
@ -84,8 +86,6 @@ PYBIND11_MODULE(configs, py_module) {
.value("UNKNOWN", Model::UNKNOWN)
.value("GEMMA2_9B", Model::GEMMA2_9B)
.value("GEMMA2_27B", Model::GEMMA2_27B)
.value("GRIFFIN_2B", Model::GRIFFIN_2B)
.value("GEMMA_TINY", Model::GEMMA_TINY)
.value("GEMMA2_2B", Model::GEMMA2_2B)
.value("PALIGEMMA2_3B_224", Model::PALIGEMMA2_3B_224)
.value("PALIGEMMA2_10B_224", Model::PALIGEMMA2_10B_224)
@ -121,15 +121,11 @@ PYBIND11_MODULE(configs, py_module) {
class_<LayerConfig>(py_module, "LayerConfig")
.def(init())
.def_readwrite("model_dim", &LayerConfig::model_dim)
.def_readwrite("griffin_dim", &LayerConfig::griffin_dim)
.def_readwrite("ff_hidden_dim", &LayerConfig::ff_hidden_dim)
.def_readwrite("heads", &LayerConfig::heads)
.def_readwrite("kv_heads", &LayerConfig::kv_heads)
.def_readwrite("qkv_dim", &LayerConfig::qkv_dim)
.def_readwrite("conv1d_width", &LayerConfig::conv1d_width)
.def_readwrite("ff_biases", &LayerConfig::ff_biases)
.def_readwrite("softmax_attn_output_biases",
&LayerConfig::softmax_attn_output_biases)
.def_readwrite("optimized_gating", &LayerConfig::optimized_gating)
.def_readwrite("post_norm", &LayerConfig::post_norm)
.def_readwrite("type", &LayerConfig::type)
@ -147,8 +143,7 @@ PYBIND11_MODULE(configs, py_module) {
.def_readwrite("image_size", &VitConfig::image_size)
.def_readwrite("layer_configs", &VitConfig::layer_configs);
class_<InternalModelConfig>(py_module, "InternalModelConfig")
.def(init<>());
class_<InternalModelConfig>(py_module, "InternalModelConfig").def(init<>());
class_<ModelConfig>(py_module, "ModelConfig")
.def(init<>())
@ -167,7 +162,6 @@ PYBIND11_MODULE(configs, py_module) {
.def_readwrite("att_cap", &ModelConfig::att_cap)
.def_readwrite("final_cap", &ModelConfig::final_cap)
.def_readwrite("absolute_pe", &ModelConfig::absolute_pe)
.def_readwrite("use_local_attention", &ModelConfig::use_local_attention)
.def_readwrite("query_scale", &ModelConfig::query_scale)
.def_readwrite("layer_configs", &ModelConfig::layer_configs)
.def_readwrite("attention_window_sizes",

View File

@ -52,10 +52,9 @@ class GemmaModel {
// Generates a single example, given a prompt and a callback to stream the
// generated tokens.
void GenerateEx(std::string prompt, gcpp::StreamFunc stream,
size_t max_generated_tokens, float temperature, float seed,
gcpp::AcceptFunc accept, bool skip_prompt) {
env_.MutableGen().seed(seed);
void GenerateEx(const std::string& prompt, gcpp::StreamFunc stream,
size_t max_generated_tokens, float temperature,
float /*seed*/, gcpp::AcceptFunc accept, bool skip_prompt) {
std::vector<int> prompt_tokens = env_.WrapAndTokenize(prompt);
gcpp::RuntimeConfig& config = env_.MutableConfig();
config.max_generated_tokens = max_generated_tokens;
@ -76,8 +75,8 @@ class GemmaModel {
}
// Generates a single example, given a prompt, and returns the result.
std::string Generate(std::string prompt, size_t max_generated_tokens,
float temperature, float seed,
std::string Generate(const std::string& prompt, size_t max_generated_tokens,
float temperature, float /*seed*/,
const std::vector<std::string>& accept,
const std::vector<std::string>& end) {
std::set<int> end_token_set{};
@ -124,7 +123,6 @@ class GemmaModel {
}
};
env_.MutableGen().seed(seed);
gcpp::RuntimeConfig& config = env_.MutableConfig();
config.max_generated_tokens = max_generated_tokens;
config.temperature = temperature;
@ -144,14 +142,13 @@ class GemmaModel {
// results.
std::vector<std::string> GenerateBatch(const std::vector<std::string>& inputs,
size_t max_generated_tokens,
float temperature, float seed,
float temperature, float /*seed*/,
size_t top_k) {
gcpp::RuntimeConfig& config = env_.MutableConfig();
config.max_generated_tokens = max_generated_tokens;
config.temperature = temperature;
config.top_k = top_k;
config.verbosity = 0;
env_.MutableGen().seed(seed);
std::vector<gcpp::QueryResult> outputs = env_.BatchQueryModel(inputs);
std::vector<std::string> result;
@ -187,8 +184,7 @@ class GemmaModel {
"image_tokens",
gcpp::Extents2D(config.vit_config.seq_len, config.model_dim),
env_.MutableEnv().ctx.allocator, gcpp::MatPadding::kOdd));
gcpp::RuntimeConfig runtime_config = {.gen = &env_.MutableGen(),
.verbosity = 0};
gcpp::RuntimeConfig runtime_config = {.verbosity = 0};
gemma.GenerateImageTokens(runtime_config, env_.MutableKVCache().SeqLen(),
c_image, *image_tokens_, env_.MutableEnv());
}
@ -196,11 +192,10 @@ class GemmaModel {
// Generates a response to the given prompt, using the last set image.
// Uses the prompt_tokens if provided, otherwise tokenizes the prompt string.
std::pair<std::string, std::vector<int>> GenerateWithImage(
std::string prompt, size_t max_generated_tokens, float temperature,
float seed, gcpp::AcceptFunc accept, std::vector<int> prompt_tokens) {
const std::string& prompt, size_t max_generated_tokens, float temperature,
float /*seed*/, gcpp::AcceptFunc accept, std::vector<int> prompt_tokens) {
if (!image_tokens_) throw std::invalid_argument("No image set.");
const gcpp::Gemma& model = *env_.GetGemma();
env_.MutableGen().seed(seed);
gcpp::RuntimeConfig& config = env_.MutableConfig();
config.max_generated_tokens = max_generated_tokens;
config.temperature = temperature;
@ -273,6 +268,7 @@ PYBIND11_MODULE(gemma, mod) {
}),
py::arg("tokenizer_path"), py::arg("weights_path"),
py::arg("max_threads") = 0)
// seed arguments are ignored.
.def("generate_ex", &GemmaModel::GenerateEx, py::arg("prompt"),
py::arg("stream"), py::arg("max_generated_tokens") = 1024,
py::arg("temperature") = 0.9, py::arg("seed") = 123456789,

View File

@ -130,7 +130,7 @@ size_t DetectTotalMiB(size_t page_bytes) {
} // namespace
Allocator::Allocator(const BoundedTopology& topology, bool enable_bind) {
CacheInfo::CacheInfo(const BoundedTopology& topology) {
line_bytes_ = DetectLineBytes();
// Ensure MaxLineBytes() is an upper bound.
HWY_ASSERT(MaxLineBytes() >= LineBytes());
@ -138,10 +138,8 @@ Allocator::Allocator(const BoundedTopology& topology, bool enable_bind) {
vector_bytes_ = hwy::VectorBytes();
step_bytes_ = HWY_MAX(line_bytes_, vector_bytes_);
base_page_bytes_ = DetectPageSize();
quantum_bytes_ = step_bytes_; // may overwrite below
const BoundedTopology::Cluster& cluster = topology.GetCluster(0, 0);
const BoundedTopology::Cluster& cluster = topology.GetCluster(0);
if (const hwy::Cache* caches = hwy::DataCaches()) {
l1_bytes_ = caches[1].size_kib << 10;
l2_bytes_ = caches[2].size_kib << 10;
@ -153,18 +151,23 @@ Allocator::Allocator(const BoundedTopology& topology, bool enable_bind) {
if (l3_bytes_ == 0) {
l3_bytes_ = (cluster.SharedKiB() ? cluster.SharedKiB() : 1024) << 10;
}
}
total_mib_ = DetectTotalMiB(base_page_bytes_);
Allocator::Allocator(const BoundedTopology& topology,
const CacheInfo& cache_info, bool enable_bind)
: line_bytes_(cache_info.LineBytes()),
base_page_bytes_(DetectPageSize()),
total_mib_(DetectTotalMiB(base_page_bytes_)) {
quantum_bytes_ = cache_info.StepBytes(); // may overwrite below
// Prerequisites for binding:
// - supported by the OS (currently Linux only),
// - the page size is known and 'reasonably small', preferably less than
// a fraction of MatMul row/col sizes, which for 27B are up to 144 KiB.
// - we successfully detected topology and there are multiple nodes;
// - there are multiple packages, because we shard by package_idx.
// - we successfully detected topology and there are multiple nodes.
if constexpr (GEMMA_BIND) {
if ((base_page_bytes_ != 0 && base_page_bytes_ <= 16 * 1024) &&
topology.NumNodes() > 1 && topology.NumPackages() > 1) {
topology.NumNodes() > 1) {
if (enable_bind) {
// Ensure pages meet the alignment requirements of `AllocBytes`.
HWY_ASSERT(base_page_bytes_ >= quantum_bytes_);

View File

@ -77,27 +77,49 @@ using AlignedPtr = std::unique_ptr<T, DeleterFunc>;
template <typename T>
using AlignedClassPtr = std::unique_ptr<T, DeleterDtor>;
// Both allocation, binding, and row accessors depend on the sizes of memory
// pages and cache lines. To avoid having to pass `Allocator&` everywhere, we
// wrap this in a singleton. A monostate requires explicit initialization,
// which we prefer to avoid because there are many main() functions.
class Allocator {
// Holds cache line size/capacity and vector size. Stored in `ThreadingContext`.
class CacheInfo {
public:
// Must be called at least once before any other function. Not thread-safe,
// hence only call this from the main thread.
Allocator(const BoundedTopology& topology, bool enable_bind);
CacheInfo(const BoundedTopology& topology);
// Bytes per cache line, or a reasonable guess if unknown. Used to choose
// ranges such that there will be no false sharing.
size_t LineBytes() const { return line_bytes_; }
// Upper bound on `LineBytes()`, for stack allocations.
static constexpr size_t MaxLineBytes() { return 256; }
// Bytes per full vector. Used to compute loop steps.
size_t VectorBytes() const { return vector_bytes_; }
// Work granularity that avoids false sharing and partial vectors.
// = HWY_MAX(LineBytes(), VectorBytes())
size_t StepBytes() const { return step_bytes_; }
// L1 and L2 are typically per core.
size_t L1Bytes() const { return l1_bytes_; }
size_t L2Bytes() const { return l2_bytes_; }
// Clusters often share an L3. We return the total size per package.
size_t L3Bytes() const { return l3_bytes_; }
private:
size_t line_bytes_;
size_t vector_bytes_;
size_t step_bytes_;
size_t l1_bytes_ = 0;
size_t l2_bytes_ = 0;
size_t l3_bytes_ = 0;
};
// NUMA-aware allocation and memory binding. Stored in `ThreadingContext`.
class Allocator {
public:
Allocator(const BoundedTopology& topology, const CacheInfo& cache_info,
bool enable_bind);
// Used by `AllocateFor`, which only takes an `Allocator` argument,
// hence copy from `CacheInfo`.
size_t LineBytes() const { return line_bytes_; }
// File size multiple required for memory mapping. Also used when binding
// memory to NUMA nodes (see `BindB/BindC`).
size_t BasePageBytes() const { return base_page_bytes_; }
@ -105,12 +127,6 @@ class Allocator {
// Desired allocator alignment: Either StepBytes, or BasePageBytes if NUMA.
size_t QuantumBytes() const { return quantum_bytes_; }
// L1 and L2 are typically per core.
size_t L1Bytes() const { return l1_bytes_; }
size_t L2Bytes() const { return l2_bytes_; }
// Clusters often share an L3. We return the total size per package.
size_t L3Bytes() const { return l3_bytes_; }
size_t TotalMiB() const { return total_mib_; }
size_t FreeMiB() const;
@ -149,28 +165,21 @@ class Allocator {
}
// Returns whether `BindMemory` can/should be called, i.e. we have page-level
// control over memory placement and multiple packages and NUMA nodes.
// control over memory placement and multiple NUMA nodes.
bool ShouldBind() const { return should_bind_; }
// Attempts to move(!) `[p, p + bytes)` to the given NUMA node, which is
// typically `BoundedTopology::GetCluster(package_idx, cluster_idx).node`.
// typically `BoundedTopology::GetCluster(cluster_idx).node`.
// Writes zeros to SOME of the memory. Only call if `ShouldBind()`.
// `p` and `bytes` must be multiples of `QuantumBytes()`.
bool BindMemory(void* p, size_t bytes, size_t node) const;
private:
size_t line_bytes_;
size_t vector_bytes_;
size_t step_bytes_;
size_t base_page_bytes_;
const size_t line_bytes_;
const size_t base_page_bytes_;
const size_t total_mib_;
size_t quantum_bytes_;
size_t l1_bytes_ = 0;
size_t l2_bytes_ = 0;
size_t l3_bytes_ = 0;
size_t total_mib_;
bool should_bind_ = false;
};

75
util/basics.cc Normal file
View File

@ -0,0 +1,75 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "util/basics.h"
#include <stddef.h>
#include <stdint.h>
#include "hwy/contrib/sort/vqsort.h"
#include "hwy/highway.h"
#include "hwy/timer.h"
namespace gcpp {
AesCtrEngine::AesCtrEngine(bool deterministic) {
// Pi-based nothing up my sleeve numbers from Randen.
key_[0] = 0x243F6A8885A308D3ull;
key_[1] = 0x13198A2E03707344ull;
if (!deterministic) { // want random seed
if (!hwy::Fill16BytesSecure(key_)) {
HWY_WARN("Failed to fill RNG key with secure random bits");
// Entropy not available. The test requires that we inject some
// differences relative to the deterministic seeds.
key_[0] ^= reinterpret_cast<uint64_t>(this);
key_[1] ^= hwy::timer::Start();
}
}
// Simple key schedule: swap and add constant (also from Randen).
for (size_t i = 0; i < kRounds; ++i) {
key_[2 + 2 * i + 0] = key_[2 * i + 1] + 0xA4093822299F31D0ull;
key_[2 + 2 * i + 1] = key_[2 * i + 0] + 0x082EFA98EC4E6C89ull;
}
}
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::Full128<uint8_t>; // 128 bits for AES
using V = hn::Vec<D>;
static V Load(const uint64_t* ptr) {
return hn::Load(D(), reinterpret_cast<const uint8_t*>(ptr));
}
uint64_t AesCtrEngine::operator()(uint64_t stream, uint64_t counter) const {
const hn::Repartition<uint64_t, D> d64;
V state = hn::BitCast(D(), hn::Dup128VecFromValues(d64, counter, stream));
state = hn::Xor(state, Load(key_)); // initial whitening
static_assert(kRounds == 5 && sizeof(key_) == 12 * sizeof(uint64_t));
state = hn::AESRound(state, Load(key_ + 2));
state = hn::AESRound(state, Load(key_ + 4));
state = hn::AESRound(state, Load(key_ + 6));
state = hn::AESRound(state, Load(key_ + 8));
// Final round: fine to use another AESRound, including MixColumns.
state = hn::AESRound(state, Load(key_ + 10));
// Return lower 64 bits of the u8 vector.
return hn::GetLane(hn::BitCast(d64, state));
}
} // namespace gcpp

View File

@ -20,7 +20,7 @@
#include <stddef.h>
#include <stdint.h>
#include "hwy/aligned_allocator.h"
#include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h"
// IWYU pragma: end_exports
@ -30,10 +30,8 @@
namespace gcpp {
// Maximum number of packages (CPU sockets) to use. `ThreadingArgs` verifies the
// runtime `max_packages` does not exceed this. MatMul's outer per-package loop
// is disabled if this is 1.
constexpr size_t kMaxPackages = 1;
// For hwy::BitSet4096. Note that KVs are extremely large for such batches.
HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096;
enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 };
@ -61,6 +59,25 @@ static inline void MaybeCheckInitialized(const void* ptr, size_t size) {
#endif
}
static inline void MaybePrintInitialized(const void* ptr, size_t size) {
#if HWY_IS_MSAN
__msan_print_shadow(ptr, size);
#else
(void)ptr;
(void)size;
#endif
}
static inline intptr_t MaybeTestInitialized(const void* ptr, size_t size) {
#if HWY_IS_MSAN
return __msan_test_shadow(ptr, size);
#else
(void)ptr;
(void)size;
return 0;
#endif
}
// Shared between gemma.h and ops-inl.h.
#pragma pack(push, 1)
struct TokenAndProb {
@ -121,6 +138,63 @@ static inline IndexRange MakeIndexRange(size_t begin, size_t end,
size_t max_size) {
return IndexRange(begin, HWY_MIN(begin + max_size, end));
}
using Logits = hwy::Span<float>; // size() is vocab_size.
// Non-cryptographic 64-bit pseudo-random number generator. Supports random or
// deterministic seeding.
//
// Based on 5-round AES-CTR. Supports 2^64 streams, each with period 2^64. This
// is useful for parallel sampling. Each thread can generate the stream for a
// particular task, without caring about prior/subsequent generations.
class alignas(16) AesCtrEngine {
// "Large-scale randomness study of security margins for 100+ cryptographic
// functions": at least four.
// "Parallel Random Numbers: As Easy as 1, 2, 3": four not Crush-resistant.
static constexpr size_t kRounds = 5;
public:
// If `deterministic` is true, uses a fixed seed; otherwise, attempts to
// grab entropy from the OS.
explicit AesCtrEngine(bool deterministic);
// Pure and thread safe; typically called via `RngStream`, which increments
// `counter`. Throughput is about 100M/s on 3 GHz Skylake. It could be
// increased 4x via unrolling by the AES latency (4-7 cycles), but because
// users generally call once at a time, this requires buffering, which is not
// worth the complexity in this application.
uint64_t operator()(uint64_t stream, uint64_t counter) const;
private:
uint64_t key_[2 * (1 + kRounds)];
};
// Flyweight per-thread adapter that maintains the counter. Conforms to C++
// `UniformRandomBitGenerator`.
class RngStream {
public:
RngStream() = default; // Allow C arrays with subsequent initialization.
// Binds to an engine, which holds the seed and must outlive this object.
// Sets the stream; any other `RngStream` with the same `counter_rng` and
// `stream` will return the same sequence. This is typically the task ID, so
// that threads can independently generate values for each task.
RngStream(const AesCtrEngine& counter_rng, uint64_t stream)
: engine_(&counter_rng), stream_(stream), counter_(0) {}
using result_type = uint64_t;
static constexpr result_type min() { return 0; }
static constexpr result_type max() { return ~result_type{0}; }
result_type operator()() { return (*engine_)(stream_, counter_++); }
private:
const AesCtrEngine* engine_ = nullptr;
uint64_t stream_ = 0; // immutable after ctor
uint64_t counter_ = 0;
// Prevent false sharing if used by multiple threads.
HWY_MAYBE_UNUSED uint8_t padding_[HWY_ALIGNMENT - 16 - sizeof(engine_)];
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_BASICS_H_

131
util/basics_test.cc Normal file
View File

@ -0,0 +1,131 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "util/basics.h"
#include <stddef.h>
#include <stdio.h>
#include "hwy/base.h"
#include "hwy/tests/hwy_gtest.h"
#include "hwy/timer.h"
namespace gcpp {
namespace {
TEST(BasicsTest, EngineIsDeterministic) {
const AesCtrEngine engine1(/*deterministic=*/true);
const AesCtrEngine engine2(/*deterministic=*/true);
RngStream rng1(engine1, 0);
RngStream rng2(engine2, 0);
// Remember for later testing after resetting the stream.
const uint64_t r0 = rng1();
const uint64_t r1 = rng1();
// Not consecutive values. This could actually happen due to the extra XOR,
// but given the deterministic seeding here, we know it will not.
HWY_ASSERT(r0 != r1);
// Let rng2 catch up.
HWY_ASSERT(r0 == rng2());
HWY_ASSERT(r1 == rng2());
for (size_t i = 0; i < 1000; ++i) {
HWY_ASSERT(rng1() == rng2());
}
// Reset counter, ensure it matches the prior sequence.
rng1 = RngStream(engine1, 0);
HWY_ASSERT(r0 == rng1());
HWY_ASSERT(r1 == rng1());
}
TEST(BasicsTest, EngineIsSeeded) {
AesCtrEngine engine1(/*deterministic=*/true);
AesCtrEngine engine2(/*deterministic=*/false);
RngStream rng1(engine1, 0);
RngStream rng2(engine2, 0);
// It would be very unlucky to have even one 64-bit value match, and two are
// extremely unlikely.
const uint64_t a0 = rng1();
const uint64_t a1 = rng1();
const uint64_t b0 = rng2();
const uint64_t b1 = rng2();
HWY_ASSERT(a0 != b0 || a1 != b1);
}
TEST(BasicsTest, StreamsDiffer) {
AesCtrEngine engine(/*deterministic=*/true);
// Compare random streams for more coverage than just the first N streams.
RngStream rng_for_stream(engine, 0);
for (size_t i = 0; i < 1000; ++i) {
RngStream rng1(engine, rng_for_stream());
RngStream rng2(engine, rng_for_stream());
// It would be very unlucky to have even one 64-bit value match, and two are
// extremely unlikely.
const uint64_t a0 = rng1();
const uint64_t a1 = rng1();
const uint64_t b0 = rng2();
const uint64_t b1 = rng2();
HWY_ASSERT(a0 != b0 || a1 != b1);
}
}
// If not close to 50% 1-bits, the RNG is quite broken.
TEST(BasicsTest, BitDistribution) {
AesCtrEngine engine(/*deterministic=*/true);
RngStream rng(engine, 0);
constexpr size_t kU64 = 2 * 1000 * 1000;
const hwy::Timestamp t0;
uint64_t one_bits = 0;
for (size_t i = 0; i < kU64; ++i) {
one_bits += hwy::PopCount(rng());
}
const uint64_t total_bits = kU64 * 64;
const double one_ratio = static_cast<double>(one_bits) / total_bits;
const double elapsed = hwy::SecondsSince(t0);
fprintf(stderr, "1-bit ratio %.5f, %.1f M/s\n", one_ratio,
kU64 / elapsed * 1E-6);
HWY_ASSERT(0.4999 <= one_ratio && one_ratio <= 0.5001);
}
TEST(BasicsTest, ChiSquared) {
AesCtrEngine engine(/*deterministic=*/true);
RngStream rng(engine, 0);
constexpr size_t kU64 = 1 * 1000 * 1000;
// Test each byte separately.
for (size_t shift = 0; shift < 64; shift += 8) {
size_t counts[256] = {};
for (size_t i = 0; i < kU64; ++i) {
const size_t byte = (rng() >> shift) & 0xFF;
counts[byte]++;
}
double chi_squared = 0.0;
const double expected = static_cast<double>(kU64) / 256.0;
for (size_t i = 0; i < 256; ++i) {
const double diff = static_cast<double>(counts[i]) - expected;
chi_squared += diff * diff / expected;
}
// Should be within ~0.5% and 99.5% percentiles. See
// https://www.medcalc.org/manual/chi-square-table.php
if (chi_squared < 196.0 || chi_squared > 311.0) {
HWY_ABORT("Chi-squared byte %zu: %.5f \n", shift / 8, chi_squared);
}
}
}
} // namespace
} // namespace gcpp
HWY_TEST_MAIN();

View File

@ -80,11 +80,13 @@ size_t Stride(MatPadding padding, size_t cols, size_t element_bytes,
void MatOwner::AllocateFor(MatPtr& mat, const Allocator& allocator,
MatPadding padding) {
const bool is_nuq = mat.GetType() == Type::kNUQ;
if (is_nuq) padding = MatPadding::kPacked;
const bool is_compressed_and_packed =
mat.GetType() == Type::kNUQ || mat.GetType() == Type::kI8;
if (is_compressed_and_packed) padding = MatPadding::kPacked;
const size_t stride =
Stride(padding, mat.Cols(), mat.ElementBytes(), allocator.LineBytes());
const size_t num = is_nuq ? mat.PackedBytes() : mat.Rows() * stride;
const size_t num =
is_compressed_and_packed ? mat.PackedBytes() : mat.Rows() * stride;
// `compress-inl` requires up to 2 BF16 vectors of padding. `MatPadding`
// might not be enough, hence add extra. `MatT` is at least one byte, which
// is half of BF16, hence adding `VectorBytes` *elements* is enough.

View File

@ -38,16 +38,28 @@ namespace gcpp {
template <typename T>
class RowPtrs {
public:
RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs) {}
RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs), r0_(0), c0_(0) {}
T* HWY_RESTRICT operator[](size_t row_idx) const {
return HWY_RCAST_ALIGNED(T*, row_ptrs_[row_idx]);
// Extra argument is for compatibility with `StridedView`.
RowPtrs View(size_t r, size_t c, size_t /*cols*/) {
RowPtrs<T> view(row_ptrs_);
view.r0_ = static_cast<uint32_t>(r0_ + r);
view.c0_ = static_cast<uint32_t>(c0_ + c);
return view;
}
T* HWY_RESTRICT Row(size_t row_idx) const {
return HWY_RCAST_ALIGNED(T*, row_ptrs_[r0_ + row_idx]) + c0_;
}
private:
uint8_t** row_ptrs_;
uint32_t r0_;
uint32_t c0_;
};
using RowPtrsBF = RowPtrs<BF16>;
// Type-erased, non-owning pointer and metadata for rank-1 or 2 tensors (vector
// or matrix). Base class of the non-type-erased `MatPtrT`. Use this class
// to store hetereogeneous tensor references in a vector.
@ -175,7 +187,10 @@ class MatPtr : public IFields {
// will return this value. Used to set the actual number of rows for
// activations preallocated according to the batch size.
void OverrideRows(size_t rows) {
HWY_ASSERT(rows <= private_rows_);
if (HWY_UNLIKELY(rows > private_rows_)) {
HWY_ABORT("%s: rows %zu > private_rows_ %u\n", name_.c_str(), rows,
private_rows_);
}
override_rows_ = static_cast<uint32_t>(rows);
}
@ -225,6 +240,8 @@ class MatPtr : public IFields {
// `CompressedArrayElements` is a wrapper function that has the same
// effect, but that requires a template argument, not `type`.
num_elements = NuqStream::PackedEnd(num_elements);
} else if (type == Type::kI8) {
num_elements = I8Stream::PackedEnd(num_elements);
}
return num_elements;
}
@ -301,8 +318,16 @@ class MatPtrT : public MatPtr {
return HWY_RCAST_ALIGNED(const T*, RowBytes(row));
}
hwy::Span<MatT> RowSpan(size_t row) {
return hwy::Span<MatT>(Row(row), Cols());
}
hwy::Span<const MatT> RowSpan(size_t row) const {
return hwy::Span<const MatT>(Row(row), Cols());
}
PackedSpan<const MatT> PaddedSpan() const {
return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), Rows() * Stride());
const size_t num = IsPacked() ? num_elements_ : Rows() * Stride();
return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), num);
}
// For `compress-inl.h` functions, which assume contiguous streams and thus
@ -341,12 +366,12 @@ RowPtrs<T> GetOrSetTempRowPtrs(
template <class Func, typename... Args>
decltype(auto) CallUpcasted(const MatPtr* base, const Func& func,
Args&&... args) {
#if GEMMA_ENABLE_NUQ
if (base->GetType() == Type::kNUQ) {
const MatPtrT<NuqStream> mat(*base);
return func(&mat, std::forward<Args>(args)...);
if constexpr (GEMMA_ENABLE_NUQ) {
if (base->GetType() == Type::kNUQ) {
const MatPtrT<NuqStream> mat(*base);
return func(&mat, std::forward<Args>(args)...);
}
}
#endif // GEMMA_ENABLE_NUQ
if (base->GetType() == Type::kF32) {
const MatPtrT<float> mat(*base);
@ -357,6 +382,9 @@ decltype(auto) CallUpcasted(const MatPtr* base, const Func& func,
} else if (base->GetType() == Type::kSFP) {
const MatPtrT<SfpStream> mat(*base);
return func(&mat, std::forward<Args>(args)...);
} else if (base->GetType() == Type::kI8) {
const MatPtrT<I8Stream> mat(*base);
return func(&mat, std::forward<Args>(args)...);
} else {
HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
}
@ -368,13 +396,13 @@ decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
const Func& func, Args&&... args) {
HWY_DASSERT(base1->GetType() == base2->GetType());
#if GEMMA_ENABLE_NUQ
if (base1->GetType() == Type::kNUQ) {
const MatPtrT<NuqStream> mat1(*base1);
const MatPtrT<NuqStream> mat2(*base2);
return func(&mat1, &mat2, std::forward<Args>(args)...);
if constexpr (GEMMA_ENABLE_NUQ) {
if (base1->GetType() == Type::kNUQ) {
const MatPtrT<NuqStream> mat1(*base1);
const MatPtrT<NuqStream> mat2(*base2);
return func(&mat1, &mat2, std::forward<Args>(args)...);
}
}
#endif // GEMMA_ENABLE_NUQ
if (base1->GetType() == Type::kF32) {
const MatPtrT<float> mat1(*base1);
@ -388,6 +416,10 @@ decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
const MatPtrT<SfpStream> mat1(*base1);
const MatPtrT<SfpStream> mat2(*base2);
return func(&mat1, &mat2, std::forward<Args>(args)...);
} else if (base1->GetType() == Type::kI8) {
const MatPtrT<I8Stream> mat1(*base1);
const MatPtrT<I8Stream> mat2(*base2);
return func(&mat1, &mat2, std::forward<Args>(args)...);
} else {
HWY_ABORT("Unhandled type %s.", TypeName(base1->GetType()));
}
@ -455,6 +487,7 @@ class MatOwner {
template <typename MatT>
class MatStorageT : public MatPtrT<MatT> {
public:
MatStorageT() = default; // for std::vector in Activations.
MatStorageT(const char* name, Extents2D extents, const Allocator& allocator,
MatPadding padding)
: MatPtrT<MatT>(name, extents) {
@ -499,5 +532,55 @@ class MatFactory {
MatPadding padding_;
};
// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
// Also used to decompress B, hence non-const.
#pragma pack(push, 1) // power of two size
template <typename T>
class StridedView {
public:
StridedView(T* HWY_RESTRICT row0, size_t cols, size_t stride)
: row0_(row0),
cols_(static_cast<uint32_t>(cols)),
stride_(static_cast<uint32_t>(stride)) {
if constexpr (HWY_IS_DEBUG_BUILD) {
if (stride < cols) {
HWY_ABORT("stride %zu < cols %zu", stride, cols);
}
}
}
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
StridedView(const MatPtrT<T>& mat, size_t r, size_t c, size_t cols)
: StridedView(const_cast<T*>(mat.Row(r)) + c, cols, mat.Stride()) {
HWY_DASSERT(c < mat.Cols());
HWY_DASSERT(cols <= mat.Cols() - c);
}
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
StridedView<T> View(size_t r, size_t c, size_t cols) const {
HWY_DASSERT(c < Cols());
HWY_DASSERT(cols <= Cols() - c);
return StridedView<T>(Row(r) + c, cols, stride_);
}
T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
size_t Cols() const { return static_cast<size_t>(cols_); }
size_t Stride() const { return static_cast<size_t>(stride_); }
void SetStride(size_t stride) {
HWY_DASSERT(stride >= Cols());
stride_ = stride;
}
private:
T* HWY_RESTRICT row0_;
uint32_t cols_;
uint32_t stride_;
};
#pragma pack(pop)
using StridedViewBF = StridedView<BF16>;
using StridedViewD = StridedView<double>;
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_

View File

@ -18,8 +18,6 @@
#include <stdio.h>
#include <algorithm> // std::sort
#include <atomic>
#include <memory>
#include <optional>
#include <vector>
@ -29,97 +27,62 @@
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/contrib/thread_pool/topology.h"
#include "hwy/profiler.h"
namespace gcpp {
// Sort T := packages/clusters by descending 'size' so that users who only use
// one Group get the largest.
template <class T>
static void SortByDescendingSize(std::vector<T>& groups) {
std::sort(groups.begin(), groups.end(),
[](const T& a, const T& b) { return a.Size() > b.Size(); });
static bool InContainer() {
return false; // placeholder for container detection, do not remove
}
// Singleton, holds the original process affinity and the pinning status.
class Pinning {
static bool InContainer() {
return false; }
public:
void SetPolicy(Tristate pin) {
if (pin == Tristate::kDefault) {
// Pinning is unreliable inside containers because the hypervisor might
// periodically change our affinity mask, or other processes might also
// pin themselves to the same LPs.
pin = InContainer() ? Tristate::kFalse : Tristate::kTrue;
}
want_pin_ = (pin == Tristate::kTrue);
any_error_.clear();
PinningPolicy::PinningPolicy(Tristate pin) {
if (pin == Tristate::kDefault) {
// Pinning is unreliable inside containers because the hypervisor might
// periodically change our affinity mask, or other processes might also
// pin themselves to the same LPs.
pin = InContainer() ? Tristate::kFalse : Tristate::kTrue;
}
want_pin_ = (pin == Tristate::kTrue);
}
// If want_pin_, tries to pin each worker in `pool` to an LP in `cluster`,
// and sets `any_error_` if any fails.
void MaybePin(const BoundedTopology& topology, size_t pkg_idx,
size_t cluster_idx, const BoundedTopology::Cluster& cluster,
hwy::ThreadPool& pool) {
const std::vector<size_t> lps = cluster.LPVector();
HWY_ASSERT(pool.NumWorkers() <= lps.size());
pool.Run(0, pool.NumWorkers(), [&](uint64_t task, size_t thread) {
HWY_ASSERT(task == thread); // each worker has one task
// If `pinning.Want()`, tries to pin each worker in `pool` to an LP in
// `cluster`, and calls `pinning.NotifyFailed()` if any fails.
static void MaybePin(const BoundedTopology& topology, size_t cluster_idx,
const BoundedTopology::Cluster& cluster,
PinningPolicy& pinning, hwy::ThreadPool& pool) {
const std::vector<size_t> lps = cluster.LPVector();
HWY_ASSERT(pool.NumWorkers() <= lps.size());
pool.Run(0, pool.NumWorkers(), [&](uint64_t task, size_t thread) {
HWY_ASSERT(task == thread); // each worker has one task
char buf[16]; // Linux limitation
const int bytes_written = snprintf(
buf, sizeof(buf), "P%zu X%02zu C%03d",
topology.SkippedPackages() + pkg_idx,
topology.SkippedClusters() + cluster_idx, static_cast<int>(task));
HWY_ASSERT(bytes_written < static_cast<int>(sizeof(buf)));
hwy::SetThreadName(buf, 0); // does not support varargs
char buf[16]; // Linux limitation
const int bytes_written = snprintf(
buf, sizeof(buf), "P%zu X%02zu C%03d", topology.SkippedPackages(),
topology.SkippedClusters() + cluster_idx, static_cast<int>(task));
HWY_ASSERT(bytes_written < static_cast<int>(sizeof(buf)));
hwy::SetThreadName(buf, 0); // does not support varargs
if (HWY_LIKELY(want_pin_)) {
if (HWY_UNLIKELY(!hwy::PinThreadToLogicalProcessor(lps[task]))) {
// Apple does not support pinning, hence do not warn there.
if (!HWY_OS_APPLE) {
HWY_WARN("Pinning failed for task %d of %zu to %zu (size %zu)\n",
static_cast<int>(task), pool.NumWorkers(), lps[task],
lps.size());
}
(void)any_error_.test_and_set();
if (HWY_LIKELY(pinning.Want())) {
if (HWY_UNLIKELY(!hwy::PinThreadToLogicalProcessor(lps[task]))) {
// Apple does not support pinning, hence do not warn there.
if (!HWY_OS_APPLE) {
HWY_WARN("Pinning failed for task %d of %zu to %zu (size %zu)\n",
static_cast<int>(task), pool.NumWorkers(), lps[task],
lps.size());
}
pinning.NotifyFailed();
}
});
}
// Called ONCE after all MaybePin because it invalidates the error status.
bool AllPinned(const char** pin_string) {
// If !want_pin_, MaybePin will return without setting any_error_, but in
// that case we still want to return false to avoid spinning.
// .test() was only added in C++20, so we use .test_and_set() instead.
const bool all_pinned = want_pin_ && !any_error_.test_and_set();
*pin_string = all_pinned ? "pinned"
: want_pin_ ? "pinning failed"
: "pinning skipped";
return all_pinned;
}
private:
std::atomic_flag any_error_ = ATOMIC_FLAG_INIT;
bool want_pin_; // set in SetPolicy
}; // Pinning
// Singleton saves global affinity across all BoundedTopology instances because
// pinning overwrites it.
static Pinning& GetPinning() {
static Pinning pinning;
return pinning;
}
});
}
static PoolPtr MakePool(const Allocator& allocator, size_t num_workers,
hwy::PoolWorkerMapping mapping,
std::optional<size_t> node = std::nullopt) {
// `ThreadPool` expects the number of threads to create, which is one less
// than the number of workers, but avoid underflow if zero.
const size_t num_threads = num_workers == 0 ? 0 : num_workers - 1;
PoolPtr ptr = allocator.AllocClasses<hwy::ThreadPool>(1, num_threads);
PoolPtr ptr =
allocator.AllocClasses<hwy::ThreadPool>(1, num_threads, mapping);
const size_t bytes =
hwy::RoundUpTo(sizeof(hwy::ThreadPool), allocator.QuantumBytes());
if (node.has_value() && allocator.ShouldBind()) {
@ -140,71 +103,56 @@ static size_t DivideMaxAcross(const size_t max, const size_t instances) {
return max;
}
NestedPools::NestedPools(const BoundedTopology& topology,
const Allocator& allocator, size_t max_threads,
Tristate pin) {
GetPinning().SetPolicy(pin);
packages_.resize(topology.NumPackages());
all_packages_ = MakePool(allocator, packages_.size());
const size_t max_workers_per_package =
DivideMaxAcross(max_threads, packages_.size());
// Each worker in all_packages_, including the main thread, will be the
// calling thread of an all_clusters->Run, and hence pinned to one of the
// `cluster.lps` if `pin`.
all_packages_->Run(0, packages_.size(), [&](uint64_t pkg_idx, size_t thread) {
HWY_ASSERT(pkg_idx == thread); // each thread has one task
packages_[pkg_idx] =
Package(topology, allocator, pkg_idx, max_workers_per_package);
});
all_pinned_ = GetPinning().AllPinned(&pin_string_);
// For mapping package/cluster/thread to noncontiguous TLS indices, in case
// cluster/thread counts differ.
HWY_ASSERT(!packages_.empty() && packages_.size() <= 16);
for (const Package& p : packages_) {
max_clusters_per_package_ =
HWY_MAX(max_clusters_per_package_, p.NumClusters());
max_workers_per_cluster_ =
HWY_MAX(max_workers_per_cluster_, p.MaxWorkersPerCluster());
}
HWY_ASSERT(max_clusters_per_package_ >= 1);
HWY_ASSERT(max_clusters_per_package_ <= 64);
HWY_ASSERT(max_workers_per_cluster_ >= 1);
HWY_ASSERT(max_workers_per_cluster_ <= 256);
hwy::Profiler::Get().SetMaxThreads(MaxWorkers());
}
// `max_or_zero` == 0 means no limit.
static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) {
return (max_or_zero == 0) ? num : HWY_MIN(num, max_or_zero);
}
NestedPools::Package::Package(const BoundedTopology& topology,
const Allocator& allocator, size_t pkg_idx,
size_t max_workers_per_package) {
// Pre-allocate because elements are set concurrently.
clusters_.resize(topology.NumClusters(pkg_idx));
const size_t max_workers_per_cluster =
DivideMaxAcross(max_workers_per_package, clusters_.size());
NestedPools::NestedPools(const BoundedTopology& topology,
const Allocator& allocator, size_t max_threads,
Tristate pin)
: pinning_(pin) {
const size_t num_clusters = topology.NumClusters();
const size_t cluster_workers_cap = DivideMaxAcross(max_threads, num_clusters);
// Precompute cluster sizes to ensure we pass the same values to `MakePool`.
// The max is also used for `all_clusters_mapping`, see below.
size_t workers_per_cluster[hwy::kMaxClusters] = {};
size_t all_clusters_node = 0;
for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
const BoundedTopology::Cluster& tcluster = topology.GetCluster(cluster_idx);
workers_per_cluster[cluster_idx] =
CapIfNonZero(tcluster.NumWorkers(), cluster_workers_cap);
// Cluster sizes can vary because individual LPs may be disabled. Use the
// max so that `GlobalIdx` is consistent within and across clusters. It is
// OK to have holes or gaps in the worker index space.
max_workers_per_cluster_ =
HWY_MAX(max_workers_per_cluster_, workers_per_cluster[cluster_idx]);
all_clusters_node = tcluster.Node(); // arbitrarily use the last node seen
}
const hwy::PoolWorkerMapping all_clusters_mapping(hwy::kAllClusters,
max_workers_per_cluster_);
all_clusters_ = MakePool(allocator, num_clusters, all_clusters_mapping,
all_clusters_node);
// Pre-allocate because elements are set concurrently.
clusters_.resize(num_clusters);
all_clusters_ = MakePool(allocator, clusters_.size(),
topology.GetCluster(pkg_idx, 0).Node());
// Parallel so we also pin the calling worker in `all_clusters` to
// `cluster.lps`.
all_clusters_->Run(
0, all_clusters_->NumWorkers(), [&](size_t cluster_idx, size_t thread) {
HWY_ASSERT(cluster_idx == thread); // each thread has one task
const BoundedTopology::Cluster& cluster =
topology.GetCluster(pkg_idx, cluster_idx);
clusters_[cluster_idx] = MakePool(
allocator, CapIfNonZero(cluster.Size(), max_workers_per_cluster),
cluster.Node());
// Pin workers AND the calling thread from `all_clusters`.
GetPinning().MaybePin(topology, pkg_idx, cluster_idx, cluster,
*clusters_[cluster_idx]);
});
all_clusters_->Run(0, num_clusters, [&](size_t cluster_idx, size_t thread) {
HWY_ASSERT(cluster_idx == thread); // each thread has one task
const BoundedTopology::Cluster& tcluster = topology.GetCluster(cluster_idx);
clusters_[cluster_idx] =
MakePool(allocator, workers_per_cluster[cluster_idx],
hwy::PoolWorkerMapping(cluster_idx, max_workers_per_cluster_),
tcluster.Node());
// Pin workers AND the calling thread from `all_clusters_`.
MaybePin(topology, cluster_idx, tcluster, pinning_,
*clusters_[cluster_idx]);
});
all_pinned_ = pinning_.AllPinned(&pin_string_);
}
} // namespace gcpp

View File

@ -19,6 +19,7 @@
#include <stddef.h>
#include <stdint.h>
#include <atomic>
#include <vector>
// IWYU pragma: begin_exports
@ -40,23 +41,49 @@ namespace gcpp {
// moving because it is a typedef to `std::unique_ptr`.
using PoolPtr = AlignedClassPtr<hwy::ThreadPool>;
class PinningPolicy {
public:
explicit PinningPolicy(Tristate pin);
bool Want() const { return want_pin_; }
void NotifyFailed() { (void)any_error_.test_and_set(); }
// Called ONCE after all MaybePin because it invalidates the error status.
bool AllPinned(const char** pin_string) {
// If !want_pin_, MaybePin will return without setting any_error_, but in
// that case we still want to return false to avoid spinning.
// .test() was only added in C++20, so we use .test_and_set() instead.
const bool all_pinned = want_pin_ && !any_error_.test_and_set();
*pin_string = all_pinned ? "pinned"
: want_pin_ ? "pinning failed"
: "pinning skipped";
return all_pinned;
}
private:
std::atomic_flag any_error_ = ATOMIC_FLAG_INIT;
bool want_pin_; // set in SetPolicy
}; // PinningPolicy
// Creates a hierarchy of thread pools according to `BoundedTopology`: one with
// a thread per enabled package; for each of those, one with a thread per
// enabled cluster (CCX/shared L3), and for each of those, the remaining
// enabled cores in that cluster.
// a thread per enabled cluster (CCX/shared L3), and for each of those, the
// remaining enabled cores in that cluster.
//
// Note that we support spin waits, thus it is important for each thread to be
// responsive, hence we do not create more than one thread per enabled core.
// For example, when there are two packages with four clusters of 8 cores,
// `AllPackages` has the main thread plus one extra thread, each `AllClusters`
// has one of the `AllPackages` threads plus three extras, each `Cluster` runs
// on one `AllClusters` thread plus seven extra workers, for a total of
// 1 + 2*3 + 2*(4*7) = 63 extras plus the main thread.
// For example, when there are four clusters of 8 cores, `AllClusters` has the
// main thread plus three extras, each `Cluster` runs on one of `AllClusters`
// plus seven extras, for a total of 3 + (4*7) = 31 extras plus the main thread.
//
// Useful when there are tasks which should be parallelized by workers sharing a
// cache, or on the same NUMA node. In both cases, individual pools have lower
// barrier synchronization latency than one large pool. However, to utilize all
// cores, call sites will have to use nested parallel-for loops.
// cores, call sites will have to use nested parallel-for loops as in
// `HierarchicalParallelFor`. To allow switching modes easily, prefer using the
// `ParallelFor` abstraction in threading_context.h).
//
// Note that this was previously intended to use all cores, but we are now
// moving toward also allowing concurrent construction with subsets of cores.
class NestedPools {
public:
// Neither move nor copy.
@ -66,14 +93,20 @@ class NestedPools {
NestedPools(NestedPools&&) = delete;
NestedPools& operator=(NestedPools&&) = delete;
// Because cross-package latency is high, this interface assumes only one
// package is used. The `skip_packages` argument to `BoundedTopology` selects
// which package that is for this `NestedPools` instance.
//
// `max_threads` is the maximum number of threads to divide among all
// clusters. This is more intuitive than a per-cluster limit for users who
// may not be aware of the CPU topology. 0 means no limit.
// may not be aware of the CPU topology. This should be zero (meaning no
// further limits) if the caller has already set limits via `skip_*` or
// `max_*` args passed to `ThreadingContext`.
//
// To ensure we do not create more threads than there are HW cores, which
// would cause huge slowdowns when spinning, the `BoundedSlice` arguments
// only impose upper bounds on the number of detected packages and clusters
// rather than defining the actual number of threads.
// only impose upper bounds on the number of detected clusters rather than
// defining the actual number of threads.
NestedPools(const BoundedTopology& topology, const Allocator& allocator,
size_t max_threads = 0, Tristate pin = Tristate::kDefault);
@ -101,107 +134,51 @@ class NestedPools {
}
}
size_t NumPackages() const { return packages_.size(); }
hwy::ThreadPool& AllPackages() { return *all_packages_; }
hwy::ThreadPool& AllClusters(size_t pkg_idx) {
HWY_DASSERT(pkg_idx < NumPackages());
return packages_[pkg_idx].AllClusters();
}
hwy::ThreadPool& Cluster(size_t pkg_idx, size_t cluster_idx) {
HWY_DASSERT(pkg_idx < NumPackages());
return packages_[pkg_idx].Cluster(cluster_idx);
size_t NumClusters() const { return clusters_.size(); }
hwy::ThreadPool& AllClusters() { return *all_clusters_; }
hwy::ThreadPool& Cluster(size_t cluster_idx) {
HWY_DASSERT(cluster_idx < clusters_.size());
return *clusters_[cluster_idx];
}
// Reasonably tight upper bounds for allocating thread-local storage (TLS).
size_t MaxWorkersPerCluster() const { return max_workers_per_cluster_; }
size_t MaxWorkersPerPackage() const {
return max_clusters_per_package_ * MaxWorkersPerCluster();
}
size_t MaxWorkers() const { return NumPackages() * MaxWorkersPerPackage(); }
// Actual number of workers.
size_t TotalWorkers() const {
size_t total_workers = 0;
for (size_t pkg_idx = 0; pkg_idx < NumPackages(); ++pkg_idx) {
total_workers += packages_[pkg_idx].TotalWorkers();
}
return total_workers;
}
size_t MaxWorkers() const { return NumClusters() * MaxWorkersPerCluster(); }
// For ShowConfig
const char* PinString() const { return pin_string_; }
// Returns a single pool on the given package: either one thread per cluster
// if there is more than one, which maximizes available memory bandwidth, or
// the first cluster, which is typically the whole package. For use by callers
// that only have a single parallel-for.
// the first cluster, which is typically the whole package. For use by
// callers that only have a single parallel-for.
// DEPRECATED: use ParallelFor instead.
hwy::ThreadPool& Pool(size_t pkg_idx = 0) {
// Only one cluster: use its pool, typically a whole socket.
if (AllClusters(pkg_idx).NumWorkers() == 1) {
return Cluster(pkg_idx, 0);
}
if (NumClusters() == 1) return Cluster(0);
// One worker per cluster to maximize bandwidth availability.
return AllClusters(pkg_idx);
return AllClusters();
}
private:
class Package {
public:
Package() = default; // for vector
Package(const BoundedTopology& topology, const Allocator& allocator,
size_t pkg_idx, size_t max_workers_per_package);
size_t NumClusters() const { return clusters_.size(); }
size_t MaxWorkersPerCluster() const {
size_t max_workers_per_cluster = 0;
for (const PoolPtr& cluster : clusters_) {
max_workers_per_cluster =
HWY_MAX(max_workers_per_cluster, cluster->NumWorkers());
}
return max_workers_per_cluster;
}
size_t TotalWorkers() const {
size_t total_workers = 0;
for (const PoolPtr& cluster : clusters_) {
total_workers += cluster->NumWorkers();
}
return total_workers;
}
hwy::ThreadPool& AllClusters() { return *all_clusters_; }
hwy::ThreadPool& Cluster(size_t cluster_idx) {
HWY_DASSERT(cluster_idx < clusters_.size());
return *clusters_[cluster_idx];
}
void SetWaitMode(hwy::PoolWaitMode wait_mode) {
all_clusters_->SetWaitMode(wait_mode);
for (PoolPtr& cluster : clusters_) {
cluster->SetWaitMode(wait_mode);
}
}
private:
std::vector<PoolPtr> clusters_;
PoolPtr all_clusters_;
}; // Package
void SetWaitMode(hwy::PoolWaitMode wait_mode) {
all_packages_->SetWaitMode(wait_mode);
for (Package& package : packages_) {
package.SetWaitMode(wait_mode);
all_clusters_->SetWaitMode(wait_mode);
for (PoolPtr& cluster : clusters_) {
cluster->SetWaitMode(wait_mode);
}
}
PinningPolicy pinning_;
bool all_pinned_;
const char* pin_string_;
std::vector<Package> packages_;
PoolPtr all_packages_;
// Must be freed after `clusters_` because it reserves threads which are
// the main threads of `clusters_`.
PoolPtr all_clusters_;
std::vector<PoolPtr> clusters_;
// For TLS indices. One might think this belongs in BoundedTopology, but it
// depends on max_threads, which is passed to the NestedPools constructor.
size_t max_clusters_per_package_ = 0;
// Used by `PoolWorkerMapping`. This depends on the `max_threads` argument,
// hence we can only compute this here, not in `BoundedTopology`.
size_t max_workers_per_cluster_ = 0;
};
@ -324,15 +301,13 @@ void ParallelizeTwoRanges(const IndexRangePartition& get1,
// Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes
// over clusters of ONE package, then within each cluster.
template <class Func>
void ParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
// Even if there are multiple packages, we only use the first.
const size_t pkg_idx = 0;
void HierarchicalParallelFor(size_t num_tasks, NestedPools& pools,
const Func& func) {
// If few tasks, run on a single cluster. Also avoids a bit of overhead if
// there is only one cluster.
hwy::ThreadPool& all_clusters = pools.AllClusters(pkg_idx);
hwy::ThreadPool& all_clusters = pools.AllClusters();
const size_t num_clusters = all_clusters.NumWorkers();
hwy::ThreadPool& cluster = pools.Cluster(pkg_idx, 0);
hwy::ThreadPool& cluster = pools.Cluster(0);
if (num_clusters == 1 || num_tasks <= cluster.NumWorkers()) {
return cluster.Run(0, num_tasks, [&](uint64_t task, size_t thread) {
func(task, thread);
@ -345,7 +320,7 @@ void ParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
ParallelizeOneRange(
ranges, all_clusters,
[&](const IndexRange& range, const size_t cluster_idx) {
hwy::ThreadPool& cluster = pools.Cluster(pkg_idx, cluster_idx);
hwy::ThreadPool& cluster = pools.Cluster(cluster_idx);
const size_t cluster_base = cluster_idx * pools.MaxWorkersPerCluster();
cluster.Run(range.begin(), range.end(),
[&](uint64_t task, size_t thread) {
@ -354,16 +329,6 @@ void ParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
});
}
// As above, but for lightweight tasks. Uses only one pool.
template <class Func>
void SmallParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
// Even if there are multiple packages, we only use the first.
const size_t pkg_idx = 0;
pools.Pool(pkg_idx).Run(
0, num_tasks, [&](uint64_t task, size_t thread) { func(task, thread); });
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_

View File

@ -20,7 +20,9 @@
#include <vector>
#include "util/zones.h"
#include "hwy/aligned_allocator.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
#include "hwy/tests/test_util.h" // RandomState
@ -28,7 +30,11 @@ namespace gcpp {
// Invokes `pool.Run` with varying task counts until auto-tuning completes, or
// an upper bound just in case.
static void TunePool(hwy::ThreadPool& pool) {
static void TunePool(hwy::PoolWaitMode wait_mode, hwy::ThreadPool& pool) {
pool.SetWaitMode(wait_mode);
// TODO(janwas): re-enable after investigating potential deadlock.
#if 0
const size_t num_workers = pool.NumWorkers();
// pool.Run would just be a serial loop without auto-tuning, so skip.
if (num_workers == 1) return;
@ -69,6 +75,19 @@ static void TunePool(hwy::ThreadPool& pool) {
HWY_ASSERT(total == prev_total + expected);
prev_total += expected;
}
#endif
}
static void TunePools(hwy::PoolWaitMode wait_mode, NestedPools& pools) {
hwy::ThreadPool& clusters = pools.AllClusters();
TunePool(wait_mode, clusters);
// Run in parallel because Turin CPUs have 16, and in real usage, we often
// run all at the same time.
clusters.Run(0, clusters.NumWorkers(),
[&](uint64_t cluster_idx, size_t /*thread*/) {
TunePool(wait_mode, pools.Cluster(cluster_idx));
});
}
ThreadingContext::ThreadingContext(const ThreadingArgs& args)
@ -76,21 +95,14 @@ ThreadingContext::ThreadingContext(const ThreadingArgs& args)
topology(BoundedSlice(args.skip_packages, args.max_packages),
BoundedSlice(args.skip_clusters, args.max_clusters),
BoundedSlice(args.skip_lps, args.max_lps)),
allocator(topology, args.bind != Tristate::kFalse),
cache_info(topology),
allocator(topology, cache_info, args.bind != Tristate::kFalse),
pools(topology, allocator, args.max_threads, args.pin) {
InitProfilerZones(profiler);
PROFILER_ZONE("Startup.ThreadingContext autotune");
TunePool(pools.AllPackages());
for (size_t pkg_idx = 0; pkg_idx < pools.NumPackages(); ++pkg_idx) {
hwy::ThreadPool& clusters = pools.AllClusters(pkg_idx);
TunePool(clusters);
// Run in parallel because Turin CPUs have 16, and in real usage, we often
// run all at the same time.
clusters.Run(0, clusters.NumWorkers(),
[&](uint64_t cluster_idx, size_t /*thread*/) {
TunePool(pools.Cluster(pkg_idx, cluster_idx));
});
}
TunePools(hwy::PoolWaitMode::kSpin, pools);
// kBlock is the default, hence set/tune it last.
TunePools(hwy::PoolWaitMode::kBlock, pools);
}
} // namespace gcpp

View File

@ -25,9 +25,10 @@
// IWYU pragma: begin_exports
#include "util/allocator.h"
#include "util/args.h"
#include "util/basics.h" // Tristate, kMaxPackages
#include "util/basics.h" // Tristate
#include "util/threading.h"
#include "util/topology.h"
#include "hwy/profiler.h"
// IWYU pragma: end_exports
namespace gcpp {
@ -40,7 +41,7 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
// For BoundedTopology:
size_t skip_packages;
size_t max_packages;
size_t max_packages = 1; // some users assign 1 to this, hence non-const.
size_t skip_clusters;
size_t max_clusters;
size_t skip_lps;
@ -55,27 +56,27 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
template <class Visitor>
void ForEach(const Visitor& visitor) {
// These can be used to partition CPU sockets/packages and their
// These can be used to partition CPU packages/sockets and their
// clusters/CCXs across several program instances. The default is to use
// all available resources.
// all available resources on the first package.
visitor(skip_packages, "skip_packages", size_t{0},
"Index of the first socket to use; default 0 = unlimited.", 2);
visitor(max_packages, "max_packages", size_t{1},
"Max sockets to use; default = 1, 0 = unlimited.", 2);
HWY_ASSERT(max_packages <= kMaxPackages);
visitor(skip_clusters, "skip_clusters", size_t{0},
"Index of the first CCX to use; default 0 = unlimited.", 2);
visitor(max_clusters, "max_clusters", size_t{0},
"Max CCXs to use; default 0 = unlimited.", 2);
// These are only used when CPU topology is unknown.
// "Logical processors" (LPs). These are used when CPU topology is unknown.
visitor(skip_lps, "skip_lps", size_t{0},
"Index of the first LP to use; default 0 = unlimited.", 2);
visitor(max_lps, "max_lps", size_t{0},
"Max LPs to use; default 0 = unlimited.", 2);
// The exact meaning is more subtle: see the comment at NestedPools ctor.
// DEPRECATED: superseded by the above fields. If nonzero, `NestedPools`
// will attempt to create this many threads distributed over the detected
// topology.
visitor(max_threads, "num_threads", size_t{0},
"Max threads to use; default 0 = unlimited.", 2);
visitor(pin, "pin", Tristate::kDefault,
"Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
visitor(spin, "spin", Tristate::kDefault,
@ -86,16 +87,126 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
}
};
// Owns threads corresponding to a subset of the system's resources. Because
// this is passed to `Gemma::Generate` (via `MatMulEnv`) rather than defined as
// a singleton, we can have multiple concurrent `Generate` calls within the
// same process, each with their own `ThreadingContext`. Because each context
// may pin its threads, it is important that they use distinct packages,
// clusters, or LPs. For example, to use two packages, the first `args` can have
// `skip_packages` = 0 and the second `skip_packages` = 1.
struct ThreadingContext {
// Expected to be called early in the program, before threading starts.
explicit ThreadingContext(const ThreadingArgs& args);
// Returns a worker index compatible with those from `ParallelFor`, assuming
// the current thread is running on one thread per cluster, which happens
// when `ParallelismStrategy` is `kAcrossClusters`.
size_t Worker(size_t cluster_idx) const {
return cluster_idx * pools.MaxWorkersPerCluster();
}
// Singleton; pass around a reference to reduce overhead.
hwy::Profiler& profiler;
// Detects topology, subject to limits imposed by user-specified `args`.
// For example, if `args.max_clusters` is 1, then `topology.NumClusters()`
// will be 1 regardless of the actual system topology.
BoundedTopology topology;
// Ctor depends on `topology` for per-cluster cache sizes.
CacheInfo cache_info;
// Ctor depends on `topology` (for NUMA) and `cache_info` (for step size).
Allocator allocator;
// Per-package/cluster/within cluster pools of threads, matching `topology`.
NestedPools pools;
};
// Describes the strategy for distributing parallel work across cores.
enum class ParallelismStrategy : uint8_t {
// Execute using a single-threaded loop on the calling thread. The `worker`
// index passed to the user's `Func` is unique across clusters.
kNone,
// One thread per cluster within the first package. The `worker` index passed
// to the user's `Func` is a `cluster_idx <= NumClusters()`. Some CPUs may
// only have a single cluster, hence `Func` should also contain a nested
// `ParallelFor` with `kWithinCluster`.
kAcrossClusters,
// All cores within the cluster identified by `cluster_idx`. The `worker`
// index passed to the user's `Func` is unique across clusters. Choose this
// strategy if already within a `ParallelFor` call with `kAcrossClusters`,
// or latency is more important than memory bandwidth.
kWithinCluster,
// Equivalent to `kAcrossClusters` if there are multiple clusters, otherwise
// `kWithinCluster`. Use for few or lightweight tasks (this only uses a
// single pool and barrier), or to maximize memory bandwidth availability.
kFlat,
// First statically partitions `kAcrossClusters`, then `kWithinCluster`. This
// utilizes all cores, but has higher fork-join overhead (two barriers); use
// if there are many or heavy tasks.
kHierarchical,
};
// Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the
// number/type of workers determined by `parallelism`. `cluster_idx` is for
// `parallelism == kWithinCluster`, and should be 0 if unknown.
template <class Func>
void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
ThreadingContext& ctx, size_t cluster_idx, const Func& func) {
HWY_DASSERT(cluster_idx < ctx.topology.NumClusters());
if (cluster_idx != 0) {
// If already running across clusters, only use within-cluster modes.
HWY_DASSERT(parallelism == ParallelismStrategy::kNone ||
parallelism == ParallelismStrategy::kWithinCluster);
}
switch (parallelism) {
case ParallelismStrategy::kNone: {
const size_t worker = ctx.Worker(cluster_idx);
for (size_t task = 0; task < num_tasks; ++task) {
func(task, worker);
}
return;
}
case ParallelismStrategy::kAcrossClusters:
return ctx.pools.AllClusters().Run(
0, num_tasks,
[&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); });
case ParallelismStrategy::kWithinCluster: {
// Ensure the worker argument is unique across clusters, because it is
// used for TLS indexing for example in profiler.h.
const size_t base = ctx.Worker(cluster_idx);
return ctx.pools.Cluster(cluster_idx)
.Run(0, num_tasks, [&](uint64_t task, size_t worker) {
func(task, base + worker);
});
}
case ParallelismStrategy::kFlat: {
// Check for single cluster; if not, we must compute `cluster_base` for
// consistent and non-overlapping worker indices.
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
const size_t num_clusters = all_clusters.NumWorkers();
if (num_clusters == 1) {
return ctx.pools.Cluster(cluster_idx)
.Run(0, num_tasks,
[&](uint64_t task, size_t worker) { func(task, worker); });
}
return ctx.pools.AllClusters().Run(
0, num_tasks, [&](uint64_t task, size_t cluster_idx) {
const size_t worker = ctx.Worker(cluster_idx);
func(task, worker);
});
}
case ParallelismStrategy::kHierarchical:
return HierarchicalParallelFor(num_tasks, ctx.pools, func);
}
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_CONTEXT_H_

View File

@ -99,23 +99,16 @@ TEST(ThreadingTest, TestBoundedTopology) {
const BoundedSlice all;
const BoundedSlice one(0, 1);
// All
{
BoundedTopology topology(all, all, all);
fprintf(stderr, "%s\n", topology.TopologyString());
}
// Max one package
{
BoundedTopology topology(one, all, all);
fprintf(stderr, "%s\n", topology.TopologyString());
ASSERT_EQ(1, topology.NumPackages());
}
// Max one cluster
{
BoundedTopology topology(all, one, all);
BoundedTopology topology(one, one, all);
fprintf(stderr, "%s\n", topology.TopologyString());
ASSERT_EQ(1, topology.NumClusters(0));
ASSERT_EQ(1, topology.NumClusters());
}
}
@ -380,24 +373,32 @@ TEST(ThreadingTest, BenchJoin) {
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
NestedPools& pools = ctx.pools;
// Use last package because the main thread has been pinned to it.
const size_t pkg_idx = pools.NumPackages() - 1;
measure(pools.AllPackages(), false, "block packages");
if (pools.AllClusters(pkg_idx).NumWorkers() > 1) {
measure(pools.AllClusters(pkg_idx), false, "block clusters");
if (pools.NumClusters() > 1) {
measure(pools.AllClusters(), false, "block clusters");
}
measure(pools.Cluster(pkg_idx, 0), false, "block in_cluster");
measure(pools.Cluster(0), false, "block in_cluster");
if (pools.AllPinned()) {
const bool kSpin = true;
measure(pools.AllPackages(), kSpin, "spin packages");
if (pools.AllClusters(pkg_idx).NumWorkers() > 1) {
measure(pools.AllClusters(pkg_idx), kSpin, "spin clusters");
if (pools.NumClusters() > 1) {
measure(pools.AllClusters(), kSpin, "spin clusters");
}
measure(pools.Cluster(pkg_idx, 0), kSpin, "spin in_cluster");
measure(pools.Cluster(0), kSpin, "spin in_cluster");
}
}
TEST(ThreadingTest, TestUnequalClusters) {
ThreadingArgs threading_args;
threading_args.max_lps = 13;
ThreadingContext ctx(threading_args);
const size_t last_workers =
ctx.pools.Cluster(ctx.topology.NumClusters() - 1).NumWorkers();
const size_t max_workers = ctx.pools.MaxWorkersPerCluster();
fprintf(stderr, "%zu clusters, last with %zu (max %zu)\n",
ctx.topology.NumClusters(), last_workers, max_workers);
HWY_ASSERT(last_workers <= max_workers);
}
} // namespace
} // namespace gcpp

View File

@ -18,21 +18,12 @@
#include <stdio.h>
#include <algorithm> // std::sort
#include <utility> // std::move
#include <vector>
#include "hwy/base.h"
namespace gcpp {
// Sort T := packages/clusters by descending 'size' so that users who only use
// one Group get the largest.
template <class T>
static void SortByDescendingSize(std::vector<T>& groups) {
std::sort(groups.begin(), groups.end(),
[](const T& a, const T& b) { return a.Size() > b.Size(); });
}
// Returns set of LPs available for use.
static LPS EnabledLPs(const BoundedSlice& lp_slice) {
LPS enabled_lps;
@ -88,21 +79,23 @@ BoundedTopology::BoundedTopology(BoundedSlice package_slice,
BoundedSlice cluster_slice,
BoundedSlice lp_slice)
: package_slice_(package_slice), cluster_slice_(cluster_slice) {
HWY_ASSERT(package_slice_.Max() == 1);
const LPS enabled_lps = EnabledLPs(lp_slice);
bool topology_ok = false;
#if !GEMMA_DISABLE_TOPOLOGY
if (HWY_LIKELY(!topology_.packages.empty())) {
InitFromTopology(enabled_lps);
topology_ok = InitFromTopology(enabled_lps);
}
#endif
// Topology unknown or no packages with enabled LPs: create a single
// package with one cluster, and one node.
if (HWY_UNLIKELY(NumPackages() == 0)) {
if (HWY_UNLIKELY(!topology_ok)) {
InitFromLPs(enabled_lps);
}
HWY_ASSERT(NumPackages() != 0 && NumClusters(0) != 0 && NumNodes() != 0);
HWY_ASSERT(NumClusters() != 0 && NumNodes() != 0);
}
// Topology is unknown, take the given set of LPs.
@ -161,9 +154,113 @@ constexpr bool kSplitLargeClusters = false;
constexpr size_t kMaxClusters = 8;
constexpr size_t kMaxLPsPerCluster = 6;
// Topology is unknown, use only the given LPs which derive from OS affinity
// and `lp_slice`.
BoundedTopology::Package::Package(const LPS& enabled_lps) {
#if !GEMMA_DISABLE_TOPOLOGY
static size_t CoresFromLPs(const LPS& lps, const hwy::Topology& topology) {
LPS cores;
lps.Foreach([&](size_t lp) {
if (topology.lps[lp].smt == 0) cores.Set(lp);
});
return cores.Count();
}
// tcluster is a modifiable copy of the first cluster in the package.
void BoundedTopology::SplitLargeCluster(const LPS& enabled_lps,
hwy::Topology::Cluster tcluster) {
const LPS lps = clusters_[0].LPSet(); // copy so we can clear
clusters_.clear();
// Split `lps` into several clusters.
LPS clusters_lps[kMaxClusters];
const size_t num_clusters =
HWY_MIN(kMaxClusters, hwy::DivCeil(lps.Count(), kMaxLPsPerCluster));
size_t num_lps = 0;
lps.Foreach(
[&](size_t lp) { clusters_lps[num_lps++ % num_clusters].Set(lp); });
HWY_DASSERT(num_lps == lps.Count());
// Create new clusters, just inserting the new LPS.
for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
tcluster.lps = clusters_lps[cluster_idx];
// Keep same `private_kib` and `shared_kib`.
clusters_.push_back(Cluster(enabled_lps, topology_.lps, tcluster));
}
}
// Main part of ctor, called when topology is known.
bool BoundedTopology::InitFromTopology(const LPS& enabled_lps) {
const size_t tpkg_idx = package_slice_.Begin();
HWY_ASSERT(tpkg_idx < topology_.packages.size());
const hwy::Topology::Package& tpackage = topology_.packages[tpkg_idx];
const std::vector<hwy::Topology::Cluster>& tclusters = tpackage.clusters;
if (HWY_UNLIKELY(tclusters.empty())) {
HWY_WARN("Topology: no clusters found in package %zu.", tpkg_idx);
return false;
}
size_t max_tcluster_cores = 0;
size_t max_tcluster_lps = 0;
for (const hwy::Topology::Cluster& tcluster : tclusters) {
const size_t cores = CoresFromLPs(tcluster.lps, topology_);
const size_t lps = tcluster.lps.Count();
max_tcluster_cores = HWY_MAX(max_tcluster_cores, cores);
max_tcluster_lps = HWY_MAX(max_tcluster_lps, lps);
}
HWY_ASSERT(max_tcluster_cores != 0);
HWY_ASSERT(max_tcluster_lps >= max_tcluster_cores);
// Populate `clusters` with the subset of clusters in `cluster_slice` that
// have any enabled LPs.
clusters_.reserve(cluster_slice_.Num(tclusters.size()));
cluster_slice_.Foreach("cluster", tclusters.size(), [&](size_t cluster_idx) {
const hwy::Topology::Cluster& tcluster = tpackage.clusters[cluster_idx];
Cluster cluster(enabled_lps, topology_.lps, tcluster);
// Skip if empty, i.e. too few `enabled_lps`.
if (HWY_LIKELY(cluster.NumWorkers() != 0)) {
clusters_.push_back(cluster);
// Remember NUMA nodes that we are actually using (not just enabled).
nodes_.Set(cluster.Node());
}
});
if (HWY_UNLIKELY(clusters_.empty())) {
HWY_WARN("Too restrictive cluster_slice or enabled_lps, no clusters left.");
return false;
}
if (kSplitLargeClusters && clusters_.size() == 1 &&
enabled_lps.Count() >= 16) {
SplitLargeCluster(enabled_lps, tpackage.clusters[0]);
}
// Sort by descending 'size' so that users who only use one get the largest.
std::sort(clusters_.begin(), clusters_.end(),
[](const Cluster& a, const Cluster& b) {
return a.NumWorkers() > b.NumWorkers();
});
// Largest number of enabled workers in any cluster, for `topology_string_`.
// This may be less than `max_tcluster_cores` if `enabled_lps` excludes some.
size_t max_cluster_workers = 0;
for (const Cluster& c : clusters_) {
max_cluster_workers = HWY_MAX(max_cluster_workers, c.NumWorkers());
}
HWY_ASSERT(max_cluster_workers <= max_tcluster_cores);
// Do not warn about large clusters: GNR has 40.
snprintf(topology_string_, sizeof(topology_string_),
"%zuS %zuX %zuC %zuH, using %zuX %zuC (nodes=%zu)",
topology_.packages.size(), tclusters.size(), max_tcluster_cores,
max_tcluster_lps / max_tcluster_cores, NumClusters(),
max_cluster_workers, nodes_.Count());
return true;
}
#endif // !GEMMA_DISABLE_TOPOLOGY
// Called when topology is unknown or `GEMMA_DISABLE_TOPOLOGY`. Uses only the
// given LPs which derive from OS affinity and `lp_slice`.
void BoundedTopology::InitFromLPs(const LPS& enabled_lps) {
LPS clusters_lps[kMaxClusters];
const size_t num_clusters =
kSplitLargeClusters
@ -178,157 +275,11 @@ BoundedTopology::Package::Package(const LPS& enabled_lps) {
});
for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
clusters.push_back(Cluster(clusters_lps[cluster_idx]));
clusters_.push_back(Cluster(clusters_lps[cluster_idx]));
}
}
// NOTE: caller is responsible for checking whether `clusters` is empty.
BoundedTopology::Package::Package(const LPS& enabled_lps,
const hwy::Topology& topology, size_t pkg_idx,
BoundedSlice cluster_slice) {
const hwy::Topology::Package& tpackage = topology.packages[pkg_idx];
// Populate `clusters` with the subset of clusters in `cluster_slice` that
// have any enabled LPs. If `clusters` remains empty, the caller will
// skip this `Package`.
clusters.reserve(cluster_slice.Num(tpackage.clusters.size()));
cluster_slice.Foreach(
"cluster", tpackage.clusters.size(), [&](size_t cluster_idx) {
const hwy::Topology::Cluster& tcluster = tpackage.clusters[cluster_idx];
Cluster cluster(enabled_lps, topology.lps, tcluster);
// Skip if empty, i.e. too few `enabled_lps`.
if (HWY_LIKELY(cluster.Size() != 0)) {
clusters.push_back(cluster);
}
});
SortByDescendingSize(clusters);
// If there is only one large cluster, split it into smaller ones.
if (kSplitLargeClusters && clusters.size() == 1 &&
enabled_lps.Count() >= 16) {
const LPS lps = clusters[0].LPSet(); // copy so we can clear
clusters.clear();
// Split `lps` into several clusters.
LPS clusters_lps[kMaxClusters];
const size_t num_clusters =
HWY_MIN(kMaxClusters, hwy::DivCeil(lps.Count(), kMaxLPsPerCluster));
size_t num_lps = 0;
lps.Foreach(
[&](size_t lp) { clusters_lps[num_lps++ % num_clusters].Set(lp); });
HWY_DASSERT(num_lps == lps.Count());
// Create new clusters, just inserting the new LPS.
hwy::Topology::Cluster tcluster = tpackage.clusters[0]; // modifiable copy
for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
tcluster.lps = clusters_lps[cluster_idx];
// Keep same `private_kib` and `shared_kib`.
clusters.push_back(Cluster(enabled_lps, topology.lps, tcluster));
}
}
}
#if !GEMMA_DISABLE_TOPOLOGY
static size_t CoresFromLPs(const LPS& lps, const hwy::Topology& topology) {
LPS cores;
lps.Foreach([&](size_t lp) {
if (topology.lps[lp].smt == 0) cores.Set(lp);
});
return cores.Count();
}
// Scans hwy::Topology for clusters and their size, for use by topology_string_.
static void ScanTClusters(hwy::Topology& topology_, size_t& max_tclusters,
size_t& max_tcluster_cores,
size_t& max_tcluster_lps) {
max_tclusters = 0;
max_tcluster_cores = 0;
max_tcluster_lps = 0;
for (size_t pkg_idx = 0; pkg_idx < topology_.packages.size(); ++pkg_idx) {
const std::vector<hwy::Topology::Cluster>& tclusters =
topology_.packages[pkg_idx].clusters;
max_tclusters = HWY_MAX(max_tclusters, tclusters.size());
size_t tcluster_cores = 0;
size_t tcluster_lps = 0;
for (size_t cluster_idx = 0; cluster_idx < tclusters.size();
++cluster_idx) {
const size_t cores = CoresFromLPs(tclusters[cluster_idx].lps, topology_);
const size_t lps = tclusters[cluster_idx].lps.Count();
tcluster_cores = HWY_MAX(tcluster_cores, cores);
tcluster_lps = HWY_MAX(tcluster_lps, lps);
}
if (tclusters.size() > 1 && tcluster_cores > 8) {
HWY_WARN(
"Package %zu: multiple clusters with max size %zu, whereas CCX "
"only have 8, may indicate a bug in hwy::Topology.",
pkg_idx, tcluster_cores);
}
max_tcluster_cores = HWY_MAX(max_tcluster_cores, tcluster_cores);
max_tcluster_lps = HWY_MAX(max_tcluster_lps, tcluster_lps);
}
HWY_ASSERT(max_tclusters != 0);
HWY_ASSERT(max_tcluster_cores != 0);
HWY_ASSERT(max_tcluster_lps >= max_tcluster_cores);
}
// Main part of ctor, called when topology is known.
void BoundedTopology::InitFromTopology(const LPS& enabled_lps) {
size_t max_tclusters, max_tcluster_cores, max_tcluster_lps;
ScanTClusters(topology_, max_tclusters, max_tcluster_cores, max_tcluster_lps);
// (Possibly empty) subset of `Topology` packages that have `enabled_lps`.
package_slice_.Foreach(
"package", topology_.packages.size(), [&](size_t pkg_idx) {
Package package(enabled_lps, topology_, pkg_idx, cluster_slice_);
// Skip if empty, i.e. too few `enabled_lps`.
if (HWY_LIKELY(!package.clusters.empty())) {
packages_.push_back(std::move(package));
}
});
if (NumPackages() == 0) return;
SortByDescendingSize(packages_);
// Remember NUMA nodes that we are actually using (not just enabled).
for (const Package& p : packages_) {
for (const Cluster& c : p.clusters) {
nodes_.Set(c.Node());
}
}
// Scan for max BoundedTopology clusters and their size, for topology_string_.
size_t all_max_cluster_size = 0;
for (size_t pkg_idx = 0; pkg_idx < NumPackages(); ++pkg_idx) {
size_t max_cluster_size = 0;
for (size_t cluster_idx = 0; cluster_idx < NumClusters(pkg_idx);
++cluster_idx) {
max_cluster_size =
HWY_MAX(max_cluster_size, GetCluster(pkg_idx, cluster_idx).Size());
}
if (NumClusters(pkg_idx) > 1 && max_cluster_size > 8) {
HWY_WARN(
"Package %zu: multiple clusters with max size %zu, whereas CCX "
"only have 8, may indicate a bug in BoundedTopology.",
pkg_idx, max_cluster_size);
}
all_max_cluster_size = HWY_MAX(all_max_cluster_size, max_cluster_size);
}
snprintf(topology_string_, sizeof(topology_string_),
"%zuS %zuX %zuC %zuH, using %zuS %zuX %zuC (nodes=%zu)",
topology_.packages.size(), max_tclusters, max_tcluster_cores,
max_tcluster_lps / max_tcluster_cores, packages_.size(),
NumClusters(0), all_max_cluster_size, nodes_.Count());
}
#endif // !GEMMA_DISABLE_TOPOLOGY
void BoundedTopology::InitFromLPs(const LPS& enabled_lps) {
packages_.push_back(Package(enabled_lps));
snprintf(topology_string_, sizeof(topology_string_), "LPs=%zu",
GetCluster(0, 0).Size());
GetCluster(0).NumWorkers());
// Assume a single NUMA node.
nodes_.Set(0);

View File

@ -40,6 +40,7 @@ class BoundedSlice {
BoundedSlice(size_t skip = 0, size_t max = 0) : skip_(skip), max_(max) {}
size_t Begin() const { return skip_; }
size_t Max() const { return max_; }
// STL-style one past the end.
size_t End(size_t detected) const {
@ -82,12 +83,11 @@ using LPS = hwy::LogicalProcessorSet;
// back to a single package and cluster.
class BoundedTopology {
public:
// Defaults to "use all detected".
BoundedTopology(BoundedSlice package_slice = BoundedSlice(),
// `package_slice` must have `Max() == 1`. Others default to "use all".
BoundedTopology(BoundedSlice package_slice,
BoundedSlice cluster_slice = BoundedSlice(),
BoundedSlice lp_slice = BoundedSlice());
size_t NumPackages() const { return packages_.size(); }
size_t NumNodes() const { return nodes_.Count(); }
const char* TopologyString() const { return topology_string_; }
@ -98,8 +98,7 @@ class BoundedTopology {
const std::vector<hwy::Topology::LP>& all_lps,
const hwy::Topology::Cluster& tcluster);
// For SortByDescendingSize.
size_t Size() const { return num_workers_; }
size_t NumWorkers() const { return num_workers_; }
// Returns vector with all enabled LPs, used for pinning.
std::vector<size_t> LPVector() const {
@ -127,26 +126,11 @@ class BoundedTopology {
size_t shared_kib_ = 0;
}; // Cluster
size_t NumClusters(size_t pkg_idx) const {
HWY_ASSERT(pkg_idx < NumPackages());
return packages_[pkg_idx].clusters.size();
size_t NumClusters() const { return clusters_.size(); }
const Cluster& GetCluster(size_t cluster_idx) const {
HWY_ASSERT(cluster_idx < clusters_.size());
return clusters_[cluster_idx];
}
const Cluster& GetCluster(size_t pkg_idx, size_t cluster_idx) const {
HWY_ASSERT(pkg_idx < NumPackages());
const Package& package = packages_[pkg_idx];
HWY_ASSERT(cluster_idx < package.clusters.size());
return package.clusters[cluster_idx];
}
Cluster& GetCluster(size_t pkg_idx, size_t cluster_idx) {
HWY_ASSERT(pkg_idx < NumPackages());
Package& package = packages_[pkg_idx];
HWY_ASSERT(cluster_idx < package.clusters.size());
return package.clusters[cluster_idx];
}
#if !GEMMA_DISABLE_TOPOLOGY
const hwy::Topology& FullTopology() const { return topology_; }
#endif
// In case we are running with a subset of packages/clusters, these are added
// to the package/cluster indices for purposes of the thread name, so that
@ -155,26 +139,17 @@ class BoundedTopology {
size_t SkippedClusters() const { return cluster_slice_.Begin(); }
private:
struct Package {
explicit Package(const LPS& enabled_lps);
Package(const LPS& enabled_lps, const hwy::Topology& topology,
size_t pkg_idx, BoundedSlice cluster_slice);
// For SortByDescendingSize.
size_t Size() const { return clusters.size(); }
std::vector<Cluster> clusters;
}; // Package
void InitFromTopology(const LPS& enabled_lps);
void SplitLargeCluster(const LPS& enabled_lps,
hwy::Topology::Cluster tcluster);
bool InitFromTopology(const LPS& enabled_lps);
void InitFromLPs(const LPS& enabled_lps);
#if !GEMMA_DISABLE_TOPOLOGY
hwy::Topology topology_;
#endif
BoundedSlice package_slice_;
BoundedSlice package_slice_; // Within the entire detected topology.
BoundedSlice cluster_slice_;
std::vector<Package> packages_;
std::vector<Cluster> clusters_;
char topology_string_[96];
LPS nodes_;
};

70
util/zones.cc Normal file
View File

@ -0,0 +1,70 @@
#include "util/zones.h"
#include "hwy/profiler.h"
namespace gcpp {
#if PROFILER_ENABLED
static constexpr size_t kNumZones = static_cast<size_t>(Zones::kNumZones);
static const char* kProfilerZoneNames[kNumZones] = {
// Keep in sync with Zones enum.
"Ops.RMSNormMul",
"Ops.RMSNorm",
"Ops.RMSNormInplace",
"Ops.Rope",
"Ops.RopeAndMulBy",
"Ops.AddFrom",
"Ops.MulByConst",
"Ops.MulByConstTo",
"Ops.MulByConstAndAdd",
"Ops.MulByConstAndAddTile",
"Ops.MulByConstAndAddTile4",
"Ops.MulByConstAndAddVector",
"Ops.Softmax",
"Ops.LogitsSoftCap",
"FlashAttention.TransposeQ",
"FlashAttention.RMSNormAndPositionalEncoding",
"FlashAttention.SingleFlashAttention",
"FlashAttention.TileFlashAttention",
"FlashAttention.TileFlashAttention4",
"FlashAttention.FlashAttention",
"Gen.Activation",
"Gen.ActivationFused",
"Gen.SampleTop1",
"Gen.SampleTopK",
"Gen.Attention.QDotK",
"Gen.Attention.DotSoftmaxWeightedSum.par",
"Startup.Weights.ReadAllToBF16",
"Startup.Weights.ReadBatches",
"MM.Dispatch",
"MM.MatMul",
"MM.TwoMatMul",
"MM.DecompressA",
"MM.NT",
"MM.NT_K",
"MM.NT_MT",
"MM.NT_MT_K",
};
static hwy::profiler::ZoneHandle profiler_zone_handles[kNumZones];
#endif
void InitProfilerZones(hwy::Profiler& profiler) {
#if PROFILER_ENABLED
// Initialize the zone handles. This is done once at startup.
for (size_t i = 0; i < kNumZones; ++i) {
profiler_zone_handles[i] = profiler.AddZone(kProfilerZoneNames[i]);
}
#endif
}
hwy::profiler::ZoneHandle GetProfilerZone(Zones zone) {
#if PROFILER_ENABLED
return profiler_zone_handles[static_cast<size_t>(zone)];
#else
return hwy::profiler::ZoneHandle();
#endif
}
} // namespace gcpp

58
util/zones.h Normal file
View File

@ -0,0 +1,58 @@
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_ZONES_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_ZONES_H_
#include "hwy/profiler.h"
namespace gcpp {
// Zones for the profiler.
enum class Zones {
kOpsRmsNormMul,
kOpsRmsNorm,
kOpsRmsNormInplace,
kOpsRope,
kOpsRopeAndMulBy,
kOpsAddFrom,
kOpsMulByConst,
kOpsMulByConstTo,
kOpsMulByConstAndAdd,
kOpsMulByConstAndAddTile,
kOpsMulByConstAndAddTile4,
kOpsMulByConstAndAddVector,
kOpsSoftmax,
kOpsLogitsSoftCap,
kFlashAttentionTransposeQ,
kFlashAttentionRmsNormAndPositionalEncoding,
kFlashAttentionSingleFlashAttention,
kFlashAttentionTileFlashAttention,
kFlashAttentionTileFlashAttention4,
kFlashAttentionFlashAttention,
kGenActivation,
kGenActivationFused,
kGenSampleTop1,
kGenSampleTopK,
kGenAttentionQDotK,
kGenAttentionDotSoftmaxWeightedSumPar,
kStartupWeightsReadAllToBF16,
kStartupWeightsReadBatches,
kMMDispatch,
kMMMatMul,
kMMTwoMatMul,
kMMDecompressA,
kMMNT,
kMMNT_K,
kMMNT_MT,
kMMNT_MT_K,
kNumZones
};
// Initializes the profiler zones. Must be called before any other profiler
// functions.
void InitProfilerZones(hwy::Profiler& profiler);
// Returns the zone handle for the given zone enum value.
hwy::profiler::ZoneHandle GetProfilerZone(Zones zone);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_ZONES_H_