Merge branch 'dev' into main

This commit is contained in:
KaranocaVe 2025-07-31 22:40:56 +08:00 committed by GitHub
commit 32286f0465
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
164 changed files with 14944 additions and 13962 deletions

37
.gitattributes vendored Normal file
View File

@ -0,0 +1,37 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
2b-pt-sfp.sbs filter=lfs diff=lfs merge=lfs -text
tokenizer.spm filter=lfs diff=lfs merge=lfs -text

View File

@ -17,12 +17,10 @@ jobs:
fail-fast: false
matrix:
# When adding another, also add to copybara's github_check_runs.
os: ['ubuntu-latest', 'macos-latest', 'windows-latest', 'ubuntu-20.04']
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
build_type: ['Release']
preset: ['make', 'windows']
exclude:
- os: ubuntu-20.04
preset: windows
- os: ubuntu-latest
preset: windows
- os: macos-latest
@ -62,44 +60,6 @@ jobs:
${{ github.workspace }}/build/gemma
${{ github.workspace }}/build/libgemma.a
- if: matrix.os == 'ubuntu-20.04'
name: Upload build artifacts to Kaggle
uses: pculliton/push-kaggle-dataset@v1.0.0
env:
KAGGLE_USERNAME: ${{ secrets.KAGGLE_USERNAME }}
KAGGLE_KEY: ${{ secrets.KAGGLE_KEY }}
with:
id: "phillipculliton/gemma-build-artifacts"
files: |
build/gemma
build/_deps/sentencepiece-build/src/libsentencepiece.so.0
- if: matrix.os == 'ubuntu-20.04'
name: Create code for new test notebook version
run: |
cat > runner.py << EOF
import subprocess
subprocess.run(["cp", "/kaggle/input/gemma-build-artifacts/gemma", "/kaggle/working"])
subprocess.run(["chmod", "700", "/kaggle/working/gemma"])
subprocess.run(["cp", "/kaggle/input/gemma-build-artifacts/_deps/sentencepiece-build/src/libsentencepiece.so.0", "/kaggle/working"])
output = subprocess.run(["/kaggle/working/gemma", "--tokenizer", "/kaggle/input/gemma/gemmacpp/2b-it-sfp/4/tokenizer.spm", "--compressed_weights", "/kaggle/input/gemma/gemmacpp/2b-it-sfp/4/2b-it-sfp.sbs", "--model", "2b-it", "--verbosity", "0", "--max_generated_tokens", "128"], stdout=subprocess.PIPE, input='Write an email to the moon.', encoding='ascii').stdout
assert("write an email to the moon." not in output.lower());
assert("moon" in output.lower());
EOF
- if: matrix.os == 'ubuntu-20.04'
name: Run kaggle test notebook
uses: pculliton/kaggle-action@v1.0.28
with:
username: ${{ secrets.KAGGLE_USERNAME }}
key: ${{ secrets.KAGGLE_KEY }}
title: GemmaCPP-CI-2
code_file: runner.py
dataset_sources: "phillipculliton/gemma-build-artifacts"
model_sources: "google/gemma/gemmaCpp/2b-it-sfp/4"
enable_gpu: False
kernel_type: script
bazel:
runs-on: ubuntu-latest
steps:
@ -116,4 +76,4 @@ jobs:
with:
path: ~/.cache/bazel
key: bazel-${{ runner.os }}
- run: bazel build --cxxopt=-std=c++20 //:all
- run: bazel build --cxxopt=-std=c++20 //:gemma --jobs=10 --show_progress_rate_limit=1

21
.gitignore vendored
View File

@ -1,4 +1,25 @@
# Build directories
.cache/
bazel-*/
build-*/
build/
# Python cache
python/*/__pycache__
# Model files
*.sbs
*.spm
*.data
*.bin
*.weights
# IDE and editor files
.vscode/
.idea/
*.swp
*~
# Local development
.env
.env.local

15
.vscode/c_cpp_properties.json vendored Normal file
View File

@ -0,0 +1,15 @@
{
"configurations": [
{
"name": "Linux",
"includePath": [
"${workspaceFolder}/**"
],
"defines": [],
"cStandard": "c17",
"cppStandard": "c++17",
"intelliSenseMode": "linux-clang-x64"
}
],
"version": 4
}

View File

@ -19,7 +19,10 @@ license(
# Dual-licensed Apache 2 and 3-clause BSD.
licenses(["notice"])
exports_files(["LICENSE"])
exports_files([
"LICENSE",
".github/workflows/build.yml",
])
cc_library(
name = "basics",
@ -29,6 +32,16 @@ cc_library(
],
)
cc_library(
name = "args",
hdrs = ["util/args.h"],
deps = [
":basics",
"//io", # Path
"@highway//:hwy",
],
)
# Split from :threading to break a circular dependency with :allocator.
cc_library(
name = "topology",
@ -59,6 +72,7 @@ cc_library(
hdrs = ["util/threading.h"],
deps = [
":allocator",
":args",
":basics",
":topology",
# Placeholder for container detection, do not remove
@ -68,14 +82,28 @@ cc_library(
],
)
cc_library(
name = "threading_context",
srcs = ["util/threading_context.cc"],
hdrs = ["util/threading_context.h"],
deps = [
":allocator",
":args",
":basics",
":threading",
":topology",
"@highway//:hwy",
"@highway//:profiler",
],
)
cc_test(
name = "threading_test",
srcs = ["util/threading_test.cc"],
deps = [
":allocator",
":basics",
":threading",
"@googletest//:gtest_main",
":threading_context",
"@googletest//:gtest_main", # buildcleaner: keep
"@highway//:auto_tune",
"@highway//:hwy",
"@highway//:hwy_test_util",
@ -97,6 +125,124 @@ cc_library(
],
)
cc_library(
name = "configs",
srcs = ["gemma/configs.cc"],
hdrs = ["gemma/configs.h"],
deps = [
":basics",
"//compression:types",
"//io",
"//io:fields",
"@highway//:hwy", # base.h
],
)
cc_test(
name = "configs_test",
srcs = ["gemma/configs_test.cc"],
deps = [
":configs",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:types",
"//io:fields",
],
)
cc_library(
name = "tensor_info",
srcs = ["gemma/tensor_info.cc"],
hdrs = ["gemma/tensor_info.h"],
deps = [
":basics",
":configs",
"//compression:types",
],
)
cc_library(
name = "mat",
srcs = ["util/mat.cc"],
hdrs = ["util/mat.h"],
deps = [
":allocator",
":basics",
":tensor_info",
":threading_context",
"//compression:types",
"//io:fields",
"@highway//:hwy",
"@highway//:profiler",
],
)
cc_library(
name = "tokenizer",
srcs = ["gemma/tokenizer.cc"],
hdrs = ["gemma/tokenizer.h"],
deps = [
":configs",
"@highway//:hwy",
"@highway//:profiler",
"@com_google_sentencepiece//:sentencepiece_processor",
],
)
cc_library(
name = "model_store",
srcs = ["gemma/model_store.cc"],
hdrs = ["gemma/model_store.h"],
deps = [
":allocator",
":basics",
":configs",
":mat",
":tensor_info",
":threading_context",
":tokenizer",
"//compression:types",
"//io",
"//io:blob_store",
"//io:fields",
"@highway//:hwy",
"@highway//:thread_pool",
],
)
cc_library(
name = "weights",
srcs = ["gemma/weights.cc"],
hdrs = ["gemma/weights.h"],
deps = [
":configs",
":gemma_args",
":mat",
":matmul",
":model_store",
":tensor_info",
":threading_context",
"//compression:compress",
"//io:blob_store",
"@highway//:hwy",
"@highway//:profiler",
"@highway//:thread_pool",
],
)
cc_test(
name = "tensor_info_test",
srcs = ["gemma/tensor_info_test.cc"],
deps = [
":configs",
":mat",
":tensor_info",
":weights",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
"@highway//:hwy", # aligned_allocator.h
],
)
# For building all tests in one command, so we can test several.
test_suite(
name = "ops_tests",
@ -104,34 +250,75 @@ test_suite(
)
cc_library(
name = "ops",
name = "matmul",
srcs = ["ops/matmul.cc"],
hdrs = ["ops/matmul.h"],
textual_hdrs = ["ops/matmul-inl.h"],
deps = [
":allocator",
":basics",
":mat",
":threading_context",
"//compression:compress",
"@highway//:bit_set",
"@highway//:hwy",
"@highway//:nanobenchmark",
"@highway//:profiler",
],
)
cc_library(
name = "matmul_static",
srcs = [
"ops/matmul.cc",
# single-file build time is ~30sec for msan, hence shard.
"ops/matmul_static_bf16.cc",
"ops/matmul_static_f32.cc",
"ops/matmul_static_nuq.cc",
"ops/matmul_static_sfp.cc",
],
hdrs = [
"ops/matmul.h",
"ops/ops.h",
"ops/matmul_static.h",
],
textual_hdrs = [
"ops/matmul_static-inl.h",
"ops/matmul-inl.h",
],
deps = [
":allocator",
":basics",
":mat",
":matmul",
":threading_context",
"//compression:compress",
"//compression:types",
"@highway//:hwy",
"@highway//:profiler",
"@highway//:timer",
],
)
cc_library(
name = "ops",
hdrs = ["ops/ops.h"],
textual_hdrs = [
"ops/dot-inl.h",
"ops/sum-inl.h",
"ops/fp_arith-inl.h",
"ops/matmul-inl.h",
"ops/matvec-inl.h",
"ops/ops-inl.h",
],
deps = [
":allocator",
":basics",
":threading",
":topology",
":mat",
":matmul",
":matmul_static",
":threading_context",
"//compression:compress",
"@highway//:algo",
"@highway//:bit_set",
"@highway//:hwy",
"@highway//:math",
"@highway//:matvec",
"@highway//:nanobenchmark",
"@highway//:profiler",
"@highway//:thread_pool",
"@highway//hwy/contrib/sort:vqsort",
@ -143,15 +330,15 @@ cc_test(
size = "small",
timeout = "long",
srcs = ["ops/dot_test.cc"],
linkstatic = True,
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["ops_tests"],
deps = [
":allocator",
":app",
":ops",
":test_util",
":threading",
":threading_context",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
"//compression:test_util",
@ -167,20 +354,23 @@ cc_test(
cc_test(
name = "ops_test",
size = "small",
timeout = "eternal",
timeout = "long",
srcs = ["ops/ops_test.cc"],
linkstatic = True,
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["ops_tests"],
deps = [
":allocator",
":app",
":common",
":basics",
":configs",
":gemma_lib",
":mat",
":ops",
":test_util",
":threading",
":threading_context",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
"//compression:types",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark", #buildcleaner: keep
@ -192,11 +382,14 @@ cc_test(
size = "small",
timeout = "long",
srcs = ["ops/gemma_matvec_test.cc"],
linkstatic = True,
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["ops_tests"],
deps = [
":mat",
":ops",
":threading_context",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
"@highway//:hwy",
@ -210,18 +403,23 @@ cc_test(
size = "small",
timeout = "long",
srcs = ["ops/matmul_test.cc"],
linkstatic = True,
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["ops_tests"],
deps = [
":allocator",
":basics",
":mat",
":matmul",
":matmul_static",
":ops",
":threading",
":threading_context",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
"//compression:test_util",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
"@highway//:thread_pool",
],
)
@ -231,6 +429,7 @@ cc_test(
size = "small",
timeout = "long",
srcs = ["ops/bench_matmul.cc"],
linkstatic = True,
local_defines = ["HWY_IS_TEST"],
tags = [
"manual",
@ -238,12 +437,12 @@ cc_test(
"ops_tests", # for test_suite.
],
deps = [
":allocator",
":basics",
":ops",
":threading",
":matmul",
":threading_context",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
"//compression:test_util",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
@ -252,101 +451,47 @@ cc_test(
],
)
cc_library(
name = "common",
srcs = [
"gemma/common.cc",
"gemma/configs.cc",
"gemma/tensor_index.cc",
],
hdrs = [
"gemma/common.h",
"gemma/configs.h",
"gemma/tensor_index.h",
],
deps = [
"//compression:fields",
"//compression:sfp",
"@highway//:hwy", # base.h
"@highway//:thread_pool",
],
)
cc_test(
name = "configs_test",
srcs = ["gemma/configs_test.cc"],
deps = [
":common",
"@googletest//:gtest_main",
"@highway//:hwy",
],
)
cc_test(
name = "tensor_index_test",
srcs = ["gemma/tensor_index_test.cc"],
deps = [
":basics",
":common",
":weights",
"@googletest//:gtest_main",
"//compression:compress",
"@highway//:hwy",
],
)
cc_library(
name = "weights",
srcs = ["gemma/weights.cc"],
hdrs = ["gemma/weights.h"],
deps = [
":common",
"//compression:blob_store",
"//compression:compress",
"//compression:io",
"@highway//:hwy",
"@highway//:profiler",
"@highway//:stats",
"@highway//:thread_pool",
],
)
cc_library(
name = "tokenizer",
srcs = ["gemma/tokenizer.cc"],
hdrs = ["gemma/tokenizer.h"],
deps = [
":common",
"//compression:io",
"//compression:sfp",
"@highway//:hwy",
"@highway//:profiler",
"@com_google_sentencepiece//:sentencepiece_processor",
],
)
cc_library(
name = "kv_cache",
srcs = ["gemma/kv_cache.cc"],
hdrs = ["gemma/kv_cache.h"],
deps = [
":common",
":basics",
":configs",
":gemma_args",
":mat",
"@highway//:hwy",
],
)
cc_library(
name = "gemma_args",
hdrs = ["gemma/gemma_args.h"],
deps = [
":args",
":basics",
":mat",
":matmul",
"//io",
"@highway//:hwy",
"@highway//:profiler",
],
)
cc_library(
name = "gemma_lib",
srcs = [
"gemma/attention.cc",
"gemma/gemma.cc",
"gemma/instantiations/bf16.cc",
"gemma/instantiations/f32.cc",
"gemma/instantiations/nuq.cc",
"gemma/instantiations/sfp.cc",
"gemma/griffin.cc",
"gemma/vit.cc",
],
hdrs = [
"gemma/activations.h",
"gemma/attention.h",
"gemma/gemma.h",
"gemma/griffin.h",
"gemma/vit.h",
],
exec_properties = {
# Avoid linker OOMs when building with sanitizer instrumentation.
@ -354,23 +499,26 @@ cc_library(
},
textual_hdrs = [
"gemma/gemma-inl.h",
# Placeholder for internal file2, do not remove,
],
deps = [
":allocator",
":basics",
":common",
":ops",
":tokenizer",
":configs",
":gemma_args",
":kv_cache",
":weights",
":mat",
":matmul",
":model_store",
":ops",
":threading",
":threading_context",
":weights",
"//compression:compress",
"//compression:io",
"//compression:sfp",
"//compression:types",
"//io:blob_store",
"//io",
"//paligemma:image",
"@highway//:hwy",
"@highway//:bit_set",
"@highway//:nanobenchmark", # timer
"@highway//:profiler",
"@highway//:thread_pool",
@ -382,35 +530,9 @@ cc_library(
srcs = ["evals/cross_entropy.cc"],
hdrs = ["evals/cross_entropy.h"],
deps = [
":common",
":gemma_lib",
":ops",
"@highway//:hwy",
],
)
cc_library(
name = "args",
hdrs = ["util/args.h"],
deps = [
":basics",
"//compression:io",
"@highway//:hwy",
],
)
cc_library(
name = "app",
hdrs = ["util/app.h"],
deps = [
":args",
":basics",
":common",
":gemma_lib",
":ops",
":threading",
"//compression:io",
"//compression:sfp",
"//compression:types",
"@highway//:hwy",
],
)
@ -420,20 +542,49 @@ cc_library(
srcs = ["evals/benchmark_helper.cc"],
hdrs = ["evals/benchmark_helper.h"],
deps = [
":app",
":args",
":common",
":configs",
":cross_entropy",
":gemma_args",
":gemma_lib",
":kv_cache",
":matmul",
":ops",
":threading",
# Placeholder for internal dep, do not remove.,
":threading_context",
":tokenizer",
"@google_benchmark//:benchmark",
"//compression:compress",
"@highway//:hwy",
"@highway//:nanobenchmark",
"@highway//:topology",
"@highway//:profiler",
],
)
cc_library(
name = "gemma_shared_lib",
srcs = [
"gemma/bindings/c_api.cc",
"gemma/bindings/context.cc",
],
hdrs = [
"gemma/bindings/c_api.h",
"gemma/bindings/context.h",
],
exec_properties = {
# Avoid linker OOMs when building with sanitizer instrumentation.
"mem": "28g",
},
deps = [
":benchmark_helper",
":gemma_args",
":gemma_lib",
":kv_cache",
":matmul",
":threading",
":threading_context",
":tokenizer",
"//paligemma:image",
"@highway//:hwy",
"@highway//:profiler",
"@highway//:timer",
],
)
@ -449,9 +600,10 @@ cc_test(
],
deps = [
":benchmark_helper",
":common",
":configs",
":gemma_lib",
"@googletest//:gtest_main",
"@googletest//:gtest_main", # buildcleaner: keep
"//io",
"@highway//:hwy",
"@highway//:hwy_test_util",
],
@ -460,6 +612,7 @@ cc_test(
cc_test(
name = "gemma_batch_bench",
srcs = ["evals/gemma_batch_bench.cc"],
linkstatic = True,
# Requires model files
tags = [
"local",
@ -468,12 +621,12 @@ cc_test(
],
deps = [
":benchmark_helper",
":common",
":gemma_lib",
":tokenizer",
"@googletest//:gtest_main",
"@googletest//:gtest_main", # buildcleaner: keep
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
"@highway//:profiler",
],
)
@ -481,15 +634,13 @@ cc_binary(
name = "gemma",
srcs = ["gemma/run.cc"],
deps = [
":app",
":args",
":benchmark_helper",
":common",
":gemma_args",
":gemma_lib",
":ops",
":threading",
# Placeholder for internal dep, do not remove.,
"//compression:sfp",
":matmul",
":tokenizer",
"//compression:types",
"//paligemma:image",
"@highway//:hwy",
"@highway//:profiler",
@ -502,22 +653,15 @@ cc_binary(
deps = [
":args",
":benchmark_helper",
":common",
":cross_entropy",
":gemma_lib",
"//compression:io",
"//io",
"@highway//:hwy",
"@highway//:nanobenchmark",
"@nlohmann_json//:json",
],
)
cc_library(
name = "benchmark_prompts",
hdrs = ["evals/prompts.h"],
deps = ["@highway//:hwy"],
)
cc_binary(
name = "benchmarks",
srcs = [
@ -526,7 +670,6 @@ cc_binary(
],
deps = [
":benchmark_helper",
":benchmark_prompts",
"@google_benchmark//:benchmark",
"@highway//:hwy", # base.h
],
@ -534,14 +677,12 @@ cc_binary(
cc_binary(
name = "debug_prompt",
srcs = [
"evals/debug_prompt.cc",
],
srcs = ["evals/debug_prompt.cc"],
deps = [
":args",
":benchmark_helper",
":gemma_lib",
"//compression:io",
"//io",
"@highway//:hwy",
"@nlohmann_json//:json",
],
@ -554,158 +695,9 @@ cc_binary(
":args",
":benchmark_helper",
":gemma_lib",
"//compression:io",
"//io",
"@highway//:hwy",
"@highway//:profiler",
"@highway//:thread_pool",
"@nlohmann_json//:json",
],
)
cc_library(
name = "prompt",
hdrs = ["backprop/prompt.h"],
deps = [],
)
cc_library(
name = "sampler",
hdrs = ["backprop/sampler.h"],
deps = [
":prompt",
],
)
cc_library(
name = "backprop",
srcs = [
"backprop/backward.cc",
"backprop/forward.cc",
],
hdrs = [
"backprop/activations.h",
"backprop/backward.h",
"backprop/forward.h",
],
textual_hdrs = [
"backprop/backward-inl.h",
"backprop/forward-inl.h",
],
deps = [
":allocator",
":common",
":ops",
":prompt",
":weights",
"//compression:compress",
"@highway//:dot",
"@highway//:hwy", # base.h
"@highway//:thread_pool",
],
)
cc_library(
name = "backprop_scalar",
hdrs = [
"backprop/activations.h",
"backprop/backward_scalar.h",
"backprop/common_scalar.h",
"backprop/forward_scalar.h",
],
deps = [
":common",
":prompt",
":weights",
"//compression:compress",
"@highway//:hwy",
],
)
cc_test(
name = "backward_scalar_test",
size = "large",
srcs = [
"backprop/backward_scalar_test.cc",
"backprop/test_util.h",
],
deps = [
":backprop_scalar",
":common",
":prompt",
":sampler",
":weights",
"@googletest//:gtest_main",
"//compression:compress",
"@highway//:thread_pool",
],
)
cc_test(
name = "backward_test",
size = "large",
srcs = [
"backprop/backward_test.cc",
"backprop/test_util.h",
],
exec_properties = {
# Avoid linker OOMs when building with sanitizer instrumentation.
"mem": "28g",
},
deps = [
":allocator",
":backprop",
":backprop_scalar",
":common",
":ops",
":prompt",
":sampler",
":threading",
":weights",
"@googletest//:gtest_main",
"//compression:compress",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:thread_pool",
],
)
cc_library(
name = "optimizer",
srcs = ["backprop/optimizer.cc"],
hdrs = ["backprop/optimizer.h"],
deps = [
":allocator",
":common",
":weights",
"//compression:compress",
"@highway//:hwy",
"@highway//:thread_pool",
],
)
cc_test(
name = "optimize_test",
srcs = [
"backprop/optimize_test.cc",
],
exec_properties = {
# Avoid linker OOMs when building with sanitizer instrumentation.
"mem": "28g",
},
deps = [
":allocator",
":backprop",
":basics",
":common",
":gemma_lib",
":ops",
":optimizer",
":prompt",
":sampler",
":threading",
":weights",
"@googletest//:gtest_main",
"//compression:sfp",
"@highway//:thread_pool",
],
)

View File

@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c5bebf84ad01edec97e336f5c97ca4e0df6b4d06 EXCLUDE_FROM_ALL)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9414b48aeec251b69e6cadbfa42bebb5ddae1c34 EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(highway)
## Note: absl needs to be installed by sentencepiece. This will only happen if
@ -39,58 +39,54 @@ set(BENCHMARK_ENABLE_GTEST_TESTS OFF)
FetchContent_Declare(benchmark GIT_REPOSITORY https://github.com/google/benchmark.git GIT_TAG v1.8.2 EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(benchmark)
# Base source files
set(SOURCES
compression/blob_store.cc
compression/blob_store.h
compression/compress-inl.h
compression/compress.cc
compression/compress.h
compression/compress-inl.h
compression/fields.cc
compression/fields.h
compression/io_win.cc
compression/io.cc
compression/io.h
compression/nuq-inl.h
compression/sfp-inl.h
compression/shared.h
compression/types.h
compression/test_util-inl.h
backprop/activations.h
backprop/backward.cc
backprop/backward.h
backprop/backward-inl.h
backprop/backward_scalar.h
backprop/common_scalar.h
backprop/forward.cc
backprop/forward.h
backprop/forward-inl.h
backprop/forward_scalar.h
backprop/optimizer.cc
backprop/optimizer.h
evals/benchmark_helper.cc
evals/benchmark_helper.h
evals/cross_entropy.cc
evals/cross_entropy.h
gemma/activations.h
gemma/common.cc
gemma/common.h
gemma/attention.cc
gemma/attention.h
gemma/configs.cc
gemma/configs.h
gemma/gemma_args.h
gemma/gemma-inl.h
gemma/gemma.cc
gemma/gemma.h
gemma/instantiations/bf16.cc
gemma/instantiations/f32.cc
gemma/instantiations/nuq.cc
gemma/instantiations/sfp.cc
gemma/griffin.cc
gemma/griffin.h
gemma/kv_cache.cc
gemma/kv_cache.h
gemma/tensor_index.cc
gemma/tensor_index.h
gemma/model_store.cc
gemma/model_store.h
gemma/tensor_info.cc
gemma/tensor_info.h
gemma/tokenizer.cc
gemma/tokenizer.h
gemma/vit.cc
gemma/vit.h
gemma/weights.cc
gemma/weights.h
io/blob_store.cc
io/blob_store.h
io/fields.cc
io/fields.h
io/io_win.cc
io/io.cc
io/io.h
ops/dot-inl.h
ops/matmul_static_bf16.cc
ops/matmul_static_f32.cc
ops/matmul_static_nuq.cc
ops/matmul_static_sfp.cc
ops/matmul-inl.h
ops/matmul.cc
ops/matmul.h
@ -102,15 +98,28 @@ set(SOURCES
paligemma/image.h
util/allocator.cc
util/allocator.h
util/app.h
util/args.h
util/basics.h
util/mat.cc
util/mat.h
util/test_util.h
util/threading_context.cc
util/threading_context.h
util/threading.cc
util/threading.h
util/topology.cc
util/topology.h
)
# Add C API sources only when building DLL
if(BUILD_GEMMA_DLL)
list(APPEND SOURCES
gemma/bindings/context.h
gemma/bindings/context.cc
gemma/bindings/c_api.h
gemma/bindings/c_api.cc
)
message(STATUS "Including C API files for DLL build")
endif()
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "Release")
@ -131,6 +140,33 @@ target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
install(TARGETS libgemma DESTINATION lib)
# Shared library target for C# interop
if(BUILD_GEMMA_DLL)
add_library(gemma_shared SHARED ${SOURCES})
set_property(TARGET gemma_shared PROPERTY CXX_STANDARD 17)
set_target_properties(gemma_shared PROPERTIES
PREFIX ""
OUTPUT_NAME "gemma"
)
set_property(TARGET gemma_shared PROPERTY POSITION_INDEPENDENT_CODE ON)
target_include_directories(gemma_shared PUBLIC ./)
target_link_libraries(gemma_shared PRIVATE
$<LINK_LIBRARY:WHOLE_ARCHIVE,hwy>
$<LINK_LIBRARY:WHOLE_ARCHIVE,hwy_contrib>
$<LINK_LIBRARY:WHOLE_ARCHIVE,sentencepiece-static>
)
target_include_directories(gemma_shared PUBLIC ${sentencepiece_SOURCE_DIR})
target_compile_definitions(gemma_shared
PRIVATE
GEMMA_EXPORTS
$<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>
)
target_compile_options(gemma_shared PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
install(TARGETS gemma_shared DESTINATION lib)
install(FILES gemma/c_api.h DESTINATION include/gemma)
install(FILES gemma/GemmaInterop.cs DESTINATION include/gemma)
endif()
# Executable Target
add_executable(gemma gemma/run.cc)
@ -154,17 +190,14 @@ enable_testing()
include(GoogleTest)
set(GEMMA_TEST_FILES
backprop/backward_scalar_test.cc
backprop/backward_test.cc
backprop/optimize_test.cc
compression/blob_store_test.cc
compression/compress_test.cc
compression/distortion_test.cc
compression/fields_test.cc
compression/nuq_test.cc
compression/sfp_test.cc
evals/gemma_test.cc
gemma/tensor_index_test.cc
gemma/tensor_info_test.cc
io/blob_store_test.cc
io/fields_test.cc
ops/bench_matmul.cc
ops/dot_test.cc
ops/gemma_matvec_test.cc
@ -197,8 +230,5 @@ endif() # GEMMA_ENABLE_TESTS
## Tools
add_executable(compress_weights compression/compress_weights.cc)
target_link_libraries(compress_weights libgemma hwy hwy_contrib)
add_executable(migrate_weights compression/migrate_weights.cc)
add_executable(migrate_weights io/migrate_weights.cc)
target_link_libraries(migrate_weights libgemma hwy hwy_contrib)

View File

@ -31,6 +31,24 @@
"lhs": "${hostSystemName}",
"rhs": "Windows"
}
},
{
"name": "windows-dll",
"inherits": "__defaults__",
"displayName": "Windows DLL",
"description": "Visual Studio 2022 with Clang/LLVM frontend (DLL build)",
"generator": "Visual Studio 17 2022",
"toolset": "ClangCL",
"condition": {
"type": "equals",
"lhs": "${hostSystemName}",
"rhs": "Windows"
},
"cacheVariables": {
"BUILD_SHARED_LIBS": "OFF",
"CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS": "ON",
"BUILD_GEMMA_DLL": "ON"
}
}
],
"buildPresets": [
@ -54,6 +72,15 @@
"displayName": "Windows",
"configuration": "Release",
"configurePreset": "windows"
},
{
"name": "windows-dll",
"displayName": "Windows DLL",
"configuration": "Release",
"configurePreset": "windows-dll",
"targets": [
"gemma_shared"
]
}
]
}

View File

@ -96,21 +96,10 @@ https://github.com/keras-team/keras-nlp/blob/master/tools/gemma/export_gemma_to_
From Pytorch, use the following script to generate uncompressed weights:
https://github.com/google/gemma.cpp/blob/dev/compression/convert_weights.py
Then run `compression/compress_weights.cc` (Bazel target
`compression:compress_weights`), specifying the resulting file as `--weights`
and the desired .sbs name as the `--compressed_weights`.
For PaliGemma, use `python/convert_from_safetensors` to create an SBS file
directly.
## Compile-Time Flags (Advanced)
There are several compile-time flags to be aware of (note these may or may not
be exposed to the build system):
- `GEMMA_MAX_SEQ_LEN` : Sets maximum sequence length to preallocate for the KV
Cache. The default is 4096 tokens but can be overridden. This is not exposed
through `CMakeLists.txt` yet.
In the medium term this will likely be deprecated in favor of handling options
at runtime - dynamically resizing the KV cache as needed.
For other models, `gemma_export_main.py` is not yet open sourced.
## Using gemma.cpp as a Library (Advanced)
@ -164,7 +153,7 @@ constrained decoding type of use cases where you want to force the generation to
fit a grammar. If you're not doing this, you can send an empty lambda or
`std::function` as a no-op which is what `run.cc` does.
### `Transformer()` implements the inference (i.e. `forward()` method in PyTorch or Jax) computation of the neural network
### `Transformer()` implements inference (i.e. `forward()` in PyTorch or Jax)
For high-level applications, you might only call `model.Generate()` and never
interact directly with the neural network, but if you're doing something a bit
@ -172,9 +161,6 @@ more custom you can call transformer which performs a single inference operation
on a single token and mutates the Activations and the KVCache through the neural
network computation.
Note that an experimental backward pass is available in backprop/, which may be
useful for fine tuning.
### For low level operations, defining new architectures, call `ops.h` functions directly
You use `ops.h` if you're writing other NN architectures or modifying the

View File

@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5")
# Require a more recent version.
git_override(
module_name = "highway",
commit = "c5bebf84ad01edec97e336f5c97ca4e0df6b4d06",
commit = "9414b48aeec251b69e6cadbfa42bebb5ddae1c34",
remote = "https://github.com/google/highway",
)
@ -71,6 +71,7 @@ pip.parse(
requirements_lock = "//compression/python:requirements.txt",
)
use_repo(pip, "compression_deps")
pip.parse(
hub_name = "python_deps",
python_version = "3.11",

216
README.md
View File

@ -6,7 +6,7 @@ foundation models from Google.
For additional information about Gemma, see
[ai.google.dev/gemma](https://ai.google.dev/gemma). Model weights, including
gemma.cpp specific artifacts, are
[available on kaggle](https://www.kaggle.com/models/google/gemma).
[available on kaggle](https://www.kaggle.com/models/google/gemma-2).
## Who is this project for?
@ -18,8 +18,8 @@ deployment-oriented C++ inference runtimes, which are not designed for
experimentation, and Python-centric ML research frameworks, which abstract away
low-level computation through compilation.
gemma.cpp provides a minimalist implementation of Gemma-1, Gemma-2, Gemma-3, and
PaliGemma models, focusing on simplicity and directness rather than full
gemma.cpp provides a minimalist implementation of Gemma-2, Gemma-3, and
PaliGemma-2 models, focusing on simplicity and directness rather than full
generality. This is inspired by vertically-integrated model implementations such
as [ggml](https://github.com/ggerganov/ggml),
[llama.c](https://github.com/karpathy/llama2.c), and
@ -45,9 +45,41 @@ this invite link](https://discord.gg/H5jCBAWxAe). This project follows
[Google's Open Source Community
Guidelines](https://opensource.google.com/conduct/).
*Active development is currently done on the `dev` branch. Please open pull
requests targeting `dev` branch instead of `main`, which is intended to be more
stable.*
> [!NOTE] Active development is currently done on the `dev` branch. Please open
> pull requests targeting `dev` branch instead of `main`, which is intended to
> be more stable.
## What's inside?
- LLM
- CPU-only inference for: Gemma 2-3, Griffin(SSM), PaliGemma 2.
- Sampling with TopK and temperature.
- Backward pass (VJP) and Adam optimizer for Gemma research.
- Optimizations
- Mixed-precision (fp8, bf16, fp32, fp64 bit) GEMM:
- Designed for BF16 instructions, can efficiently emulate them.
- Automatic runtime autotuning 7 parameters per matrix shape.
- Weight compression integrated directly into GEMM:
- Custom fp8 format with 2..3 mantissa bits; tensor scaling.
- Also bf16, f32 and non-uniform 4-bit (NUQ); easy to add new formats.
- Infrastructure
- SIMD: single implementation via Highway. Chooses ISA at runtime.
- Tensor parallelism: CCX-aware, multi-socket thread pool.
- Disk I/O: memory map or parallel read (heuristic with user override).
- Custom format with forward/backward-compatible metadata serialization.
- Model conversion from Safetensors, not yet open sourced.
- Portability: Linux, Windows/OS X supported. CMake/Bazel. 'Any' CPU.
- Frontends
- C++ APIs with streaming for single query and batched inference.
- Basic interactive command-line app.
- Basic Python bindings (pybind11).
## Quick Start
@ -74,57 +106,20 @@ winget install --id Microsoft.VisualStudio.2022.BuildTools --force --override "-
Visit the
[Kaggle page for Gemma-2](https://www.kaggle.com/models/google/gemma-2/gemmaCpp)
[or Gemma-1](https://www.kaggle.com/models/google/gemma/frameworks/gemmaCpp),
and select `Model Variations |> Gemma C++`.
On this tab, the `Variation` dropdown includes the options below. Note bfloat16
weights are higher fidelity, while 8-bit switched floating point weights enable
faster inference. In general, we recommend starting with the `-sfp` checkpoints.
If you are unsure which model to start with, we recommend starting with the
smallest Gemma-2 model, i.e. `2.0-2b-it-sfp`.
Alternatively, visit the
[gemma.cpp](https://huggingface.co/models?other=gemma.cpp) models on the Hugging
Face Hub. First go the model repository of the model of interest (see
recommendations below). Then, click the `Files and versions` tab and download
the model and tokenizer files. For programmatic downloading, if you have
`huggingface_hub` installed, you can also download by running:
```
huggingface-cli login # Just the first time
huggingface-cli download google/gemma-2b-sfp-cpp --local-dir build/
```
Gemma-1 2B instruction-tuned (`it`) and pre-trained (`pt`) models:
| Model name | Description |
| ----------- | ----------- |
| `2b-it` | 2 billion parameter instruction-tuned model, bfloat16 |
| `2b-it-sfp` | 2 billion parameter instruction-tuned model, 8-bit switched floating point |
| `2b-pt` | 2 billion parameter pre-trained model, bfloat16 |
| `2b-pt-sfp` | 2 billion parameter pre-trained model, 8-bit switched floating point |
Gemma-1 7B instruction-tuned (`it`) and pre-trained (`pt`) models:
| Model name | Description |
| ----------- | ----------- |
| `7b-it` | 7 billion parameter instruction-tuned model, bfloat16 |
| `7b-it-sfp` | 7 billion parameter instruction-tuned model, 8-bit switched floating point |
| `7b-pt` | 7 billion parameter pre-trained model, bfloat16 |
| `7b-pt-sfp` | 7 billion parameter pre-trained model, 8-bit switched floating point |
> [!NOTE]
> **Important**: We strongly recommend starting off with the `2b-it-sfp` model to
> get up and running.
> [!NOTE] **Important**: We strongly recommend starting off with the
> `gemma2-2b-it-sfp` model to get up and running.
Gemma 2 models are named `gemma2-2b-it` for 2B and `9b-it` or `27b-it`. See the
`kModelFlags` definition in `common.cc`.
`ModelPrefix` function in `configs.cc`.
### Step 2: Extract Files
If you downloaded the models from Hugging Face, skip to step 3.
After filling out the consent form, the download should proceed to retrieve a
tar archive file `archive.tar.gz`. Extract files from `archive.tar.gz` (this can
take a few minutes):
@ -162,10 +157,9 @@ cmake --build --preset make -j [number of parallel threads to use]
```
Replace `[number of parallel threads to use]` with a number - the number of
cores available on your system is a reasonable heuristic. For example,
`make -j4 gemma` will build using 4 threads. If the `nproc` command is
available, you can use `make -j$(nproc) gemma` as a reasonable default
for the number of threads.
cores available on your system is a reasonable heuristic. For example, `make -j4
gemma` will build using 4 threads. If the `nproc` command is available, you can
use `make -j$(nproc) gemma` as a reasonable default for the number of threads.
If you aren't sure of the right value for the `-j` flag, you can simply run
`make gemma` instead and it should still build the `./gemma` executable.
@ -174,7 +168,8 @@ If you aren't sure of the right value for the `-j` flag, you can simply run
> On Windows Subsystem for Linux (WSL) users should set the number of
> parallel threads to 1. Using a larger number may result in errors.
If the build is successful, you should now have a `gemma` executable in the `build/` directory.
If the build is successful, you should now have a `gemma` executable in the
`build/` directory.
#### Windows
@ -186,7 +181,8 @@ cmake --preset windows
cmake --build --preset windows -j [number of parallel threads to use]
```
If the build is successful, you should now have a `gemma.exe` executable in the `build/` directory.
If the build is successful, you should now have a `gemma.exe` executable in the
`build/` directory.
#### Bazel
@ -194,7 +190,8 @@ If the build is successful, you should now have a `gemma.exe` executable in the
bazel build -c opt --cxxopt=-std=c++20 :gemma
```
If the build is successful, you should now have a `gemma` executable in the `bazel-bin/` directory.
If the build is successful, you should now have a `gemma` executable in the
`bazel-bin/` directory.
#### Make
@ -208,33 +205,21 @@ You can now run `gemma` from inside the `build/` directory.
`gemma` has the following required arguments:
Argument | Description | Example value
--------------- | ---------------------------- | -----------------------
`--model` | The model type. | `2b-it` ... (see below)
`--weights` | The compressed weights file. | `2b-it-sfp.sbs`
`--weight_type` | The compressed weight type. | `sfp`
`--tokenizer` | The tokenizer file. | `tokenizer.spm`
`gemma` is invoked as:
```sh
./gemma \
--tokenizer [tokenizer file] \
--weights [compressed weights file] \
--weight_type [f32 or bf16 or sfp (default:sfp)] \
--model [2b-it or 2b-pt or 7b-it or 7b-pt or ...]
```
Argument | Description | Example value
------------- | ---------------------------- | ---------------
`--weights` | The compressed weights file. | `2b-it-sfp.sbs`
`--tokenizer` | The tokenizer file. | `tokenizer.spm`
Example invocation for the following configuration:
- Compressed weights file `2b-it-sfp.sbs` (2B instruction-tuned model, 8-bit
switched floating point).
- Tokenizer file `tokenizer.spm`.
- weights file `gemma2-2b-it-sfp.sbs` (Gemma2 2B instruction-tuned model,
8-bit switched floating point).
- Tokenizer file `tokenizer.spm` (can omit for single-format weights files
created after 2025-05-06, or output by migrate_weights.cc).
```sh
./gemma \
--tokenizer tokenizer.spm \
--weights 2b-it-sfp.sbs --model 2b-it
--tokenizer tokenizer.spm --weights gemma2-2b-it-sfp.sbs
```
### RecurrentGemma
@ -256,23 +241,20 @@ Step 1, and run the binary as follows:
### PaliGemma Vision-Language Model
This repository includes a version of the PaliGemma VLM
([paper](https://arxiv.org/abs/2407.07726),
[code](https://github.com/google-research/big_vision/tree/main/big_vision/configs/proj/paligemma))
and its successor PaliGemma 2 ([paper](https://arxiv.org/abs/2412.03555)). We
provide a C++ implementation of the PaliGemma model family here.
This repository includes a version of the PaliGemma 2 VLM
([paper](https://arxiv.org/abs/2412.03555)). We provide a C++ implementation of
the PaliGemma 2 model here.
To use the version of PaliGemma included in this repository, build the gemma
binary as noted above in Step 3. Download the compressed weights and tokenizer
from
[Kaggle](https://www.kaggle.com/models/google/paligemma/gemmaCpp/paligemma-3b-mix-224)
[Kaggle](https://www.kaggle.com/models/google/paligemma-2/gemmaCpp/paligemma2-3b-mix-224)
and run the binary as follows:
```sh
./gemma \
--tokenizer paligemma_tokenizer.model \
--model paligemma-224 \
--weights paligemma-3b-mix-224-sfp.sbs \
--weights paligemma2-3b-mix-224-sfp.sbs \
--image_file paligemma/testdata/image.ppm
```
@ -312,12 +294,12 @@ allows to contain the tokenizer (and the model type) directly. A tool to migrate
from the multi-file format to the single-file format is available.
```sh
compression/migrate_weights \
io/migrate_weights \
--tokenizer .../tokenizer.spm --weights .../gemma2-2b-it-sfp.sbs \
--model gemma2-2b-it --output_weights .../gemma2-2b-it-sfp-single.sbs
--output_weights .../gemma2-2b-it-sfp-single.sbs
```
After migration, you can use the new weights file with gemma.cpp like this:
After migration, you can omit the tokenizer argument like this:
```sh
./gemma --weights .../gemma2-2b-it-sfp-single.sbs
@ -325,15 +307,6 @@ After migration, you can use the new weights file with gemma.cpp like this:
### Troubleshooting and FAQs
**Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."**
The most common problem is that the `--weight_type` argument does not match that
of the model file. Revisit step #3 and check which weights you downloaded.
Note that we have already moved weight type from a compile-time decision to a
runtime argument. In a subsequent step, we plan to bake this information into
the weights.
**Problems building in Windows / Visual Studio**
Currently if you're using Windows, we recommend building in WSL (Windows
@ -344,22 +317,22 @@ configurations, see issues for active discussion.
A common issue is that you are using a pre-trained model, which is not
instruction-tuned and thus does not respond to instructions. Make sure you are
using an instruction-tuned model (`2b-it-sfp`, `2b-it`, `7b-it-sfp`, `7b-it`)
and not a pre-trained model (any model with a `-pt` suffix).
using an instruction-tuned model (`gemma2-2b-it-sfp`) and not a pre-trained
model (any model with a `-pt` suffix).
**What sequence lengths are supported?**
See `seq_len` in `configs.cc`. For the Gemma 3 models larger than 1B, this is
typically 32K but 128K would also work given enough RAM. Note that long
sequences will be slow due to the quadratic cost of attention.
See `max_seq_len` in `configs.cc` and `InferenceArgs.seq_len`. For the Gemma 3
models larger than 1B, this is typically 32K but 128K would also work given
enough RAM. Note that long sequences will be slow due to the quadratic cost of
attention.
**How do I convert my fine-tune to a `.sbs` compressed model file?**
For PaliGemma (1 and 2) checkpoints, you can use
python/convert_from_safetensors.py to convert from safetensors format (tested
with building via bazel). For an adapter model, you will likely need to call
merge_and_unload() to convert the adapter model to a single-file format before
converting it.
For PaliGemma 2 checkpoints, you can use python/convert_from_safetensors.py to
convert from safetensors format (tested with building via bazel). For an adapter
model, you will likely need to call merge_and_unload() to convert the adapter
model to a single-file format before converting it.
Here is how to use it using a bazel build of the compression library assuming
locally installed (venv) torch, numpy, safetensors, absl-py, etc.:
@ -373,22 +346,18 @@ ln -s $BAZEL_OUTPUT_DIR [...]/site-packages/compression
python3 python/convert_from_safetensors.py --load_path [...].safetensors.index.json
```
See also compression/convert_weights.py for a slightly older option to convert a
pytorch checkpoint. (The code may need updates to work with Gemma-2 models.)
**What are some easy ways to make the model run faster?**
1. Make sure you are using the 8-bit switched floating point `-sfp` models.
These are half the size of bf16 and thus use less memory bandwidth and cache
space.
2. If you're on a laptop, make sure power mode is set to maximize performance
2. Due to auto-tuning, the second and especially third query will be faster.
3. If you're on a laptop, make sure power mode is set to maximize performance
and saving mode is **off**. For most laptops, the power saving modes get
activated automatically if the computer is not plugged in.
3. Close other unused cpu-intensive applications.
4. On macs, anecdotally we observe a "warm-up" ramp-up in speed as performance
4. Close other unused cpu-intensive applications.
5. On macs, anecdotally we observe a "warm-up" ramp-up in speed as performance
cores get engaged.
5. Experiment with the `--num_threads` argument value. Depending on the device,
larger numbers don't always mean better performance.
We're also working on algorithmic and optimization approaches for faster
inference, stay tuned.
@ -411,7 +380,7 @@ newline input.
By default, verbosity is set to 1, bringing up a terminal-based interactive
interface when `gemma` is invoked:
```console
```sh
$ ./gemma [...]
__ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __
/ _` |/ _ \ '_ ` _ \| '_ ` _ \ / _` | / __| '_ \| '_ \
@ -420,11 +389,7 @@ $ ./gemma [...]
__/ | | | | |
|___/ |_| |_|
tokenizer : tokenizer.spm
compressed_weights : 2b-it-sfp.sbs
model : 2b-it
weights : [no path specified]
max_generated_tokens : 2048
...
*Usage*
Enter an instruction and press enter (%C reset conversation, %Q quits).
@ -462,7 +427,7 @@ For using the `gemma` executable as a command line tool, it may be useful to
create an alias for gemma.cpp with arguments fully specified:
```sh
alias gemma2b="~/gemma.cpp/build/gemma -- --tokenizer ~/gemma.cpp/build/tokenizer.spm --weights ~/gemma.cpp/build/gemma2-2b-it-sfp.sbs --model gemma2-2b-it --verbosity 0"
alias gemma2b="~/gemma.cpp/build/gemma -- --tokenizer ~/gemma.cpp/build/tokenizer.spm --weights ~/gemma.cpp/build/gemma2-2b-it-sfp.sbs --verbosity 0"
```
Replace the above paths with your own paths to the model and tokenizer paths
@ -481,7 +446,7 @@ cat configs.h | tail -n 35 | tr '\n' ' ' | xargs -0 echo "What does this C++ cod
The output of the above command should look like:
```console
```sh
[ Reading prompt ] [...]
This C++ code snippet defines a set of **constants** used in a large language model (LLM) implementation, likely related to the **attention mechanism**.
@ -492,8 +457,8 @@ Let's break down the code:
### Incorporating gemma.cpp as a Library in your Project
The easiest way to incorporate gemma.cpp in your own project is to pull in
gemma.cpp and dependencies using `FetchContent`. You can add the following to your
CMakeLists.txt:
gemma.cpp and dependencies using `FetchContent`. You can add the following to
your CMakeLists.txt:
```
include(FetchContent)
@ -562,9 +527,10 @@ submit a PR with a `README.md` edit.
## Acknowledgements and Contacts
gemma.cpp was started in fall 2023 by [Austin Huang](mailto:austinvhuang@google.com)
and [Jan Wassenberg](mailto:janwas@google.com), and subsequently released February 2024
thanks to contributions from Phil Culliton, Paul Chang, and Dan Zheng.
gemma.cpp was started in fall 2023 by
[Austin Huang](mailto:austinvhuang@google.com) and
[Jan Wassenberg](mailto:janwas@google.com), and subsequently released February
2024 thanks to contributions from Phil Culliton, Paul Chang, and Dan Zheng.
Griffin support was implemented in April 2024 thanks to contributions by Andrey
Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode

View File

@ -1,81 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_
#define THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_
#include <stddef.h>
#include <vector>
#include "compression/compress.h" // MatStorageT
#include "gemma/configs.h" // ModelConfig
namespace gcpp {
template <typename T>
struct ForwardLayer {
ForwardLayer(const LayerConfig& config, size_t seq_len)
: input("input", seq_len, config.model_dim),
pre_att_rms_out("pre_att_rms_out", seq_len, config.model_dim),
qkv("qkv", seq_len * (config.heads + 2), config.qkv_dim),
att("att", seq_len * config.heads, seq_len),
att_out("att_out", seq_len * config.heads, config.qkv_dim),
att_post1("att_post1", seq_len, config.model_dim),
attention_out("attention_out", seq_len, config.model_dim),
bf_pre_ffw_rms_out("bf_pre_ffw_rms_out", seq_len, config.model_dim),
ffw_hidden("ffw_hidden", seq_len, config.ff_hidden_dim * 2),
ffw_hidden_gated("ffw_hidden_gated", seq_len, config.ff_hidden_dim),
layer_config(config) {}
MatStorageT<T> input;
MatStorageT<T> pre_att_rms_out;
MatStorageT<T> qkv;
MatStorageT<T> att;
MatStorageT<T> att_out;
MatStorageT<T> att_post1;
MatStorageT<T> attention_out;
MatStorageT<T> bf_pre_ffw_rms_out;
MatStorageT<T> ffw_hidden;
MatStorageT<T> ffw_hidden_gated;
const LayerConfig& layer_config;
};
template <typename T>
struct ForwardPass {
ForwardPass(const ModelConfig& config)
: final_layer_output("final_layer_output", config.seq_len,
config.model_dim),
final_norm_output("final_norm_output", config.seq_len,
config.model_dim),
logits("logits", config.seq_len, config.vocab_size),
probs("probs", config.seq_len, config.vocab_size),
weights_config(config) {
for (const auto& layer_config : config.layer_configs) {
layers.emplace_back(layer_config, config.seq_len);
}
}
std::vector<ForwardLayer<T>> layers;
MatStorageT<T> final_layer_output;
MatStorageT<T> final_norm_output;
MatStorageT<T> logits;
MatStorageT<T> probs;
const ModelConfig& weights_config;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_

View File

@ -1,404 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Implementation of the Vector-Jacobian Products (VJP) of the individual
// operations of the forward pass.
// Include guard for non-SIMD code.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
#include <stddef.h>
#include <cmath>
#include <vector>
#include "backprop/activations.h"
#include "backprop/prompt.h"
#include "gemma/common.h"
#include "gemma/weights.h"
#include "util/allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
#endif
#include "hwy/highway.h"
// After highway.h
#include "ops/matmul-inl.h"
#include "ops/ops-inl.h"
#include "hwy/contrib/dot/dot-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
HWY_INLINE void MatMulVJP(const float* HWY_RESTRICT weights, // kRows * kCols,
const float* HWY_RESTRICT x, // num_tokens * kCols
const float* HWY_RESTRICT v, // num_tokens * kRows
size_t cols, size_t rows, size_t num_tokens,
float* HWY_RESTRICT grad_w, // kRows * kCols,
float* HWY_RESTRICT grad_x, // num_tokens * kCols
hwy::ThreadPool& pool) {
hwy::ZeroBytes(grad_x, num_tokens * cols * sizeof(grad_x[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t voffs = pos * rows;
const size_t xoffs = pos * cols;
for (size_t j = 0; j < rows; ++j) {
MulByConstAndAdd(v[voffs + j], &x[xoffs], &grad_w[j * cols], cols);
MulByConstAndAdd(v[voffs + j], &weights[j * cols], &grad_x[xoffs], cols);
}
}
}
HWY_INLINE void MultiHeadMatMulVJP(
const float* HWY_RESTRICT weights, // heads * kRows * kCols
const float* HWY_RESTRICT x, // num_tokens * heads * kCols
const float* HWY_RESTRICT v, // num_tokens * kRows
size_t heads, size_t cols, size_t rows, size_t num_tokens,
float* HWY_RESTRICT grad_w, // heads * kRows * kCols
float* HWY_RESTRICT grad_x, // num_tokens * heads * kCols
hwy::ThreadPool& pool) {
hwy::ZeroBytes(grad_x, num_tokens * heads * cols * sizeof(grad_x[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t j = 0; j < rows; ++j) {
for (size_t h = 0; h < heads; ++h) {
MulByConstAndAdd(v[pos * rows + j], &x[pos * heads * cols + h * cols],
&grad_w[h * rows * cols + j * cols], cols);
MulByConstAndAdd(v[pos * rows + j],
&weights[h * rows * cols + j * cols],
&grad_x[pos * heads * cols + h * cols], cols);
}
}
}
}
template <class D, HWY_IF_F32_D(D)>
static HWY_INLINE hn::Vec<D> DGelu(D d, hn::Vec<D> v) {
const hn::Vec<D> kMul = hn::Set(d, 0.044715f);
const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f);
const hn::Vec<D> kHalf = hn::Set(d, 0.5f);
const hn::Vec<D> kOne = hn::Set(d, 1.0f);
// kSqrtOverPi*3*kMul
const hn::Vec<D> kMulv2 = hn::Set(d, 0.1070322244f);
const hn::Vec<D> v2 = hn::Mul(v, v);
const hn::Vec<D> v3 = hn::Mul(v2, v);
const hn::Vec<D> arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v));
const hn::Vec<D> tanh = hn::Tanh(d, arg);
const hn::Vec<D> cdf = hn::MulAdd(kHalf, tanh, kHalf);
const hn::Vec<D> dtanh = hn::Sub(kOne, hn::Mul(tanh, tanh));
const hn::Vec<D> darg = hn::MulAdd(kMulv2, v2, kSqrt2OverPi);
return hn::MulAdd(kHalf, hn::Mul(v, hn::Mul(dtanh, darg)), cdf);
}
static HWY_NOINLINE void SoftmaxVJP(const float* HWY_RESTRICT forward,
float* HWY_RESTRICT backward,
const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
const auto offset =
hn::Set(d, hn::Dot::Compute<0>(d, forward, backward, size));
hn::Transform1(
d, backward, size, forward,
[&offset](const auto d, const auto v, const auto y)
HWY_ATTR { return hn::Mul(y, hn::Sub(v, offset)); });
}
static HWY_NOINLINE void RMSNormVJP(
const float* HWY_RESTRICT weights, const float* HWY_RESTRICT x,
const float* HWY_RESTRICT v, size_t model_dim, size_t num_tokens,
float* HWY_RESTRICT grad_w, float* HWY_RESTRICT grad_x,
hwy::ThreadPool& pool) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t offset = pos * model_dim;
const float ss = detail::RMSNormMul(x + offset, model_dim);
for (size_t i = 0; i < model_dim; ++i) {
grad_w[i] += v[offset + i] * x[offset + i] * ss;
}
const float ss3 = ss * ss * ss / StaticCast<float>(model_dim);
float tmp = 0.0f;
for (size_t i = 0; i < model_dim; ++i) {
tmp += (1.0f + weights[i]) * v[offset + i] * x[offset + i];
}
tmp *= ss3;
for (size_t i = 0; i < model_dim; ++i) {
grad_x[offset + i] = ss * (1.0f + weights[i]) * v[offset + i] -
tmp * x[offset + i];
}
}
}
static HWY_NOINLINE void InputEmbeddingVJP(
const float* weights, const std::vector<int>& prompt,
const float scaling, const float* HWY_RESTRICT v,
float* HWY_RESTRICT grad, size_t model_dim) {
HWY_ASSERT(!prompt.empty());
for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
int token = prompt[pos];
MulByConstAndAdd(scaling, v + pos * model_dim,
grad + token * model_dim, model_dim);
}
}
template <typename T>
void LayerVJP(const LayerWeightsPtrs<T>& weights,
const ForwardLayer<float>& forward,
const float* HWY_RESTRICT next_layer_grad, size_t num_tokens,
LayerWeightsPtrs<T>& grad, ForwardLayer<float>& backward,
const RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
const LayerConfig& config = weights.layer_config;
const size_t model_dim = config.model_dim;
const size_t qkv_dim = config.qkv_dim;
const size_t heads = config.heads;
const size_t seq_len = forward.input.Rows();
const size_t ff_hidden_dim = config.ff_hidden_dim;
const float query_scale =
static_cast<float>(1.0 / sqrt(static_cast<double>(qkv_dim)));
HWY_ASSERT(num_tokens <= seq_len);
MatMulVJP(weights.linear_w.data(), forward.ffw_hidden_gated.data(),
next_layer_grad, ff_hidden_dim, model_dim, num_tokens,
grad.linear_w.data(), backward.ffw_hidden_gated.data(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t hidden_offset = pos * ff_hidden_dim * 2;
const float* HWY_RESTRICT f_out = forward.ffw_hidden.data() + hidden_offset;
const float* HWY_RESTRICT f_out_mul = f_out + ff_hidden_dim;
const float* HWY_RESTRICT b_out_gated =
backward.ffw_hidden_gated.data() + pos * ff_hidden_dim;
float* HWY_RESTRICT b_out = backward.ffw_hidden.data() + hidden_offset;
float* HWY_RESTRICT b_out_mul = b_out + ff_hidden_dim;
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
DF df;
for (size_t i = 0; i < ff_hidden_dim; i += Lanes(df)) {
const auto y = Load(df, f_out + i);
const auto x = Load(df, f_out_mul + i);
const auto v = Load(df, b_out_gated + i);
hn::Store(hn::Mul(v, Gelu(df, y)), df, b_out_mul + i);
hn::Store(hn::Mul(v, hn::Mul(x, DGelu(df, y))), df, b_out + i);
}
}
MatMulVJP(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(),
backward.ffw_hidden.data(), model_dim, ff_hidden_dim * 2,
num_tokens, grad.gating_einsum_w.data(),
backward.bf_pre_ffw_rms_out.data(), pool);
RMSNormVJP(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(),
backward.bf_pre_ffw_rms_out.data(), model_dim, num_tokens,
grad.pre_ffw_norm_scale.data(), backward.attention_out.data(),
pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(next_layer_grad + pos * model_dim,
backward.attention_out.data() + pos * model_dim, model_dim);
}
backward.qkv.ZeroInit();
MultiHeadMatMulVJP(weights.attn_vec_einsum_w.data(), forward.att_out.data(),
backward.attention_out.data(), heads, qkv_dim, model_dim,
num_tokens, grad.attn_vec_einsum_w.data(),
backward.att_out.data(), pool);
for (size_t head = 0; head < heads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t aoffset = head * seq_len + pos * heads * seq_len;
const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
const float* HWY_RESTRICT b_att_out =
backward.att_out.data() + (pos * heads + head) * qkv_dim;
float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t v2offs = (pos2 * (heads + 2) + heads + 1) * qkv_dim;
const float* HWY_RESTRICT f_v2 = forward.qkv.data() + v2offs;
float* HWY_RESTRICT b_v2 = backward.qkv.data() + v2offs;
b_head_att[pos2] = Dot(b_att_out, f_v2, qkv_dim);
MulByConstAndAdd(f_head_att[pos2], b_att_out, b_v2, qkv_dim);
}
}
}
for (size_t head = 0; head < heads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t aoffset = head * seq_len + pos * heads * seq_len;
const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
SoftmaxVJP(f_head_att, b_head_att, pos + 1);
}
}
for (size_t head = 0; head < heads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t qoffs = (pos * (heads + 2) + head) * qkv_dim;
const size_t aoffs = head * seq_len + pos * heads * seq_len;
const float* HWY_RESTRICT f_q = forward.qkv.data() + qoffs;
const float* HWY_RESTRICT b_head_att = backward.att.data() + aoffs;
float* HWY_RESTRICT b_q = backward.qkv.data() + qoffs;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t k2offs = (pos2 * (heads + 2) + heads) * qkv_dim;
const float* HWY_RESTRICT f_k2 = forward.qkv.data() + k2offs;
float* HWY_RESTRICT b_k2 = backward.qkv.data() + k2offs;
MulByConstAndAdd(b_head_att[pos2], f_k2, b_q, qkv_dim);
MulByConstAndAdd(b_head_att[pos2], f_q, b_k2, qkv_dim);
}
}
}
for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
float* HWY_RESTRICT b_kv =
backward.qkv.data() + (pos * (heads + 2) + heads) * qkv_dim;
Rope(b_kv, qkv_dim, inv_timescale.Const(), -pos);
}
for (size_t head = 0; head < heads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
float* HWY_RESTRICT b_q =
backward.qkv.data() + (pos * (heads + 2) + head) * qkv_dim;
MulByConst(query_scale, b_q, qkv_dim);
Rope(b_q, qkv_dim, inv_timescale.Const(), -pos);
}
}
MatMulVJP(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
backward.qkv.data(), model_dim, (heads + 2) * qkv_dim, num_tokens,
grad.qkv_einsum_w.data(), backward.pre_att_rms_out.data(), pool);
RMSNormVJP(weights.pre_attention_norm_scale.data(), forward.input.data(),
backward.pre_att_rms_out.data(), model_dim, num_tokens,
grad.pre_attention_norm_scale.data(), backward.input.data(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(backward.attention_out.data() + pos * model_dim,
backward.input.data() + pos * model_dim, model_dim);
}
}
static HWY_NOINLINE void SoftcapVJP(const float cap,
const float* HWY_RESTRICT forward,
float* HWY_RESTRICT backward,
const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
const auto one = hn::Set(d, 1.0f);
const auto vcap = hn::Set(d, cap);
const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
hn::Transform1(d, backward, size, forward,
[&](const auto d, const auto v, const auto y) HWY_ATTR {
const auto scaled = hn::Mul(vinv_cap, y); // = tanh
return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled)));
});
}
static HWY_NOINLINE void CrossEntropyLossGrad(
const float* HWY_RESTRICT x, float* HWY_RESTRICT grad,
const Prompt& prompt, size_t vocab_size) {
HWY_ASSERT(!prompt.tokens.empty());
const float scaling = -1.0 / std::log(2.0);
size_t num_tokens = prompt.tokens.size() - 1;
hwy::ZeroBytes(grad, num_tokens * vocab_size * sizeof(grad[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) {
if (pos + 1 < prompt.context_size) {
continue;
}
const int next_token = prompt.tokens[pos + 1];
grad[pos * vocab_size + next_token] =
scaling / x[pos * vocab_size + next_token];
}
}
template <typename T>
void CrossEntropyLossBackwardPassInl(const Prompt& prompt,
const ModelWeightsPtrs<T>& weights,
const ForwardPass<float>& forward,
ModelWeightsPtrs<T>& grad,
ForwardPass<float>& backward,
RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
const ModelConfig& config = weights.weights_config;
const size_t kVocabSize = config.vocab_size;
const size_t model_dim = config.model_dim;
const size_t kLayers = config.layer_configs.size();
const float kEmbScaling = EmbeddingScaling(model_dim);
HWY_ASSERT(!config.absolute_pe);
HWY_ASSERT(config.layer_configs[0].post_norm == PostNormType::None);
HWY_ASSERT(config.layer_configs[0].kv_heads == 1);
HWY_DASSERT(prompt.context_size > 0);
HWY_DASSERT(prompt.context_size < prompt.tokens.size());
const size_t num_tokens = prompt.tokens.size() - 1;
CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
kVocabSize);
for (size_t pos = 0; pos < num_tokens; ++pos) {
SoftmaxVJP(forward.probs.data() + pos * kVocabSize,
backward.logits.data() + pos * kVocabSize,
kVocabSize);
}
if (config.final_cap > 0.0f) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
SoftcapVJP(config.final_cap, forward.logits.data() + pos * kVocabSize,
backward.logits.data() + pos * kVocabSize, kVocabSize);
}
}
MatMulVJP(weights.embedder_input_embedding.data(),
forward.final_norm_output.data(), backward.logits.data(), model_dim,
kVocabSize, num_tokens, grad.embedder_input_embedding.data(),
backward.final_norm_output.data(), pool);
RMSNormVJP(weights.final_norm_scale.data(), forward.final_layer_output.data(),
backward.final_norm_output.data(), model_dim, num_tokens,
grad.final_norm_scale.data(), backward.final_layer_output.data(),
pool);
for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) {
auto layer_config = config.layer_configs[layer];
// TODO(szabadka) Implement Griffin layer vjp.
HWY_ASSERT(layer_config.type == LayerAttentionType::kGemma);
float* next_layer_grad = layer + 1 < kLayers
? backward.layers[layer + 1].input.data()
: backward.final_layer_output.data();
LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
num_tokens, *grad.GetLayer(layer), backward.layers[layer],
inv_timescale, pool);
}
InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens,
kEmbScaling, backward.layers[0].input.data(),
grad.embedder_input_embedding.data(), model_dim);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // NOLINT

View File

@ -1,73 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "backprop/backward.h"
#include "backprop/activations.h"
#include "backprop/prompt.h"
#include "gemma/common.h"
#include "gemma/weights.h"
#include "util/allocator.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "backprop/backward.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "backprop/backward-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
void CrossEntropyLossBackwardPassT(const Prompt& prompt,
const ModelWeightsPtrs<float>& weights,
const ForwardPass<float>& forward,
ModelWeightsPtrs<float>& grad,
ForwardPass<float>& backward,
RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
CrossEntropyLossBackwardPassInl(prompt, weights, forward, grad, backward,
inv_timescale, pool);
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(CrossEntropyLossBackwardPassT);
void CrossEntropyLossBackwardPass(const Prompt& prompt,
const ModelWeightsPtrs<float>& weights,
const ForwardPass<float>& forward,
ModelWeightsPtrs<float>& grad,
ForwardPass<float>& backward,
RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)(
prompt, weights, forward, grad, backward, inv_timescale, pool);
}
} // namespace gcpp
#endif // HWY_ONCE

View File

@ -1,37 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
#include "backprop/activations.h"
#include "backprop/prompt.h"
#include "gemma/weights.h"
#include "util/allocator.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
void CrossEntropyLossBackwardPass(const Prompt& prompt,
const ModelWeightsPtrs<float>& weights,
const ForwardPass<float>& forward,
ModelWeightsPtrs<float>& grad,
ForwardPass<float>& backward,
RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_

View File

@ -1,349 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_
#include <stddef.h>
#include <string.h>
#include <cmath>
#include <vector>
#include "backprop/activations.h"
#include "backprop/common_scalar.h"
#include "backprop/prompt.h"
#include "gemma/common.h" // EmbeddingScaling
#include "gemma/weights.h"
namespace gcpp {
template<typename T>
void MatMulVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx,
size_t N, size_t M, size_t K) {
memset(dx, 0, M * K * sizeof(dx[0]));
for (size_t i = 0; i < K; ++i) {
for (size_t j = 0; j < N; ++j) {
MulByConstAndAddT(dy[i * N + j], &x[i * M], &dw[j * M], M);
MulByConstAndAddT(dy[i * N + j], &w[j * M], &dx[i * M], M);
}
}
}
template<typename T>
void MultiHeadMatMulVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx,
size_t H, size_t N, size_t M, size_t K) {
memset(dx, 0, H * M * K * sizeof(dx[0]));
for (size_t i = 0; i < K; ++i) {
for (size_t j = 0; j < N; ++j) {
for (size_t h = 0; h < H; ++h) {
MulByConstAndAddT(dy[i * N + j], &x[i * H * M + h * M],
&dw[h * N * M + j * M], M);
MulByConstAndAddT(dy[i * N + j], &w[h * N * M + j * M],
&dx[i * H * M + h * M], M);
}
}
}
}
template<typename T>
void RMSNormVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx,
size_t N, size_t K) {
for (size_t i = 0; i < K; ++i) {
constexpr T eps(1e-6);
T ss = SquaredL2(x + i * N, N);
ss = T(1.0) / std::sqrt(ss / T(N) + eps);
for (size_t j = 0; j < N; ++j) {
dw[j] += dy[i * N + j] * x[i * N + j] * ss;
}
const T ss3 = ss * ss * ss / T(N);
T tmp = 0.0;
for (size_t j = 0; j < N; ++j) {
tmp += (T(1.0) + w[j]) * dy[i* N + j] * x[i * N + j];
}
tmp *= ss3;
for (size_t j = 0; j < N; ++j) {
dx[i * N + j] = ss * (T(1.0) + w[j]) * dy[i* N + j] - tmp * x[i * N + j];
}
}
}
template<typename T>
void SoftmaxVJPT(const T* y, T* dy, size_t N) {
T sum = {};
for (size_t i = 0; i < N; ++i) {
sum += y[i] * dy[i];
}
for (size_t i = 0; i < N; ++i) {
dy[i] = y[i] * (dy[i] - sum);
}
}
template<typename T>
void SoftmaxVJPT(const T* y, T* dy, size_t N, size_t K) {
for (size_t i = 0; i < K; ++i) {
SoftmaxVJPT(y + i * N, dy + i * N, N);
}
}
template<typename T>
T GeluDerivative(T x) {
static const T kMul = 0.044715;
static const T kSqrt2OverPi = 0.797884560804236;
static const T kMul2 = kSqrt2OverPi * T(3.0) * kMul;
const T x2 = x * x;
const T x3 = x2 * x;
const T arg = kSqrt2OverPi * (kMul * x3 + x);
const T tanh = std::tanh(arg);
const T cdf = T(0.5) * (T(1.0) + tanh);
const T dtanh = T(1.0) - tanh * tanh;
const T darg = kMul2 * x2 + kSqrt2OverPi;
return T(0.5) * x * dtanh * darg + cdf;
}
template<typename T>
void GatedGeluVJP(const T* in, const T* d_out, T* d_in, size_t N, size_t K) {
for (size_t i = 0; i < K; ++i) {
const T* x1 = in + i * 2 * N;
const T* x2 = x1 + N;
const T* v = d_out + i * N;
T* dx1 = d_in + i * 2 * N;
T* dx2 = dx1 + N;
for (size_t j = 0; j < N; ++j) {
dx1[j] = v[j] * x2[j] * GeluDerivative(x1[j]);
dx2[j] = v[j] * Gelu(x1[j]);
}
}
}
template <typename T>
void MaskedAttentionVJP(const T* qkv, const T* doutput, T* dqkv,
size_t num_tokens, size_t kHeads, size_t qkv_dim,
size_t seq_len) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t offset = pos * (kHeads + 2) * qkv_dim;
memset(dqkv + offset, 0, (kHeads + 1) * qkv_dim * sizeof(qkv[0]));
}
for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t qoffs = (pos * (kHeads + 2) + head) * qkv_dim;
const size_t aoffs = head * seq_len + pos * kHeads * seq_len;
const T* q = qkv + qoffs;
const T* dout = doutput + aoffs;
T* dq = dqkv + qoffs;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t koffs = (pos2 * (kHeads + 2) + kHeads) * qkv_dim;
const T* k = qkv + koffs;
T* dk = dqkv + koffs;
MulByConstAndAddT(dout[pos2], k, dq, qkv_dim);
MulByConstAndAddT(dout[pos2], q, dk, qkv_dim);
}
}
}
}
template <typename T>
void MaskedSoftmaxVJPT(const T* y, T* dy, size_t num_tokens, size_t kHeads,
size_t seq_len) {
for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
size_t offset = pos * kHeads * seq_len + head * seq_len;
SoftmaxVJPT(y + offset, dy + offset, pos + 1);
memset(dy + offset + pos + 1, 0, (seq_len - pos - 1) * sizeof(T));
}
}
}
template <typename T>
void MixByAttentionVJP(const T* qkv, const T* attention, const T* doutput,
T* dqkv, T* dattention, size_t num_tokens, size_t kHeads,
size_t qkv_dim, size_t seq_len) {
auto v_offset = [&](size_t pos) {
return (pos * (kHeads + 2) + kHeads + 1) * qkv_dim;
};
for (size_t pos = 0; pos < num_tokens; ++pos) {
memset(&dqkv[v_offset(pos)], 0, qkv_dim * sizeof(qkv[0]));
}
for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t offset = head * qkv_dim + pos * kHeads * qkv_dim;
const size_t aoffset = head * seq_len + pos * kHeads * seq_len;
const T* att = &attention[aoffset];
const T* dout = &doutput[offset];
T* datt = &dattention[aoffset];
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
datt[pos2] = DotT(dout, &qkv[v_offset(pos2)], qkv_dim);
MulByConstAndAddT(att[pos2], dout, &dqkv[v_offset(pos2)], qkv_dim);
}
}
}
}
template<typename T>
void InputEmbeddingVJPT(const T* w, const std::vector<int>& tokens, T scaling,
const T* dy, T* dw, size_t N) {
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
for (size_t i = 0; i < num_tokens; ++i) {
int token = tokens[i];
MulByConstAndAddT(scaling, dy + i * N, dw + token * N, N);
}
}
template <typename T>
void LayerVJP(const LayerWeightsPtrs<T>& weights,
const ForwardLayer<T>& forward, const T* dy,
LayerWeightsPtrs<T>& grad, ForwardLayer<T>& backward,
size_t num_tokens) {
const LayerConfig& layer_config = weights.layer_config;
const size_t model_dim = layer_config.model_dim;
const size_t seq_len = forward.input.Rows();
const size_t qkv_dim = layer_config.qkv_dim;
const size_t kHeads = layer_config.heads;
const size_t kFFHiddenDim = layer_config.ff_hidden_dim;
const T kQueryScale = 1.0 / std::sqrt(T(qkv_dim));
MatMulVJPT(weights.linear_w.data(), forward.ffw_hidden_gated.data(), dy,
grad.linear_w.data(), backward.ffw_hidden_gated.data(), model_dim,
kFFHiddenDim, num_tokens);
GatedGeluVJP(forward.ffw_hidden.data(), backward.ffw_hidden_gated.data(),
backward.ffw_hidden.data(), kFFHiddenDim, num_tokens);
MatMulVJPT(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(),
backward.ffw_hidden.data(), grad.gating_einsum_w.data(),
backward.bf_pre_ffw_rms_out.data(), kFFHiddenDim * 2, model_dim,
num_tokens);
RMSNormVJPT(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(),
backward.bf_pre_ffw_rms_out.data(),
grad.pre_ffw_norm_scale.data(), backward.attention_out.data(),
model_dim, num_tokens);
AddFromT(dy, backward.attention_out.data(), num_tokens * model_dim);
MultiHeadMatMulVJPT(weights.attn_vec_einsum_w.data(), forward.att_out.data(),
backward.attention_out.data(),
grad.attn_vec_einsum_w.data(), backward.att_out.data(),
kHeads, model_dim, qkv_dim, num_tokens);
MixByAttentionVJP(forward.qkv.data(), forward.att.data(),
backward.att_out.data(), backward.qkv.data(),
backward.att.data(), num_tokens, kHeads, qkv_dim, seq_len);
MaskedSoftmaxVJPT(forward.att.data(), backward.att.data(), num_tokens, kHeads,
seq_len);
MaskedAttentionVJP(forward.qkv.data(), backward.att.data(),
backward.qkv.data(), num_tokens, kHeads, qkv_dim, seq_len);
for (size_t pos = 0; pos < num_tokens; ++pos) {
T* qkv = backward.qkv.data() + pos * (kHeads + 2) * qkv_dim;
MulByConstT(kQueryScale, qkv, kHeads * qkv_dim);
}
for (int pos = 0; pos < num_tokens; ++pos) {
T* qkv = backward.qkv.data() + pos * (kHeads + 2) * qkv_dim;
for (size_t h = 0; h <= kHeads; ++h) {
Rope(qkv + h * qkv_dim, qkv_dim, -pos);
}
}
MatMulVJPT(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
backward.qkv.data(), grad.qkv_einsum_w.data(),
backward.pre_att_rms_out.data(), (kHeads + 2) * qkv_dim, model_dim,
num_tokens);
RMSNormVJPT(weights.pre_attention_norm_scale.data(), forward.input.data(),
backward.pre_att_rms_out.data(),
grad.pre_attention_norm_scale.data(), backward.input.data(),
model_dim, num_tokens);
AddFromT(backward.attention_out.data(), backward.input.data(),
num_tokens * model_dim);
}
template <typename T>
void SoftcapVJPT(float cap, const T* y, T* dy, size_t N) {
const T inv_cap = T{1.0} / static_cast<T>(cap);
for (size_t i = 0; i < N; ++i) {
T scaled = y[i] * inv_cap; // tanh
dy[i] *= (T{1.0} - scaled * scaled);
}
}
template<typename T>
void CrossEntropyLossGrad(const T* x, T* dx, const Prompt& prompt, size_t V) {
T scaling = -1.0 / std::log(2.0);
const std::vector<int> tokens = prompt.tokens;
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
memset(dx, 0, V * num_tokens * sizeof(x[0]));
for (size_t i = 0; i < num_tokens; ++i) {
if (i + 1 < prompt.context_size) {
continue;
}
const int next_token = tokens[i + 1];
dx[i * V + next_token] = scaling / x[i * V + next_token];
}
}
template <typename T>
void CrossEntropyLossBackwardPass(const Prompt& prompt,
const ModelWeightsPtrs<T>& weights,
const ForwardPass<T>& forward,
ModelWeightsPtrs<T>& grad,
ForwardPass<T>& backward) {
const ModelConfig& config = weights.weights_config;
const size_t model_dim = config.model_dim;
const size_t vocab_size = config.vocab_size;
const size_t layers = config.layer_configs.size();
const std::vector<int> tokens = prompt.tokens;
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
vocab_size);
SoftmaxVJPT(forward.probs.data(), backward.logits.data(), vocab_size,
num_tokens);
if (config.final_cap > 0.0f) {
for (size_t i = 0; i < num_tokens; ++i) {
SoftcapVJPT(config.final_cap, forward.logits.data() + i * vocab_size,
backward.logits.data() + i * vocab_size, vocab_size);
}
}
MatMulVJPT(
weights.embedder_input_embedding.data(), forward.final_norm_output.data(),
backward.logits.data(), grad.embedder_input_embedding.data(),
backward.final_norm_output.data(), vocab_size, model_dim, num_tokens);
RMSNormVJPT(weights.final_norm_scale.data(),
forward.final_layer_output.data(),
backward.final_norm_output.data(), grad.final_norm_scale.data(),
backward.final_layer_output.data(), model_dim, num_tokens);
for (int layer = static_cast<int>(layers) - 1; layer >= 0; --layer) {
T* next_layer_grad = layer + 1 < layers
? backward.layers[layer + 1].input.data()
: backward.final_layer_output.data();
LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
*grad.GetLayer(layer), backward.layers[layer], num_tokens);
}
const T kEmbScaling = EmbeddingScaling(model_dim);
InputEmbeddingVJPT(weights.embedder_input_embedding.data(), tokens,
kEmbScaling, backward.layers[0].input.data(),
grad.embedder_input_embedding.data(), model_dim);
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_

View File

@ -1,635 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "backprop/backward_scalar.h"
#include <stddef.h>
#include <stdio.h>
#include <string.h> // memcpy
#include <complex>
#include <limits>
#include <random>
#include <vector>
#include "gtest/gtest.h"
#include "backprop/activations.h"
#include "backprop/common_scalar.h"
#include "backprop/forward_scalar.h"
#include "backprop/prompt.h"
#include "backprop/sampler.h"
#include "backprop/test_util.h"
#include "compression/compress.h"
#include "gemma/configs.h"
#include "gemma/weights.h"
namespace gcpp {
TEST(BackPropTest, MatMulVJP) {
static const size_t kRows = 8;
static const size_t kCols = 64;
static const size_t kTokens = 5;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
MatStorageT<T> weights("weights", kRows, kCols);
MatStorageT<T> x("x", kTokens, kCols);
MatStorageT<T> grad("grad", kRows, kCols);
MatStorageT<T> dx("dx", kTokens, kCols);
MatStorageT<TC> c_weights("c_weights", kRows, kCols);
MatStorageT<TC> c_x("c_x", kTokens, kCols);
MatStorageT<TC> c_y("c_y", kTokens, kRows);
MatStorageT<T> dy("dy", kTokens, kRows);
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0 * (1 << iter), gen);
RandInit(x, 1.0 * (1 << iter), gen);
RandInit(dy, 1.0, gen);
Complexify(weights, c_weights);
Complexify(x, c_x);
auto func = [&]() {
MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens);
return DotT(dy.data(), c_y.data(), kTokens * kRows);
};
grad.ZeroInit();
MatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(),
kRows, kCols, kTokens);
TestGradient(dx, c_x, func, 1e-11, 1e-12, __LINE__);
TestGradient(grad, c_weights, func, 1e-14, 1e-12, __LINE__);
}
}
TEST(BackPropTest, MultiHeadMatMulVJP) {
static const size_t kRows = 2;
static const size_t kCols = 16;
static const size_t kHeads = 4;
static const size_t kTokens = 3;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
MatStorageT<T> weights("weights", kRows, kCols * kHeads);
MatStorageT<T> x("x", kTokens, kCols * kHeads);
MatStorageT<T> grad("grad", kRows, kCols * kHeads);
MatStorageT<T> dx("dx", kTokens, kCols * kHeads);
MatStorageT<TC> c_weights("c_weights", kRows, kCols * kHeads);
MatStorageT<TC> c_x("c_x", kTokens, kCols * kHeads);
MatStorageT<TC> c_y("c_y", kTokens, kRows);
MatStorageT<T> dy("dy", kTokens, kRows);
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0 * (1 << iter), gen);
RandInit(x, 1.0 * (1 << iter), gen);
RandInit(dy, 1.0, gen);
Complexify(weights, c_weights);
Complexify(x, c_x);
auto func = [&]() {
MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows,
kCols, kTokens);
return DotT(dy.data(), c_y.data(), kTokens * kRows);
};
grad.ZeroInit();
MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(),
dx.data(), kHeads, kRows, kCols, kTokens);
TestGradient(dx, c_x, func, 1e-15, 1e-13, __LINE__);
TestGradient(grad, c_weights, func, 1e-15, 1e-13, __LINE__);
}
}
TEST(BackPropTest, RMSNormVJP) {
static const size_t K = 2;
static const size_t N = 64;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
MatStorageT<T> weights("weights", N, 1);
MatStorageT<T> grad("grad", N, 1);
MatStorageT<T> x("x", K, N);
MatStorageT<T> dx("dx", K, N);
MatStorageT<T> dy("dy", K, N);
MatStorageT<TC> c_weights("c_weights", N, 1);
MatStorageT<TC> c_x("c_x", K, N);
MatStorageT<TC> c_y("c_y", K, N);
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0 * (1 << iter), gen);
RandInit(x, 1.0 * (1 << iter), gen);
Complexify(weights, c_weights);
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K);
return DotT(dy.data(), c_y.data(), K * N);
};
grad.ZeroInit();
RMSNormVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(),
N, K);
TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__);
TestGradient(grad, c_weights, func, 1e-15, 1e-14, __LINE__);
}
}
TEST(BackPropTest, SoftmaxVJP) {
static const size_t N = 64;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
MatStorageT<T> x("x", N, 1);
MatStorageT<T> dx("dx", N, 1);
MatStorageT<T> dy("dy", N, 1);
MatStorageT<TC> c_x("c_x", N, 1);
MatStorageT<TC> c_y("c_y", N, 1);
for (int iter = 0; iter < 10; ++iter) {
RandInit(x, 1.0 * (1 << iter), gen);
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
memcpy(c_y.data(), c_x.data(), c_x.SizeBytes());
Softmax(c_y.data(), N);
return DotT(dy.data(), c_y.data(), N);
};
Softmax(x.data(), N);
memcpy(dx.data(), dy.data(), dx.SizeBytes());
SoftmaxVJPT(x.data(), dx.data(), N);
TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
}
}
TEST(BackPropTest, MaskedSoftmaxVJP) {
static const size_t kSeqLen = 16;
static const size_t kHeads = 2;
static const size_t kTokens = 14;
static const size_t N = kTokens * kHeads * kSeqLen;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
MatStorageT<T> x("x", N, 1);
MatStorageT<T> dy("dy", N, 1);
MatStorageT<T> dx("dx", N, 1);
MatStorageT<TC> c_x("c_x", N, 1);
MatStorageT<TC> c_y("c_y", N, 1);
dx.ZeroInit();
for (int iter = 0; iter < 10; ++iter) {
RandInit(x, 1.0 * (1 << iter), gen);
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
memcpy(c_y.data(), c_x.data(),
kTokens * kHeads * kSeqLen * sizeof(c_x.At(0)));
MaskedSoftmax(c_y.data(), kTokens, kHeads, kSeqLen);
return DotT(dy.data(), c_y.data(), N);
};
MaskedSoftmax(x.data(), kTokens, kHeads, kSeqLen);
memcpy(dx.data(), dy.data(), kTokens * kHeads * kSeqLen * sizeof(dx.At(0)));
MaskedSoftmaxVJPT(x.data(), dx.data(), kTokens, kHeads, kSeqLen);
TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__);
}
}
TEST(BackPropTest, SoftcapVJP) {
static const size_t N = 64;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
MatStorageT<T> x("x", N, 1);
MatStorageT<T> dx("dx", N, 1);
MatStorageT<T> dy("dy", N, 1);
MatStorageT<TC> c_x("c_x", N, 1);
MatStorageT<TC> c_y("c_y", N, 1);
constexpr float kCap = 30.0f;
for (int iter = 0; iter < 10; ++iter) {
RandInit(x, 1.0 * (1 << iter), gen);
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
memcpy(c_y.data(), c_x.data(), N * sizeof(c_x.At(0)));
Softcap(kCap, c_y.data(), N);
return DotT(dy.data(), c_y.data(), N);
};
Softcap(kCap, x.data(), N);
memcpy(dx.data(), dy.data(), dx.SizeBytes());
SoftcapVJPT(kCap, x.data(), dx.data(), N);
TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__);
}
}
TEST(BackPropTest, CrossEntropyLossGrad) {
static const size_t K = 8;
static const size_t V = 64;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
MatStorageT<T> x("x", K, V);
MatStorageT<T> dx("dx", K, V);
MatStorageT<TC> c_x("c_x", K, V);
Prompt prompt;
prompt.tokens = { 0, 1, 2, 3, 0, 3, 2, 1, 0 };
const float kCap = 30.0f;
for (int iter = 0; iter < 10; ++iter) {
prompt.context_size = 1 + (iter % 6);
RandInit(x, 1.0 * (1 << iter), gen);
Softcap(kCap, x.data(), V * K);
Softmax(x.data(), V, K);
CrossEntropyLossGrad(x.data(), dx.data(), prompt, V);
Complexify(x, c_x);
auto func = [&]() {
return CrossEntropyLoss(c_x.data(), prompt, V);
};
TestGradient(dx, c_x, func, 1e-100, 1e-15, __LINE__);
}
}
TEST(BackPropTest, GatedGeluVJP) {
static const size_t K = 2;
static const size_t N = 64;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
MatStorageT<T> x("x", K, 2 * N);
MatStorageT<T> dx("dx", K, 2 * N);
MatStorageT<T> dy("dy", K, N);
MatStorageT<TC> c_x("c_x", K, 2 * N);
MatStorageT<TC> c_y("c_y", K, N);
for (int iter = 0; iter < 10; ++iter) {
RandInit(x, 1.0, gen);
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
GatedGelu(c_x.data(), c_y.data(), N, K);
return DotT(dy.data(), c_y.data(), N * K);
};
GatedGeluVJP(x.data(), dy.data(), dx.data(), N, K);
TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
}
}
TEST(BackPropTest, MaskedAttentionVJP) {
static const size_t kSeqLen = 16;
static const size_t kHeads = 2;
static const size_t kQKVDim = 8;
static const size_t kTokens = 14;
static const size_t kQKVSize = kSeqLen * (kHeads + 2) * kQKVDim;
static const size_t kOutSize = kTokens * kHeads * kSeqLen;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
MatStorageT<T> x("x", kQKVSize, 1);
MatStorageT<T> dx("dx", kQKVSize, 1);
MatStorageT<T> dy("dy", kOutSize, 1);
MatStorageT<TC> c_x("c_x", kQKVSize, 1);
MatStorageT<TC> c_y("c_y", kOutSize, 1);
dx.ZeroInit();
c_y.ZeroInit();
for (int iter = 0; iter < 10; ++iter) {
RandInit(x, 1.0, gen);
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
MaskedAttention(c_x.data(), c_y.data(), kTokens, kHeads, kQKVDim,
kSeqLen);
return DotT(dy.data(), c_y.data(), kOutSize);
};
MaskedAttentionVJP(x.data(), dy.data(), dx.data(),
kTokens, kHeads, kQKVDim, kSeqLen);
TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__);
}
}
TEST(BackPropTest, MixByAttentionVJP) {
static const size_t kSeqLen = 16;
static const size_t kHeads = 2;
static const size_t kQKVDim = 8;
static const size_t kTokens = 14;
static const size_t kQKVSize = kSeqLen * (kHeads + 2) * kQKVDim;
static const size_t kAttnSize = kSeqLen * kHeads * kSeqLen;
static const size_t kOutSize = kSeqLen * kHeads * kQKVDim;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
MatStorageT<T> qkv("qkv", kQKVSize, 1);
MatStorageT<T> dqkv("dqkv", kQKVSize, 1);
MatStorageT<T> attn("attn", kAttnSize, 1);
MatStorageT<T> dattn("dattn", kAttnSize, 1);
MatStorageT<T> dy("dy", kOutSize, 1);
MatStorageT<TC> c_qkv("c_qkv", kQKVSize, 1);
MatStorageT<TC> c_attn("c_attn", kAttnSize, 1);
MatStorageT<TC> c_y("c_y", kOutSize, 1);
dqkv.ZeroInit();
dattn.ZeroInit();
c_y.ZeroInit();
for (int iter = 0; iter < 10; ++iter) {
RandInit(qkv, 1.0, gen);
RandInit(attn, 1.0, gen);
Complexify(qkv, c_qkv);
Complexify(attn, c_attn);
RandInit(dy, 1.0, gen);
auto func = [&]() {
MixByAttention(c_qkv.data(), c_attn.data(), c_y.data(),
kTokens, kHeads, kQKVDim, kSeqLen);
return DotT(dy.data(), c_y.data(), kOutSize);
};
MixByAttentionVJP(qkv.data(), attn.data(), dy.data(), dqkv.data(),
dattn.data(), kTokens, kHeads, kQKVDim, kSeqLen);
TestGradient(dqkv, c_qkv, func, 1e-14, 1e-15, __LINE__);
TestGradient(dattn, c_attn, func, 1e-14, 1e-15, __LINE__);
}
}
TEST(BackPropTest, InputEmbeddingVJP) {
static const size_t kSeqLen = 8;
static const size_t kVocabSize = 4;
static const size_t kModelDim = 16;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
MatStorageT<T> weights("weights", kVocabSize, kModelDim);
MatStorageT<T> grad("grad", kVocabSize, kModelDim);
MatStorageT<T> dy("dy", kSeqLen, kModelDim);
MatStorageT<TC> c_weights("c_weights", kVocabSize, kModelDim);
MatStorageT<TC> c_y("c_y", kSeqLen, kModelDim);
std::vector<int> tokens = { 0, 1, 2, 3, 0, 1, 2 };
size_t num_tokens = tokens.size() - 1;
for (size_t iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0, gen);
RandInit(dy, 1.0, gen);
Complexify(weights, c_weights);
auto func = [&]() {
InputEmbedding(c_weights.data(), tokens, TC(3.0), c_y.data(), kModelDim);
return DotT(dy.data(), c_y.data(), num_tokens * kModelDim);
};
grad.ZeroInit();
InputEmbeddingVJPT(weights.data(), tokens, 3.0, dy.data(), grad.data(),
kModelDim);
TestGradient(grad, c_weights, func, 1e-16, 1e-14, __LINE__);
}
}
static ModelConfig TestConfig() {
ModelConfig config;
config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w",
"gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"};
config.model_dim = 32;
config.vocab_size = 12;
config.seq_len = 18;
LayerConfig layer_config;
layer_config.model_dim = config.model_dim;
layer_config.ff_hidden_dim = 48;
layer_config.heads = 3;
layer_config.kv_heads = 1;
layer_config.qkv_dim = 12;
config.layer_configs = {2, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;
config.attention_window_sizes = FixedAttentionWindowSizes<2>(32);
// This is required for optimize_test to pass.
config.final_cap = 30.0f;
return config;
}
TEST(BackPropTest, LayerVJP) {
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
ModelConfig config = TestConfig();
TensorIndex tensor_index(config, /*llm_layer_idx=*/0, /*img_layer_idx=*/-1,
/*reshape_att=*/false);
const size_t kOutputSize = config.seq_len * config.model_dim;
LayerWeightsPtrs<T> weights(config.layer_configs[0], tensor_index);
LayerWeightsPtrs<T> grad(config.layer_configs[0], tensor_index);
ForwardLayer<T> forward(config.layer_configs[0], config.seq_len);
ForwardLayer<T> backward(config.layer_configs[0], config.seq_len);
LayerWeightsPtrs<TC> c_weights(config.layer_configs[0], tensor_index);
ForwardLayer<TC> c_forward(config.layer_configs[0], config.seq_len);
MatStorageT<T> y("y", kOutputSize, 1);
MatStorageT<T> dy("dy", kOutputSize, 1);
MatStorageT<TC> c_y("c_y", kOutputSize, 1);
const size_t num_tokens = 3;
std::vector<MatStorage> layer_storage;
weights.Allocate(layer_storage);
grad.Allocate(layer_storage);
c_weights.Allocate(layer_storage);
backward.input.ZeroInit();
for (size_t iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0, gen);
RandInit(forward.input, 1.0, gen);
RandInit(dy, 1.0, gen);
Complexify(weights, c_weights);
Complexify(forward.input, c_forward.input);
auto func = [&]() {
ApplyLayer(c_weights, c_forward, num_tokens, c_y.data());
return DotT(dy.data(), c_y.data(), num_tokens * config.model_dim);
};
grad.ZeroInit(/*layer_idx=*/0);
ApplyLayer(weights, forward, num_tokens, y.data());
LayerVJP(weights, forward, dy.data(), grad, backward, num_tokens);
TestGradient(backward.input, c_forward.input, func, 1e-11, 5e-11,
__LINE__);
TestGradient(grad, c_weights, func, 1e-11);
}
}
TEST(BackPropTest, EndToEnd) {
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
ModelConfig config = TestConfig();
WeightsWrapper<T> weights(config);
WeightsWrapper<T> grad(config);
ForwardPass<T> forward(config);
ForwardPass<T> backward(config);
WeightsWrapper<TC> c_weights(config);
ForwardPass<TC> c_forward(config);
ReverseSequenceSampler training_task({0, 0, 1, 1});
std::vector<Prompt> batch = training_task.SampleBatch(3, gen);
for (const Prompt& prompt : batch) {
ReverseSequenceSampler::LogPrompt(prompt);
RandInit(weights.get(), 1.0, gen);
CrossEntropyLossForwardPass(prompt, weights.get(), forward);
grad.ZeroInit();
CrossEntropyLossBackwardPass(
prompt, weights.get(), forward, grad.get(), backward);
Complexify(weights.get(), c_weights.get());
auto func = [&]() {
return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward);
};
TestGradient(grad.get(), c_weights.get(), func, 1e-11);
}
}
template <typename T>
void MulByConstAndAddT(T c, const LayerWeightsPtrs<T>& x,
LayerWeightsPtrs<T>& out) {
MulByConstAndAddT(c, x.pre_attention_norm_scale,
out.pre_attention_norm_scale);
MulByConstAndAddT(c, x.attn_vec_einsum_w, out.attn_vec_einsum_w);
MulByConstAndAddT(c, x.qkv_einsum_w, out.qkv_einsum_w);
MulByConstAndAddT(c, x.pre_ffw_norm_scale, out.pre_ffw_norm_scale);
MulByConstAndAddT(c, x.gating_einsum_w, out.gating_einsum_w);
MulByConstAndAddT(c, x.linear_w, out.linear_w);
}
template <typename T>
void MulByConstAndAddT(T c, const ModelWeightsPtrs<T>& x,
ModelWeightsPtrs<T>& out) {
const size_t layers = x.c_layers.size();
MulByConstAndAddT(c, x.embedder_input_embedding,
out.embedder_input_embedding);
MulByConstAndAddT(c, x.final_norm_scale, out.final_norm_scale);
for (size_t i = 0; i < layers; ++i) {
MulByConstAndAddT(c, *x.GetLayer(i), *out.GetLayer(i));
}
}
// Evaluates forward pass on a batch.
template <typename T>
T CrossEntropyLossForwardPass(const std::vector<Prompt>& batch,
const WeightsWrapper<T>& weights,
ForwardPass<T>& forward) {
T loss = 0.0;
for (const Prompt& prompt : batch) {
loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward);
}
T scale = 1.0 / batch.size();
return loss * scale;
}
// Evaluates forward pass on a batch by applying gradient with the given
// learning rate. Does not update weights, but uses the given tmp weights
// instead.
template <typename T>
T CrossEntropyLossForwardPass(T learning_rate, const std::vector<Prompt>& batch,
const WeightsWrapper<T>& weights,
const WeightsWrapper<T>& grad,
WeightsWrapper<T>& tmp, ForwardPass<T>& forward) {
tmp.CopyFrom(weights);
const T scale = -learning_rate / batch.size();
MulByConstAndAddT(scale, grad.get(), tmp.get());
return CrossEntropyLossForwardPass(batch, tmp, forward);
}
// Uses line search in the negative gradient direction to update weights. We do
// this so that we can test that each step during the gradient descent can
// decrease the objective function value.
template <typename T>
T FindOptimalUpdate(const WeightsWrapper<T>& grad, WeightsWrapper<T>& weights,
WeightsWrapper<T>& tmp, ForwardPass<T>& forward,
const std::vector<Prompt>& batch, T loss,
T initial_learning_rate) {
T lr0 = initial_learning_rate;
T loss0 = CrossEntropyLossForwardPass(
lr0, batch, weights, grad, tmp, forward);
for (size_t iter = 0; iter < 30; ++iter) {
T lr1 = lr0 * 0.5;
T loss1 = CrossEntropyLossForwardPass(
lr1, batch, weights, grad, tmp, forward);
if (loss0 < loss && loss1 >= loss0) {
break;
}
loss0 = loss1;
lr0 = lr1;
}
for (size_t iter = 0; iter < 30; ++iter) {
T lr1 = lr0 * 2.0;
T loss1 = CrossEntropyLossForwardPass(
lr1, batch, weights, grad, tmp, forward);
if (loss1 >= loss0) {
break;
}
loss0 = loss1;
lr0 = lr1;
}
const T scale = -lr0 / batch.size();
MulByConstAndAddT(scale, grad.get(), weights.get());
return lr0;
}
TEST(BackProptest, Convergence) {
std::mt19937 gen(42);
using T = float;
using TC = std::complex<double>;
ModelConfig config = TestConfig();
WeightsWrapper<T> weights(config);
WeightsWrapper<T> grad(config);
WeightsWrapper<T> tmp(config);
ForwardPass<T> forward(config);
ForwardPass<T> backward(config);
WeightsWrapper<TC> c_weights(config);
ForwardPass<TC> c_forward(config);
constexpr size_t kBatchSize = 5;
ReverseSequenceSampler training_task({0, 0, 0, 1, 1});
T learning_rate = 0.01;
RandInit(weights.get(), T(1.0), gen);
printf("Sample batch:\n");
for (size_t i = 0; i < 10; ++i) {
ReverseSequenceSampler::LogPrompt(training_task.Sample(gen));
}
T prev_loss = std::numeric_limits<T>::max();
bool stop = false;
size_t step = 0;
while (!stop) {
T loss = 0.0;
grad.ZeroInit();
std::mt19937 sgen(42);
std::vector<Prompt> batch = training_task.SampleBatch(kBatchSize, sgen);
for (const Prompt& prompt : batch) {
loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward);
CrossEntropyLossBackwardPass(
prompt, weights.get(), forward, grad.get(), backward);
}
if (step % 250 == 0) {
printf("Checking gradient...\n");
Complexify(weights.get(), c_weights.get());
auto func = [&]() {
TC scale = batch.size();
return CrossEntropyLossForwardPass(batch, c_weights, c_forward) * scale;
};
TestGradient(grad.get(), c_weights.get(), func, 5e-3f);
}
loss /= batch.size();
EXPECT_LT(loss, prev_loss);
stop = step >= 10000 || loss < 1e-2;
if (step % 10 == 0 || stop) {
printf("step: %5zu loss: %.15f learning_rate: %.15f\n",
step, loss, learning_rate);
}
if (!stop) {
learning_rate = FindOptimalUpdate(
grad, weights, tmp, forward, batch, loss, learning_rate);
++step;
}
prev_loss = loss;
}
EXPECT_LT(step, 1000);
}
} // namespace gcpp

View File

@ -1,279 +0,0 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif
#include <stddef.h>
#include <complex>
#include <cstdlib> // std::abs
#include <random>
#include <vector>
#include "backprop/activations.h"
#include "backprop/backward_scalar.h"
#include "backprop/common_scalar.h"
#include "backprop/forward_scalar.h"
#include "backprop/prompt.h"
#include "backprop/sampler.h"
#include "backprop/test_util.h"
#include "gemma/configs.h"
#include "ops/ops.h"
#include "util/threading.h"
#include "hwy/base.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "backprop/backward_test.cc" //NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
#include "hwy/tests/test_util-inl.h"
// After highway.h
#include "backprop/backward-inl.h"
#include "backprop/forward-inl.h"
#include "compression/compress.h"
#include "ops/ops-inl.h"
#include "util/allocator.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
void TestMatMulVJP() {
static const size_t kRows = 8;
static const size_t kCols = 64;
static const size_t kTokens = 5;
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8));
Allocator::Init(topology);
gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
std::mt19937 gen(42);
MatStorageT<float> weights("weights", kRows, kCols);
MatStorageT<float> x("x", kTokens, kCols);
MatStorageT<float> dy("dy", kTokens, kRows);
MatStorageT<float> grad("grad", kRows, kCols);
MatStorageT<float> dx("dx", kTokens, kCols);
MatStorageT<float> grad_scalar("grad_scalar", kRows, kCols);
MatStorageT<float> dx_scalar("dx_scalar", kTokens, kCols);
using TC = std::complex<double>;
MatStorageT<TC> c_weights("c_weights", kRows, kCols);
MatStorageT<TC> c_x("c_x", kTokens, kCols);
MatStorageT<TC> c_y("c_y", kTokens, kRows);
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0f * (1 << iter), gen);
RandInit(x, 1.0f * (1 << iter), gen);
RandInit(dy, 1.0f, gen);
Complexify(weights, c_weights);
Complexify(x, c_x);
auto func = [&]() {
MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens);
return DotT(dy.data(), c_y.data(), kTokens * kRows);
};
grad.ZeroInit();
MatMulVJP(weights.data(), x.data(), dy.data(), kCols, kRows, kTokens,
grad.data(), dx.data(), pools.Pool());
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
grad_scalar.ZeroInit();
MatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
dx_scalar.data(), kRows, kCols, kTokens);
TestNear(dx, dx_scalar, 5e-5, 1e-4, __LINE__);
TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__);
}
}
void TestMultiHeadMatMulVJP() {
static const size_t kRows = 2;
static const size_t kCols = 16;
static const size_t kHeads = 4;
static const size_t kTokens = 3;
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8));
Allocator::Init(topology);
gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
std::mt19937 gen(42);
MatStorageT<float> weights("weights", kRows, kCols * kHeads);
MatStorageT<float> x("x", kTokens, kCols * kHeads);
MatStorageT<float> grad("grad", kRows, kCols * kHeads);
MatStorageT<float> dx("dx", kTokens, kCols * kHeads);
MatStorageT<float> dy("dy", kTokens, kRows);
MatStorageT<float> grad_scalar("grad_scalar", kRows, kCols * kHeads);
MatStorageT<float> dx_scalar("dx_scalar", kTokens, kCols * kHeads);
using TC = std::complex<double>;
MatStorageT<TC> c_weights("c_weights", kRows, kCols * kHeads);
MatStorageT<TC> c_x("c_x", kTokens, kCols * kHeads);
MatStorageT<TC> c_y("c_y", kTokens, kRows);
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0f * (1 << iter), gen);
RandInit(x, 1.0f * (1 << iter), gen);
RandInit(dy, 1.0f, gen);
Complexify(weights, c_weights);
Complexify(x, c_x);
auto func = [&]() {
MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows,
kCols, kTokens);
return DotT(dy.data(), c_y.data(), kTokens * kRows);
};
grad.ZeroInit();
MultiHeadMatMulVJP(weights.data(), x.data(), dy.data(), kHeads, kCols,
kRows, kTokens, grad.data(), dx.data(), pools.Pool());
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
grad_scalar.ZeroInit();
MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
dx_scalar.data(), kHeads, kRows, kCols, kTokens);
TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__);
TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__);
}
}
void TestRMSNormVJP() {
static const size_t K = 2;
static const size_t N = 64;
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8));
Allocator::Init(topology);
gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
std::mt19937 gen(42);
MatStorageT<float> weights("weights", N, 1);
MatStorageT<float> x("x", K, N);
MatStorageT<float> grad("grad", N, 1);
MatStorageT<float> dx("dx", K, N);
MatStorageT<float> dy("dy", K, N);
MatStorageT<float> grad_scalar("grad_scalar", N, 1);
MatStorageT<float> dx_scalar("dx_scalar", K, N);
using TC = std::complex<double>;
MatStorageT<TC> c_weights("c_weights", N, 1);
MatStorageT<TC> c_x("c_x", K, N);
MatStorageT<TC> c_y("c_y", K, N);
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0f * (1 << iter), gen);
RandInit(x, 1.0f * (1 << iter), gen);
RandInit(dy, 1.0f, gen);
Complexify(weights, c_weights);
Complexify(x, c_x);
auto func = [&]() {
RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K);
return DotT(dy.data(), c_y.data(), K * N);
};
grad.ZeroInit();
RMSNormVJP(weights.data(), x.data(), dy.data(), N, K, grad.data(),
dx.data(), pools.Pool());
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
grad_scalar.ZeroInit();
RMSNormVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
dx_scalar.data(), N, K);
TestNear(dx, dx_scalar, 0, 2e-5, __LINE__);
TestNear(grad, grad_scalar, 0, 2e-5, __LINE__);
}
}
static ModelConfig TestConfig() {
ModelConfig config;
config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w",
"gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"};
config.model_dim = 32;
config.vocab_size = 16;
config.seq_len = 24;
LayerConfig layer_config;
layer_config.model_dim = config.model_dim;
layer_config.ff_hidden_dim = 64;
layer_config.heads = 3;
layer_config.kv_heads = 1;
layer_config.qkv_dim = 16;
config.layer_configs = {2, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;
config.attention_window_sizes = FixedAttentionWindowSizes<2>(32);
// This is required for optimize_test to pass.
config.att_cap = 50.0f;
config.final_cap = 30.0f;
return config;
}
void TestEndToEnd() {
std::mt19937 gen(42);
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 1));
Allocator::Init(topology);
gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
ModelConfig config = TestConfig();
WeightsWrapper<float> weights(config);
WeightsWrapper<float> grad(config);
ForwardPass<float> forward0(config);
ForwardPass<float> forward1(config);
ForwardPass<float> backward(config);
using TC = std::complex<double>;
WeightsWrapper<TC> c_weights(config);
ForwardPass<TC> c_forward(config);
ReverseSequenceSampler training_task({0, 0, 1, 1});
std::vector<Prompt> batch = training_task.SampleBatch(3, gen);
RowVectorBatch<float> inv_timescale = CreateInvTimescale(
config.layer_configs[0].qkv_dim,
config.layer_configs[0].post_qk == PostQKType::HalfRope);
for (const Prompt& prompt : batch) {
ReverseSequenceSampler::LogPrompt(prompt);
RandInit(weights.get(), 1.0f, gen);
float loss0 = CrossEntropyLossForwardPass(prompt, weights.get(), forward0);
float loss1 = CrossEntropyLossForwardPass(
prompt.tokens, prompt.context_size, weights.get(), forward1,
inv_timescale, pools.Pool());
EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5);
grad.ZeroInit();
CrossEntropyLossBackwardPassInl(prompt, weights.get(), forward1, grad.get(),
backward, inv_timescale, pools.Pool());
Complexify(weights.get(), c_weights.get());
auto func = [&]() {
return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward);
};
TestGradient(grad.get(), c_weights.get(), func, 2e-3f);
}
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_BEFORE_TEST(BackwardTest);
HWY_EXPORT_AND_TEST_P(BackwardTest, TestMatMulVJP);
HWY_EXPORT_AND_TEST_P(BackwardTest, TestMultiHeadMatMulVJP);
HWY_EXPORT_AND_TEST_P(BackwardTest, TestRMSNormVJP);
HWY_EXPORT_AND_TEST_P(BackwardTest, TestEndToEnd);
HWY_AFTER_TEST();
} // namespace gcpp
#endif

View File

@ -1,121 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
#include <stddef.h>
#include <complex>
#include "compression/compress.h" // MatStorageT
namespace gcpp {
template<typename T, typename U>
U DotT(const T* a, const U* b, size_t N) {
U sum = {};
for (size_t i = 0; i < N; ++i) {
sum += a[i] * b[i];
}
return sum;
}
template<>
inline std::complex<double> DotT(const float* a, const std::complex<double>* b,
size_t N) {
std::complex<double> sum = {};
for (size_t i = 0; i < N; ++i) {
sum += static_cast<double>(a[i]) * b[i];
}
return sum;
}
template<typename T>
void MulByConstT(T c, T* x, size_t N) {
for (size_t i = 0; i < N; ++i) {
x[i] *= c;
}
}
// out += c * x
template<typename T>
void MulByConstAndAddT(T c, const T* x, T* out, size_t N) {
for (size_t i = 0; i < N; ++i) {
out[i] += c * x[i];
}
}
template <typename T>
void MulByConstAndAddT(T c, const MatPtrT<T>& x, MatPtrT<T>& out) {
MulByConstAndAddT(c, x.data(), out.data(), x.NumElements());
}
template<typename T>
void AddFromT(const T* a, T* out, size_t N) {
for (size_t i = 0; i < N; ++i) {
out[i] += a[i];
}
}
template<typename T>
T SquaredL2(const T* x, size_t N) {
T sum = {};
for (size_t i = 0; i < N; ++i) {
sum += x[i] * x[i];
}
return sum;
}
template<typename T>
T Gelu(T x) {
static const T kMul = 0.044715;
static const T kSqrt2OverPi = 0.797884560804236;
const T x3 = x * x * x;
const T arg = kSqrt2OverPi * (kMul * x3 + x);
const T cdf = T(0.5) * (T(1.0) + std::tanh(arg));
return x * cdf;
}
template<typename T, typename U>
void Rope(T* x, U base, size_t N, int i) {
const size_t N2 = N / 2;
for (size_t dim = 0; dim < N2; ++dim) {
const T freq_exponents = T(2 * dim) / T(N);
const T timescale = std::pow(base, freq_exponents);
const T theta = T(i) / timescale;
const T cos_val = std::cos(theta);
const T sin_val = std::sin(theta);
const T x0 = x[dim];
const T x1 = x[dim + N2];
x[dim] = x0 * cos_val - x1 * sin_val;
x[dim + N2] = x0 * sin_val + x1 * cos_val;
}
}
template<typename T>
void Rope(T* x, size_t N, int i) {
Rope(x, T(10000.0), N, i);
}
template<typename T>
void Rope(std::complex<T>* x, size_t N, int i) {
Rope(x, T(10000.0), N, i);
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_

View File

@ -1,296 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Include guard for non-SIMD code.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_
#include <stddef.h>
#include <stdint.h>
#include <cmath>
#include <vector>
#include "backprop/activations.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/weights.h"
#include "util/allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_
// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
#endif
#include "hwy/highway.h"
// After highway.h
#include "ops/matvec-inl.h"
#include "ops/ops-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
template <typename ArrayT>
void InputEmbedding(const ArrayT& weights, const std::vector<int>& prompt,
const float scaling, float* HWY_RESTRICT output,
size_t model_dim, size_t vocab_size) {
const hn::ScalableTag<float> df;
HWY_ASSERT(!prompt.empty());
for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
int token = prompt[pos];
DecompressAndZeroPad(df, MakeSpan(weights.data(), model_dim * vocab_size),
token * model_dim, output + pos * model_dim,
model_dim);
MulByConst(scaling, output + pos * model_dim, model_dim);
}
}
template<typename WT, typename XT, typename OutT>
void ApplyRMSNorm(const WT* HWY_RESTRICT weights, const XT* HWY_RESTRICT x,
size_t model_dim, size_t num_tokens,
OutT* HWY_RESTRICT output,
hwy::ThreadPool& pool) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t offset = pos * model_dim;
RMSNorm(x + offset, weights, output + offset, model_dim);
}
}
static HWY_NOINLINE float CrossEntropyLoss(const float* HWY_RESTRICT probs,
const std::vector<int>& prompt,
size_t context_size,
size_t vocab_size,
hwy::ThreadPool& pool) {
HWY_ASSERT(!prompt.empty());
float loss = 0.0f;
for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
if (pos + 1 < context_size) {
continue; // next token is part of context, don't try to predict it
}
const int next_token = prompt[pos + 1];
loss += std::log(probs[pos * vocab_size + next_token]);
}
float scaling = -1.0 / std::log(2.0);
return loss * scaling;
}
template <typename T>
void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
ForwardLayer<float>& activations, size_t num_tokens,
float* HWY_RESTRICT output,
const RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
const LayerConfig& config = weights.layer_config;
const size_t model_dim = config.model_dim;
const size_t kSeqLen = activations.input.Rows();
const size_t kQKVDim = config.qkv_dim;
const size_t kHeads = config.heads;
static const float query_scale =
static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
HWY_ASSERT(num_tokens <= kSeqLen);
ApplyRMSNorm(weights.pre_attention_norm_scale.data(),
activations.input.data(), model_dim, num_tokens,
activations.pre_att_rms_out.data(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec(weights.qkv_einsum_w, 0, (kHeads + 2) * kQKVDim, model_dim,
activations.pre_att_rms_out.data() + pos * model_dim,
activations.qkv.data() + pos * (kHeads + 2) * kQKVDim, pool);
}
const size_t num_tasks = kHeads * num_tokens;
for (size_t pos = 0; pos < num_tokens; ++pos) {
float* HWY_RESTRICT k =
activations.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
Rope(k, kQKVDim, inv_timescale.Const(), pos);
}
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t pos = task / kHeads;
float* HWY_RESTRICT q =
activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
Rope(q, kQKVDim, inv_timescale.Const(), pos);
MulByConst(query_scale, q, kQKVDim);
});
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t pos = task / kHeads;
const float* HWY_RESTRICT q =
activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
float* HWY_RESTRICT head_att =
activations.att.data() + (pos * kHeads + head) * kSeqLen;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const float* HWY_RESTRICT k2 =
activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
const float score = Dot(q, k2, kQKVDim);
head_att[pos2] = score;
}
});
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t pos = task / kHeads;
float* HWY_RESTRICT head_att =
activations.att.data() + (pos * kHeads + head) * kSeqLen;
Softmax(head_att, pos + 1);
});
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t pos = task / kHeads;
const float* HWY_RESTRICT head_att =
activations.att.data() + (pos * kHeads + head) * kSeqLen;
float* HWY_RESTRICT att_out =
activations.att_out.data() + (pos * kHeads + head) * kQKVDim;
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
float* HWY_RESTRICT v2 =
activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
}
});
activations.attention_out.ZeroInit();
for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t head = 0; head < kHeads; ++head) {
MatVec(
weights.attn_vec_einsum_w, head * model_dim * kQKVDim, model_dim,
kQKVDim,
activations.att_out.data() + pos * kHeads * kQKVDim + head * kQKVDim,
activations.att_post1.data() + pos * model_dim, pool);
AddFrom(activations.att_post1.data() + pos * model_dim,
activations.attention_out.data() + pos * model_dim, model_dim);
}
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(activations.input.data() + pos * model_dim,
activations.attention_out.data() + pos * model_dim, model_dim);
}
ApplyRMSNorm(weights.pre_ffw_norm_scale.data(),
activations.attention_out.data(), model_dim, num_tokens,
activations.bf_pre_ffw_rms_out.data(), pool);
const size_t kFFHiddenDim = config.ff_hidden_dim;
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec(weights.gating_einsum_w, 0, kFFHiddenDim * 2, model_dim,
activations.bf_pre_ffw_rms_out.data() + pos * model_dim,
activations.ffw_hidden.data() + pos * kFFHiddenDim * 2, pool);
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t hidden_offset = pos * kFFHiddenDim * 2;
const float* HWY_RESTRICT out =
activations.ffw_hidden.data() + hidden_offset;
const float* HWY_RESTRICT out_mul = out + kFFHiddenDim;
float* HWY_RESTRICT out_gated =
activations.ffw_hidden_gated.data() + pos * kFFHiddenDim;
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
DF df;
for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) {
const auto y = hn::Load(df, out + i);
const auto x = hn::Load(df, out_mul + i);
hn::Store(hn::Mul(x, Gelu(df, y)), df, out_gated + i);
}
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec(weights.linear_w, 0, model_dim, kFFHiddenDim,
activations.ffw_hidden_gated.data() + pos * kFFHiddenDim,
output + pos * model_dim, pool);
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(activations.attention_out.data() + pos * model_dim,
output + pos * model_dim, model_dim);
}
}
template <typename T>
float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
size_t context_size,
const ModelWeightsPtrs<T>& weights,
ForwardPass<float>& forward,
const RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
const ModelConfig& config = weights.weights_config;
const size_t vocab_size = config.vocab_size;
const size_t model_dim = config.model_dim;
const size_t layers = config.layer_configs.size();
const float emb_scaling = EmbeddingScaling(model_dim);
HWY_ASSERT(!config.absolute_pe);
HWY_ASSERT(config.layer_configs[0].post_norm == PostNormType::None);
HWY_ASSERT(config.layer_configs[0].kv_heads == 1);
HWY_DASSERT(context_size > 0);
HWY_DASSERT(context_size < prompt.size());
const size_t num_tokens = prompt.size() - 1;
InputEmbedding(weights.embedder_input_embedding, prompt, emb_scaling,
forward.layers[0].input.data(), model_dim, vocab_size);
for (size_t layer = 0; layer < config.layer_configs.size(); ++layer) {
auto type = config.layer_configs[layer].type;
// TODO(szabadka) Implement Griffin layer.
HWY_ASSERT(type == LayerAttentionType::kGemma);
float* HWY_RESTRICT output = layer + 1 < layers
? forward.layers[layer + 1].input.data()
: forward.final_layer_output.data();
ApplyForwardLayer(*weights.GetLayer(layer), forward.layers[layer],
num_tokens, output, inv_timescale, pool);
}
ApplyRMSNorm(weights.final_norm_scale.data(),
forward.final_layer_output.data(), model_dim, num_tokens,
forward.final_norm_output.data(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec(weights.embedder_input_embedding, 0, vocab_size, model_dim,
forward.final_norm_output.data() + pos * model_dim,
forward.logits.data() + pos * vocab_size, pool);
}
if (config.final_cap > 0.0f) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
LogitsSoftCap(config.final_cap, forward.logits.data() + pos * vocab_size,
vocab_size);
}
}
hwy::CopyBytes(forward.logits.data(), forward.probs.data(),
num_tokens * vocab_size * sizeof(forward.logits.At(0)));
for (size_t pos = 0; pos < num_tokens; ++pos) {
Softmax(forward.probs.data() + pos * vocab_size, vocab_size);
}
return CrossEntropyLoss(forward.probs.data(), prompt, context_size,
vocab_size, pool);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // NOLINT

View File

@ -1,68 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "backprop/forward.h"
#include "backprop/activations.h"
#include "backprop/prompt.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "util/allocator.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "backprop/forward.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "backprop/forward-inl.h"
#include "gemma/weights.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
float CrossEntropyLossForwardPassT(const Prompt& prompt,
const ModelWeightsPtrs<float>& weights,
ForwardPass<float>& forward,
RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
return CrossEntropyLossForwardPass(prompt.tokens, prompt.context_size,
weights, forward, inv_timescale, pool);
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(CrossEntropyLossForwardPassT);
float CrossEntropyLossForwardPass(const Prompt& prompt,
const ModelWeightsPtrs<float>& weights,
ForwardPass<float>& forward,
RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)(
prompt, weights, forward, inv_timescale, pool);
}
} // namespace gcpp
#endif // HWY_ONCE

View File

@ -1,35 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
#include "backprop/activations.h"
#include "backprop/prompt.h"
#include "gemma/weights.h"
#include "util/allocator.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
float CrossEntropyLossForwardPass(const Prompt& prompt,
const ModelWeightsPtrs<float>& weights,
ForwardPass<float>& forward,
RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_

View File

@ -1,294 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_
#include <stddef.h>
#include <string.h>
#include <cmath>
#include <complex>
#include <vector>
#include "backprop/activations.h"
#include "backprop/common_scalar.h"
#include "backprop/prompt.h"
#include "gemma/common.h" // EmbeddingScaling
#include "gemma/weights.h"
#include "hwy/base.h"
namespace gcpp {
// w is N x M matrix in row-major order, x is M x K matrix in column-major order
// y = w * x is N x K matrix in column-major order.
template<typename T>
void MatMulT(const T* w, const T* x, T* y, size_t N, size_t M, size_t K) {
for (size_t i = 0; i < K; ++i) {
for (size_t j = 0; j < N; ++j) {
y[i * N + j] = DotT(&w[j * M], &x[i * M], M);
}
}
}
// w is H concatenated N x M matrix in row-major order, x is HM x K matrix in
// column-major order and y = w' * x is N x K matrix in column-major order,
// where w' is the rearrangement of w into an N x HM matrix.
template<typename T>
void MultiHeadMatMul(const T* w, const T* x, T* y, size_t H, size_t N,
size_t M, size_t K) {
memset(y, 0, N * K * sizeof(y[0]));
for (size_t i = 0; i < K; ++i) {
for (size_t h = 0; h < H; ++h) {
for (size_t j = 0; j < N; ++j) {
y[i * N + j] += DotT(&w[h * N * M + j * M], &x[i * H * M + h * M], M);
}
}
}
}
template<typename T>
void RMSNormT(const T* w, const T* x, T* out, size_t N, size_t K) {
constexpr T eps(1e-6);
for (size_t i = 0; i < K; ++i) {
T ss = SquaredL2(x + i * N, N);
ss = T(1.0) / std::sqrt(ss / T(N) + eps);
for (size_t j = 0; j < N; j++) {
out[i * N + j] = (T(1.0) + w[j]) * (ss * x[i * N + j]);
}
}
}
template<typename T>
void Softmax(T* x, size_t N) {
T sum = {};
auto maxreal = std::real(x[0]);
for (size_t i = 1; i < N; ++i) {
if (std::real(x[i]) > maxreal) {
maxreal = std::real(x[i]);
}
}
for (size_t i = 0; i < N; ++i) {
x[i] = std::exp(x[i] - maxreal);
sum += x[i];
}
T scale = T(1.0) / sum;
for (size_t i = 0; i < N; ++i) {
x[i] *= scale;
}
}
template<typename T>
void Softmax(T* x, size_t N, size_t K) {
for (size_t i = 0; i < K; ++i) {
Softmax(x + i * N, N);
}
}
template <typename T>
void Softcap(float cap, T* x, size_t N) {
const T inv_cap = T{1.0} / static_cast<T>(cap);
for (size_t i = 0; i < N; ++i) {
x[i] = static_cast<T>(cap) * std::tanh(x[i] * inv_cap);
}
}
template<typename T>
void GatedGelu(const T* in, T* out, size_t N, size_t K) {
for (size_t i = 0; i < K; ++i) {
const T* x1 = in + i * 2 * N;
const T* x2 = x1 + N;
T* y = out + i * N;
for (size_t j = 0; j < N; ++j) {
y[j] = x2[j] * Gelu(x1[j]);
}
}
}
template<typename T>
void InputEmbedding(const T* w, const std::vector<int>& tokens, T scaling,
T* y, size_t N) {
HWY_ASSERT(w != nullptr);
HWY_ASSERT(y != nullptr);
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
for (size_t i = 0; i < num_tokens; ++i) {
int token = tokens[i];
memcpy(y + i * N, w + token * N, N * sizeof(y[0]));
MulByConstT(scaling, y + i * N, N);
}
}
template <typename T>
void MaskedAttention(const T* qkv, T* output, size_t num_tokens, size_t heads,
size_t qkv_dim, size_t seq_len) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t head = 0; head < heads; ++head) {
const size_t qoffset = pos * (heads + 2) * qkv_dim;
const size_t aoffset = pos * heads * seq_len + head * seq_len;
const T* q = qkv + qoffset + head * qkv_dim;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const T* k = qkv + (pos2 * (heads + 2) + heads) * qkv_dim;
output[aoffset + pos2] = DotT(q, k, qkv_dim);
}
}
}
}
template <typename T>
void MaskedSoftmax(T* x, size_t num_tokens, size_t heads, size_t seq_len) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t head = 0; head < heads; ++head) {
size_t offset = pos * heads * seq_len + head * seq_len;
Softmax(x + offset, pos + 1);
memset(x + offset + pos + 1, 0, (seq_len - pos - 1) * sizeof(T));
}
}
}
template <typename T>
void MixByAttention(const T* qkv, const T* attention, T* output,
size_t num_tokens, size_t heads, size_t qkv_dim,
size_t seq_len) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t head = 0; head < heads; ++head) {
const T* att = &attention[pos * heads * seq_len + head * seq_len];
T* out = &output[head * qkv_dim + pos * heads * qkv_dim];
memset(out, 0, qkv_dim * sizeof(out[0]));
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
size_t v_offset = (pos2 * (heads + 2) + heads + 1) * qkv_dim;
const T* v = &qkv[v_offset];
MulByConstAndAddT(att[pos2], v, out, qkv_dim);
}
}
}
}
template <typename T>
void ApplyLayer(const LayerWeightsPtrs<T>& weights,
ForwardLayer<T>& activations, size_t num_tokens, T* output) {
const LayerConfig& layer_config = weights.layer_config;
const size_t model_dim = layer_config.model_dim;
const size_t seq_len = activations.input.Rows();
const size_t qkv_dim = layer_config.qkv_dim;
const size_t heads = layer_config.heads;
const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
static const T query_scale = T(1.0) / std::sqrt(T(qkv_dim));
RMSNormT(weights.pre_attention_norm_scale.data(), activations.input.data(),
activations.pre_att_rms_out.data(), model_dim, num_tokens);
MatMulT(weights.qkv_einsum_w.data(), activations.pre_att_rms_out.data(),
activations.qkv.data(), (heads + 2) * qkv_dim, model_dim, num_tokens);
for (size_t pos = 0; pos < num_tokens; ++pos) {
T* qkv = activations.qkv.data() + pos * (heads + 2) * qkv_dim;
for (size_t h = 0; h <= heads; ++h) {
Rope(qkv + h * qkv_dim, qkv_dim, pos);
}
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
T* qkv = activations.qkv.data() + pos * (heads + 2) * qkv_dim;
MulByConstT(query_scale, qkv, heads * qkv_dim);
}
MaskedAttention(activations.qkv.data(), activations.att.data(), num_tokens,
heads, qkv_dim, seq_len);
MaskedSoftmax(activations.att.data(), num_tokens, heads, seq_len);
MixByAttention(activations.qkv.data(), activations.att.data(),
activations.att_out.data(), num_tokens, heads, qkv_dim,
seq_len);
MultiHeadMatMul(weights.attn_vec_einsum_w.data(), activations.att_out.data(),
activations.attention_out.data(), heads, model_dim, qkv_dim,
num_tokens);
AddFromT(activations.input.data(), activations.attention_out.data(),
num_tokens * model_dim);
RMSNormT(weights.pre_ffw_norm_scale.data(), activations.attention_out.data(),
activations.bf_pre_ffw_rms_out.data(), model_dim, num_tokens);
MatMulT(weights.gating_einsum_w.data(), activations.bf_pre_ffw_rms_out.data(),
activations.ffw_hidden.data(), ff_hidden_dim * 2, model_dim,
num_tokens);
GatedGelu(activations.ffw_hidden.data(), activations.ffw_hidden_gated.data(),
ff_hidden_dim, num_tokens);
MatMulT(weights.linear_w.data(), activations.ffw_hidden_gated.data(), output,
model_dim, ff_hidden_dim, num_tokens);
AddFromT(activations.attention_out.data(), output, num_tokens * model_dim);
}
template<typename T>
T CrossEntropyLoss(const T* x, const Prompt& prompt, size_t V) {
T loss = {};
const std::vector<int> tokens = prompt.tokens;
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
for (size_t i = 0; i < num_tokens; ++i) {
if (i + 1 < prompt.context_size) {
continue; // next token is part of context, don't try to predict it
}
const int next_token = tokens[i + 1];
loss += std::log(x[i * V + next_token]);
}
T scaling = -1.0 / std::log(2.0);
return loss * scaling;
}
template <typename T>
T CrossEntropyLossForwardPass(const Prompt& prompt,
const ModelWeightsPtrs<T>& weights,
ForwardPass<T>& forward) {
const ModelConfig& config = weights.weights_config;
const size_t model_dim = config.model_dim;
const size_t vocab_size = config.vocab_size;
const size_t layers = config.layer_configs.size();
const std::vector<int> tokens = prompt.tokens;
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
const T kEmbScaling = EmbeddingScaling(model_dim);
InputEmbedding(weights.embedder_input_embedding.data(), tokens, kEmbScaling,
forward.layers[0].input.data(), model_dim);
for (size_t layer = 0; layer < layers; ++layer) {
T* output = layer + 1 < layers ? forward.layers[layer + 1].input.data()
: forward.final_layer_output.data();
ApplyLayer(*weights.GetLayer(layer), forward.layers[layer], num_tokens,
output);
}
RMSNormT(weights.final_norm_scale.data(), forward.final_layer_output.data(),
forward.final_norm_output.data(), model_dim, num_tokens);
MatMulT(weights.embedder_input_embedding.data(),
forward.final_norm_output.data(), forward.logits.data(), vocab_size,
model_dim, num_tokens);
for (size_t pos = 0; pos < num_tokens; ++pos) {
if (config.final_cap > 0.0f) {
Softcap(config.final_cap, forward.logits.data() + pos * vocab_size,
vocab_size);
}
}
memcpy(forward.probs.data(), forward.logits.data(),
num_tokens * vocab_size * sizeof(forward.logits.At(0)));
Softmax(forward.probs.data(), vocab_size, num_tokens);
return CrossEntropyLoss(forward.probs.data(), prompt, vocab_size);
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_

View File

@ -1,155 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stddef.h>
#include <algorithm>
#include <cstdio>
#include <random>
#include <vector>
#include "gtest/gtest.h"
#include "backprop/activations.h"
#include "backprop/backward.h"
#include "backprop/forward.h"
#include "backprop/optimizer.h"
#include "backprop/prompt.h"
#include "backprop/sampler.h"
#include "compression/shared.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "gemma/weights.h"
#include "ops/ops.h"
#include "util/allocator.h"
#include "util/basics.h"
#include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
TEST(OptimizeTest, GradientDescent) {
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 1));
Allocator::Init(topology);
NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
MatMulEnv env(topology, pools);
hwy::ThreadPool& pool = pools.Pool();
std::mt19937 gen(42);
const ModelInfo info = {
.model = Model::GEMMA_TINY,
.wrapping = PromptWrapping::GEMMA_IT,
.weight = Type::kF32,
};
ModelConfig config = ConfigFromModel(info.model);
ModelWeightsStorage grad, grad_m, grad_v;
grad.Allocate(info.model, info.weight, pool);
grad_m.Allocate(info.model, info.weight, pool);
grad_v.Allocate(info.model, info.weight, pool);
grad_m.ZeroInit();
grad_v.ZeroInit();
ForwardPass<float> forward(config), backward(config);
KVCache kv_cache = KVCache::Create(config, /*prefill_tbatch_size=*/16);
RowVectorBatch<float> inv_timescale = CreateInvTimescale(
config.layer_configs[0].qkv_dim,
config.layer_configs[0].post_qk == PostQKType::HalfRope);
Gemma gemma(GemmaTokenizer(), info, env);
const auto generate = [&](const std::vector<int>& prompt) {
std::vector<int> reply;
auto stream_token = [&reply](int token, float) {
reply.push_back(token);
return token != ReverseSequenceSampler::kEndToken;
};
RuntimeConfig runtime = {
.max_generated_tokens = 16,
.temperature = 1.0f,
.gen = &gen,
.verbosity = 0,
.stream_token = stream_token,
.eos_id = ReverseSequenceSampler::kEndToken,
};
TimingInfo timing_info;
gemma.Generate(runtime, prompt, 0, kv_cache, timing_info);
return reply;
};
// Sanity check of reply tokens.
// 1) Its length should be greater than the prompt.
// 2) The prompt should be a prefix of the reply.
auto verify = [&](const Prompt& prompt) {
const std::vector<int>& context = prompt.context();
std::vector<int> reply = generate(context);
if (reply.size() <= context.size()) return false;
return std::equal(context.begin(), context.end(), reply.begin(),
reply.begin() + context.size());
};
gemma.MutableWeights().RandInit(gen);
gemma.MutableWeights().AllocAndCopyWithTranspose(pool);
printf("Initial weights:\n");
gemma.MutableWeights().LogWeightStats();
constexpr size_t kBatchSize = 8;
const float alpha = 0.001f;
const float beta1 = 0.9f;
const float beta2 = 0.999f;
const float epsilon = 1e-8f;
ReverseSequenceSampler training_task({
0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1});
size_t steps = 0;
size_t num_ok;
for (; steps < 1000000; ++steps) {
std::mt19937 sgen(42);
grad.ZeroInit();
float total_loss = 0.0f;
num_ok = 0;
for (size_t i = 0; i < kBatchSize; ++i) {
Prompt prompt = training_task.Sample(sgen);
total_loss += CrossEntropyLossForwardPass(
prompt, *gemma.Weights().GetWeightsOfType<float>(), forward,
inv_timescale, pool);
CrossEntropyLossBackwardPass(
prompt, *gemma.Weights().GetWeightsOfType<float>(), forward,
*grad.GetWeightsOfType<float>(), backward, inv_timescale, pool);
gemma.MutableWeights().CopyWithTranspose(pool);
num_ok += verify(prompt) ? 1 : 0;
}
total_loss /= kBatchSize;
AdamUpdate(info.weight, grad, alpha, beta1, beta2, epsilon, steps + 1,
gemma.Weights(), grad_m, grad_v, pool);
printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
steps, total_loss, num_ok, kBatchSize);
if (steps % 100 == 0) {
printf("Batch gradient:\n");
grad.LogWeightStats();
}
if (total_loss < 0.5f) {
break;
}
}
printf("Num steps: %zu\n", steps);
printf("Final weights:\n");
gemma.MutableWeights().LogWeightStats();
EXPECT_LT(steps, 300);
EXPECT_EQ(num_ok, kBatchSize);
}
} // namespace gcpp

View File

@ -1,95 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "backprop/optimizer.h"
#include <cmath>
#include "compression/compress.h"
#include "gemma/common.h"
#include "gemma/weights.h"
#include "util/allocator.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
namespace {
class AdamUpdater {
public:
explicit AdamUpdater(float alpha, float beta1, float beta2, float epsilon,
size_t t)
: alpha_(alpha), beta1_(beta1), beta2_(beta2), cbeta1_(1.0f - beta1),
cbeta2_(1.0f - beta2), norm1_(1.0 / (1.0 - std::pow(beta1, t))),
norm2_(1.0 / (1.0 - std::pow(beta2, t))), epsilon_(epsilon) {}
void operator()(const char* name, const MatPtr& grad, MatPtr& weights,
MatPtr& grad_m, MatPtr& grad_v) {
const float* HWY_RESTRICT g = grad.data<float>();
float* HWY_RESTRICT w = weights.data<float>();
float* HWY_RESTRICT m = grad_m.data<float>();
float* HWY_RESTRICT v = grad_v.data<float>();
for (size_t i = 0; i < grad.NumElements(); ++i) {
m[i] *= beta1_;
m[i] += cbeta1_ * g[i];
v[i] *= beta2_;
v[i] += cbeta2_ * g[i] * g[i];
const float mhat = m[i] * norm1_;
const float vhat = v[i] * norm2_;
w[i] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_);
}
}
private:
float alpha_;
float beta1_;
float beta2_;
float cbeta1_;
float cbeta2_;
float norm1_;
float norm2_;
float epsilon_;
};
void AdamUpdate(ModelWeightsPtrs<float>* grad, float alpha, float beta1,
float beta2, float epsilon, size_t t,
ModelWeightsPtrs<float>* weights,
ModelWeightsPtrs<float>* grad_m,
ModelWeightsPtrs<float>* grad_v, hwy::ThreadPool& pool) {
AdamUpdater updater(alpha, beta1, beta2, epsilon, t);
ModelWeightsPtrs<float>::ForEachTensor(
{grad, weights, grad_m, grad_v}, ForEachType::kLoadNoToc,
[&updater](const char* name, hwy::Span<MatPtr*> tensors) {
updater(name, *tensors[0], *tensors[1], *tensors[2], *tensors[3]);
});
}
} // namespace
void AdamUpdate(Type weight_type, const ModelWeightsStorage& grad, float alpha,
float beta1, float beta2, float epsilon, size_t t,
const ModelWeightsStorage& weights,
const ModelWeightsStorage& grad_m,
const ModelWeightsStorage& grad_v, hwy::ThreadPool& pool) {
HWY_ASSERT(weight_type == Type::kF32);
AdamUpdate(grad.GetWeightsOfType<float>(), alpha, beta1, beta2, epsilon, t,
weights.GetWeightsOfType<float>(),
grad_m.GetWeightsOfType<float>(), grad_v.GetWeightsOfType<float>(),
pool);
}
} // namespace gcpp

View File

@ -1,33 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
#include "gemma/common.h"
#include "gemma/weights.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
void AdamUpdate(Type weight_type, const ModelWeightsStorage& grad, float alpha,
float beta1, float beta2, float epsilon, size_t t,
const ModelWeightsStorage& weights,
const ModelWeightsStorage& grad_m,
const ModelWeightsStorage& grad_v, hwy::ThreadPool& pool);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_

View File

@ -1,90 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
#include <stddef.h>
#include <stdio.h>
#include <random>
#include <vector>
#include "backprop/prompt.h"
namespace gcpp {
class PromptSampler {
public:
virtual Prompt Sample(std::mt19937& gen) = 0;
virtual ~PromptSampler() = default;
std::vector<Prompt> SampleBatch(size_t batch_size, std::mt19937& gen) {
std::vector<Prompt> batch;
batch.reserve(batch_size);
for (size_t i = 0; i < batch_size; ++i) {
batch.emplace_back(Sample(gen));
}
return batch;
}
};
class ReverseSequenceSampler : public PromptSampler {
public:
explicit ReverseSequenceSampler(const std::vector<int>& length_histo)
: token_dist_(0, 9) {
for (int i = 0; i < length_histo.size(); ++i) {
const int count = length_histo[i];
for (int j = 0; j < count; ++j) {
length_lut_.push_back(i + 1);
}
}
length_dist_ = std::uniform_int_distribution<>(0, length_lut_.size() - 1);
}
virtual ~ReverseSequenceSampler() = default;
static constexpr int kReverseToken = 10;
static constexpr int kEndToken = 11;
Prompt Sample(std::mt19937& gen) override {
Prompt prompt;
int len = length_lut_[length_dist_(gen)];
prompt.tokens.resize(2 * len + 2);
prompt.tokens[len] = kReverseToken;
prompt.tokens[2 * len + 1] = kEndToken;
for (size_t i = 0; i < len; ++i) {
prompt.tokens[i] = prompt.tokens[2 * len - i] = token_dist_(gen);
}
prompt.context_size = len + 1;
return prompt;
}
static void LogPrompt(const Prompt& prompt) {
static const char* kVocab[] = {
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "-->", "|",
};
for (int token : prompt.tokens) printf("%s", kVocab[token]);
printf(" [context_size: %zu]\n", prompt.context_size);
}
private:
std::uniform_int_distribution<> token_dist_;
std::uniform_int_distribution<> length_dist_;
std::vector<int> length_lut_;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_

View File

@ -1,209 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_
#include <stddef.h>
#include <cmath>
#include <complex>
#include <random>
#include <vector>
#include "gtest/gtest.h"
#include "compression/compress.h"
#include "gemma/configs.h"
#include "gemma/weights.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
template <typename T>
void RandInit(MatPtrT<T>& x, T stddev, std::mt19937& gen) {
std::normal_distribution<T> dist(0.0, stddev);
for (size_t i = 0; i < x.NumElements(); ++i) {
x.At(i) = dist(gen);
}
}
// TODO: make a member of Layer<T>.
template <typename T>
void RandInit(LayerWeightsPtrs<T>& w, T stddev, std::mt19937& gen) {
RandInit(w.pre_attention_norm_scale, stddev, gen);
RandInit(w.attn_vec_einsum_w, stddev, gen);
RandInit(w.qkv_einsum_w, stddev, gen);
RandInit(w.pre_ffw_norm_scale, stddev, gen);
RandInit(w.gating_einsum_w, stddev, gen);
RandInit(w.linear_w, stddev, gen);
}
template <typename T>
void RandInit(ModelWeightsPtrs<T>& w, T stddev, std::mt19937& gen) {
const size_t kLayers = w.c_layers.size();
RandInit(w.embedder_input_embedding, stddev, gen);
RandInit(w.final_norm_scale, stddev, gen);
for (size_t i = 0; i < kLayers; ++i) {
RandInit(*w.GetLayer(i), stddev, gen);
}
}
template <typename T, typename U>
void Complexify(const MatPtrT<T>& x, MatPtrT<std::complex<U>>& c_x) {
for (size_t i = 0; i < x.NumElements(); ++i) {
c_x.At(i) = std::complex<U>(x.At(i), 0.0);
}
}
template <typename T, typename U>
void Complexify(const LayerWeightsPtrs<T>& w, LayerWeightsPtrs<U>& c_w) {
Complexify(w.pre_attention_norm_scale, c_w.pre_attention_norm_scale);
Complexify(w.attn_vec_einsum_w, c_w.attn_vec_einsum_w);
Complexify(w.qkv_einsum_w, c_w.qkv_einsum_w);
Complexify(w.pre_ffw_norm_scale, c_w.pre_ffw_norm_scale);
Complexify(w.gating_einsum_w, c_w.gating_einsum_w);
Complexify(w.linear_w, c_w.linear_w);
}
template <typename T, typename U>
void Complexify(const ModelWeightsPtrs<T>& w, ModelWeightsPtrs<U>& c_w) {
const size_t kLayers = w.c_layers.size();
Complexify(w.embedder_input_embedding, c_w.embedder_input_embedding);
Complexify(w.final_norm_scale, c_w.final_norm_scale);
for (size_t i = 0; i < kLayers; ++i) {
Complexify(*w.GetLayer(i), *c_w.GetLayer(i));
}
}
// Somewhat duplicates ModelWeightsStorage, but that has neither double nor
// complex types allowed and it would cause code bloat to add them there.
template <typename T>
class WeightsWrapper {
public:
explicit WeightsWrapper(const ModelConfig& config)
: pool_(0), weights_(config) {
weights_.Allocate(data_, pool_);
}
const ModelWeightsPtrs<T>& get() const { return weights_; }
ModelWeightsPtrs<T>& get() { return weights_; }
void ZeroInit() { weights_.ZeroInit(); }
void CopyFrom(const WeightsWrapper<T>& other) {
weights_.CopyFrom(other.weights_);
}
private:
hwy::ThreadPool pool_;
std::vector<MatStorage> data_;
ModelWeightsPtrs<T> weights_;
};
template <typename T, typename U>
void TestNear(const MatPtrT<T>& actual, const MatPtrT<U>& expected,
double max_abs_err, double max_rel_err, int line) {
double sum0 = 0;
double sum1 = 0;
double sum01 = 0;
for (size_t i = 0; i < actual.NumElements(); ++i) {
sum0 += actual.At(i) * actual.At(i);
sum1 += expected.At(i) * expected.At(i);
sum01 += actual.At(i) * expected.At(i);
ASSERT_NEAR(actual.At(i), expected.At(i),
std::max(max_abs_err, std::abs(expected.At(i)) * max_rel_err))
<< "line: " << line << " dim=" << expected.NumElements() << " i=" << i;
}
if (sum0 > 1e-40) {
double norm_dot = sum01 / std::sqrt(sum0) / std::sqrt(sum1);
ASSERT_NEAR(norm_dot, 1.0, 1e-7)
<< "line: " << line << " sum0: " << sum0 << " sum1: " << sum1
<< " sum01: " << sum01;
}
}
// Compute gradient with the finite difference method in the complex plane.
// If f : R->R is the tested function and F : C->C is its extension on the
// complex plane so that F is complex differentiable in x, then
//
// F(x + ih) = F(x) + ih F'(x) + O(h^2) F''(x)
//
// which means that
//
// F'(x) ~= Imag(F(x + ih)) / h
//
// This method is more numerically stable than the real-valued finite difference
// method since we don't need to subtract floating point numbers that are near
// to each other.
template <typename FUNC, typename T, typename U>
void TestGradient(const MatPtrT<T>& grad, MatPtrT<std::complex<U>>& x,
FUNC func, U step, T max_abs_err, T max_rel_err, int line) {
MatStorageT<T> exp_grad("exp_grad", x.Rows(), x.Cols());
const U inv_step = 1.0 / step;
for (size_t i = 0; i < x.NumElements(); ++i) {
const U x0 = std::real(x.At(i));
const std::complex<U> x1 = std::complex<U>(x0, step);
x.At(i) = x1;
const std::complex<U> f1 = func();
exp_grad.At(i) = std::imag(f1) * inv_step;
x.At(i) = x0;
}
TestNear(grad, exp_grad, max_abs_err, max_rel_err, line);
}
template <typename FUNC>
void TestGradient(const MatPtrT<float>& grad, MatPtrT<std::complex<float>>& x,
FUNC func, float max_abs_err, float max_rel_error, int line) {
TestGradient(grad, x, func, 1e-30f, max_abs_err, max_rel_error, line);
}
template <typename FUNC, typename T>
void TestGradient(const MatPtrT<T>& grad, MatPtrT<std::complex<double>>& x,
FUNC func, T max_abs_err, T max_rel_error, int line) {
TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line);
}
template <typename T, typename U, typename FUNC>
void TestGradient(const LayerWeightsPtrs<T>& grad,
LayerWeightsPtrs<U>& c_weights, FUNC func, T max_err) {
TestGradient(grad.pre_attention_norm_scale,
c_weights.pre_attention_norm_scale,
func, max_err, max_err, __LINE__);
TestGradient(grad.attn_vec_einsum_w, c_weights.attn_vec_einsum_w,
func, max_err, max_err, __LINE__);
TestGradient(grad.qkv_einsum_w, c_weights.qkv_einsum_w,
func, max_err, max_err, __LINE__);
TestGradient(grad.pre_ffw_norm_scale, c_weights.pre_ffw_norm_scale,
func, max_err, max_err, __LINE__);
TestGradient(grad.gating_einsum_w, c_weights.gating_einsum_w,
func, max_err, max_err, __LINE__);
TestGradient(grad.linear_w, c_weights.linear_w,
func, max_err, max_err, __LINE__);
}
template <typename T, typename U, typename FUNC>
void TestGradient(const ModelWeightsPtrs<T>& grad,
ModelWeightsPtrs<U>& c_weights, FUNC func, T max_err) {
TestGradient(grad.embedder_input_embedding,
c_weights.embedder_input_embedding,
func, 2 * max_err, max_err, __LINE__);
TestGradient(grad.final_norm_scale, c_weights.final_norm_scale,
func, max_err, max_err, __LINE__);
for (size_t i = 0; i < grad.c_layers.size(); ++i) {
TestGradient(*grad.GetLayer(i), *c_weights.GetLayer(i), func, max_err);
}
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_

3
build/.gitignore vendored
View File

@ -1,3 +0,0 @@
*
!.gitignore
!.hgignore

View File

@ -1,4 +1,4 @@
# Weight compression, I/O and analysis
# Weight compression and analysis.
package(
default_applicable_licenses = [
@ -20,78 +20,11 @@ config_setting(
visibility = ["//visibility:private"],
)
FILE_DEPS = select({
"//conditions:default": [
# Placeholder for io deps, do not remove
],
":android": [],
# Placeholder for internal build rules, do not remove
})
cc_library(
name = "io",
srcs = [
"io.cc",
# Placeholder for io backend, do not remove
],
hdrs = ["io.h"],
local_defines = select({
# Placeholder for internal build rules, do not remove
"//conditions:default": [],
}),
deps = [
"@highway//:hwy",
] + FILE_DEPS,
)
cc_library(
name = "fields",
srcs = ["fields.cc"],
hdrs = ["fields.h"],
deps = [
"@highway//:hwy",
],
)
cc_test(
name = "fields_test",
srcs = ["fields_test.cc"],
deps = [
":fields",
"@googletest//:gtest_main", # buildcleaner: keep
"@highway//:hwy_test_util",
],
)
cc_library(
name = "blob_store",
srcs = ["blob_store.cc"],
hdrs = ["blob_store.h"],
deps = [
":io",
"@highway//:hwy",
"@highway//:thread_pool",
],
)
cc_test(
name = "blob_store_test",
srcs = ["blob_store_test.cc"],
deps = [
":blob_store",
":io",
"@googletest//:gtest_main", # buildcleaner: keep
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:thread_pool",
],
)
cc_library(
name = "distortion",
hdrs = [
"distortion.h",
"shared.h",
"types.h",
],
deps = [
"//:basics",
@ -115,21 +48,29 @@ cc_test(
)
cc_library(
name = "sfp",
hdrs = ["shared.h"],
textual_hdrs = ["sfp-inl.h"],
name = "types",
hdrs = ["types.h"],
deps = [
"//:basics",
"@highway//:hwy",
],
)
cc_library(
name = "sfp",
textual_hdrs = ["sfp-inl.h"],
deps = [
":types",
"@highway//:hwy",
],
)
cc_library(
name = "nuq",
hdrs = ["shared.h"],
textual_hdrs = ["nuq-inl.h"],
deps = [
":sfp",
":types",
"//:basics",
"@highway//:hwy",
"@highway//hwy/contrib/sort:vqsort",
@ -144,8 +85,10 @@ cc_library(
deps = [
":compress",
":distortion",
"//:mat",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:thread_pool",
],
)
@ -153,7 +96,6 @@ cc_test(
name = "sfp_test",
size = "small",
srcs = ["sfp_test.cc"],
features = ["fully_static_link"],
linkstatic = True,
local_defines = ["HWY_IS_TEST"],
# for test_suite.
@ -174,7 +116,6 @@ cc_test(
size = "small",
timeout = "long",
srcs = ["nuq_test.cc"],
features = ["fully_static_link"],
linkstatic = True,
local_defines = ["HWY_IS_TEST"],
# for test_suite.
@ -182,7 +123,6 @@ cc_test(
deps = [
":distortion",
":nuq",
":sfp",
"@googletest//:gtest_main", # buildcleaner: keep
"//:test_util",
"@highway//:hwy",
@ -196,21 +136,18 @@ cc_library(
srcs = ["compress.cc"],
hdrs = [
"compress.h",
"shared.h",
"types.h",
],
textual_hdrs = ["compress-inl.h"],
deps = [
":blob_store",
":distortion",
":fields",
":io",
":nuq",
":sfp",
"//:allocator",
"//:basics",
"//:common",
"//:mat",
"@highway//:hwy",
"@highway//:nanobenchmark",
"@highway//:profiler",
"@highway//:stats",
"@highway//:thread_pool",
],
@ -221,7 +158,6 @@ cc_test(
size = "small",
timeout = "long",
srcs = ["compress_test.cc"],
features = ["fully_static_link"],
linkstatic = True,
local_defines = ["HWY_IS_TEST"],
# for test_suite.
@ -245,51 +181,10 @@ cc_library(
deps = [
":nuq",
":sfp",
":types",
"@highway//:hwy",
"@highway//:stats",
"@highway//:thread_pool",
"@highway//hwy/contrib/sort:vqsort",
],
)
cc_binary(
name = "compress_weights",
srcs = ["compress_weights.cc"],
deps = [
":compress",
":io",
"//:allocator",
"//:args",
"//:common",
"//:tokenizer",
"//:weights",
"@highway//:hwy",
"@highway//:thread_pool",
],
)
cc_binary(
name = "blob_compare",
srcs = ["blob_compare.cc"],
deps = [
":blob_store",
":io",
"//:allocator",
"//:basics",
"//:threading",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
],
)
cc_binary(
name = "migrate_weights",
srcs = ["migrate_weights.cc"],
deps = [
"//:app",
"//:args",
"//:benchmark_helper",
"//:gemma_lib",
],
)

View File

@ -26,7 +26,7 @@
#include <cstdlib> // std::abs
#include <vector>
#include "compression/shared.h"
#include "compression/types.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/stats.h"

View File

@ -1,230 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <atomic>
#include <vector>
#include "compression/blob_store.h"
#include "compression/io.h" // Path
#include "util/allocator.h"
#include "util/basics.h" // IndexRange
#include "util/threading.h"
#include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h"
#include "hwy/timer.h"
namespace gcpp {
using KeySpan = hwy::Span<const hwy::uint128_t>;
// Returns false if any keys differ, because then blobs are not comparable.
bool CompareKeys(const BlobReader& reader1, const BlobReader& reader2) {
KeySpan keys1 = reader1.Keys();
KeySpan keys2 = reader2.Keys();
if (keys1.size() != keys2.size()) {
fprintf(stderr, "#keys mismatch: %zu vs %zu\n", keys1.size(), keys2.size());
return false;
}
for (size_t i = 0; i < keys1.size(); ++i) {
if (keys1[i] != keys2[i]) {
fprintf(stderr, "key %zu mismatch: %s vs %s\n", i,
StringFromKey(keys1[i]).c_str(), StringFromKey(keys2[i]).c_str());
return false;
}
}
return true;
}
// Total amount to allocate for all blobs.
size_t TotalBytes(BlobReader& reader) {
size_t total_bytes = 0;
for (const hwy::uint128_t key : reader.Keys()) {
total_bytes += reader.BlobSize(key);
}
return total_bytes;
}
using BytePtr = hwy::AlignedFreeUniquePtr<uint8_t[]>;
using ByteSpan = hwy::Span<uint8_t>; // Sections within BytePtr
using BlobVec = std::vector<ByteSpan>; // in order of keys
// Allocates memory within the single allocation and updates `pos`.
BlobVec ReserveMemory(BlobReader& reader, BytePtr& all_blobs, size_t& pos) {
BlobVec blobs;
for (const hwy::uint128_t key : reader.Keys()) {
const size_t bytes = reader.BlobSize(key);
blobs.push_back(ByteSpan(all_blobs.get() + pos, bytes));
pos += bytes;
}
return blobs;
}
// Reads one set of blobs in parallel (helpful if in disk cache).
void ReadBlobs(BlobReader& reader, BlobVec& blobs, hwy::ThreadPool& pool) {
HWY_ASSERT(reader.Keys().size() == blobs.size());
for (size_t i = 0; i < blobs.size(); ++i) {
reader.Enqueue(reader.Keys()[i], blobs[i].data(), blobs[i].size());
}
const BlobError err = reader.ReadAll(pool);
if (err != 0) {
HWY_ABORT("Parallel read failed: %d\n", err);
}
}
// Parallelizes ReadBlobs across (two) packages, if available.
void ReadBothBlobs(BlobReader& reader1, BlobReader& reader2, size_t total_bytes,
BlobVec& blobs1, BlobVec& blobs2, NestedPools& pools) {
const double t0 = hwy::platform::Now();
fprintf(stderr, "Reading %zu GiB, %zux%zu cores: ", total_bytes >> 30,
pools.AllPackages().NumWorkers(), pools.Pool().NumWorkers());
pools.AllPackages().Run(0, 2, [&](size_t task, size_t pkg_idx) {
ReadBlobs(task ? reader2 : reader1, task ? blobs2 : blobs1,
pools.Pool(pkg_idx));
});
const double t1 = hwy::platform::Now();
fprintf(stderr, "%.1f GB/s\n", total_bytes / (t1 - t0) * 1E-9);
}
// Returns number of elements with a mismatch. For float and bf16 blobs, uses
// L1 and relative error, otherwise byte-wise comparison.
size_t BlobDifferences(const ByteSpan& data1, const ByteSpan& data2,
const hwy::uint128_t key) {
if (data1.size() != data2.size() || data1.size() == 0) {
HWY_ABORT("key %s size mismatch: %zu vs %zu\n", StringFromKey(key).c_str(),
data1.size(), data2.size());
}
size_t mismatches = 0;
char type;
hwy::CopyBytes(&key, &type, 1);
if (type == 'F') {
HWY_ASSERT(data1.size() % sizeof(float) == 0);
for (size_t j = 0; j < data1.size(); j += sizeof(float)) {
float f1, f2;
hwy::CopyBytes(&data1[j], &f1, sizeof(f1));
hwy::CopyBytes(&data2[j], &f2, sizeof(f2));
const float l1 = hwy::ScalarAbs(f1 - f2);
const float rel = hwy::ScalarAbs(f1) == 0.0f ? 0.0f : l1 / f1;
if (l1 > 1E-3f || rel > 1E-2f) {
fprintf(stderr, "key %s %5zu: L1 %.5f rel %.4f\n",
StringFromKey(key).c_str(), j, l1, rel);
++mismatches;
}
}
} else if (type == 'B') {
for (size_t j = 0; j < data1.size(); j += sizeof(hwy::bfloat16_t)) {
hwy::bfloat16_t b1, b2;
hwy::CopyBytes(&data1[j], &b1, sizeof(b1));
hwy::CopyBytes(&data2[j], &b2, sizeof(b2));
const float f1 = hwy::ConvertScalarTo<float>(b1);
const float f2 = hwy::ConvertScalarTo<float>(b2);
const float l1 = hwy::ScalarAbs(f1 - f2);
const float rel = hwy::ScalarAbs(f1) == 0.0f ? 0.0f : l1 / f1;
if (l1 > 1E-2f || rel > 1E-1f) {
fprintf(stderr, "key %s %5zu: L1 %.5f rel %.4f\n",
StringFromKey(key).c_str(), j, l1, rel);
++mismatches;
}
}
} else {
for (size_t j = 0; j < data1.size(); ++j) {
if (data1[j] != data2[j]) {
if (mismatches == 0) {
fprintf(stderr, "key %s mismatch at byte %5zu\n",
StringFromKey(key).c_str(), j);
}
++mismatches;
}
}
}
return mismatches;
}
void CompareBlobs(const KeySpan& keys, BlobVec& blobs1, BlobVec& blobs2,
size_t total_bytes, NestedPools& pools) {
fprintf(stderr, "Comparing %zu blobs in parallel: ", keys.size());
const double t0 = hwy::platform::Now();
std::atomic<size_t> blobs_equal{};
std::atomic<size_t> blobs_diff{};
const IndexRangePartition ranges = StaticPartition(
IndexRange(0, keys.size()), pools.AllPackages().NumWorkers(), 1);
ParallelizeOneRange(
ranges, pools.AllPackages(),
[&](const IndexRange& range, size_t pkg_idx) {
pools.Pool(pkg_idx).Run(
range.begin(), range.end(), [&](size_t i, size_t /*thread*/) {
const size_t mismatches =
BlobDifferences(blobs1[i], blobs2[i], keys[i]);
if (mismatches != 0) {
fprintf(stderr, "key %s has %zu mismatches in %zu bytes!\n",
StringFromKey(keys[i]).c_str(), mismatches,
blobs1[i].size());
blobs_diff.fetch_add(1);
} else {
blobs_equal.fetch_add(1);
}
});
});
const double t1 = hwy::platform::Now();
fprintf(stderr, "%.1f GB/s; total blob matches=%zu, mismatches=%zu\n",
total_bytes / (t1 - t0) * 1E-9, blobs_equal.load(),
blobs_diff.load());
}
// Compares two sbs files, including blob order.
void ReadAndCompareBlobs(const char* path1, const char* path2) {
// Open files.
BlobReader reader1;
BlobReader reader2;
const BlobError err1 = reader1.Open(Path(path1));
const BlobError err2 = reader2.Open(Path(path2));
if (err1 != 0 || err2 != 0) {
HWY_ABORT("Failed to open files: %s %s: %d %d\n", path1, path2, err1, err2);
}
if (!CompareKeys(reader1, reader2)) return;
// Single allocation, avoid initializing the memory.
BoundedTopology topology;
Allocator::Init(topology);
NestedPools pools(topology);
const size_t total_bytes = TotalBytes(reader1) + TotalBytes(reader2);
BytePtr all_blobs = hwy::AllocateAligned<uint8_t>(total_bytes);
size_t pos = 0;
BlobVec blobs1 = ReserveMemory(reader1, all_blobs, pos);
BlobVec blobs2 = ReserveMemory(reader2, all_blobs, pos);
ReadBothBlobs(reader1, reader2, total_bytes, blobs1, blobs2, pools);
CompareBlobs(reader1.Keys(), blobs1, blobs2, total_bytes, pools);
}
} // namespace gcpp
int main(int argc, char** argv) {
if (argc != 3) {
HWY_ABORT("Usage: %s <sbs_path> <sbs_path>\n", argv[0]);
}
if (strcmp(argv[1], argv[2]) == 0) {
HWY_ABORT("Filenames are the same, skipping comparison: %s\n", argv[1]);
}
gcpp::ReadAndCompareBlobs(argv[1], argv[2]);
return 0;
}

View File

@ -1,341 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "compression/blob_store.h"
#include <stddef.h>
#include <stdint.h>
#include <atomic>
#include <cstdio>
#include <memory>
#include <string>
#include <vector>
#include "compression/io.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/detect_compiler_arch.h"
namespace gcpp {
hwy::uint128_t MakeKey(const char* string) {
size_t length = 0;
for (size_t i = 0; string[i] != '\0'; ++i) {
++length;
}
if (length > 16) {
HWY_ABORT("Key %s is too long, please truncate to 16 chars.", string);
}
hwy::uint128_t ret;
hwy::ZeroBytes<sizeof(ret)>(&ret);
hwy::CopyBytes(string, &ret, length);
return ret;
}
std::string StringFromKey(hwy::uint128_t key) {
std::string name(sizeof(key) + 1, '\0');
hwy::CopyBytes(&key, name.data(), sizeof(key));
name.resize(name.find('\0'));
return name;
}
namespace {
void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data,
std::vector<BlobIO>& requests) {
// Split into chunks for load-balancing even if blob sizes vary.
constexpr size_t kChunkSize = 4 * 1024 * 1024; // bytes
// Split into whole chunks and possibly one remainder.
uint64_t pos = 0;
if (size >= kChunkSize) {
for (; pos <= size - kChunkSize; pos += kChunkSize) {
requests.emplace_back(offset + pos, kChunkSize, data + pos, 0);
}
}
if (pos != size) {
requests.emplace_back(offset + pos, size - pos, data + pos, 0);
}
}
} // namespace
static_assert(HWY_IS_LITTLE_ENDIAN, "Assumes little endian");
// On-disk representation (little-endian).
//
// Deliberately omits a version number because this file format is unchanging.
// Additional data may be added only inside new blobs. Changes to the blob
// contents or type should be handled by renaming keys.
#pragma pack(push, 1)
class BlobStore {
static constexpr uint32_t kMagic = 0x0A534253; // SBS\n
public:
// NOT including padding, so that we can also use ZeroFillPadding after
// copying the header.
static constexpr size_t HeaderSize(size_t num_blobs) {
// 16-byte fixed fields plus per-blob: 16-byte key, 16-byte offset/size.
return 16 + 32 * num_blobs;
}
// Returns how many bytes to allocate for the header without the subsequent
// blobs. Requires num_blobs_ to already be set, typically by reading
// sizeof(BlobStore) bytes from disk.
size_t PaddedHeaderSize() const {
return hwy::RoundUpTo(HeaderSize(num_blobs_), kBlobAlign);
}
// Returns aligned offset and zero-fills between that and `offset`.
uint64_t ZeroFillPadding(uint64_t offset) {
uint8_t* const bytes = reinterpret_cast<uint8_t*>(this);
const uint64_t padded = hwy::RoundUpTo(offset, kBlobAlign);
hwy::ZeroBytes(bytes + offset, padded - offset);
return padded;
}
BlobError CheckValidity(const uint64_t file_size) {
if (magic_ != kMagic) return __LINE__;
if (num_blobs_ == 0) return __LINE__;
if (file_size_ != file_size) return __LINE__;
// Ensure blobs are back to back, and zero-pad.
uint64_t offset = ZeroFillPadding(HeaderSize(num_blobs_));
for (size_t i = 0; i < num_blobs_; ++i) {
const hwy::uint128_t val = keys_[num_blobs_ + i];
if (val.lo != offset) return __LINE__;
offset = hwy::RoundUpTo(offset + val.hi, kBlobAlign);
}
if (offset != file_size_) return __LINE__;
return 0; // all OK
}
static BlobStorePtr Allocate(uint64_t total_size) {
uint8_t* bytes =
static_cast<uint8_t*>(hwy::AllocateAlignedBytes(total_size));
if (!bytes) return BlobStorePtr();
return BlobStorePtr(new (bytes) BlobStore(), hwy::AlignedFreer());
}
static std::vector<BlobIO> PrepareWriteRequests(
const hwy::uint128_t keys[], const hwy::Span<const uint8_t> blobs[],
size_t num_blobs, BlobStore* bs) {
// Sanity check and ensure the cast below is safe.
HWY_ASSERT(num_blobs < (1ULL << 20));
// Allocate var-length header.
const size_t header_size = HeaderSize(num_blobs);
const size_t padded_header_size = hwy::RoundUpTo(header_size, kBlobAlign);
const uint64_t padded_header_end = bs->ZeroFillPadding(header_size);
HWY_ASSERT(padded_header_end == padded_header_size);
// All-zero buffer used to write padding to the file without copying the
// input blobs.
static uint8_t zeros[kBlobAlign] = {0};
// Total file size will be the header plus all padded blobs.
uint64_t payload = 0;
for (size_t i = 0; i < num_blobs; ++i) {
payload += hwy::RoundUpTo(blobs[i].size(), kBlobAlign);
}
const size_t total_size = padded_header_size + payload;
// Fill header.
bs->magic_ = kMagic;
bs->num_blobs_ = static_cast<uint32_t>(num_blobs);
bs->file_size_ = total_size;
hwy::CopyBytes(keys, bs->keys_, num_blobs * sizeof(keys[0]));
// First IO request is for the header (not yet filled!).
std::vector<BlobIO> requests;
requests.reserve(1 + 2 * num_blobs);
requests.emplace_back(/*offset=*/0, padded_header_size,
reinterpret_cast<uint8_t*>(bs), 0);
// Fill second half of keys_ with offset/size and prepare IO requests.
uint64_t offset = padded_header_end;
for (size_t i = 0; i < num_blobs; ++i) {
bs->keys_[num_blobs + i].lo = offset;
bs->keys_[num_blobs + i].hi = blobs[i].size();
EnqueueChunkRequests(offset, blobs[i].size(),
const_cast<uint8_t*>(blobs[i].data()), requests);
offset += blobs[i].size();
const size_t padded_size = hwy::RoundUpTo(blobs[i].size(), kBlobAlign);
if (padded_size != blobs[i].size()) {
const size_t padding = padded_size - blobs[i].size();
HWY_ASSERT(padding <= kBlobAlign);
requests.emplace_back(offset, padding, zeros, 0);
offset += padding;
}
}
HWY_ASSERT(offset == total_size);
return requests;
}
bool FindKey(const hwy::uint128_t key, uint64_t& offset, size_t& size) const {
for (size_t i = 0; i < num_blobs_; ++i) {
if (keys_[i] == key) {
const hwy::uint128_t val = keys_[num_blobs_ + i];
offset = val.lo;
size = val.hi;
return true;
}
}
return false;
}
hwy::Span<const hwy::uint128_t> Keys() const {
return hwy::Span<const hwy::uint128_t>(keys_, num_blobs_);
}
private:
uint32_t magic_;
uint32_t num_blobs_; // never 0
uint64_t file_size_; // must match actual size of file
hwy::uint128_t keys_[1]; // length: 2 * num_blobs
// Padding, then the blob identified by keys[0], then padding etc.
};
#pragma pack(pop)
BlobError BlobReader::Open(const Path& filename) {
file_ = OpenFileOrNull(filename, "r");
if (!file_) return __LINE__;
// Read first part of header to get actual size.
BlobStore bs;
if (!file_->Read(0, sizeof(bs), &bs)) return __LINE__;
const size_t padded_size = bs.PaddedHeaderSize();
HWY_ASSERT(padded_size >= sizeof(bs));
// Allocate full header.
blob_store_ = BlobStore::Allocate(padded_size);
if (!blob_store_) return __LINE__;
// Copy what we already read (more efficient than seek + re-read).
hwy::CopySameSize(&bs, blob_store_.get());
// Read the rest of the header, but not the full file.
uint8_t* bytes = reinterpret_cast<uint8_t*>(blob_store_.get());
if (!file_->Read(sizeof(bs), padded_size - sizeof(bs), bytes + sizeof(bs))) {
return __LINE__;
}
return blob_store_->CheckValidity(file_->FileSize());
}
size_t BlobReader::BlobSize(hwy::uint128_t key) const {
uint64_t offset;
size_t size;
if (!blob_store_->FindKey(key, offset, size)) return 0;
return size;
}
BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) {
uint64_t offset;
size_t actual_size;
if (!blob_store_->FindKey(key, offset, actual_size)) return __LINE__;
if (actual_size != size) {
fprintf(stderr,
"Mismatch between expected %d and actual %d KiB size of blob %s. "
"Please see README.md on how to update the weights.\n",
static_cast<int>(size >> 10), static_cast<int>(actual_size >> 10),
StringFromKey(key).c_str());
return __LINE__;
}
EnqueueChunkRequests(offset, actual_size, reinterpret_cast<uint8_t*>(data),
requests_);
return 0;
}
// Parallel synchronous I/O. Alternatives considered:
// - readv is limited to 0x7FFFF000 bytes on Linux (even 64-bit). Note that
// pread calls preadv with a single iovec.
// - O_DIRECT seems undesirable because we do want to use the OS cache
// between consecutive runs.
// - memory-mapped I/O is less predictable and adds noise to measurements.
BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
File* pfile = file_.get(); // not owned
const auto& requests = requests_;
std::atomic_flag err = ATOMIC_FLAG_INIT;
// >5x speedup from parallel reads when cached.
pool.Run(0, requests.size(),
[pfile, &requests, &err](uint64_t i, size_t /*thread*/) {
if (!pfile->Read(requests[i].offset, requests[i].size,
requests[i].data)) {
fprintf(stderr, "Failed to read blob %zu\n",
static_cast<size_t>(i));
err.test_and_set();
}
});
if (err.test_and_set()) return __LINE__;
return 0;
}
BlobError BlobReader::ReadOne(hwy::uint128_t key, void* data,
size_t size) const {
uint64_t offset;
size_t actual_size;
if (!blob_store_->FindKey(key, offset, actual_size)) return __LINE__;
if (actual_size != size) {
fprintf(stderr,
"Mismatch between expected %d and actual %d KiB size of blob %s. "
"Please see README.md on how to update the weights.\n",
static_cast<int>(size >> 10), static_cast<int>(actual_size >> 10),
StringFromKey(key).c_str());
return __LINE__;
}
if (!file_->Read(offset, actual_size, data)) {
return __LINE__;
}
return 0;
}
hwy::Span<const hwy::uint128_t> BlobReader::Keys() const {
return blob_store_->Keys();
}
BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
HWY_ASSERT(keys_.size() == blobs_.size());
// Concatenate blobs in memory.
const size_t header_size = BlobStore::HeaderSize(keys_.size());
const size_t padded_header_size = hwy::RoundUpTo(header_size, kBlobAlign);
const BlobStorePtr bs = BlobStore::Allocate(padded_header_size);
const std::vector<BlobIO> requests = BlobStore::PrepareWriteRequests(
keys_.data(), blobs_.data(), keys_.size(), bs.get());
// Create/replace existing file.
std::unique_ptr<File> file = OpenFileOrNull(filename, "w+");
if (!file) return __LINE__;
File* pfile = file.get(); // not owned
std::atomic_flag err = ATOMIC_FLAG_INIT;
pool.Run(0, requests.size(),
[pfile, &requests, &err](uint64_t i, size_t /*thread*/) {
if (!pfile->Write(requests[i].data, requests[i].size,
requests[i].offset)) {
err.test_and_set();
}
});
if (err.test_and_set()) return __LINE__;
return 0;
}
} // namespace gcpp

View File

@ -1,117 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_BLOB_STORE_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_BLOB_STORE_H_
#include <stddef.h>
#include <stdint.h>
#include <memory>
#include <string>
#include <vector>
#include "compression/io.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // hwy::uint128_t
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
// Convenient way to construct a key from a string (<= 16 chars).
hwy::uint128_t MakeKey(const char* string);
// Returns a string from a key.
std::string StringFromKey(hwy::uint128_t key);
// Ordered list of opaque blobs (~hundreds), identified by unique opaque
// 128-bit keys.
class BlobStore;
// Incomplete type, so dtor will not be called.
using BlobStorePtr = hwy::AlignedFreeUniquePtr<BlobStore>;
// 0 if successful, otherwise the line number of the failing check.
using BlobError = int;
// Blob offsets on disk and memory addresses are a multiple of this, because
// we pad the header and each blob's size. This matches CUDA alignment and the
// maximum SVE vector size, and exceeds typical x86 cache line sizes (64 or
// 128), which can help performance.
static constexpr size_t kBlobAlign = 256;
// One I/O request, serviced by threads in a pool.
struct BlobIO {
BlobIO(uint64_t offset, size_t size, void* data, uint64_t padding)
: offset(offset), size(size), data(data), padding(padding) {}
uint64_t offset;
size_t size; // bytes
void* data;
uint64_t padding;
};
class BlobReader {
public:
BlobReader() { requests_.reserve(500); }
~BlobReader() = default;
// Opens `filename` and reads its header.
BlobError Open(const Path& filename);
// Returns the size of the blob identified by `key`, or 0 if not found.
size_t BlobSize(hwy::uint128_t key) const;
// Enqueues read requests if `key` is found and its size matches `size`, which
// is in units of bytes.
BlobError Enqueue(hwy::uint128_t key, void* data, size_t size);
// Reads all enqueued requests.
BlobError ReadAll(hwy::ThreadPool& pool);
// Reads one blob directly.
BlobError ReadOne(hwy::uint128_t key, void* data, size_t size) const;
// Returns all available blob keys.
hwy::Span<const hwy::uint128_t> Keys() const;
private:
BlobStorePtr blob_store_; // holds header, not the entire file
std::vector<BlobIO> requests_;
std::unique_ptr<File> file_;
};
class BlobWriter {
public:
// `size` is in bytes.
void Add(hwy::uint128_t key, const void* data, size_t size) {
keys_.push_back(key);
blobs_.emplace_back(static_cast<const uint8_t*>(data), size);
}
// Stores all blobs to disk in the given order with padding for alignment.
BlobError WriteAll(hwy::ThreadPool& pool, const Path& filename);
// Returns the number of blobs added.
size_t DebugNumBlobsAdded() const { return keys_.size(); }
private:
std::vector<hwy::uint128_t> keys_;
std::vector<hwy::Span<const uint8_t>> blobs_;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_BLOB_STORE_H_

View File

@ -1,86 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "compression/blob_store.h"
#include <stdio.h>
#include <algorithm>
#include <array>
#include "compression/io.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/tests/hwy_gtest.h"
#include "hwy/tests/test_util-inl.h" // HWY_ASSERT_EQ
namespace gcpp {
namespace {
#if !HWY_TEST_STANDALONE
class BlobStoreTest : public testing::Test {};
#endif
#if !HWY_OS_WIN
TEST(BlobStoreTest, TestReadWrite) {
static const std::array<float, 4> kOriginalData = {-1, 0, 3.14159, 2.71828};
// mkstemp will modify path_str so it holds a newly-created temporary file.
char path_str[] = "/tmp/blob_store_test.sbs-XXXXXX";
const int fd = mkstemp(path_str);
HWY_ASSERT(fd > 0);
hwy::ThreadPool pool(4);
const Path path(path_str);
std::array<float, 4> buffer = kOriginalData;
const hwy::uint128_t keyA = MakeKey("0123456789abcdef");
const hwy::uint128_t keyB = MakeKey("q");
BlobWriter writer;
writer.Add(keyA, "DATA", 5);
writer.Add(keyB, buffer.data(), sizeof(buffer));
HWY_ASSERT_EQ(writer.WriteAll(pool, path), 0);
HWY_ASSERT_ARRAY_EQ(kOriginalData.data(), buffer.data(), buffer.size());
std::fill(buffer.begin(), buffer.end(), 0);
BlobReader reader;
HWY_ASSERT_EQ(reader.Open(path), 0);
HWY_ASSERT_EQ(reader.BlobSize(keyA), 5);
HWY_ASSERT_EQ(reader.BlobSize(keyB), sizeof(buffer));
HWY_ASSERT_EQ(reader.Enqueue(keyB, buffer.data(), sizeof(buffer)), 0);
HWY_ASSERT_EQ(reader.ReadAll(pool), 0);
HWY_ASSERT_ARRAY_EQ(kOriginalData.data(), buffer.data(), buffer.size());
{
std::array<char, 5> buffer;
HWY_ASSERT(reader.ReadOne(keyA, buffer.data(), 1) != 0);
HWY_ASSERT_EQ(reader.ReadOne(keyA, buffer.data(), 5), 0);
HWY_ASSERT_STRING_EQ("DATA", buffer.data());
}
const hwy::Span<const hwy::uint128_t> keys = reader.Keys();
HWY_ASSERT_EQ(keys.size(), 2);
HWY_ASSERT_EQ(keys[0], keyA);
HWY_ASSERT_EQ(keys[1], keyB);
close(fd);
unlink(path_str);
}
#endif
} // namespace
} // namespace gcpp
HWY_TEST_MAIN();

View File

@ -21,19 +21,20 @@
#include <stdint.h>
#include <stdio.h>
#include <cmath> // lroundf, only if COMPRESS_STATS
#include <string>
#include <memory>
#include <vector>
#include "compression/blob_store.h"
#include "compression/compress.h" // IWYU pragma: export
#include "compression/distortion.h"
#include "gemma/configs.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/timer.h"
#if COMPRESS_STATS
#include <cmath> // lroundf
#endif
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_INL_H_
// Include guard for (potentially) SIMD code.
@ -64,8 +65,8 @@ static constexpr bool kIsTest = false;
template <typename T> // primary, must specialize
struct CompressTraits {};
// Used by backprop/, where weights are currently f32; also MatMul for f32
// weights or activations, if native `ReorderWidenMulAccumulate` is available.
// Used by MatMul for f32 weights or activations, if native
// `ReorderWidenMulAccumulate` is available.
template <>
struct CompressTraits<float> {
using Packed = float;
@ -379,7 +380,7 @@ struct CompressTraits<SfpStream> {
using Packed = SfpStream;
// Callers are responsible for scaling `raw` such that its magnitudes do not
// exceed `SfpStream::kMax`. See CompressedArray::scale().
// exceed `SfpStream::kMax`. See CompressedArray::Scale().
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw,
size_t num, CompressPerThread& tls,
@ -387,7 +388,7 @@ struct CompressTraits<SfpStream> {
const size_t packed_ofs) {
SfpCodec::Enc(df, raw, num, packed.ptr + packed_ofs);
if (COMPRESS_STATS) {
if constexpr (COMPRESS_STATS) {
const hn::Repartition<BF16, DF> dbf;
auto distorted =
hwy::AllocateAligned<BF16>(hwy::RoundUpTo(num, hn::Lanes(dbf)));
@ -431,9 +432,10 @@ struct CompressTraits<NuqStream> {
size_t num, CompressPerThread& tls,
const PackedSpan<Packed>& packed,
const size_t packed_ofs) {
NuqCodec::Enc(df, raw, num, tls.buf, packed, packed_ofs);
if (!tls.buf) tls.buf = std::make_unique<NuqStream::ClusterBuf>();
NuqCodec::Enc(df, raw, num, *tls.buf, packed, packed_ofs);
if (COMPRESS_STATS) {
if constexpr (COMPRESS_STATS) {
for (size_t i = 0; i < num; ++i) {
tls.stats.NotifyIn(static_cast<int>(lroundf(raw[i] * 100.0f + 500.0f)));
}
@ -477,7 +479,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
const size_t packed_ofs, hwy::ThreadPool& pool) {
packed.BoundsCheck(packed_ofs, num);
work.tls.resize(pool.NumWorkers());
if (COMPRESS_STATS) {
if constexpr (COMPRESS_STATS) {
for (auto& tls : work.tls) {
tls.stats.Reset();
}
@ -486,7 +488,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
const bool want_bench = COMPRESS_STATS || !kIsTest;
const double t0 = want_bench ? hwy::platform::Now() : 0.0;
using Traits = CompressTraits<Packed>;
using Traits = CompressTraits<hwy::RemoveConst<Packed>>;
constexpr size_t kBatch = 8192;
const size_t num_batches = hwy::DivCeil(num, kBatch);
pool.Run(0, num_batches,
@ -507,7 +509,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
fprintf(stderr, "Compress %.1f MB/s\n", mbps);
}
if (COMPRESS_STATS) {
if constexpr (COMPRESS_STATS) {
for (size_t i = 1; i < work.tls.size(); ++i) {
work.tls[0].stats.Assimilate(work.tls[i].stats);
}
@ -515,26 +517,25 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
}
}
// Adapter that compresses into `MatStorageT`. `raw` must already be scaled
// to fit the value range, if `Packed` is `SfpStream`.
// Same as above, but without parallelization nor benchmarking.
template <typename Packed>
HWY_INLINE void CompressScaled(const float* HWY_RESTRICT raw, size_t num,
CompressWorkingSet& work,
MatStorageT<Packed>& compressed,
hwy::ThreadPool& pool) {
Compress(raw, num, work,
MakeSpan(compressed.data(), compressed.NumElements()),
/*packed_ofs=*/0, pool);
HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
CompressPerThread& tls,
const PackedSpan<Packed>& packed,
const size_t packed_ofs) {
packed.BoundsCheck(packed_ofs, num);
using Traits = CompressTraits<hwy::RemoveConst<Packed>>;
const hn::ScalableTag<float> df;
Traits::Compress(df, raw, num, tls, packed, packed_ofs);
}
// Stores two f32 vectors to f32 or bf16; avoids duplicating RMSNorm and
// RMSNormInplace for the two output types.
// Stores two f32 vectors to f32 or bf16.
template <class DF, typename Packed, HWY_IF_F32_D(DF), class VF = hn::Vec<DF>>
void Compress2(DF df, VF raw0, VF raw1, const PackedSpan<Packed>& packed,
const size_t packed_ofs) {
static_assert(hwy::IsSameEither<Packed, float, BF16>());
packed.BoundsCheck(packed_ofs, 2 * hn::Lanes(df));
using Traits = CompressTraits<Packed>;
using Traits = CompressTraits<hwy::RemoveConst<Packed>>;
Traits::Store2(df, raw0, raw1, packed, packed_ofs);
}
@ -566,7 +567,7 @@ HWY_INLINE void Decompress2(DRaw d, const PackedSpan<Packed>& packed,
// Decompresses from any type of `packed`, starting at (any) `packed_ofs`, to
// (any) `num` elements in `raw`, then appends `[0, hn::Lanes(d))` zeroes as
// required to round `num` up to one vector, if it is not already. The caller is
// responsible for scaling `raw` to the original range because `EmbedToken`
// responsible for scaling `raw` to the original range because `EmbedMMToken`
// also wants to scale the decompressed elements.
// `TRaw` can be `float/BF16`, or `double` if `Packed` is `float`.
template <class DRaw, typename Packed, typename TRaw = hn::TFromD<DRaw>>
@ -708,51 +709,6 @@ HWY_INLINE float DecompressAndCall(D, const PackedSpan<const VT> v,
comp3);
}
// Functor called for each tensor, which compresses and stores them along with
// their scaling factors to BlobStore.
class Compressor {
public:
explicit Compressor(hwy::ThreadPool& pool) : writer_(pool) {}
template <typename Packed>
void operator()(MatPtrT<Packed>* compressed, const char* decorated_name,
const float* HWY_RESTRICT weights) {
size_t num_weights = compressed->NumElements();
if (num_weights == 0 || weights == nullptr || compressed->Ptr() == nullptr)
return;
size_t num_compressed = compressed->NumElements();
PackedSpan<Packed> packed = MakeSpan(compressed->data(), num_compressed);
fprintf(stderr, "Compressing %s (%zuM), please wait\n", decorated_name,
num_weights / (1000 * 1000));
Compress(weights, num_weights, work_, packed, /*packed_ofs=*/0,
writer_.pool());
writer_(compressed, decorated_name);
}
void AddTokenizer(const std::string& tokenizer) {
writer_.AddTokenizer(tokenizer);
}
void AddScales(const float* scales, size_t len) {
writer_.AddScales(scales, len);
}
// Writes all blobs to disk in the given order. The config is optional and
// if given, it is written to the file, along with the TOC, making it
// single-file format. Otherwise, the file is written in the multi-file format
// without a TOC.
BlobError WriteAll(const Path& blob_filename, const ModelConfig* config) {
return writer_.WriteAll(blob_filename, config);
}
// Returns the number of blobs added.
size_t DebugNumBlobsAdded() const { return writer_.DebugNumBlobsAdded(); }
private:
CompressWorkingSet work_;
WriteToBlobStore writer_;
};
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp

View File

@ -15,8 +15,34 @@
#include "compression/compress.h"
#include <stddef.h>
#include <stdint.h>
#include "util/mat.h"
#include "hwy/base.h"
#include "hwy/profiler.h"
namespace gcpp {
MatPtr::~MatPtr() {}
float ScaleWeights(float* HWY_RESTRICT raw, size_t num) {
PROFILER_FUNC;
float maxabs = 0.0;
for (size_t i = 0; i < num; ++i) {
maxabs = HWY_MAX(maxabs, hwy::ScalarAbs(raw[i]));
}
if (maxabs <= SfpStream::kMax) {
return 1.0f;
}
const float scale = maxabs / SfpStream::kMax;
const float inv_scale = static_cast<float>(1.0 / static_cast<double>(scale));
for (size_t i = 0; i < num; ++i) {
// Clamp because kMax may still be exceeded.
const float magn =
HWY_MIN(SfpStream::kMax, hwy::ScalarAbs(raw[i] * inv_scale));
raw[i] = hwy::ScalarCopySign(magn, raw[i]);
}
return scale;
}
} // namespace gcpp

View File

@ -17,31 +17,19 @@
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
#include "hwy/base.h"
#define COMPRESS_STATS 0
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <cstdio>
#include <cstring>
#include <string>
#include <unordered_map>
#include <utility>
#if COMPRESS_STATS
#include <stdio.h>
#endif
#include <memory>
#include <vector>
// IWYU pragma: begin_exports
#include "compression/blob_store.h"
#include "compression/fields.h"
#include "compression/io.h"
#include "compression/shared.h"
#include "gemma/tensor_index.h"
#include "util/basics.h"
// IWYU pragma: end_exports
#include "gemma/configs.h"
#include "util/allocator.h"
#include "hwy/per_target.h"
#include "compression/types.h" // IWYU pragma: export
#if COMPRESS_STATS
#include "compression/distortion.h"
#include "hwy/stats.h"
@ -49,388 +37,6 @@
namespace gcpp {
// Base class for rank-1 or 2 tensors (vector or matrix).
// Supports both dynamic and compile-time sizing.
// Holds metadata and a non-owning pointer to the data, owned by the derived
// MatStorageT class.
// This class also provides easy conversion from/to a table of contents for a
// BlobStore file, and a templated (compile-time) accessor for a 2-d array of
// fixed inner dimension and type.
// It is designed to be put in a vector, and has default copy and operator=, so
// it is easy to read/write a blob_store file.
class MatPtr : public IFields {
public:
// Full constructor for dynamic sizing.
MatPtr(const std::string& name, Type type, size_t element_size, size_t rows,
size_t cols)
: name_(name),
type_(type),
element_size_(element_size),
num_elements_(rows * cols),
rows_(rows),
cols_(cols),
ptr_(nullptr) {
stride_ = cols;
}
// Default is to leave all fields default-initialized.
MatPtr() = default;
virtual ~MatPtr();
// Compatibility interface for CompressedArray.
// TODO: remove.
template <typename T>
T* data() {
return HWY_RCAST_ALIGNED(T*, ptr_);
}
template <typename T>
const T* data() const {
return HWY_RCAST_ALIGNED(const T*, ptr_);
}
const void* Ptr() const { return ptr_; }
void* Ptr() { return ptr_; }
// Sets the pointer from another MatPtr.
void SetPtr(const MatPtr& other) { ptr_ = other.ptr_; }
// Copying allowed as the metadata is small.
MatPtr(const MatPtr& other) = default;
MatPtr& operator=(const MatPtr& other) = default;
// Returns the name of the blob.
const char* Name() const override { return name_.c_str(); }
void SetName(const std::string& name) { name_ = name; }
// Returns the type of the blob.
Type GetType() const { return type_; }
// Returns the size of each element in bytes.
size_t ElementSize() const { return element_size_; }
// Returns the number of elements in the array.
size_t NumElements() const { return num_elements_; }
// Returns the number of bytes in the array.
size_t SizeBytes() const {
if (this->GetType() == TypeEnum<NuqStream>()) {
return NuqStream::PackedEnd(num_elements_);
}
return num_elements_ * element_size_;
}
// Returns the number of rows in the 2-d array (outer dimension).
size_t Rows() const { return rows_; }
// Returns the number of columns in the 2-d array (inner dimension).
size_t Cols() const { return cols_; }
Extents2D Extents() const { return Extents2D(rows_, cols_); }
// Currently same as cols, but may differ in the future. This is the offset by
// which to advance pointers to the next row.
size_t Stride() const { return stride_; }
// Decoded elements should be multiplied by this to restore their original
// range. This is required because SfpStream can only encode a limited range
// of magnitudes.
float scale() const { return scale_; }
void set_scale(float scale) { scale_ = scale; }
std::string LayerName(int layer) const {
std::string name = name_ + std::to_string(layer);
HWY_ASSERT(name.size() <= sizeof(hwy::uint128_t));
return name;
}
// Sets all data to zero.
void ZeroInit() {
if (ptr_ == nullptr)
HWY_ABORT("ptr_ is null on tensor %s\n", name_.c_str());
hwy::ZeroBytes(ptr_, SizeBytes());
}
void VisitFields(IFieldsVisitor& visitor) override {
visitor(name_);
visitor(type_);
visitor(element_size_);
visitor(num_elements_);
visitor(rows_);
visitor(cols_);
visitor(scale_);
visitor(stride_);
}
// Calls func on the upcasted type. Since MatPtr by design is not templated,
// here we provide a way to get to the derived type, provided that `Type()`
// is one of the strings returned by `TypeName()`.
template <class FuncT, typename... TArgs>
decltype(auto) CallUpcasted(FuncT& func, TArgs&&... args);
protected:
// Arbitrary name for the array of preferably <= 16 characters.
std::string name_;
// Should be the result of TypeEnum<T> for CallUpcasted() to work.
Type type_;
// sizeof(T)
uint32_t element_size_ = 0;
// Number of elements in the array.
uint32_t num_elements_ = 0; // In element_size units.
// Number of rows in the 2-d array (outer dimension).
uint32_t rows_ = 0;
// Number of columns in the 2-d array (inner dimension).
uint32_t cols_ = 0;
// Scaling to apply to each element.
float scale_ = 1.0f;
// Aligned data array. This is always a borrowed pointer. It should never be
// freed. The underlying memory is owned by a subclass or some external class
// and must outlive this object.
void* ptr_ = nullptr;
uint32_t stride_;
};
// MatPtrT adds a single template argument to MatPtr for an explicit type.
// Use this class as a function argument where the type needs to be known.
// Use MatPtr where the type does not need to be known.
template <typename MatT>
class MatPtrT : public MatPtr {
public:
// Full constructor for dynamic sizing.
MatPtrT(const std::string& name, size_t rows, size_t cols)
: MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), rows, cols) {}
// Construction from TensorIndex entry to remove duplication of sizes.
MatPtrT(const std::string& name, const TensorIndex& tensor_index)
: MatPtrT<MatT>(name, tensor_index.FindName(name)) {}
MatPtrT(const std::string& name, const TensorInfo* tensor)
: MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), 0, 0) {
if (tensor == nullptr) {
cols_ = 0;
rows_ = 0;
} else {
cols_ = tensor->shape.back();
rows_ = 1;
if (tensor->cols_take_extra_dims) {
// The columns eat the extra dimensions.
rows_ = tensor->shape[0];
for (size_t i = 1; i < tensor->shape.size() - 1; ++i) {
cols_ *= tensor->shape[i];
}
} else {
// The rows eat the extra dimensions.
for (size_t i = 0; i < tensor->shape.size() - 1; ++i) {
rows_ *= tensor->shape[i];
}
}
}
stride_ = cols_;
num_elements_ = rows_ * cols_;
}
// Copying allowed as the metadata is small.
MatPtrT(const MatPtr& other) : MatPtr(other) {}
MatPtrT& operator=(const MatPtr& other) {
MatPtr::operator=(other);
return *this;
}
MatPtrT(const MatPtrT& other) = default;
MatPtrT& operator=(const MatPtrT& other) = default;
std::string CacheName(int layer = -1, char separator = ' ',
int index = -1) const {
// Already used/retired: s, S, n, 1
const char prefix = hwy::IsSame<MatT, float>() ? 'F'
: hwy::IsSame<MatT, BF16>() ? 'B'
: hwy::IsSame<MatT, SfpStream>() ? '$'
: hwy::IsSame<MatT, NuqStream>() ? '2'
: '?';
std::string name = std::string(1, prefix) + name_;
if (layer >= 0 || index >= 0) {
name += '_';
if (layer >= 0) name += std::to_string(layer);
if (index >= 0) {
name += separator + std::to_string(index);
}
}
return name;
}
// Sets the number of elements in the array. For use when the number of
// elements is != rows * cols ONLY.
void SetNumElements(size_t num_elements) {
num_elements_ = CompressedArrayElements<MatT>(num_elements);
}
// 2-d Accessor for a specific type but with a dynamic inner dimension.
template <typename T = MatT>
const T& At(size_t row, size_t col) const {
size_t index = row * cols_ + col;
HWY_DASSERT(index < num_elements_);
return HWY_RCAST_ALIGNED(const T*, ptr_)[index];
}
// 1-d Accessor for a specific type.
// TODO: replace this with a Foreach(), or at least a ForEachRow().
const MatT& At(size_t index) const {
HWY_DASSERT(index < num_elements_);
return HWY_RCAST_ALIGNED(const MatT*, ptr_)[index];
}
MatT& At(size_t index) { return HWY_RCAST_ALIGNED(MatT*, ptr_)[index]; }
// Compatibility interface for CompressedArray.
// TODO: remove
template <typename T = MatT>
T* data() {
return HWY_RCAST_ALIGNED(T*, ptr_);
}
template <typename T = MatT>
const T* data() const {
return HWY_RCAST_ALIGNED(const T*, ptr_);
}
// The const accessor data_scale1() asserts (!) that the scale is 1.0f, so
// calling it means "I am sure the scale is 1 and therefore ignore the scale".
// A scale of 0 indicates that the scale has likely never been set, so is
// "implicitly 1".
const MatT* data_scale1() const {
HWY_ASSERT(scale() == 1.f);
return HWY_RCAST_ALIGNED(const MatT*, ptr_);
}
};
template <class FuncT, typename... TArgs>
decltype(auto) MatPtr::CallUpcasted(FuncT& func, TArgs&&... args) {
if (type_ == TypeEnum<float>()) {
return func(dynamic_cast<MatPtrT<float>*>(this),
std::forward<TArgs>(args)...);
} else if (type_ == TypeEnum<BF16>()) {
return func(dynamic_cast<MatPtrT<BF16>*>(this),
std::forward<TArgs>(args)...);
} else if (type_ == TypeEnum<SfpStream>()) {
return func(dynamic_cast<MatPtrT<SfpStream>*>(this),
std::forward<TArgs>(args)...);
} else if (type_ == TypeEnum<NuqStream>()) {
return func(dynamic_cast<MatPtrT<NuqStream>*>(this),
std::forward<TArgs>(args)...);
} else {
HWY_ABORT("Type %d unknown.", type_);
}
}
// MatStorageT adds the actual data storage to MatPtrT.
// TODO: use Extents2D instead of rows and cols.
template <typename MatT>
class MatStorageT : public MatPtrT<MatT> {
public:
// Full constructor for dynamic sizing.
MatStorageT(const std::string& name, size_t rows, size_t cols)
: MatPtrT<MatT>(name, rows, cols) {
Allocate();
}
// Can copy the metadata, from a MatPtr, and allocate later.
MatStorageT(const MatPtr& other) : MatPtrT<MatT>(other) {}
~MatStorageT() = default;
// Move-only because this contains a unique_ptr.
MatStorageT(const MatStorageT& other) = delete;
MatStorageT& operator=(const MatStorageT& other) = delete;
MatStorageT(MatStorageT&& other) = default;
MatStorageT& operator=(MatStorageT&& other) = default;
// Allocate the memory and copy the pointer to the MatPtr.
// num_elements is in elements. In the default (zero) case, it is computed
// from the current num_elements_ which was set by the constructor from the
// rows and cols.
void Allocate(size_t num_elements = 0) {
if (num_elements == 0) {
num_elements = hwy::DivCeil(this->SizeBytes(), sizeof(MatT));
} else {
this->num_elements_ = num_elements;
}
// Pad to allow overrunning the last row by 2 BF16 vectors, hence at most
// `2 * VectorBytes / sizeof(BF16)` elements of MatT.
const size_t padding = hwy::VectorBytes();
data_ = Allocator::Alloc<MatT>(num_elements + padding);
hwy::ZeroBytes(&data_[num_elements], padding * sizeof(MatT));
this->ptr_ = data_.get();
}
// Zeros the content.
void ZeroInit() {
HWY_ASSERT(data_ != nullptr);
hwy::ZeroBytes(data_.get(), this->SizeBytes());
}
private:
AlignedPtr<MatT> data_;
};
// MatStorage allows heterogeneous tensors to be stored in a single vector.
using MatStorage = MatStorageT<hwy::uint128_t>;
// Table of contents for a blob store file. Full metadata, but not actual data.
class BlobToc {
public:
BlobToc() = default;
// Loads the table of contents from the given reader.
BlobError LoadToc(BlobReader& reader) {
hwy::uint128_t toc_key = MakeKey(kTocName);
size_t toc_size = reader.BlobSize(toc_key);
if (toc_size != 0) {
std::vector<uint32_t> toc(toc_size / sizeof(uint32_t));
BlobError err = reader.ReadOne(toc_key, toc.data(), toc_size);
if (err != 0) {
fprintf(stderr, "Failed to read toc (error %d)\n", err);
return err;
}
size_t consumed = 0;
size_t prev_consumed = static_cast<size_t>(-1);
while (consumed < toc.size() && prev_consumed != consumed) {
MatPtr blob;
const IFields::ReadResult result =
blob.Read(hwy::Span<const uint32_t>(toc), consumed);
prev_consumed = consumed;
consumed = result.pos;
if (blob.NumElements() > 0) {
AddToToc(blob);
}
}
}
return 0;
}
bool Empty() const { return toc_map_.empty(); }
// Returns true if the table of contents contains the given name.
bool Contains(const std::string& name) const {
return toc_map_.find(name) != toc_map_.end();
}
// Returns the blob with the given name, or nullptr if not found.
const MatPtr* Get(const std::string& name) const {
auto it = toc_map_.find(name);
if (it == toc_map_.end()) return nullptr;
return &toc_[it->second];
}
// The name of the toc in the blob store file.
static constexpr char kTocName[] = "toc";
// The name of the config in the blob store file.
static constexpr char kConfigName[] = "config";
// The name of the tokenizer in the blob store file.
static constexpr char kTokenizerName[] = "tokenizer";
private:
// Adds the blob to the table of contents.
void AddToToc(const MatPtr& blob) {
HWY_ASSERT(!Contains(blob.Name()));
toc_map_[blob.Name()] = toc_.size();
toc_.push_back(blob);
}
std::unordered_map<std::string, size_t> toc_map_;
std::vector<MatPtr> toc_;
};
#if COMPRESS_STATS
class CompressStats {
public:
@ -489,7 +95,8 @@ struct CompressStats {
#endif // COMPRESS_STATS
struct CompressPerThread {
NuqStream::ClusterBuf buf;
// Allocated the first time NUQ is used.
std::unique_ptr<NuqStream::ClusterBuf> buf;
CompressStats stats;
};
@ -497,196 +104,11 @@ struct CompressWorkingSet {
std::vector<CompressPerThread> tls;
};
// Class to collect and write a set of tensors to a blob store file.
class WriteToBlobStore {
public:
explicit WriteToBlobStore(hwy::ThreadPool& pool) : pool_(pool) {}
template <typename Packed>
void operator()(MatPtrT<Packed>* compressed, const char* decorated_name) {
if (compressed->Ptr() == nullptr) return;
writer_.Add(MakeKey(decorated_name), compressed->Ptr(),
compressed->SizeBytes());
MatPtr renamed_tensor(*compressed);
renamed_tensor.SetName(decorated_name);
renamed_tensor.AppendTo(toc_);
}
void AddTokenizer(const std::string& tokenizer) {
writer_.Add(MakeKey(BlobToc::kTokenizerName), tokenizer.data(),
tokenizer.size() * sizeof(tokenizer[0]));
}
void AddScales(const float* scales, size_t len) {
if (len) {
MatPtrT<float> scales_ptr("scales", 0, 1);
writer_.Add(MakeKey(scales_ptr.CacheName().c_str()), scales,
len * sizeof(scales[0]));
}
}
// Writes all blobs to disk in the given order. The config is optional and
// if given, it is written to the file, along with the TOC, making it
// single-file format. Otherwise, the file is written in the multi-file format
// without a TOC.
BlobError WriteAll(const Path& blob_filename, const ModelConfig* config) {
if (config) {
writer_.Add(MakeKey(BlobToc::kTocName), toc_.data(),
toc_.size() * sizeof(toc_[0]));
config_buffer_ = config->Write();
writer_.Add(MakeKey(BlobToc::kConfigName), config_buffer_.data(),
config_buffer_.size() * sizeof(config_buffer_[0]));
}
const BlobError err = writer_.WriteAll(pool_, blob_filename);
if (err != 0) {
fprintf(stderr, "Failed to write blobs to %s (error %d)\n",
blob_filename.path.c_str(), err);
}
return err;
}
// Returns the number of blobs added.
size_t DebugNumBlobsAdded() const { return writer_.DebugNumBlobsAdded(); }
hwy::ThreadPool& pool() { return pool_; }
protected:
hwy::ThreadPool& pool_;
private:
std::vector<uint32_t> toc_;
BlobWriter writer_;
std::vector<uint32_t> config_buffer_;
};
// Functor called for each tensor, which loads them and their scaling factors
// from BlobStore.
class ReadFromBlobStore {
public:
explicit ReadFromBlobStore(const Path& blob_filename) {
err_ = reader_.Open(blob_filename);
if (HWY_UNLIKELY(err_ != 0)) {
fprintf(stderr, "Error %d opening BlobStore %s.\n", err_,
blob_filename.path.c_str());
return; // avoid overwriting err_ to ensure ReadAll will fail.
}
err_ = file_toc_.LoadToc(reader_);
if (HWY_UNLIKELY(err_ != 0)) {
fprintf(stderr, "Found a TOC, but failed to load it (code %d)\n", err_);
}
}
// Returns true if there is a TOC.
bool HaveToc() const { return !file_toc_.Empty(); }
// Reads the config from the blob store file.
BlobError LoadConfig(ModelConfig& config) {
hwy::uint128_t config_key = MakeKey(BlobToc::kConfigName);
size_t config_size = reader_.BlobSize(config_key);
if (config_size == 0) return __LINE__;
std::vector<uint32_t> config_buffer(config_size / sizeof(uint32_t));
BlobError err =
reader_.ReadOne(config_key, config_buffer.data(), config_size);
if (err != 0) {
fprintf(stderr, "Failed to read config (error %d)\n", err);
return err;
}
config.Read(hwy::Span<const uint32_t>(config_buffer), 0);
return 0;
}
// Reads the tokenizer from the blob store file.
BlobError LoadTokenizer(std::string& tokenizer) {
hwy::uint128_t key = MakeKey(BlobToc::kTokenizerName);
size_t tokenizer_size = reader_.BlobSize(key);
if (tokenizer_size == 0) return __LINE__;
tokenizer.resize(tokenizer_size);
;
BlobError err = reader_.ReadOne(key, tokenizer.data(), tokenizer_size);
if (err != 0) {
fprintf(stderr, "Failed to read tokenizer (error %d)\n", err);
return err;
}
return 0;
}
// Called for each tensor, enqueues read requests.
void operator()(const char* name, hwy::Span<MatPtr*> tensors) {
if (file_toc_.Empty() || file_toc_.Contains(name)) {
model_toc_.push_back(tensors[0]);
file_keys_.push_back(name);
}
}
BlobError LoadScales(float* scales, size_t len) {
for (size_t i = 0; i < len; ++i) {
scales[i] = 1.0f;
}
MatPtrT<float> scales_ptr("scales", 0, 1);
auto key = MakeKey(scales_ptr.CacheName().c_str());
if (reader_.BlobSize(key) == 0) return 0;
return reader_.Enqueue(key, scales, len * sizeof(scales[0]));
}
// Returns whether all tensors are successfully loaded from cache.
BlobError ReadAll(hwy::ThreadPool& pool,
std::vector<MatStorage>& model_memory) {
// reader_ invalid or any Enqueue failed
if (err_ != 0) return err_;
// Setup the model_memory.
for (size_t b = 0; b < model_toc_.size(); ++b) {
const std::string& file_key = file_keys_[b];
MatPtr* blob = model_toc_[b];
if (!file_toc_.Empty()) {
const MatPtr* toc_blob = file_toc_.Get(file_key);
if (toc_blob == nullptr) {
fprintf(stderr, "Blob %s not found in TOC\n", file_key.c_str());
return __LINE__;
}
if (toc_blob->Rows() != blob->Rows() ||
toc_blob->Cols() != blob->Cols()) {
fprintf(stderr, "Blob %s has size mismatch TOC\n", file_key.c_str());
return __LINE__;
}
std::string name = blob->Name();
*blob = *toc_blob;
blob->SetName(name);
}
model_memory.emplace_back(*blob);
model_memory.back().SetName(file_key);
}
// Allocate in parallel using the pool.
pool.Run(0, model_memory.size(),
[this, &model_memory](uint64_t task, size_t /*thread*/) {
model_memory[task].Allocate();
model_toc_[task]->SetPtr(model_memory[task]);
});
// Enqueue the read requests.
for (auto& blob : model_memory) {
err_ =
reader_.Enqueue(MakeKey(blob.Name()), blob.data(), blob.SizeBytes());
if (err_ != 0) {
fprintf(stderr,
"Failed to read blob %s (error %d) of size %zu x %zu x %zu\n",
blob.Name(), err_, blob.Rows(), blob.Cols(),
blob.ElementSize());
return err_;
}
}
return reader_.ReadAll(pool);
}
private:
BlobReader reader_;
BlobError err_ = 0;
// Table of contents from the file, if present.
BlobToc file_toc_;
// Table of contents from the model. Pointers to original MatPtrT so the
// data pointers can be updated.
std::vector<MatPtr*> model_toc_;
// Mangled names of the tensors in model_toc_ for reading from the file.
std::vector<std::string> file_keys_;
};
// Returns 1.0f if all magnitudes are <= `SfpStream::kMax`, otherwise scales
// them such that the largest magnitude is `SfpStream::kMax`, and returns the
// multiplier with which to restore the original values. This is only necessary
// before compressing to `SfpStream` and `NuqStream`.
float ScaleWeights(float* HWY_RESTRICT raw, size_t num);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_

View File

@ -13,10 +13,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// SFP uses ConcatEven/Odd which are not supported; skip SVE for faster tests.
#include "compression/types.h"
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SVE)
#endif
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
#include "compression/compress.h"
@ -80,7 +80,7 @@ struct TestDecompress2T {
stats.Notify(raw[i], hwy::ConvertScalarTo<float>(dec[i]));
}
if constexpr (false) {
if constexpr (true) { // leave enabled due to sporadic failures
fprintf(stderr,
"TypeName<Packed>() %s TypeName<T>() %s: num %zu: stats.SumL1() "
"%f stats.GeomeanValueDivL1() %f stats.WeightedAverageL1() %f "

View File

@ -1,286 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Command line tool to create compressed weights.
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"compression/compress_weights.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#include "gemma/configs.h"
#include "gemma/tokenizer.h"
#ifndef GEMMA_COMPRESS_WEIGHTS_ONCE
#define GEMMA_COMPRESS_WEIGHTS_ONCE
#include <stddef.h>
#include <stdio.h>
#include <algorithm> // std::clamp
#include <cstdlib>
#include <iostream>
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "compression/compress.h"
#include "compression/io.h" // Path
#include "compression/shared.h" // PromptWrapping
#include "gemma/common.h" // Model
#include "gemma/weights.h"
#include "util/allocator.h"
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
namespace {
} // namespace
struct Args : public ArgsBase<Args> {
static constexpr size_t kDefaultNumThreads = ~size_t{0};
void ChooseNumThreads() {
if (num_threads == kDefaultNumThreads) {
// This is a rough heuristic, replace with something better in the future.
num_threads = static_cast<size_t>(std::clamp(
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
}
}
public:
Args(int argc, char* argv[]) {
InitAndParse(argc, argv);
ChooseNumThreads();
}
// Returns error string or nullptr if OK.
const char* Validate() {
if (const char* err = ParseModelTypeAndWrapping(model_type_str, model_type_,
prompt_wrapping_)) {
return err;
}
if (const char* err = ParseType(weight_type_str, weight_type_)) {
return err;
}
if (weights.path.empty()) {
return "Missing --weights flag, a file for the uncompressed model.";
}
if (compressed_weights.path.empty()) {
return "Missing --compressed_weights flag, a file for the compressed "
"model.";
}
if (!weights.Exists()) {
return "Can't open file specified with --weights flag.";
}
return nullptr;
}
Path weights; // uncompressed weights file location
Path compressed_weights; // compressed weights file location
std::string model_type_str;
std::string weight_type_str;
size_t num_threads;
// If non-empty, whether to include the config and TOC in the output file, as
// well as the tokenizer.
Path tokenizer;
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(weights, "weights", Path(),
"Path to model weights (.bin) file.\n"
" Required argument.");
visitor(model_type_str, "model", std::string(),
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n "
"gr2b-it = griffin 2B parameters, instruction-tuned\n "
"gr2b-pt = griffin 2B parameters, pretrained\n "
" Required argument.");
visitor(weight_type_str, "weight_type", std::string("sfp"),
"Weight type\n f32 = float, bf16 = bfloat16, SFP = 8-bit FP\n"
" Required argument.");
visitor(compressed_weights, "compressed_weights", Path(),
"Path name where compressed weights (.sbs) file will be written.\n"
" Required argument.");
visitor(num_threads, "num_threads",
kDefaultNumThreads, // see ChooseNumThreads
"Number of threads to use.\n Default = Estimate of the "
"number of supported concurrent threads.",
2);
visitor(tokenizer, "tokenizer", Path(),
"Path to tokenizer file. If given, the config and TOC are also "
"added to the output file.");
}
// Uninitialized before Validate, must call after that.
gcpp::Model ModelType() const { return model_type_; }
gcpp::PromptWrapping PromptWrappingType() const { return prompt_wrapping_; }
gcpp::Type WeightType() const { return weight_type_; }
private:
Model model_type_;
PromptWrapping prompt_wrapping_;
Type weight_type_;
};
void ShowHelp(gcpp::Args& args) {
std::cerr
<< "Usage:\n./compress_weights --weights <path to uncompressed weights> "
" --model <model type> --compressed_weights <output path>\n";
std::cerr << "\n*Arguments*\n\n";
args.Help();
std::cerr << "\n";
}
} // namespace gcpp
#endif // GEMMA_COMPRESS_WEIGHTS_ONCE
// SIMD code, compiled once per target.
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
template <typename T>
void CompressWeights(const Path& weights_path,
const Path& compressed_weights_path, Model model_type,
Type weight_type, PromptWrapping wrapping,
const Path& tokenizer_path, hwy::ThreadPool& pool) {
if (!weights_path.Exists()) {
HWY_ABORT("The model weights file '%s' does not exist.",
weights_path.path.c_str());
}
printf("Compressing weights from %s to %s\n", weights_path.path.c_str(),
compressed_weights_path.path.c_str());
ModelConfig config = ConfigFromModel(model_type);
config.weight = weight_type;
config.wrapping = wrapping;
std::vector<MatStorage> model_storage;
ModelWeightsPtrs<T> c_weights(config);
c_weights.Allocate(model_storage, pool);
ModelWeightsPtrs<float> uc_weights(config);
uc_weights.Allocate(model_storage, pool);
// Get uncompressed weights, compress, and store.
FILE* fptr = fopen(weights_path.path.c_str(), "rb");
if (fptr == nullptr) {
HWY_ABORT("Failed to open model file %s - does it exist?",
weights_path.path.c_str());
}
bool ok = true;
uint64_t total_size = 0;
ModelWeightsPtrs<float>::ForEachTensor(
{&uc_weights}, ForEachType::kLoadNoToc,
[&](const char* name, hwy::Span<MatPtr*> tensors) {
fprintf(stderr, "Loading Parameters (size %zu): %s\n",
tensors[0]->SizeBytes(), name);
ok &= 1 == fread(tensors[0]->Ptr(), tensors[0]->SizeBytes(), 1, fptr);
total_size += tensors[0]->SizeBytes();
});
if (!tokenizer_path.path.empty()) {
uc_weights.AllocAndCopyWithTranspose(pool, model_storage);
}
const bool scale_for_compression = config.num_tensor_scales > 0;
std::vector<float> scales;
if (scale_for_compression) {
uc_weights.GetOrApplyScales(scales);
}
Compressor compressor(pool);
ModelWeightsPtrs<T>::ForEachTensor(
{reinterpret_cast<ModelWeightsPtrs<T>*>(&uc_weights), &c_weights},
tokenizer_path.path.empty() ? ForEachType::kLoadNoToc
: ForEachType::kLoadWithToc,
[&compressor](const char* name, hwy::Span<MatPtr*> tensors) {
tensors[1]->CallUpcasted(
compressor, name,
reinterpret_cast<const float*>(tensors[0]->Ptr()));
});
if (!tokenizer_path.path.empty()) {
std::string tokenizer_proto = ReadFileToString(tokenizer_path);
compressor.AddTokenizer(tokenizer_proto);
} else {
compressor.AddScales(scales.data(), scales.size() * sizeof(scales[0]));
}
compressor.WriteAll(compressed_weights_path,
tokenizer_path.path.empty() ? nullptr : &config);
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
void Run(Args& args) {
hwy::ThreadPool pool(args.num_threads);
if (args.PromptWrappingType() == PromptWrapping::PALIGEMMA) {
HWY_ABORT("PaliGemma is not supported in compress_weights.");
}
const Model model_type = args.ModelType();
const Type weight_type = args.WeightType();
switch (weight_type) {
case Type::kF32:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<float>)
(args.weights, args.compressed_weights, model_type, weight_type,
args.PromptWrappingType(), args.tokenizer, pool);
break;
case Type::kBF16:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<BF16>)
(args.weights, args.compressed_weights, model_type, weight_type,
args.PromptWrappingType(), args.tokenizer, pool);
break;
case Type::kSFP:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<SfpStream>)
(args.weights, args.compressed_weights, model_type, weight_type,
args.PromptWrappingType(), args.tokenizer, pool);
break;
case Type::kNUQ:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<NuqStream>)
(args.weights, args.compressed_weights, model_type, weight_type,
args.PromptWrappingType(), args.tokenizer, pool);
break;
default:
HWY_ABORT("Weight type %d unsupported.", static_cast<int>(weight_type));
}
}
} // namespace gcpp
int main(int argc, char** argv) {
gcpp::Args args(argc, argv);
if (gcpp::HasHelp(argc, argv)) {
gcpp::ShowHelp(args);
return 0;
}
if (const char* error = args.Validate()) {
gcpp::ShowHelp(args);
HWY_ABORT("\nInvalid args: %s", error);
}
gcpp::Run(args);
return 0;
}
#endif // HWY_ONCE

View File

@ -1,209 +0,0 @@
# Copyright 2024 Google LLC
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Converts pytorch to f32 for use by compress_weights.cc."""
import argparse
import collections
import os
from gemma import config
from gemma import model as gemma_model
import numpy as np
import torch
# Requires torch 2.2 and gemma package from
# https://github.com/google/gemma_pytorch
def check_file_exists(value):
if not os.path.exists(str(value)):
raise argparse.ArgumentTypeError(
"The file %s does not appear to exist." % value
)
return value
def check_model_types(value):
if str(value).lower() not in ["2b", "7b"]:
raise argparse.ArgumentTypeError(
"Model type value %s is not in [2b, 7b]." % value
)
return value
parser = argparse.ArgumentParser()
parser.add_argument(
"--tokenizer",
dest="tokenizer",
default="models/tokenizer.spm",
help="Location of tokenizer file (.model or .spm)",
type=check_file_exists,
)
parser.add_argument(
"--weights",
dest="weights",
default="models/gemma-2b-it.ckpt",
help="Location of input checkpoint file (.ckpt)",
type=check_file_exists,
)
parser.add_argument(
"--output_file",
dest="output_file",
default="2bit-f32.sbs",
help="Location to write converted weights",
type=str,
)
parser.add_argument(
"--model_type",
dest="model_type",
default="2b",
help="Model size / type (2b, 7b)",
type=check_model_types,
)
args = parser.parse_args()
TRANSFORMATIONS = {
"2b": collections.defaultdict(
lambda: lambda x: x,
{
"embedder.weight": lambda x: x,
"self_attn.qkv_proj.weight": lambda x: x.reshape((10, 256, 2048)),
"self_attn.o_proj.weight": lambda x: x.reshape(
(2048, 8, 256)
).transpose([1, 0, 2]),
"mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
"mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
"mlp.down_proj.weight": lambda x: x,
},
),
"7b": collections.defaultdict(
lambda: lambda x: x,
{
"embedder.weight": lambda x: x,
"self_attn.qkv_proj.weight": lambda x: x.reshape(
(3, 16, 256, 3072)
).transpose([1, 0, 2, 3]),
"self_attn.o_proj.weight": lambda x: x.reshape(
(3072, 16, 256)
).transpose([1, 0, 2]),
"mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
"mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
"mlp.down_proj.weight": lambda x: x,
},
),
}
VALIDATIONS = {
"2b": {
"embedder.weight": lambda x: x.shape == (256000, 2048),
"model.norm.weight": lambda x: x.shape == (2048,),
"self_attn.qkv_proj.weight": lambda x: x.shape == (10, 256, 2048),
"self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256),
"mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048),
"mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048),
"mlp.down_proj.weight": lambda x: x.shape == (2048, 16384),
"input_layernorm.weight": lambda x: x.shape == (2048,),
"post_attention_layernorm.weight": lambda x: x.shape == (2048,),
},
"7b": {
"embedder.weight": lambda x: x.shape == (256000, 3072),
"model.norm.weight": lambda x: x.shape == (3072,),
"self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072),
"self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256),
"mlp.gate_proj.weight": lambda x: x.shape == (1, 24576, 3072),
"mlp.up_proj.weight": lambda x: x.shape == (1, 24576, 3072),
"mlp.down_proj.weight": lambda x: x.shape == (3072, 24576),
"input_layernorm.weight": lambda x: x.shape == (3072,),
"post_attention_layernorm.weight": lambda x: x.shape == (3072,),
},
}
def param_names(num_hidden_layers: int):
"""Return parameter names in the order they are expected for deserialization."""
# note *weight_scaler params are ignored in the forward computation unless
# quantization is being used.
#
# since we are working with the full precision weights as input, don't
# include these in the parameters being iterated over.
names = [
("embedder.weight",) * 2, # embedder_input_embedding
("model.norm.weight",) * 2, # final_norm_scale
]
layer_params = [
"self_attn.o_proj.weight", # attn_vec_einsum_w
"self_attn.qkv_proj.weight", # qkv_einsum_w
"mlp.gate_proj.weight", # gating_einsum_w
"mlp.up_proj.weight",
"mlp.down_proj.weight", # linear_w
"input_layernorm.weight", # pre_attention_norm_scale
"post_attention_layernorm.weight", # pre_ffw_norm_scale
]
for layer in range(num_hidden_layers):
for layer_param in layer_params:
names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)]
return names
def convert_weights():
"""Main function; loads weights, runs transformations, writes f32."""
model_type = args.model_type
output_file = args.output_file
model_config = config.get_model_config(model_type)
model_config.dtype = "float32"
model_config.tokenizer = args.tokenizer
device = torch.device("cpu")
torch.set_default_dtype(torch.float)
model = gemma_model.GemmaForCausalLM(model_config)
model.load_weights(args.weights)
model.to(device).eval()
model_dict = dict(model.named_parameters())
param_order = param_names(model_config.num_hidden_layers)
all_ok = True
print("Checking transformations ...")
for name, layer_name in param_order:
arr = model_dict[name].detach().numpy()
arr = TRANSFORMATIONS[model_type][layer_name](arr)
check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
if check == "FAILED":
all_ok = False
print(f" {name : <60}{str(arr.shape) : <20}{check}")
if all_ok:
print("Writing parameters ...")
with open(output_file, "wb") as bin_handle:
for name, layer_name in param_order:
arr = model_dict[name].detach().numpy()
arr = TRANSFORMATIONS[model_type][layer_name](arr)
check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
print(f" {name : <60}{str(arr.shape) : <20}{check}")
arr.flatten().astype(np.float32).tofile(bin_handle)
if __name__ == "__main__":
convert_weights()
print("Done")

View File

@ -17,7 +17,7 @@
#include <stdio.h>
#include "compression/shared.h" // SfpStream::kMax
#include "compression/types.h" // SfpStream::kMax
#include "util/test_util.h"
#include "hwy/nanobenchmark.h"
#include "hwy/tests/hwy_gtest.h"

View File

@ -1,121 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Safe to be first, does not include POSIX headers.
#include "hwy/detect_compiler_arch.h"
// Only compile this file on non-Windows; it replaces io_win.cc. It is easier to
// check this in source code because we support multiple build systems.
#if !HWY_OS_WIN
// Request POSIX 2008, including `pread()` and `posix_fadvise()`.
#if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700
#undef _XOPEN_SOURCE
#define _XOPEN_SOURCE 700
#endif
#if !defined(_POSIX_C_SOURCE) || _POSIX_C_SOURCE < 200809
#define _POSIX_C_SOURCE 200809
#endif
// Make `off_t` 64-bit even on 32-bit systems. Works for Android >= r15c.
#undef _FILE_OFFSET_BITS
#define _FILE_OFFSET_BITS 64
#include <fcntl.h> // open
#include <stddef.h>
#include <stdint.h>
#include <stdio.h> // SEEK_END - unistd isn't enough for IDE.
#include <sys/stat.h> // O_RDONLY
#include <unistd.h> // read, write, close
#include <memory>
#include "compression/io.h"
#include "hwy/base.h" // HWY_ASSERT
namespace gcpp {
class FilePosix : public File {
int fd_ = 0;
public:
explicit FilePosix(int fd) : fd_(fd) { HWY_ASSERT(fd > 0); }
~FilePosix() override {
if (fd_ != 0) {
HWY_ASSERT(close(fd_) != -1);
}
}
uint64_t FileSize() const override {
static_assert(sizeof(off_t) == 8, "64-bit off_t required");
const off_t size = lseek(fd_, 0, SEEK_END);
if (size < 0) {
return 0;
}
return static_cast<uint64_t>(size);
}
bool Read(uint64_t offset, uint64_t size, void* to) const override {
uint8_t* bytes = reinterpret_cast<uint8_t*>(to);
uint64_t pos = 0;
for (;;) {
// pread seems to be faster than lseek + read when parallelized.
const auto bytes_read = pread(fd_, bytes + pos, size - pos, offset + pos);
if (bytes_read <= 0) break;
pos += bytes_read;
HWY_ASSERT(pos <= size);
if (pos == size) break;
}
return pos == size; // success if managed to read desired size
}
bool Write(const void* from, uint64_t size, uint64_t offset) override {
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(from);
uint64_t pos = 0;
for (;;) {
const auto bytes_written =
pwrite(fd_, bytes + pos, size - pos, offset + pos);
if (bytes_written <= 0) break;
pos += bytes_written;
HWY_ASSERT(pos <= size);
if (pos == size) break;
}
return pos == size; // success if managed to write desired size
}
}; // FilePosix
HWY_MAYBE_UNUSED extern std::unique_ptr<File> OpenFileGoogle(
const Path& filename, const char* mode);
std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode) {
std::unique_ptr<File> file; // OpenFileGoogle omitted
if (file) return file;
const bool is_read = mode[0] != 'w';
const int flags = is_read ? O_RDONLY : O_CREAT | O_RDWR | O_TRUNC;
const int fd = open(filename.path.c_str(), flags, 0644);
if (fd < 0) return file;
#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21)
if (is_read) {
// Doubles the readahead window, which seems slightly faster when cached.
(void)posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL);
}
#endif
return std::make_unique<FilePosix>(fd);
}
} // namespace gcpp
#endif // !HWY_OS_WIN

View File

@ -23,7 +23,7 @@
#include <cstdio>
#include "compression/shared.h"
#include "compression/types.h"
#include "util/basics.h"
#include "hwy/base.h"

View File

@ -13,20 +13,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// SFP uses ConcatEven/Odd which are not supported; skip SVE for faster tests.
#include "compression/types.h"
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SVE)
#endif
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <algorithm> // std::shuffle
#include <array>
#include <random>
#include "compression/distortion.h"
#include "compression/shared.h"
#include "util/test_util.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
@ -104,7 +104,7 @@ struct TestPlateaus {
HWY_ASSERT(-0.5f <= in[i] && in[i] < 0.5f);
}
std::random_device rd;
std::random_device rd; // NOLINT
std::mt19937 rng(rd());
std::shuffle(in.get(), in.get() + kGroupSize, rng);
@ -151,7 +151,7 @@ struct TestRamp {
HWY_ASSERT(-0.45f <= in[i] && in[i] < 0.55f);
}
std::random_device rd;
std::random_device rd; // NOLINT
std::mt19937 rng(rd());
std::shuffle(in.get(), in.get() + kGroupSize, rng);
@ -246,7 +246,8 @@ struct TestOffset {
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
auto dec1 = hwy::AllocateAligned<T>(total);
auto dec2 = hwy::AllocateAligned<T>(kMidLen);
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
auto nuq = hwy::AllocateAligned<NuqStream>(
hwy::RoundUpTo(NuqStream::PackedEnd(total), hwy::VectorBytes()));
HWY_ASSERT(in && dec1 && dec2 && nuq);
const auto nuq_span = MakeSpan(nuq.get(), total);
@ -296,7 +297,8 @@ struct TestUnalignedOffset {
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
auto dec1 = hwy::AllocateAligned<T>(total);
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
auto nuq = hwy::AllocateAligned<NuqStream>(
hwy::RoundUpTo(NuqStream::PackedEnd(total), hwy::VectorBytes()));
auto dec2 = hwy::AllocateAligned<T>(num_decompressed);
HWY_ASSERT(in && dec1 && dec2 && nuq);
const auto nuq_span = MakeSpan(nuq.get(), total);
@ -347,7 +349,8 @@ struct TestDec2 {
auto dec0 = hwy::AllocateAligned<T>(total);
auto dec1 = hwy::AllocateAligned<T>(total);
auto dec2 = hwy::AllocateAligned<T>(kMidLen);
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
auto nuq = hwy::AllocateAligned<NuqStream>(
hwy::RoundUpTo(NuqStream::PackedEnd(total), hwy::VectorBytes()));
HWY_ASSERT(in && dec0 && dec1 && dec2 && nuq);
const auto nuq_span = MakeSpan(nuq.get(), total);
@ -449,7 +452,8 @@ struct TestEncDec {
const size_t num = 4 * kGroupSize;
auto in = hwy::AllocateAligned<float>(num); // Enc() requires f32
auto out = hwy::AllocateAligned<T>(num); // already padded
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(num));
auto nuq = hwy::AllocateAligned<NuqStream>(
hwy::RoundUpTo(NuqStream::PackedEnd(num), hwy::VectorBytes()));
HWY_ASSERT(in && out && nuq);
const auto nuq_span = MakeSpan(nuq.get(), num);
@ -512,6 +516,7 @@ HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_BEFORE_TEST(NuqTest);
#if GEMMA_ENABLE_NUQ
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllFlat);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllPlateaus);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllRamp);
@ -525,6 +530,9 @@ HWY_EXPORT_AND_TEST_P(NuqTest, TestUnalignedOffsetF32);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNibble);
HWY_EXPORT_AND_TEST_P(NuqTest, TestEncDecBF16);
HWY_EXPORT_AND_TEST_P(NuqTest, TestEncDecF32);
#else
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(NuqTest);
#endif // GEMMA_ENABLE_NUQ
HWY_AFTER_TEST();
} // namespace gcpp
#endif // HWY_ONCE

View File

@ -14,11 +14,16 @@ cc_library(
hdrs = ["compression_clif_aux.h"],
visibility = ["//visibility:private"],
deps = [
"@abseil-cpp//absl/types:span",
"//:common",
"//:basics",
"//:configs",
"//:mat",
"//:model_store",
"//:tensor_info",
"//:threading_context",
"//:tokenizer",
"//compression:compress",
"//compression:io",
"//io",
"//io:blob_store",
"@highway//:hwy",
"@highway//:thread_pool",
],
@ -29,9 +34,9 @@ pybind_extension(
srcs = ["compression_extension.cc"],
deps = [
":compression_clif_aux",
"@abseil-cpp//absl/types:span",
"//:common",
"//compression:sfp",
"//:mat",
"//:tensor_info",
"//compression:types",
],
)

View File

@ -15,14 +15,29 @@
#include "compression/python/compression_clif_aux.h"
#include <cstddef>
#include <cstdio>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <string>
#include <vector>
#include "compression/compress.h"
#include "compression/shared.h"
#include "hwy/aligned_allocator.h"
#include "compression/compress.h" // ScaleWeights
#include "gemma/configs.h" // ModelConfig
#include "gemma/model_store.h" // ModelStore
#include "gemma/tensor_info.h" // TensorInfo
#include "gemma/tokenizer.h"
#include "io/blob_store.h" // BlobWriter
#include "io/io.h" // Path
#include "util/basics.h"
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
@ -32,157 +47,97 @@
// After highway.h
#include "compression/compress-inl.h"
// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
// compile pass, whereas we want this defined in the first.
#ifndef GEMMA_ONCE
#define GEMMA_ONCE
#include "absl/types/span.h"
#include "compression/io.h"
#include "gemma/configs.h"
#include "gemma/tensor_index.h"
#include "gemma/tokenizer.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
class WriterInterface {
public:
virtual ~WriterInterface() = default;
virtual void Insert(std::string name, absl::Span<const float> weights,
Type type, const TensorInfo& tensor_info,
float scale) = 0;
virtual void InsertSfp(std::string name, absl::Span<const float> weights) = 0;
virtual void InsertNUQ(std::string name, absl::Span<const float> weights) = 0;
virtual void InsertBfloat16(std::string name,
absl::Span<const float> weights) = 0;
virtual void InsertFloat(std::string name,
absl::Span<const float> weights) = 0;
virtual void AddScales(const std::vector<float>& scales) = 0;
virtual void AddTokenizer(const std::string& tokenizer_path) = 0;
virtual size_t DebugNumBlobsAdded() const = 0;
virtual int WriteWithConfig(std::string path, const ModelConfig* config) = 0;
};
} // namespace gcpp
#endif // GEMMA_ONCE
// SIMD code, compiled once per target.
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
class SbsWriterImpl : public WriterInterface {
// Implementation for the currently compiled SIMD target.
class SbsWriterImpl : public ISbsWriter {
template <typename Packed>
void AllocateAndCompress(const std::string& name,
absl::Span<const float> weights) {
MatPtrT<Packed> storage(name, 1, weights.size());
model_memory_.push_back(storage);
model_memory_.back().Allocate();
storage.SetPtr(model_memory_.back());
std::string decorated_name = storage.CacheName();
compressor_(&storage, decorated_name.c_str(), weights.data());
}
template <typename Packed>
void AllocateWithShape(const std::string& name,
absl::Span<const float> weights,
const TensorInfo& tensor_info, float scale) {
MatPtrT<Packed> storage(name, &tensor_info);
storage.set_scale(scale);
void InsertT(const char* name, F32Span weights,
const TensorInfo& tensor_info) {
// TODO(janwas): 1D parallel-for.
hwy::ThreadPool& pool = ctx_.pools.Pool();
// Don't reset num_elements for NUQ.
if (!hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>()) {
storage.SetNumElements(CompressedArrayElements<Packed>(weights.size()));
MatPtrT<Packed> mat(name, ExtentsFromInfo(&tensor_info));
// SFP and NUQ (which uses SFP for cluster centers) have a limited range
// and depending on the input values may require rescaling. Scaling is
// cheap for matmul and probably not an issue for other ops, but it might be
// beneficial for precision to keep the original data range for other types.
if (mat.GetType() == Type::kSFP || mat.GetType() == Type::kNUQ) {
mat.SetScale(ScaleWeights(weights.data(), weights.size()));
}
model_memory_.push_back(storage);
if (mode_ == CompressorMode::kTEST_ONLY) return;
model_memory_.back().Allocate();
storage.SetPtr(model_memory_.back());
std::string decorated_name = storage.CacheName();
compressor_(&storage, decorated_name.c_str(), weights.data());
if (weights.size() == 0) {
HWY_WARN("Ignoring zero-sized tensor %s.", name);
return;
}
mat.AppendTo(serialized_mat_ptrs_);
MatOwner mat_owner;
mat_owner.AllocateFor(mat, ctx_.allocator, MatPadding::kPacked);
// Handle gemma_export_test's MockArray. Write blobs so that the test
// succeeds, but we only have 10 floats, not the full tensor.
if (weights.size() == 10 && mat.Extents().Area() != 10) {
Compress(weights.data(), weights.size(), working_set_, mat.Span(),
/*packed_ofs=*/0, pool);
writer_.Add(name, mat.Packed(), mat.ElementBytes() * 10);
return;
}
fprintf(stderr, "Compressing %s (%zu x %zu = %zuM) to %s, please wait\n",
name, mat.Rows(), mat.Cols(), weights.size() / (1000 * 1000),
TypeName(TypeEnum<Packed>()));
HWY_ASSERT(weights.size() == mat.Extents().Area());
Compress(weights.data(), weights.size(), working_set_, mat.Span(),
/*packed_ofs=*/0, pool);
writer_.Add(name, mat.Packed(), mat.PackedBytes());
}
public:
explicit SbsWriterImpl(CompressorMode mode)
: pool_(0), compressor_(pool_), mode_(mode) {}
SbsWriterImpl(const std::string& sbs_path)
: ctx_(ThreadingArgs()),
writer_(gcpp::Path(sbs_path), ctx_.pools.Pool()) {}
void Insert(std::string name, absl::Span<const float> weights, Type type,
const TensorInfo& tensor_info, float scale) override {
void Insert(const char* name, F32Span weights, Type type,
const TensorInfo& tensor_info) override {
switch (type) {
case Type::kSFP:
AllocateWithShape<SfpStream>(name, weights, tensor_info, scale);
InsertT<SfpStream>(name, weights, tensor_info);
break;
case Type::kNUQ:
AllocateWithShape<NuqStream>(name, weights, tensor_info, scale);
InsertT<NuqStream>(name, weights, tensor_info);
break;
case Type::kBF16:
AllocateWithShape<BF16>(name, weights, tensor_info, scale);
InsertT<BF16>(name, weights, tensor_info);
break;
case Type::kF32:
AllocateWithShape<float>(name, weights, tensor_info, scale);
InsertT<float>(name, weights, tensor_info);
break;
default:
HWY_ABORT("Unsupported type");
HWY_ABORT("Unsupported destination (compressed) type %s",
TypeName(type));
}
}
void InsertSfp(std::string name, absl::Span<const float> weights) override {
AllocateAndCompress<SfpStream>(name, weights);
void Write(const ModelConfig& config,
const std::string& tokenizer_path) override {
const GemmaTokenizer tokenizer(
tokenizer_path.empty() ? kMockTokenizer
: ReadFileToString(Path(tokenizer_path)));
WriteSingleFile(config, tokenizer, serialized_mat_ptrs_, writer_);
}
void InsertNUQ(std::string name, absl::Span<const float> weights) override {
AllocateAndCompress<NuqStream>(name, weights);
}
void InsertBfloat16(std::string name,
absl::Span<const float> weights) override {
AllocateAndCompress<BF16>(name, weights);
}
void InsertFloat(std::string name, absl::Span<const float> weights) override {
AllocateAndCompress<float>(name, weights);
}
void AddScales(const std::vector<float>& scales) override {
HWY_ASSERT(scales_.empty());
scales_ = scales;
compressor_.AddScales(scales_.data(), scales_.size());
}
void AddTokenizer(const std::string& tokenizer_path) override {
Path path(tokenizer_path);
GemmaTokenizer tokenizer(path);
std::string tokenizer_proto = tokenizer.Serialize();
HWY_ASSERT(!tokenizer_proto.empty());
compressor_.AddTokenizer(tokenizer_proto);
}
// Returns the number of blobs added.
size_t DebugNumBlobsAdded() const {
if (mode_ == CompressorMode::kTEST_ONLY) return model_memory_.size();
return compressor_.DebugNumBlobsAdded();
}
int WriteWithConfig(std::string path, const ModelConfig* config) override {
return compressor_.WriteAll(gcpp::Path(path), config);
}
hwy::ThreadPool pool_;
Compressor compressor_;
ThreadingContext ctx_;
CompressWorkingSet working_set_;
std::vector<MatStorage> model_memory_;
std::vector<float> scales_;
CompressorMode mode_;
BlobWriter writer_;
std::vector<uint32_t> serialized_mat_ptrs_;
};
WriterInterface* NewSbsWriter(CompressorMode mode) {
return new SbsWriterImpl(mode);
ISbsWriter* NewSbsWriter(const std::string& sbs_path) {
return new SbsWriterImpl(sbs_path);
}
} // namespace HWY_NAMESPACE
@ -194,43 +149,11 @@ namespace gcpp {
HWY_EXPORT(NewSbsWriter);
SbsWriter::SbsWriter(CompressorMode mode)
: impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)(mode)) {}
SbsWriter::~SbsWriter() = default;
SbsWriter::SbsWriter(const std::string& path)
: impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)(path)) {}
void SbsWriter::Insert(std::string name, absl::Span<const float> weights,
Type type, const TensorInfo& tensor_info, float scale) {
impl_->Insert(name, weights, type, tensor_info, scale);
}
void SbsWriter::InsertSfp(std::string name, absl::Span<const float> weights) {
impl_->InsertSfp(name, weights);
}
void SbsWriter::InsertNUQ(std::string name, absl::Span<const float> weights) {
impl_->InsertNUQ(name, weights);
}
void SbsWriter::InsertBfloat16(std::string name,
absl::Span<const float> weights) {
impl_->InsertBfloat16(name, weights);
}
void SbsWriter::InsertFloat(std::string name, absl::Span<const float> weights) {
impl_->InsertFloat(name, weights);
}
void SbsWriter::AddScales(const std::vector<float>& scales) {
impl_->AddScales(scales);
}
void SbsWriter::AddTokenizer(const std::string& tokenizer_path) {
impl_->AddTokenizer(tokenizer_path);
}
size_t SbsWriter::DebugNumBlobsAdded() const {
return impl_->DebugNumBlobsAdded();
}
int SbsWriter::WriteWithConfig(std::string path, const ModelConfig* config) {
return impl_->WriteWithConfig(path, config);
}
SbsReader::SbsReader(const std::string& path)
: reader_(Path(path)), model_(reader_) {}
} // namespace gcpp
#endif // HWY_ONCE

View File

@ -16,52 +16,67 @@
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_
#include <cstddef>
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "absl/types/span.h"
#include "compression/shared.h"
#include "compression/types.h" // Type
#include "gemma/configs.h"
#include "gemma/tensor_index.h"
#include "gemma/model_store.h"
#include "gemma/tensor_info.h"
#include "io/blob_store.h"
#include "util/mat.h"
#include "hwy/aligned_allocator.h" // Span
namespace gcpp {
// How to process the data.
enum class CompressorMode {
// No compression, no write to file, just for testing.
kTEST_ONLY,
// Old-style compression, no table of contents.
kNO_TOC,
// New-style compression, with table of contents.
kWITH_TOC,
// Can be modified in place by ScaleWeights.
using F32Span = hwy::Span<float>;
// Interface because we compile one derived implementation per SIMD target,
// because Compress() uses SIMD.
class ISbsWriter {
public:
virtual ~ISbsWriter() = default;
virtual void Insert(const char* name, F32Span weights, Type type,
const TensorInfo& tensor_info) = 0;
virtual void Write(const ModelConfig& config,
const std::string& tokenizer_path) = 0;
};
class WriterInterface;
// Non-virtual class used by pybind that calls the interface's virtual methods.
// This avoids having to register the derived types with pybind.
class SbsWriter {
public:
explicit SbsWriter(CompressorMode mode);
~SbsWriter();
explicit SbsWriter(const std::string& sbs_path);
void Insert(std::string name, absl::Span<const float> weights, Type type,
const TensorInfo& tensor_info, float scale);
void InsertSfp(std::string name, absl::Span<const float> weights);
void InsertNUQ(std::string name, absl::Span<const float> weights);
void InsertBfloat16(std::string name, absl::Span<const float> weights);
void InsertFloat(std::string name, absl::Span<const float> weights);
void AddScales(const std::vector<float>& scales);
void AddTokenizer(const std::string& tokenizer_path);
void Insert(const char* name, F32Span weights, Type type,
const TensorInfo& tensor_info) {
impl_->Insert(name, weights, type, tensor_info);
}
size_t DebugNumBlobsAdded() const;
int Write(std::string path) { return WriteWithConfig(path, nullptr); }
int WriteWithConfig(std::string path, const ModelConfig* config);
void Write(const ModelConfig& config, const std::string& tokenizer_path) {
impl_->Write(config, tokenizer_path);
}
private:
// Isolates Highway-dispatched types and other internals from CLIF.
std::unique_ptr<WriterInterface> impl_;
std::unique_ptr<ISbsWriter> impl_;
};
// Limited metadata-only reader for tests.
class SbsReader {
public:
SbsReader(const std::string& path);
const ModelConfig& Config() const { return model_.Config(); }
const MatPtr* FindMat(const char* name) const { return model_.FindMat(name); }
private:
gcpp::BlobReader reader_;
gcpp::ModelStore model_;
};
} // namespace gcpp

View File

@ -15,58 +15,54 @@
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <stdexcept>
#include <string>
#include "absl/types/span.h"
#include "compression/python/compression_clif_aux.h"
#include "compression/shared.h"
#include "compression/types.h" // Type
#include "gemma/tensor_info.h"
#include "util/mat.h"
using gcpp::CompressorMode;
using gcpp::MatPtr;
using gcpp::SbsReader;
using gcpp::SbsWriter;
namespace py = pybind11;
namespace pybind11 {
namespace {
template <auto Func>
void wrap_span(SbsWriter& writer, std::string name, py::array_t<float> data) {
static void CallWithF32Span(SbsWriter& writer, const char* name,
array_t<float> data, gcpp::Type type,
const gcpp::TensorInfo& tensor_info) {
if (data.ndim() != 1 || data.strides(0) != sizeof(float)) {
throw std::domain_error("Input array must be 1D and densely packed.");
HWY_ABORT("Input array must be 1D (not %d) and contiguous floats.",
static_cast<int>(data.ndim()));
}
std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size()));
std::invoke(Func, writer, name,
gcpp::F32Span(data.mutable_data(0), data.size()), type,
tensor_info);
}
template <auto Func>
void wrap_span_typed(SbsWriter& writer, std::string name,
py::array_t<float> data, gcpp::Type type,
gcpp::TensorInfo tensor_info, float scale) {
if (data.ndim() != 1 || data.strides(0) != sizeof(float)) {
throw std::domain_error("Input array must be 1D and densely packed.");
}
std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size()),
type, tensor_info, scale);
}
} // namespace
PYBIND11_MODULE(compression, m) {
py::enum_<CompressorMode>(m, "CompressorMode")
.value("TEST_ONLY", CompressorMode::kTEST_ONLY)
.value("NO_TOC", CompressorMode::kNO_TOC)
.value("WITH_TOC", CompressorMode::kWITH_TOC);
class_<SbsWriter>(m, "SbsWriter")
.def(init<std::string>())
.def("insert", CallWithF32Span<&SbsWriter::Insert>)
.def("write", &SbsWriter::Write, arg("config"), arg("tokenizer_path"));
py::class_<SbsWriter>(m, "SbsWriter")
.def(py::init<CompressorMode>())
// NOTE: Individual compression backends may impose constraints on the
// array length, such as a minimum of (say) 32 elements.
.def("insert", wrap_span_typed<&SbsWriter::Insert>)
.def("insert_sfp", wrap_span<&SbsWriter::InsertSfp>)
.def("insert_nuq", wrap_span<&SbsWriter::InsertNUQ>)
.def("insert_bf16", wrap_span<&SbsWriter::InsertBfloat16>)
.def("insert_float", wrap_span<&SbsWriter::InsertFloat>)
.def("add_scales", &SbsWriter::AddScales)
.def("add_tokenizer", &SbsWriter::AddTokenizer)
.def("debug_num_blobs_added", &SbsWriter::DebugNumBlobsAdded)
.def("write", &SbsWriter::Write)
.def("write_with_config", &SbsWriter::WriteWithConfig);
class_<MatPtr>(m, "MatPtr")
// No init, only created within C++.
.def_property_readonly("rows", &MatPtr::Rows, "Number of rows")
.def_property_readonly("cols", &MatPtr::Cols, "Number of cols")
.def_property_readonly("type", &MatPtr::GetType, "Element type")
.def_property_readonly("scale", &MatPtr::Scale, "Scaling factor");
class_<SbsReader>(m, "SbsReader")
.def(init<std::string>())
.def_property_readonly("config", &SbsReader::Config,
return_value_policy::reference_internal,
"ModelConfig")
.def("find_mat", &SbsReader::FindMat,
return_value_policy::reference_internal,
"Returns MatPtr for given name.");
}
} // namespace pybind11

View File

@ -25,46 +25,120 @@ from python import configs
class CompressionTest(absltest.TestCase):
def test_sbs_writer(self):
info_192 = configs.TensorInfo()
info_192.name = "ignored_192"
info_192.axes = [0]
info_192.shape = [192]
temp_file = self.create_tempfile("test.sbs")
tensor_info = configs.TensorInfo()
tensor_info.name = "foo"
tensor_info.axes = [0]
tensor_info.shape = [192]
writer = compression.SbsWriter(compression.CompressorMode.NO_TOC)
writer = compression.SbsWriter(temp_file.full_path)
writer.insert(
"foo",
np.array([0.0012] * 128 + [0.001] * 64, dtype=np.float32),
"tensor0",
# Large enough to require scaling.
np.array([3.0012] * 128 + [4.001] * 64, dtype=np.float32),
configs.Type.kSFP,
tensor_info,
1.0,
info_192,
)
tensor_info_nuq = configs.TensorInfo()
tensor_info_nuq.name = "fooNUQ"
tensor_info_nuq.axes = [0]
tensor_info_nuq.shape = [256]
# 2D tensor.
info_2d = configs.TensorInfo()
info_2d.name = "ignored_2d"
info_2d.axes = [0, 1]
info_2d.shape = [96, 192]
writer.insert(
"fooNUQ",
"tensor_2d",
np.array([i / 1e3 for i in range(96 * 192)], dtype=np.float32),
configs.Type.kBF16,
info_2d,
)
# 3D collapsed into rows.
info_3d = configs.TensorInfo()
info_3d.name = "ignored_3d"
info_3d.axes = [0, 1, 2]
info_3d.shape = [10, 12, 192]
info_3d.cols_take_extra_dims = False
writer.insert(
"tensor_3d",
# Verification of scale below depends on the shape and multiplier here.
np.array([i / 1e3 for i in range(10 * 12 * 192)], dtype=np.float32),
configs.Type.kSFP,
info_3d,
)
# Exercise all types supported by Compress.
info_256 = configs.TensorInfo()
info_256.name = "ignored_256"
info_256.axes = [0]
info_256.shape = [256]
writer.insert(
"tensor_sfp",
np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32),
configs.Type.kNUQ,
tensor_info_nuq,
1.0,
configs.Type.kSFP,
info_256,
)
writer.insert_sfp(
"bar", np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32)
writer.insert(
"tensor_bf",
np.array([0.000375] * 128 + [0.00007] * 128, dtype=np.float32),
configs.Type.kBF16,
info_256,
)
writer.insert_nuq(
"baz", np.array([0.000125] * 128 + [0.00008] * 128, dtype=np.float32)
writer.insert(
"tensor_f32",
np.array([0.000375] * 128 + [0.00006] * 128, dtype=np.float32),
configs.Type.kF32,
info_256,
)
writer.insert_bf16(
"qux", np.array([0.000375] * 128 + [0.00007] * 128, dtype=np.float32)
config = configs.ModelConfig(
configs.Model.GEMMA_TINY,
configs.Type.kSFP,
configs.PromptWrapping.GEMMA_IT,
)
writer.insert_float(
"quux", np.array([0.000375] * 128 + [0.00006] * 128, dtype=np.float32)
)
self.assertEqual(writer.debug_num_blobs_added(), 6)
self.assertEqual(writer.write(temp_file.full_path), 0)
tokenizer_path = "" # no tokenizer required for testing
writer.write(config, tokenizer_path)
print("Ignore next two warnings; test does not enable model deduction.")
reader = compression.SbsReader(temp_file.full_path)
self.assertEqual(reader.config.model, configs.Model.GEMMA_TINY)
self.assertEqual(reader.config.weight, configs.Type.kSFP)
mat = reader.find_mat("tensor0")
self.assertEqual(mat.cols, 192)
self.assertEqual(mat.rows, 1)
self.assertEqual(mat.type, configs.Type.kSFP)
self.assertAlmostEqual(mat.scale, 4.001 / 1.875, places=5)
mat = reader.find_mat("tensor_2d")
self.assertEqual(mat.cols, 192)
self.assertEqual(mat.rows, 96)
self.assertEqual(mat.type, configs.Type.kBF16)
self.assertAlmostEqual(mat.scale, 1.0)
mat = reader.find_mat("tensor_3d")
self.assertEqual(mat.cols, 192)
self.assertEqual(mat.rows, 10 * 12)
self.assertEqual(mat.type, configs.Type.kSFP)
self.assertAlmostEqual(mat.scale, 192 * 120 / 1e3 / 1.875, places=2)
mat = reader.find_mat("tensor_sfp")
self.assertEqual(mat.cols, 256)
self.assertEqual(mat.rows, 1)
self.assertEqual(mat.type, configs.Type.kSFP)
self.assertAlmostEqual(mat.scale, 1.0)
mat = reader.find_mat("tensor_bf")
self.assertEqual(mat.cols, 256)
self.assertEqual(mat.rows, 1)
self.assertEqual(mat.type, configs.Type.kBF16)
self.assertAlmostEqual(mat.scale, 1.0)
mat = reader.find_mat("tensor_f32")
self.assertEqual(mat.cols, 256)
self.assertEqual(mat.rows, 1)
self.assertEqual(mat.type, configs.Type.kF32)
self.assertAlmostEqual(mat.scale, 1.0)
if __name__ == "__main__":

View File

@ -20,7 +20,7 @@
#include <stddef.h>
#include <stdint.h>
#include "compression/shared.h"
#include "compression/types.h"
#include "hwy/base.h"
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_

View File

@ -13,10 +13,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// We use ConcatEven/Odd which are not supported. Use HWY_EMU128 instead.
#include "compression/types.h"
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
#include <stddef.h>
#include <stdint.h>
@ -25,7 +25,6 @@
#include <set>
#include "compression/distortion.h"
#include "compression/shared.h"
#include "util/test_util.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"

View File

@ -18,10 +18,13 @@
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_
// IWYU pragma: begin_exports
#include "compression/compress.h"
#include "compression/distortion.h"
#include "util/mat.h"
// IWYU pragma: end_exports
#include "compression/compress.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_
// Include guard for (potentially) SIMD code.
@ -59,7 +62,63 @@ void ForeachPackedAndRawType() {
ForeachRawType<BF16, TestT>();
ForeachRawType<float, TestT>();
ForeachRawType<SfpStream, TestT>();
ForeachRawType<NuqStream, TestT>();
if constexpr (GEMMA_ENABLE_NUQ) {
ForeachRawType<NuqStream, TestT>();
}
}
// Generates inputs: deterministic, within max SfpStream range.
template <typename MatT>
MatStorageT<MatT> GenerateMat(const Extents2D& extents,
const Allocator& allocator, MatPadding padding,
hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
ws.tls.resize(pool.NumWorkers());
MatStorageT<float> raw("raw", extents, allocator, MatPadding::kPacked);
MatStorageT<MatT> compressed("mat", extents, allocator, padding);
const float scale = SfpStream::kMax / extents.Area();
pool.Run(0, extents.rows, [&](const size_t r, size_t thread) {
float* HWY_RESTRICT row = raw.Row(r);
for (size_t c = 0; c < extents.cols; c++) {
float f = static_cast<float>(r * extents.cols + c) * scale;
if ((r + c) & 1) f = -f; // Also generate some negative values.
row[c] = f;
}
Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
MakeSpan(compressed.Row(r), compressed.Cols()),
/*packed_ofs=*/0);
});
compressed.SetScale(0.6f); // Arbitrary value, different from 1.
return compressed;
}
// Same, but `extents` describes the transposed matrix.
template <typename MatT>
MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
const Allocator& allocator,
MatPadding padding,
hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
ws.tls.resize(pool.NumWorkers());
MatStorageT<float> raw("raw", extents, allocator, MatPadding::kPacked);
MatStorageT<MatT> compressed("trans", extents, allocator, padding);
const float scale = SfpStream::kMax / extents.Area();
pool.Run(0, extents.rows, [&](const size_t r, size_t thread) {
float* HWY_RESTRICT row = raw.Row(r);
for (size_t c = 0; c < extents.cols; c++) {
float f = static_cast<float>(c * extents.rows + r) * scale;
if ((r + c) & 1) f = -f; // Also generate some negative values.
row[c] = f;
}
Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
MakeSpan(compressed.Row(r), compressed.Cols()),
/*packed_ofs=*/0);
});
// Arbitrary value, different from 1, must match `GenerateMat`.
compressed.SetScale(0.6f);
return compressed;
}
// NOLINTNEXTLINE(google-readability-namespace-comments)

View File

@ -13,18 +13,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Definitions shared between the public compress-inl.h interface and the
// sfp-inl.h and nuq-inl.h implementation details.
// Types shared between tensor definitions and `compress-inl.h`.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_TYPES_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TYPES_H_
#include <stddef.h>
#include <stdint.h>
#include <complex>
#include <cstdio>
// IWYU pragma: begin_exports
#include "util/basics.h" // BF16
#include "hwy/aligned_allocator.h"
@ -33,6 +29,35 @@
namespace gcpp {
// EMU128 must not be disabled because we disable SCALAR.
#define HWY_BROKEN_EMU128 0
// Allow user override of disabled targets.
#ifndef GEMMA_DISABLED_TARGETS
// All platforms: exclude SCALAR because we use ReorderWidenMulAccumulate.
#if HWY_ARCH_ARM_V7
// No NEON because we require double-precision support.
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_ALL_NEON)
#elif HWY_ARCH_ARM_A64
// We do not yet use AES (e.g. for random generation), hence NEON is the same
// as NEON_WITHOUT_AES. Also skip SVE because SVE2_128 and SVE_256 cover most.
#define GEMMA_DISABLED_TARGETS (HWY_SCALAR | HWY_NEON | HWY_SVE)
#elif HWY_ARCH_X86
// Skip anything older than Haswell (2013); also use Zen4 for recent CPUs,
// because we do not use anything added by SPR (e.g. FP16) nor AVX 10.2.
#define GEMMA_DISABLED_TARGETS \
(HWY_SCALAR | HWY_SSE2 | HWY_SSSE3 | HWY_SSE4 | HWY_AVX3_SPR | HWY_AVX10_2)
#endif // HWY_ARCH_*
#endif // GEMMA_DISABLED_TARGETS
// Only used in experiments, hence disable in default builds.
#ifndef GEMMA_ENABLE_NUQ
#define GEMMA_ENABLE_NUQ 0
#endif
// Switching Floating Point: a hybrid 8-bit float representation of bf16/f32
// inputs that combines the advantages of e4m3 and e5m2 into a single format.
// It supports seeking at a granularity of 1 and decoding to bf16/f32.
@ -63,30 +88,6 @@ struct SfpStream {
};
#pragma pack(pop)
// Returns 1.0f if all magnitudes are <= SfpStream::kMax, otherwise scales them
// such that the largest magnitude is SfpStream::kMax, and returns the
// multiplier with which to restore the original values. This is only necessary
// before compressing to SfpStream.
// TODO: vectorize
static inline float ScaleWeights(float* HWY_RESTRICT raw, size_t num) {
float maxabs = 0.0;
for (size_t i = 0; i < num; ++i) {
maxabs = HWY_MAX(maxabs, hwy::ScalarAbs(raw[i]));
}
if (maxabs <= SfpStream::kMax) {
return 1.0f;
}
const float scale = maxabs / SfpStream::kMax;
const float inv_scale = static_cast<float>(1.0 / static_cast<double>(scale));
for (size_t i = 0; i < num; ++i) {
// Clamp because kMax may still be exceeded.
const float magn =
HWY_MIN(SfpStream::kMax, hwy::ScalarAbs(raw[i] * inv_scale));
raw[i] = hwy::ScalarCopySign(magn, raw[i]);
}
return scale;
}
// Non-uniform quantization: a compressed representation of f32 inputs that
// supports seeking at a granularity of 1 (for `DecompressAndZeroPad`) or
// two vectors (for `Decompress2`), and decoding to bf16/f32.
@ -185,31 +186,25 @@ constexpr bool IsNuqStream() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>();
}
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
enum class PromptWrapping {
GEMMA_IT,
GEMMA_PT,
GEMMA_VLM,
PALIGEMMA,
kSentinel // must be last
// Tensor types for loading weights.
enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64 };
// These are used in `ModelConfig.Specifier`, hence the strings will not
// change, though new ones may be added.
static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16",
"sfp", "nuq", "f64"};
static constexpr size_t kNumTypes =
sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
static constexpr size_t kTypeBits[] = {
0,
8 * sizeof(float),
8 * sizeof(BF16),
8 * sizeof(SfpStream),
4 /* NuqStream, actually 4.5 */,
8 * sizeof(double),
};
inline bool EnumValid(PromptWrapping type) {
return static_cast<int>(type) >= 0 &&
static_cast<int>(type) < static_cast<int>(PromptWrapping::kSentinel);
}
// Tensor types for loading weights. Note that not all types are supported as
// weights for a model, but can be used for other purposes, such as types for
// ModelWeightsPtrs. When adding a new type that is supported, also
// update gemma.cc, weights.*, and add instantiations/new_one.cc.
enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64, kU128 };
constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
"nuq", "f64", "c64", "u128"};
inline bool EnumValid(Type type) {
return static_cast<int>(type) >= 0 &&
static_cast<int>(type) <= static_cast<int>(Type::kU128);
static inline bool EnumValid(Type type) {
return static_cast<size_t>(type) < kNumTypes;
}
// Returns a Type enum for the type of the template parameter.
@ -226,20 +221,22 @@ Type TypeEnum() {
return Type::kNUQ;
} else if constexpr (hwy::IsSame<Packed, double>()) {
return Type::kF64;
} else if constexpr (hwy::IsSame<Packed, std::complex<double>>()) {
return Type::kC64;
} else if constexpr (hwy::IsSame<Packed, hwy::uint128_t>()) {
return Type::kU128;
} else {
HWY_DASSERT(false);
return Type::kUnknown;
}
}
// Returns a string name for the type of the template parameter.
static inline size_t TypeBits(Type type) {
return kTypeBits[static_cast<int>(type)];
}
static inline const char* TypeName(Type type) {
return kTypeStrings[static_cast<int>(type)];
}
template <typename PackedT>
const char* TypeName() {
return kTypeStrings[static_cast<int>(TypeEnum<PackedT>())];
return TypeName(TypeEnum<PackedT>());
}
template <typename Packed>
@ -248,7 +245,9 @@ constexpr bool IsCompressed() {
}
// Returns the number of `MatT` elements required to store `capacity` values,
// which must not be zero.
// which must not be zero. This is only intended to support the extra tables
// required for NUQ. `capacity` includes any padding and is `rows * stride`.
// Deprecated, replaced by fixup within `MatPtr`. Only used by tests.
template <typename Packed>
constexpr size_t CompressedArrayElements(size_t capacity) {
if constexpr (hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>()) {
@ -304,4 +303,4 @@ HWY_INLINE PackedSpan<const Packed> MakeConst(PackedSpan<Packed> packed) {
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TYPES_H_

View File

@ -6,14 +6,12 @@
#include <iostream>
#include <ostream>
#include <string>
#include <utility> // std::pair
#include <vector>
#include "compression/io.h" // Path
#include "evals/benchmark_helper.h"
#include "evals/cross_entropy.h"
#include "gemma/common.h"
#include "gemma/gemma.h"
#include "io/io.h" // Path
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/timer.h"
@ -27,7 +25,6 @@ class BenchmarkArgs : public ArgsBase<BenchmarkArgs> {
public:
BenchmarkArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
Path goldens;
Path summarize_text;
Path cross_entropy;
Path trivia_qa;
@ -36,8 +33,6 @@ class BenchmarkArgs : public ArgsBase<BenchmarkArgs> {
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(goldens.path, "goldens_dir", std::string(""),
"Directory containing golden files", 2);
visitor(summarize_text.path, "summarize_text", std::string(""),
"Path to text file to summarize", 2);
visitor(cross_entropy.path, "cross_entropy", std::string(""),
@ -53,56 +48,6 @@ class BenchmarkArgs : public ArgsBase<BenchmarkArgs> {
}
};
std::vector<std::pair<std::string, std::string>> load_goldens(
const std::string& path) {
std::ifstream goldens_file(path);
if (!goldens_file) {
std::cout << "Could not load goldens file: " << path << "\n" << std::flush;
return {};
}
std::vector<std::pair<std::string, std::string>> res;
std::string query_separator;
std::string query;
std::string answer_separator;
std::string answer;
while (std::getline(goldens_file, query_separator) &&
std::getline(goldens_file, query) &&
std::getline(goldens_file, answer_separator) &&
std::getline(goldens_file, answer)) {
res.push_back({query, answer});
}
return res;
}
int BenchmarkGoldens(GemmaEnv& env, const std::string& golden_path) {
std::vector<std::pair<std::string, std::string>> queries_answers =
load_goldens(golden_path);
size_t correct_answers = 0;
size_t total_tokens = 0;
const double time_start = hwy::platform::Now();
for (auto& [question, expected_answer] : queries_answers) {
QueryResult result = env.QueryModel(question);
total_tokens += result.tokens_generated;
if (result.response.find(expected_answer) != std::string::npos) {
correct_answers++;
} else {
std::cout << "Wrong!\n";
std::cout << "Input: " << question << "\n";
std::cout << "Expected: " << expected_answer << "\n";
std::cout << "Output: " << result.response << "\n\n" << std::flush;
}
}
LogSpeedStats(time_start, total_tokens);
std::cout << "Correct: " << correct_answers << " out of "
<< queries_answers.size() << "\n"
<< std::flush;
if (correct_answers != queries_answers.size()) {
return EXIT_FAILURE;
}
return EXIT_SUCCESS;
}
int BenchmarkSummary(GemmaEnv& env, const Path& text) {
std::string prompt("Here is some text to summarize:\n");
prompt.append(ReadFileToString(text));
@ -117,6 +62,7 @@ int BenchmarkSummary(GemmaEnv& env, const Path& text) {
int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
size_t batch_tokens) {
const Gemma& gemma = *env.GetGemma();
std::string input = ReadFileToString(text);
std::vector<int> prompt = env.Tokenize(input);
std::cout << "Number of input tokens: " << prompt.size() << "\n";
@ -128,10 +74,11 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
std::vector<int> prompt_slice(prompt.begin() + pos,
prompt.begin() + pos + num_tokens);
KVCache kv_cache = KVCache::Create(env.GetModel()->GetModelConfig(),
env.MutableConfig().prefill_tbatch_size);
float entropy = ComputeCrossEntropy(
*env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
KVCache kv_cache(gemma.Config(), gemma.Inference(),
env.MutableEnv().ctx.allocator);
float entropy =
ComputeCrossEntropy(*env.GetGemma(), num_tokens, prompt_slice, kv_cache,
env.MutableEnv(), env.Verbosity());
total_entropy += entropy;
LogSpeedStats(time_start, pos + num_tokens);
std::string text_slice = env.StringFromTokens(prompt_slice);
@ -183,14 +130,7 @@ int main(int argc, char** argv) {
gcpp::GemmaEnv env(argc, argv);
gcpp::BenchmarkArgs benchmark_args(argc, argv);
if (!benchmark_args.goldens.Empty()) {
const std::string golden_path =
benchmark_args.goldens.path + "/" +
gcpp::ModelString(env.GetModel()->Info().model,
env.GetModel()->Info().wrapping) +
".txt";
return BenchmarkGoldens(env, golden_path);
} else if (!benchmark_args.summarize_text.Empty()) {
if (!benchmark_args.summarize_text.Empty()) {
return BenchmarkSummary(env, benchmark_args.summarize_text);
} else if (!benchmark_args.cross_entropy.Empty()) {
return BenchmarkCrossEntropy(env, benchmark_args.cross_entropy,

View File

@ -18,27 +18,21 @@
#include <stdio.h>
#include <time.h>
#include <cstdio>
#include <iostream>
#include <memory>
#include <ostream>
#include <random>
#include <string>
#include <vector>
// Placeholder for internal header, do not modify.
#include "compression/compress.h" // TypeName
#include "compression/types.h" // TypeName
#include "evals/cross_entropy.h"
#include "gemma/common.h" // StringFromType
#include "gemma/gemma.h"
#include "gemma/kv_cache.h"
#include "util/app.h"
#include "util/args.h"
#include "util/threading.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/topology.h"
#include "gemma/gemma_args.h"
#include "ops/matmul.h" // MatMulEnv
#include "util/threading_context.h"
#include "hwy/highway.h"
#include "hwy/per_target.h" // VectorBytes
#include "hwy/per_target.h" // DispatchedTarget
#include "hwy/profiler.h" // PROFILER_ENABLED
#include "hwy/timer.h"
namespace gcpp {
@ -49,53 +43,37 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
gen.seed(0x12345678);
} else {
// Depending on the library implementation, this may still be deterministic.
std::random_device rd;
std::random_device rd; // NOLINT
gen.seed(rd());
}
}
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
const AppArgs& app)
: topology_(CreateTopology(app)),
pools_(CreatePools(topology_, app)),
env_(topology_, pools_) {
InferenceArgs mutable_inference = inference;
AbortIfInvalidArgs(mutable_inference);
LoaderArgs mutable_loader = loader;
if (const char* err = mutable_loader.Validate()) {
mutable_loader.Help();
fprintf(stderr, "Skipping model load because: %s\n", err);
} else {
fprintf(stderr, "Loading model...\n");
model_ = AllocateGemma(mutable_loader, env_);
// Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.resize(1);
kv_caches_[0] = KVCache::Create(model_->GetModelConfig(),
inference.prefill_tbatch_size);
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference)
: ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) {
const ModelConfig& config = gemma_.Config();
// Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.push_back(KVCache(config, inference, ctx_.allocator));
if (inference.verbosity >= 2) {
ShowConfig(loader, threading, inference, config, gemma_.WeightReadMode(),
ctx_);
}
InitGenerator(inference, gen_);
runtime_config_ = {
.max_generated_tokens = inference.max_generated_tokens,
.temperature = inference.temperature,
.gen = &gen_,
.verbosity = app.verbosity,
.verbosity = inference.verbosity,
};
}
// Internal init must run before the GemmaEnv ctor above, hence it cannot occur
// in the argv ctor below because its body runs *after* the delegating ctor.
// This helper function takes care of the init, and could be applied to any of
// the *Args classes, it does not matter which.
static AppArgs MakeAppArgs(int argc, char** argv) {
{ // So that indentation matches expectations.
// Placeholder for internal init, do not modify.
}
return AppArgs(argc, argv);
inference.CopyTo(runtime_config_);
}
GemmaEnv::GemmaEnv(int argc, char** argv)
: GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv),
MakeAppArgs(argc, argv)) {}
: GemmaEnv(LoaderArgs(argc, argv), ThreadingArgs(argc, argv),
InferenceArgs(argc, argv)) {}
QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
QueryResult result;
@ -117,8 +95,8 @@ QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
}
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
runtime_config_.batch_stream_token = batch_stream_token;
model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
timing_info);
gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_,
timing_info);
return result;
}
@ -127,23 +105,25 @@ void GemmaEnv::QueryModel(
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
const StreamFunc previous_stream_token = runtime_config_.stream_token;
runtime_config_.stream_token = stream_token;
model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
timing_info);
gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_,
timing_info);
runtime_config_.stream_token = previous_stream_token;
}
std::vector<QueryResult> GemmaEnv::BatchQueryModel(
const QueriesPromptTokens& queries_prompt) {
const QueriesPromptTokens& queries_prompt,
const hwy::Span<const size_t>& prefix_end) {
const size_t num_queries = queries_prompt.size();
HWY_ASSERT(num_queries != 0);
std::vector<QueryResult> res(num_queries);
const BatchStreamFunc batch_stream_token = [&res, &queries_prompt, this](
size_t query_index, size_t pos,
int token, float) {
const BatchStreamFunc batch_stream_token = [&, this](const size_t query_index,
const size_t pos,
const int token, float) {
HWY_ASSERT(query_index < num_queries);
std::string token_text;
HWY_ASSERT(
model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
HWY_ASSERT(gemma_.Tokenizer().Decode(std::vector<int>{token}, &token_text));
res[query_index].response.append(token_text);
HWY_ASSERT(pos == res[query_index].tokens_generated);
res[query_index].tokens_generated += 1;
if (res[query_index].tokens_generated ==
queries_prompt[query_index].size()) {
@ -151,6 +131,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
}
return true;
};
runtime_config_.batch_stream_token = batch_stream_token;
if (runtime_config_.verbosity >= 2) {
fprintf(stderr, "Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
runtime_config_.max_generated_tokens, runtime_config_.temperature,
@ -158,23 +139,16 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
runtime_config_.decode_qbatch_size);
}
// Ensure we have one KVCache per query.
if (kv_caches_.size() < num_queries) {
kv_caches_.resize(num_queries);
}
for (size_t i = 1; i < num_queries; ++i) {
if (kv_caches_[i].seq_len == 0) {
kv_caches_[i] = KVCache::Create(model_->GetModelConfig(),
runtime_config_.prefill_tbatch_size);
}
// Ensure we have at least one KVCache per query.
while (kv_caches_.size() < num_queries) {
kv_caches_.push_back(
KVCache(gemma_.Config(), gemma_.Inference(), ctx_.allocator));
}
const hwy::Span<KVCache> kv_caches(&kv_caches_[0], num_queries);
gcpp::AllQueries all_queries(queries_prompt, kv_caches, prefix_end);
gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
runtime_config_.batch_stream_token = batch_stream_token;
std::vector<size_t> queries_pos(num_queries, 0);
model_->GenerateBatch(runtime_config_, queries_prompt,
QueriesPos(queries_pos.data(), num_queries),
KVCaches(&kv_caches_[0], num_queries), timing_info);
gemma_.GenerateBatch(runtime_config_, all_queries, env_, timing_info);
return res;
}
@ -203,8 +177,8 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
float GemmaEnv::CrossEntropy(const std::string& input) {
std::vector<int> prompt = Tokenize(input);
prompt.insert(prompt.begin(), BOS_ID);
return ComputeCrossEntropy(*GetModel(), /*max_generated_tokens=*/3072, prompt,
MutableKVCache(),
return ComputeCrossEntropy(*GetGemma(), /*max_generated_tokens=*/3072, prompt,
MutableKVCache(), env_,
/*verbosity=*/0) /
static_cast<int>(input.size());
}
@ -236,13 +210,37 @@ std::string CacheString() {
return buf;
}
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
const BoundedTopology& topology, NestedPools& pools) {
loader.Print(app.verbosity);
inference.Print(app.verbosity);
app.Print(app.verbosity);
static constexpr const char* CompiledConfig() {
if constexpr (HWY_IS_ASAN) {
return "asan";
} else if constexpr (HWY_IS_MSAN) {
return "msan";
} else if constexpr (HWY_IS_TSAN) {
return "tsan";
} else if constexpr (HWY_IS_HWASAN) {
return "hwasan";
} else if constexpr (HWY_IS_UBSAN) {
return "ubsan";
} else if constexpr (HWY_IS_DEBUG_BUILD) {
return "dbg";
} else {
return "opt";
}
}
if (app.verbosity >= 2) {
void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference, const ModelConfig& config,
const WeightsPtrs::Mode weight_read_mode,
const ThreadingContext& ctx) {
threading.Print(inference.verbosity);
loader.Print(inference.verbosity);
inference.Print(inference.verbosity);
fprintf(
stderr, "Model : %s, to_bf16 %d, mmap %d => %s\n",
config.Specifier().c_str(), static_cast<int>(loader.to_bf16),
static_cast<int>(loader.map), WeightsPtrs::ToString(weight_read_mode));
if (inference.verbosity >= 2) {
time_t now = time(nullptr);
char* dt = ctime(&now); // NOLINT
char cpu100[100] = "unknown";
@ -250,38 +248,34 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
fprintf(stderr,
"Date & Time : %s" // dt includes \n
"CPU : %s\n"
"CPU : %s, bind %d\n"
"CPU topology : %s, %s, %s\n"
"Instruction set : %s (%zu bits)\n"
"Compiled config : %s\n"
"Weight Type : %s\n"
"EmbedderInput Type : %s\n",
dt, cpu100, topology.TopologyString(), pools.PinString(),
"Compiled config : %s, profiler %d\n"
"Memory MiB : %4zu\n",
dt, cpu100, static_cast<int>(threading.bind),
ctx.topology.TopologyString(), ctx.pools.PinString(),
CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()),
hwy::VectorBytes() * 8, CompiledConfig(),
StringFromType(loader.Info().weight), TypeName<EmbedderInputT>());
ctx.allocator.VectorBytes() * 8, CompiledConfig(), PROFILER_ENABLED,
ctx.allocator.TotalMiB());
}
}
void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference) {
std::cerr
<< "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
"==========================================================\n\n"
"To run gemma.cpp, you need to "
"specify 3 required model loading arguments:\n"
" --tokenizer\n"
" --weights\n"
" --model,\n"
" or with the newer weights format, specify just:\n"
" --weights\n";
"To run with pre-2025 weights, specify --tokenizer and --weights.\n"
"With the single-file weights format, specify just --weights.\n";
std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
"--weights 2b-it-sfp.sbs --model 2b-it\n";
"--weights gemma2-2b-it-sfp.sbs\n";
std::cerr << "\n*Model Loading Arguments*\n\n";
loader.Help();
std::cerr << "\n*Threading Arguments*\n\n";
threading.Help();
std::cerr << "\n*Inference Arguments*\n\n";
inference.Help();
std::cerr << "\n*Application Arguments*\n\n";
app.Help();
std::cerr << "\n";
}

View File

@ -18,15 +18,16 @@
#include <stddef.h>
#include <memory>
#include <random>
#include <string>
#include <vector>
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/tokenizer.h" // WrapAndTokenize
#include "ops/matmul.h"
#include "util/app.h"
#include "util/threading.h"
#include "util/threading_context.h"
#include "hwy/base.h"
namespace gcpp {
@ -46,19 +47,21 @@ class GemmaEnv {
public:
// Calls the other constructor with *Args arguments initialized from argv.
GemmaEnv(int argc, char** argv);
GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
const AppArgs& app);
GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference);
MatMulEnv& Env() { return env_; }
size_t MaxGeneratedTokens() const {
return runtime_config_.max_generated_tokens;
}
void SetMaxGeneratedTokens(size_t max_generated_tokens) {
runtime_config_.max_generated_tokens = max_generated_tokens;
void SetMaxGeneratedTokens(int max_generated_tokens) {
runtime_config_.max_generated_tokens =
static_cast<size_t>(max_generated_tokens);
}
std::vector<int> Tokenize(const std::string& input) const {
std::vector<int> tokens;
HWY_ASSERT(model_->Tokenizer().Encode(input, &tokens));
HWY_ASSERT(gemma_.Tokenizer().Encode(input, &tokens));
return tokens;
}
@ -69,20 +72,23 @@ class GemmaEnv {
}
std::vector<int> WrapAndTokenize(std::string& input) const {
return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->Info(), 0, input);
return gcpp::WrapAndTokenize(gemma_.Tokenizer(), gemma_.ChatTemplate(),
gemma_.Config().wrapping, 0, input);
}
std::string StringFromTokens(const std::vector<int>& tokens) const {
std::string string;
HWY_ASSERT(model_->Tokenizer().Decode(tokens, &string));
HWY_ASSERT(gemma_.Tokenizer().Decode(tokens, &string));
return string;
}
// Runs inference on the given input and returns the top-1 result string and
// the number of tokens that were generated.
QueryResult QueryModel(const std::vector<int>& tokens);
// The default prefix_end means "causal attention".
std::vector<QueryResult> BatchQueryModel(
const QueriesPromptTokens& queries_prompt);
const QueriesPromptTokens& queries_prompt,
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>());
// Adds turn structure to input, tokenizes and calls the above overload.
QueryResult QueryModel(std::string& input);
std::vector<QueryResult> BatchQueryModel(
@ -97,20 +103,19 @@ class GemmaEnv {
// number of bits per token.
float CrossEntropy(const std::string& input);
// Returns nullptr if the model failed to load.
Gemma* GetModel() const { return model_.get(); }
const Gemma* GetGemma() const { return &gemma_; }
int Verbosity() const { return runtime_config_.verbosity; }
RuntimeConfig& MutableConfig() { return runtime_config_; }
std::mt19937& MutableGen() { return gen_; }
KVCache& MutableKVCache() { return kv_caches_[0]; }
MatMulEnv& MutableEnv() { return env_; }
private:
BoundedTopology topology_;
NestedPools pools_; // Thread pool.
ThreadingContext ctx_;
MatMulEnv env_;
std::mt19937 gen_; // Random number generator.
std::unique_ptr<Gemma> model_;
Gemma gemma_;
std::mt19937 gen_; // Random number generator.
std::vector<KVCache> kv_caches_; // Same number as query batch.
RuntimeConfig runtime_config_;
};
@ -118,9 +123,12 @@ class GemmaEnv {
// Logs the inference speed in tokens/sec.
void LogSpeedStats(double time_start, size_t total_tokens);
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
const BoundedTopology& topology, NestedPools& pools);
void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference, const ModelConfig& config,
WeightsPtrs::Mode weight_read_mode,
const ThreadingContext& ctx);
void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference);
} // namespace gcpp

View File

@ -13,6 +13,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
@ -38,17 +43,12 @@
#include <vector>
#include "evals/cross_entropy.h"
#include "gemma/common.h"
#include "gemma/gemma.h"
#include "hwy/base.h"
namespace gcpp {
namespace {
template <typename TConfig>
struct GetVocabSize {
int operator()() const { return TConfig::kVocabSize; }
};
static std::string TokenString(const GemmaTokenizer& tokenizer, int token) {
std::string token_str;
@ -85,7 +85,7 @@ namespace gcpp {
namespace HWY_NAMESPACE {
void CallSoftmax(float* HWY_RESTRICT logits, size_t vocab_size) {
Softmax(logits, vocab_size);
Softmax(logits, vocab_size, /*worker=*/0);
}
} // namespace HWY_NAMESPACE
@ -97,12 +97,12 @@ namespace gcpp {
HWY_EXPORT(CallSoftmax);
float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
int verbosity) {
MatMulEnv& env, int verbosity) {
const StreamFunc stream_token = [](int, float) { return true; };
const int vocab_size = gemma.GetModelConfig().vocab_size;
const int vocab_size = gemma.Config().vocab_size;
float cross_entropy = std::log(vocab_size); // first token; == -log(1/v_s)
size_t pos = 1;
@ -145,7 +145,7 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
};
TimingInfo timing_info;
gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info);
gemma.Generate(runtime, prompt0, 0, kv_cache, env, timing_info);
const float scale = 1.0f / std::log(2.0f);
return cross_entropy * scale;

View File

@ -24,9 +24,9 @@
namespace gcpp {
float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
int verbosity);
MatMulEnv& env, int verbosity);
} // namespace gcpp

View File

@ -18,9 +18,9 @@
#include <string>
#include <vector>
#include "compression/io.h"
#include "evals/benchmark_helper.h"
#include "gemma/gemma.h" // LayersOutputFunc
#include "io/io.h"
#include "util/args.h"
#include "hwy/base.h"
#include "nlohmann/json.hpp"

View File

@ -13,25 +13,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/gemma.h"
#include <stdio.h>
#include <string>
#include <vector>
#include "evals/benchmark_helper.h"
#include "gemma/common.h"
#include "gemma/gemma.h"
#include "hwy/base.h"
#include "hwy/nanobenchmark.h"
#include "hwy/profiler.h"
#include "hwy/tests/hwy_gtest.h"
// This test can be run manually with the downloaded gemma weights.
// To run the test, pass the following flags:
// --model <model> --tokenizer <tokenizer_path> --weights <weights_path>
// It should pass for the following models:
// Gemma1: 2b-it (v1 and v1.1), 7b-it (v1 and v1.1), gr2b-it,
// Gemma2: gemma2-2b-it, 9b-it, 27b-it,
namespace gcpp {
namespace {
@ -40,61 +33,23 @@ namespace {
// non-local static variables with dtors.
GemmaEnv* s_env = nullptr;
class GemmaTest : public ::testing::Test {
class GemmaBatchBench : public ::testing::Test {
protected:
std::vector<std::string> BatchGemmaReply(
const std::vector<std::string>& inputs) {
s_env->SetMaxGeneratedTokens(64);
s_env->SetMaxGeneratedTokens(24);
s_env->MutableConfig().temperature = 0.0f; // deterministic
s_env->MutableConfig().verbosity = 5;
s_env->MutableConfig().verbosity = 2;
std::vector<std::string> replies;
// Using the turn structure worsens results sometimes.
// However, some models need the turn structure to work.
// It would be good to make these tests more consistent.
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
for (QueryResult result : s_env->BatchQueryModel(inputs)) {
replies.push_back(result.response);
}
return replies;
}
// Otherwise, do not use turn structure.
std::vector<std::vector<int>> prompts_vector;
prompts_vector.reserve(inputs.size());
for (const auto& input_string : inputs) {
prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string));
}
std::vector<PromptTokens> prompt_spans;
for (const auto& prompt : prompts_vector) {
prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
}
QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size());
for (const QueryResult& result : s_env->BatchQueryModel(prompts)) {
for (const QueryResult& result : s_env->BatchQueryModel(inputs)) {
replies.push_back(result.response);
}
return replies;
}
void GenerateTokens(std::vector<std::string> &kQA, size_t num_questions) {
ASSERT_NE(s_env->GetModel(), nullptr);
std::vector<std::string> inputs;
for (size_t i = 0; i < num_questions; ++i) {
inputs.push_back(kQA[i]);
}
std::vector<std::string> responses = BatchGemmaReply(inputs);
for (size_t i = 0; i < num_questions; ++i) {
std::string response = responses.at(i);
fprintf(stderr, "Batch answer %zu '%s'\n\n", i + 1, response.c_str());
}
}
};
TEST_F(GemmaTest, RandomQuestionsBatched) {
s_env->MutableConfig().decode_qbatch_size = 3;
s_env->MutableConfig().verbosity = 5;
static std::vector<std::string> kQA = {
TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
const std::vector<std::string> questions = {
{"Write me a poem about Australia?"},
{"What's the history of Denmark?"},
{"Write me a comedy story about the USA."},
@ -128,13 +83,27 @@ TEST_F(GemmaTest, RandomQuestionsBatched) {
{"Tell me about space travel."},
{"Explain to me how electric cars work."},
};
static const size_t kNum = kQA.size();
GenerateTokens(kQA, kNum);
// Fills prompts round robin from `questions` until the desired batch size.
std::vector<std::string> inputs;
inputs.reserve(s_env->MutableConfig().decode_qbatch_size);
size_t qpos = 0;
for (size_t i = 0; i < inputs.capacity(); ++i) {
inputs.push_back(questions[qpos++]);
if (qpos == questions.size()) qpos = 0;
}
std::vector<std::string> responses = BatchGemmaReply(inputs);
for (size_t i = 0; i < hwy::Unpredictable1() * 3; ++i) {
fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
}
PROFILER_PRINT_RESULTS();
}
} // namespace
} // namespace gcpp
int main(int argc, char** argv) {
fprintf(stderr, "GemmaEnv setup..\n");
gcpp::GemmaEnv env(argc, argv);
gcpp::s_env = &env;

View File

@ -21,7 +21,8 @@
#include <vector>
#include "evals/benchmark_helper.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "io/io.h"
#include "hwy/base.h"
#include "hwy/tests/hwy_gtest.h"
@ -36,132 +37,75 @@
namespace gcpp {
namespace {
// Shared state. Requires argc/argv, so construct in main and use the same raw
// pointer approach as in benchmarks.cc. Note that the style guide forbids
// non-local static variables with dtors.
GemmaEnv* s_env = nullptr;
class GemmaTest : public ::testing::Test {
protected:
std::string GemmaReply(const std::string& prompt) {
s_env->SetMaxGeneratedTokens(2048);
s_env->MutableConfig().temperature = 0.0f; // deterministic
s_env->MutableConfig().verbosity = 0;
// Using the turn structure worsens results sometimes.
// However, some models need the turn structure to work.
// It would be good to make these tests more consistent.
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
std::string mutable_prompt = prompt;
QueryResult result = s_env->QueryModel(mutable_prompt); // Uses turns.
return result.response;
}
// Otherwise, do not use turn structure.
const std::vector<int> tokens = s_env->TokenizeAndPrependBOS(prompt);
QueryResult result = s_env->QueryModel(tokens);
return result.response;
public:
// Requires argc/argv, hence do not use `SetUpTestSuite`.
static void InitEnv(int argc, char** argv) {
HWY_ASSERT(s_env == nullptr); // Should only be called once.
s_env = new GemmaEnv(argc, argv);
const gcpp::ModelConfig& config = s_env->GetGemma()->Config();
fprintf(stderr, "Using %s\n", config.Specifier().c_str());
}
static void DeleteEnv() { delete s_env; }
protected:
std::vector<std::string> BatchGemmaReply(
const std::vector<std::string>& inputs) {
HWY_ASSERT(s_env); // must have called InitEnv()
s_env->SetMaxGeneratedTokens(64);
s_env->MutableConfig().temperature = 0.0f; // deterministic
s_env->MutableConfig().verbosity = 0;
// Always use turn structure (WrapAndTokenize).
std::vector<std::string> replies;
// Using the turn structure worsens results sometimes.
// However, some models need the turn structure to work.
// It would be good to make these tests more consistent.
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
for (QueryResult result : s_env->BatchQueryModel(inputs)) {
replies.push_back(result.response);
}
return replies;
}
// Otherwise, do not use turn structure.
std::vector<std::vector<int>> prompts_vector;
prompts_vector.reserve(inputs.size());
for (const auto& input_string : inputs) {
prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string));
}
std::vector<PromptTokens> prompt_spans;
for (const auto& prompt : prompts_vector) {
prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
}
QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size());
for (const QueryResult& result : s_env->BatchQueryModel(prompts)) {
for (QueryResult result : s_env->BatchQueryModel(inputs)) {
replies.push_back(result.response);
}
return replies;
}
void TestQuestions(const char* kQA[][2], size_t num_questions, bool batch) {
ASSERT_NE(s_env->GetModel(), nullptr);
if (batch) {
std::vector<std::string> inputs;
for (size_t i = 0; i < num_questions; ++i) {
fprintf(stderr, "Batch Question %zu\n\n", i + 1);
inputs.push_back(kQA[i][0]);
}
std::vector<std::string> responses = BatchGemmaReply(inputs);
for (size_t i = 0; i < num_questions; ++i) {
std::string response = responses.at(i);
fprintf(stderr, "Batch answer %zu '%s'\n\n", i + 1, response.c_str());
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT
}
} else {
for (size_t i = 0; i < num_questions; ++i) {
fprintf(stderr, "Question %zu\n\n", i + 1);
std::string response = GemmaReply(kQA[i][0]);
fprintf(stderr, "'%s'\n\n", response.c_str());
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT
}
}
}
// Shared state. Requires argc/argv, so construct in main via InitEnv.
// Note that the style guide forbids non-local static variables with dtors.
static GemmaEnv* s_env;
};
TEST_F(GemmaTest, GeographyBatched) {
s_env->MutableConfig().decode_qbatch_size = 3;
// 6 are enough to test batching and the loop.
GemmaEnv* GemmaTest::s_env = nullptr;
TEST_F(GemmaTest, Batched) {
// Test remainder handling in MatMul (four rows per tile), but avoid a
// second batch in debug builds to speed up the test.
s_env->MutableConfig().decode_qbatch_size = HWY_IS_DEBUG_BUILD ? 6 : 3;
static const char* kQA[][2] = {
{"What is the capital of Australia?", "Canberra"},
{"What is the capital of Denmark?", "Copenhagen"},
{"Ljubljana is the capital of which country?", "Slovenia"},
{"Is Chicago a country?", "city"},
{"How many states does the US have?", "50"},
{"What is the Pacific?", "ocean"},
};
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
TestQuestions(kQA, HWY_MIN(kNum, 3), /*batch=*/false);
TestQuestions(kQA, 1, /*batch=*/true);
TestQuestions(kQA, kNum, /*batch=*/true);
}
TEST_F(GemmaTest, History) {
static const char* kQA[][2] = {
{"When was the battle of Hastings?", "1066"},
};
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
TestQuestions(kQA, kNum, /*batch=*/false);
}
TEST_F(GemmaTest, Arithmetic) {
static const char* kQA[][2] = {
{"what is 13 + 14?", "27"},
{"what is 7 * 8?", "56"},
};
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
TestQuestions(kQA, kNum, /*batch=*/false);
const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
std::vector<std::string> inputs;
for (size_t i = 0; i < kNum; ++i) {
inputs.push_back(kQA[i][0]);
}
std::vector<std::string> responses = BatchGemmaReply(inputs);
HWY_ASSERT(responses.size() == kNum);
for (size_t i = 0; i < kNum; ++i) {
fprintf(stderr, "#%zu: '%s'\n\n", i, responses[i].c_str());
EXPECT_TRUE(responses[i].find(kQA[i][1]) != std::string::npos); // NOLINT
}
}
TEST_F(GemmaTest, Multiturn) {
Gemma* model = s_env->GetModel();
ASSERT_NE(model, nullptr);
const Gemma* model = s_env->GetGemma();
const ModelConfig& config = model->Config();
size_t abs_pos = 0;
std::string response;
auto stream_token = [&](int token, float) {
if (token == EOS_ID) return true;
auto stream_token = [&](size_t query_idx, size_t pos, int token, float) {
HWY_ASSERT(query_idx == 0);
HWY_ASSERT(pos == abs_pos);
++abs_pos;
if (config.IsEOS(token)) return true;
std::string token_text;
EXPECT_TRUE(
model->Tokenizer().Decode(std::vector<int>{token}, &token_text));
@ -173,83 +117,44 @@ TEST_F(GemmaTest, Multiturn) {
.temperature = 0.0f,
.gen = &s_env->MutableGen(),
.verbosity = 2,
.stream_token = stream_token,
.batch_stream_token = stream_token,
};
TimingInfo timing_info{.verbosity = 0};
// First "say" something slightly unusual.
std::string mutable_prompt = "I have a car and its color is turquoise.";
std::vector<int> tokens = WrapAndTokenize(model->Tokenizer(), model->Info(),
abs_pos, mutable_prompt);
std::vector<int> tokens =
WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(),
config.wrapping, abs_pos, mutable_prompt);
model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
timing_info);
s_env->MutableEnv(), timing_info);
// Note: we do not rewind any <end_of_turn> tokens here. If the model
// produced one and WrapAndTokenize() inserts another one, it will just be
// duplicated.
mutable_prompt = "Please repeat all prior statements.";
tokens = WrapAndTokenize(model->Tokenizer(), model->Info(), abs_pos,
mutable_prompt);
tokens = WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(),
config.wrapping, abs_pos, mutable_prompt);
// Reset the `response` string here, then check that the model actually has
// access to the previous turn by asking to reproduce.
response.clear();
model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
timing_info);
fprintf(stderr, "decoded: %s\n", response.c_str());
s_env->MutableEnv(), timing_info);
fprintf(stderr, "decoded: '%s'\n", response.c_str());
bool remembered_turquoise =
response.find("turquoise") != std::string::npos; // NOLINT
bool remembered_car = response.find("car") != std::string::npos; // NOLINT
EXPECT_TRUE(remembered_turquoise || remembered_car);
}
static const char kJingleBells[] = R"(
Dashing through the snow
In a one-horse open sleigh
O'er the fields we go
Laughing all the way
Bells on bobtails ring
Making spirits bright
What fun it is to ride and sing
A sleighing song tonight
)";
// The "Hay Draft" of the Gettysburg Address.
static const char kGettysburg[] = {
"Four score and seven years ago our fathers brought forth, upon this "
"continent, a new nation, conceived in Liberty, and dedicated to the "
"proposition that all men are created equal.\n\nNow we are engaged in a "
"great civil war, testing whether that nation, or any nation, so "
"conceived, and so dedicated, can long endure. We are met here on a great "
"battlefield of that war. We have come to dedicate a portion of it as a "
"final resting place for those who here gave their lives that that nation "
"might live. It is altogether fitting and proper that we should do "
"this.\n\nBut in a larger sense we can not dedicate -- we can not "
"consecrate -- we can not hallow this ground. The brave men, living and "
"dead, who struggled, here, have consecrated it far above our poor power "
"to add or detract. The world will little note, nor long remember, what we "
"say here, but can never forget what they did here. It is for us, the "
"living, rather to be dedicated here to the unfinished work which they "
"have, thus far, so nobly carried on. It is rather for us to be here "
"dedicated to the great task remaining before us -- that from these "
"honored dead we take increased devotion to that cause for which they here "
"gave the last full measure of devotion -- that we here highly resolve "
"that these dead shall not have died in vain; that this nation shall have "
"a new birth of freedom; and that this government of the people, by the "
"people, for the people, shall not perish from the earth.\n"};
TEST_F(GemmaTest, CrossEntropySmall) {
ASSERT_NE(s_env->GetModel(), nullptr);
HWY_ASSERT(s_env->GetGemma() != nullptr);
const ModelConfig& config = s_env->GetGemma()->Config();
static const char kSmall[] =
"The capital of Hungary is Budapest which is located in Europe.";
float entropy = s_env->CrossEntropy(kSmall);
fprintf(stderr, "per-token entropy: %f\n", entropy);
switch (s_env->GetModel()->Info().model) {
case gcpp::Model::GEMMA_2B:
// 2B v.1 and v.1.1 produce slightly different results.
EXPECT_NEAR(entropy, 2.6f, 0.2f);
break;
case gcpp::Model::GEMMA_7B:
// 7B v.1 and v.1.1 produce slightly different results.
EXPECT_NEAR(entropy, 2.8f, 0.2f);
break;
switch (config.model) {
case gcpp::Model::GRIFFIN_2B:
EXPECT_NEAR(entropy, 2.61f, 0.02f);
break;
@ -268,76 +173,14 @@ TEST_F(GemmaTest, CrossEntropySmall) {
}
}
TEST_F(GemmaTest, CrossEntropyJingleBells) {
ASSERT_NE(s_env->GetModel(), nullptr);
float entropy = s_env->CrossEntropy(kJingleBells);
fprintf(stderr, "per-token entropy: %f\n", entropy);
switch (s_env->GetModel()->Info().model) {
case gcpp::Model::GEMMA_2B:
// 2B v.1 and v.1.1 produce slightly different results.
EXPECT_NEAR(entropy, 1.9f, 0.2f);
break;
case gcpp::Model::GEMMA_7B:
// 7B v.1 and v.1.1 produce slightly different results.
EXPECT_NEAR(entropy, 1.07f, 0.05f);
break;
case gcpp::Model::GRIFFIN_2B:
EXPECT_NEAR(entropy, 1.62f, 0.02f);
break;
case gcpp::Model::GEMMA2_2B:
EXPECT_NEAR(entropy, 0.49f, 0.02f);
break;
case gcpp::Model::GEMMA2_9B:
EXPECT_NEAR(entropy, 0.37f, 0.02f);
break;
case gcpp::Model::GEMMA2_27B:
EXPECT_NEAR(entropy, 0.33f, 0.02f);
break;
default:
FAIL() << "no entropy expectation for this model";
break;
}
}
TEST_F(GemmaTest, CrossEntropyGettysburg) {
ASSERT_NE(s_env->GetModel(), nullptr);
float entropy = s_env->CrossEntropy(kGettysburg);
fprintf(stderr, "per-token entropy: %f\n", entropy);
switch (s_env->GetModel()->Info().model) {
case gcpp::Model::GEMMA_2B:
// 2B v.1 and v.1.1 produce slightly different results.
EXPECT_NEAR(entropy, 1.1f, 0.1f);
break;
case gcpp::Model::GEMMA_7B:
// 7B v.1 and v.1.1 produce slightly different results.
EXPECT_NEAR(entropy, 0.75f, 0.1f);
break;
case gcpp::Model::GRIFFIN_2B:
EXPECT_NEAR(entropy, 0.71f, 0.02f);
break;
case gcpp::Model::GEMMA2_2B:
EXPECT_NEAR(entropy, 0.20f, 0.02f);
break;
case gcpp::Model::GEMMA2_9B:
EXPECT_NEAR(entropy, 0.15f, 0.02f);
break;
case gcpp::Model::GEMMA2_27B:
EXPECT_NEAR(entropy, 0.14f, 0.02f);
break;
default:
FAIL() << "no entropy expectation for this model";
break;
}
}
} // namespace
} // namespace gcpp
int main(int argc, char** argv) {
gcpp::GemmaEnv env(argc, argv);
gcpp::s_env = &env;
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
gcpp::InternalInit();
gcpp::GemmaTest::InitEnv(argc, argv);
int ret = RUN_ALL_TESTS();
gcpp::GemmaTest::DeleteEnv();
return ret;
}

View File

@ -19,12 +19,11 @@
#include <string>
#include <vector>
#include "compression/io.h" // Path
#include "evals/benchmark_helper.h"
#include "gemma/gemma.h" // Gemma
#include "io/io.h" // Path
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
#include "hwy/profiler.h"
#include "nlohmann/json.hpp"
@ -89,7 +88,7 @@ void Run(GemmaEnv& env, JsonArgs& json) {
"A", "B", "C", "D", //
" A", " B", " C", " D", //
"**", "**:", ":**", "The", "Answer", "is", ":", "."};
const TokenSet accept_set(env.GetModel()->Tokenizer(), accept_strings);
const TokenSet accept_set(env.GetGemma()->Tokenizer(), accept_strings);
for (auto sample : json_data["samples"]) {
const int id = sample["i"];
@ -131,8 +130,9 @@ void Run(GemmaEnv& env, JsonArgs& json) {
.verbosity = env.Verbosity(),
.stream_token = stream_token,
};
env.GetModel()->Generate(runtime_config, prompt, /*pos=*/0,
env.MutableKVCache(), timing_info);
env.GetGemma()->Generate(runtime_config, prompt, /*pos=*/0,
env.MutableKVCache(), env.MutableEnv(),
timing_info);
std::string output_string = env.StringFromTokens(predicted_token_ids);
fprintf(stderr, "Correct %s, model '%s'\n", correct_answer.c_str(),

View File

@ -10,13 +10,11 @@ cc_binary(
name = "hello_world",
srcs = ["run.cc"],
deps = [
# Placeholder for internal dep, do not remove.,
"//:app",
"//:args",
"//:gemma_args",
"//:gemma_lib",
"//:threading",
"//:threading_context",
"//:tokenizer",
"@highway//:hwy",
"@highway//:thread_pool",
],
)

View File

@ -23,19 +23,15 @@
#include <string>
#include <vector>
// Placeholder for internal header, do not modify.
#include "gemma/gemma.h"
#include "gemma/gemma_args.h" // LoaderArgs
#include "gemma/tokenizer.h"
#include "util/app.h" // LoaderArgs
#include "util/args.h"
#include "util/threading.h"
#include "util/threading_context.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
int main(int argc, char **argv) { {
// Placeholder for internal init, do not modify.
}
int main(int argc, char **argv) {
gcpp::LoaderArgs loader(argc, argv);
gcpp::InferenceArgs inference(argc, argv);
gcpp::AppArgs app(argc, argv);

View File

@ -10,10 +10,10 @@ cc_library(
name = "gemma",
hdrs = ["gemma.hpp"],
deps = [
"//:app",
"//:gemma_args",
"//:gemma_lib",
"//:ops",
"//:threading",
"//:matmul",
"//:threading_context",
"//:tokenizer",
"@highway//:hwy",
],
@ -24,15 +24,6 @@ cc_binary(
srcs = ["run.cc"],
deps = [
":gemma",
# Placeholder for internal dep, do not remove.,
"//:app",
"//:args",
"//:common",
"//:gemma_lib",
"//:ops",
"//:threading",
"//:tokenizer",
"@highway//:hwy",
"@highway//:thread_pool",
"//:gemma_args",
],
)

View File

@ -14,10 +14,11 @@
cmake_minimum_required(VERSION 3.11)
project(simplified_gemma)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
include(FetchContent)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c5bebf84ad01edec97e336f5c97ca4e0df6b4d06)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9414b48aeec251b69e6cadbfa42bebb5ddae1c34)
FetchContent_MakeAvailable(highway)
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
FetchContent_MakeAvailable(sentencepiece)
@ -31,7 +32,7 @@ if (NOT BUILD_MODE)
endif()
if (BUILD_MODE STREQUAL "local")
# Relative path to gemma.cpp from examples/simplified_gemma/build/
FetchContent_Declare(gemma SOURCE_DIR ../../..)
FetchContent_Declare(gemma SOURCE_DIR ../../..)
else()
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG a9aa63fd2ea6b786ed0706d619588bfe2d43370e)
endif()

View File

@ -41,7 +41,7 @@ gemma.cpp specifying the tokenizer, compressed weights file, and model type, for
example:
```sh
./simplified_gemma --tokenizer tokenizer.spm --compressed_weights 2b-it-sfp.sbs --model 2b-it
./simplified_gemma --tokenizer tokenizer.spm --weights 2b-it-sfp.sbs --model 2b-it
```
Should print a greeting to the terminal:

View File

@ -24,55 +24,39 @@
#include <vector>
#include "third_party/gemma_cpp/gemma/gemma.h"
#include "third_party/gemma_cpp/gemma/gemma_args.h" // LoaderArgs
#include "third_party/gemma_cpp/gemma/tokenizer.h"
#include "third_party/gemma_cpp/ops/matmul.h"
#include "third_party/gemma_cpp/util/app.h" // LoaderArgs
#include "third_party/gemma_cpp/util/threading.h"
#include "third_party/gemma_cpp/util/threading_context.h"
#include "third_party/highway/hwy/base.h"
class SimplifiedGemma {
public:
SimplifiedGemma(const gcpp::LoaderArgs& loader,
const gcpp::InferenceArgs& inference = gcpp::InferenceArgs(),
const gcpp::AppArgs& app = gcpp::AppArgs())
: loader_(loader),
inference_(inference),
app_(app),
topology_(gcpp::CreateTopology(app_)),
pools_(gcpp::CreatePools(topology_, app_)),
env_(topology_, pools_),
model_(gcpp::CreateGemma(loader_, env_)) {
Init();
}
SimplifiedGemma(int argc, char** argv)
: loader_(argc, argv, /*validate=*/true),
inference_(argc, argv),
app_(argc, argv),
topology_(gcpp::CreateTopology(app_)),
pools_(gcpp::CreatePools(topology_, app_)),
env_(topology_, pools_),
model_(gcpp::CreateGemma(loader_, env_)) {
Init();
}
void Init() {
// Instantiate model and KV Cache
kv_cache_ = gcpp::KVCache::Create(model_.GetModelConfig(),
inference_.prefill_tbatch_size);
const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
: ctx_(UpdateArgs(threading, inference)),
env_(ctx_),
gemma_(loader, inference, ctx_),
kv_cache_(gemma_.Config(), inference, ctx_.allocator) {
// Initialize random number generator
std::random_device rd;
gen_.seed(rd());
}
SimplifiedGemma(int argc, char** argv)
: SimplifiedGemma(gcpp::LoaderArgs(argc, argv),
gcpp::ThreadingArgs(argc, argv),
gcpp::InferenceArgs(argc, argv)) {}
void Generate(std::string& prompt, size_t max_generated_tokens = 1024,
float temperature = 0.7,
const std::set<int>& reject_tokens = {}) {
size_t generated = 0;
const std::vector<int> tokens = gcpp::WrapAndTokenize(
model_.Tokenizer(), loader_.Info(), generated, prompt);
gemma_.Tokenizer(), gemma_.ChatTemplate(),
gemma_.Config().wrapping, generated, prompt);
const size_t prompt_size = tokens.size();
// This callback function gets invoked every time a token is generated
@ -80,9 +64,9 @@ class SimplifiedGemma {
++generated;
if (generated < prompt_size) {
// print feedback
} else if (!this->model_.GetModelConfig().IsEOS(token)) {
} else if (!gemma_.Config().IsEOS(token)) {
std::string token_text;
HWY_ASSERT(this->model_.Tokenizer().Decode({token}, &token_text));
HWY_ASSERT(gemma_.Tokenizer().Decode({token}, &token_text));
std::cout << token_text << std::flush;
}
return true;
@ -100,19 +84,15 @@ class SimplifiedGemma {
return !reject_tokens.contains(token);
},
};
model_.Generate(runtime_config, tokens, 0, kv_cache_, timing_info);
gemma_.Generate(runtime_config, tokens, 0, kv_cache_, env_, timing_info);
}
~SimplifiedGemma() = default;
private:
gcpp::LoaderArgs loader_;
gcpp::InferenceArgs inference_;
gcpp::AppArgs app_;
gcpp::BoundedTopology topology_;
gcpp::NestedPools pools_;
gcpp::ThreadingContext ctx_;
gcpp::MatMulEnv env_;
gcpp::Gemma model_;
gcpp::Gemma gemma_;
gcpp::KVCache kv_cache_;
std::mt19937 gen_;
std::string validation_error_;
};
};

View File

@ -17,30 +17,25 @@
#include <string>
// Placeholder for internal header, do not modify.
#include "third_party/gemma_cpp/examples/simplified_gemma/gemma.hpp"
#include "util/app.h" // LoaderArgs
#include "gemma/gemma_args.h" // LoaderArgs
int main(int argc, char** argv) {
{
// Placeholder for internal init, do not modify.
}
// Standard usage: LoaderArgs takes argc and argv as input, then parses
// necessary flags.
gcpp::LoaderArgs loader(argc, argv, /*validate=*/true);
gcpp::LoaderArgs loader(argc, argv);
// Optional: LoaderArgs can also take tokenizer and weights paths directly.
//
// gcpp::LoaderArgs loader("/path/to/tokenizer", "/path/to/weights",
// "model_identifier");
// Optional: InferenceArgs and AppArgs can be passed in as well. If not
// Optional: ThreadingArgs and InferenceArgs can be passed in as well. If not
// specified, default values will be used.
//
// gcpp::InferenceArgs inference(argc, argv);
// gcpp::AppArgs app(argc, argv);
// SimplifiedGemma gemma(loader, inference, app);
// gcpp::ThreadingArgs threading(argc, argv);
// SimplifiedGemma gemma(loader, threading, inference);
SimplifiedGemma gemma(loader);
std::string prompt = "Write a greeting to the world.";

View File

@ -16,104 +16,199 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
#include <math.h> // sqrtf
#include <stddef.h>
#include <stdint.h>
#include "compression/shared.h" // BF16
#include "gemma/configs.h"
#include "ops/matmul.h" // MatMulEnv
#include "ops/ops.h" // CreateInvTimescale
#include "util/allocator.h" // RowVectorBatch
#include "util/threading.h"
#include "hwy/base.h" // HWY_DASSERT
#include "hwy/contrib/thread_pool/thread_pool.h"
#include <atomic>
#include <vector>
#include "gemma/configs.h" // ModelConfig
#include "ops/matmul.h" // MatMulEnv
#include "ops/ops.h" // CreateInvTimescale
#include "util/allocator.h" // Allocator
#include "util/basics.h" // BF16
#include "util/mat.h" // MatStorageT
namespace gcpp {
struct Activations {
explicit Activations(const ModelConfig& config)
: weights_config(config),
layer_config(config.layer_configs[0]),
seq_len(config.seq_len),
cache_pos_size(config.CachePosSize()) {}
struct GriffinActivations {
GriffinActivations(const ModelConfig& config, size_t batch_size,
const Allocator& allocator)
: griffin_x(
MatFactory("griffin_x", batch_size, config.model_dim, allocator)),
griffin_y(
MatFactory("griffin_y", batch_size, config.model_dim, allocator)),
griffin_gate_x(MatFactory("griffin_gate_x", batch_size,
config.model_dim, allocator)),
griffin_multiplier(MatFactory("griffin_mul", batch_size,
config.model_dim, allocator)) {}
RowVectorBatch<float> x; // input
RowVectorBatch<float> q; // query, also KV if MHA.
RowVectorBatch<float> logits;
void SetBatchSize(size_t batch_size) {
if (griffin_x.Rows() == 0) return;
griffin_x.OverrideRows(batch_size);
griffin_y.OverrideRows(batch_size);
griffin_gate_x.OverrideRows(batch_size);
griffin_multiplier.OverrideRows(batch_size);
}
// Attention
RowVectorBatch<float> pre_att_rms_out;
RowVectorBatch<float> att; // attention vector
RowVectorBatch<float> att_out; // attention output
MatStorageT<float> griffin_x;
MatStorageT<float> griffin_y;
MatStorageT<float> griffin_gate_x;
MatStorageT<float> griffin_multiplier;
};
struct AttentionActivations {
// Returns the scale value to use for the query in the attention computation.
// Also called by ops_test.
static inline float ChooseQueryScale(const ModelConfig& config) {
if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
return 1.0f / sqrtf(static_cast<float>(config.model_dim /
config.layer_configs[0].heads));
// QueryScaleType::SqrtKeySize
return 1.0f / sqrtf(static_cast<float>(config.layer_configs[0].qkv_dim));
}
AttentionActivations(
const ModelConfig& config, const LayerConfig& layer_config,
size_t batch_size, size_t seq_len, const Allocator& allocator,
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
: config(config),
// `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
// and does not use an external KV cache.
q(MatFactory("q", batch_size,
config.vocab_size == 0
? layer_config.heads * 3 * layer_config.qkv_dim
: layer_config.heads * layer_config.qkv_dim,
allocator)),
pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size,
config.model_dim, allocator)),
att(MatFactory("att", batch_size, layer_config.heads * seq_len,
allocator)),
att_out(MatFactory("att_out", batch_size,
layer_config.heads * layer_config.qkv_dim,
allocator)),
att_sums(
MatFactory("att_sums", batch_size, config.model_dim, allocator)),
inv_timescale(
CreateInvTimescale(allocator, layer_config.qkv_dim,
layer_config.post_qk == PostQKType::HalfRope)),
inv_timescale_global(CreateInvTimescale(
allocator, layer_config.qkv_dim,
layer_config.post_qk == PostQKType::HalfRope, 1000000.0)),
div_seq_len(static_cast<uint32_t>(seq_len)),
div_heads(static_cast<uint32_t>(layer_config.heads)),
query_scale(ChooseQueryScale(config)) {
// Batch size can be 0 in experimental code so do not assert.
if (batch_size == 0) {
static std::atomic_flag warned = ATOMIC_FLAG_INIT;
if (!warned.test_and_set()) {
HWY_WARN("Creating mostly empty activations with a batch_size of 0.");
}
return;
}
// For MatMul outputs, precompute their row pointers.
// If we forget any MatMul outputs here, debug builds print a warning but
// fill them in each MatMul call.
q.AllocateAndAttachRowPtrs(row_ptrs);
att_sums.AllocateAndAttachRowPtrs(row_ptrs);
}
void SetBatchSize(size_t batch_size) {
q.OverrideRows(batch_size);
pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size);
att_out.OverrideRows(batch_size);
att_sums.OverrideRows(batch_size);
}
const ModelConfig& config;
MatStorageT<float> q; // query
MatStorageT<float> pre_att_rms_out;
MatStorageT<float> att; // attention vector
MatStorageT<float> att_out; // attention output
// Accumulation of attention outputs over heads
RowVectorBatch<float> att_sums;
// Gated FFW
RowVectorBatch<BF16> bf_pre_ffw_rms_out;
RowVectorBatch<float> C1;
RowVectorBatch<float> C2;
RowVectorBatch<float> ffw_out;
// Griffin
RowVectorBatch<float> griffin_x;
RowVectorBatch<float> griffin_y;
RowVectorBatch<float> griffin_gate_x;
RowVectorBatch<float> griffin_multiplier;
MatStorageT<BF16> att_sums;
// Rope
RowVectorBatch<float> inv_timescale;
RowVectorBatch<float> inv_timescale_global;
MatStorageT<float> inv_timescale;
MatStorageT<float> inv_timescale_global;
// Dynamic because no default ctor and only initialized in `Allocate`.
MatMulEnv* env;
hwy::Divisor div_seq_len;
// Unfortunately, some models (Griffin) have non-power-of-two heads.
hwy::Divisor div_heads;
float query_scale;
};
PostQKType post_qk = PostQKType::Rope;
// And the config.
const ModelConfig& weights_config;
const LayerConfig& layer_config;
size_t seq_len;
size_t cache_pos_size = 0;
struct Activations {
Activations(const ModelConfig& config, size_t batch_size, size_t seq_len,
const Allocator& allocator,
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
: layer_config(config.layer_configs[0]),
void Allocate(size_t batch_size, MatMulEnv* env) {
post_qk = layer_config.post_qk;
const size_t model_dim = weights_config.model_dim;
const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
const size_t vocab_size = weights_config.vocab_size;
const size_t qkv_dim = layer_config.qkv_dim;
const size_t heads = layer_config.heads;
x(MatFactory("x", batch_size, config.model_dim, allocator)),
logits(MatFactory("logits", batch_size, config.vocab_size, allocator)),
x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
q = RowVectorBatch<float>(
Extents2D(batch_size, heads * layer_config.QStride()));
if (vocab_size > 0) {
logits = RowVectorBatch<float>(Extents2D(batch_size, vocab_size));
}
pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size,
config.model_dim, allocator)),
C1(MatFactory("C1", batch_size, layer_config.ff_hidden_dim, allocator)),
C2(MatFactory("C2", batch_size, layer_config.ff_hidden_dim, allocator)),
ffw_out(MatFactory("ffw_out", batch_size, config.model_dim, allocator)),
pre_att_rms_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
att = RowVectorBatch<float>(
Extents2D(batch_size, heads * weights_config.seq_len));
att_out = RowVectorBatch<float>(Extents2D(batch_size, heads * qkv_dim));
att_sums = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
attention(config, layer_config, batch_size, seq_len, allocator,
row_ptrs),
griffin(config, config.model == Model::GRIFFIN_2B ? batch_size : 0,
allocator) {
HWY_ASSERT(batch_size != 0);
bf_pre_ffw_rms_out = RowVectorBatch<BF16>(Extents2D(batch_size, model_dim));
C1 = RowVectorBatch<float>(Extents2D(batch_size, ff_hidden_dim));
C2 = RowVectorBatch<float>(Extents2D(batch_size, ff_hidden_dim));
ffw_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
// For MatMul outputs, precompute their row pointers.
// If we forget any MatMul outputs here, debug builds print a warning but
// fill them in each MatMul call.
x.AllocateAndAttachRowPtrs(row_ptrs);
logits.AllocateAndAttachRowPtrs(row_ptrs);
C1.AllocateAndAttachRowPtrs(row_ptrs);
C2.AllocateAndAttachRowPtrs(row_ptrs);
ffw_out.AllocateAndAttachRowPtrs(row_ptrs);
if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) {
griffin_x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
griffin_y = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
griffin_gate_x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
griffin_multiplier =
RowVectorBatch<float>(Extents2D(batch_size, model_dim));
}
inv_timescale = CreateInvTimescale(layer_config.qkv_dim,
post_qk == PostQKType::HalfRope);
inv_timescale_global =
CreateInvTimescale(qkv_dim, post_qk == PostQKType::HalfRope, 1000000.0);
this->env = env;
// Note that BindC on any MatMul output considerably slows down Prefill.
}
// Negligible CPU time.
void SetBatchSize(size_t batch_size) {
x.OverrideRows(batch_size);
logits.OverrideRows(batch_size);
pre_ffw_rms_out.OverrideRows(batch_size);
C1.OverrideRows(batch_size);
C2.OverrideRows(batch_size);
ffw_out.OverrideRows(batch_size);
attention.SetBatchSize(batch_size);
griffin.SetBatchSize(batch_size);
}
const LayerConfig& layer_config;
MatStorageT<float> x; // input
MatStorageT<float> logits;
// Gated FFW
MatStorageT<BF16> pre_ffw_rms_out;
// Norm may be large, so prefer to keep as f32.
MatStorageT<float> C1;
MatStorageT<float> C2;
MatStorageT<BF16> ffw_out;
AttentionActivations attention;
GriffinActivations griffin;
};
} // namespace gcpp

358
gemma/attention.cc Normal file
View File

@ -0,0 +1,358 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stddef.h>
#include <stdint.h>
#include <vector>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
#include "gemma/activations.h"
#include "gemma/configs.h" // kMaxQKVDim
#include "gemma/gemma.h"
#include "gemma/weights.h"
#include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/attention.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#include "ops/ops-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
// Computes Q.K scores, which are "logits" (or scores) stored to att.
// `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
const hwy::Divisor& div_seq_len,
const float* HWY_RESTRICT q,
const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
const size_t worker) {
PROFILER_ZONE2(worker, "Gen.Attention.QDotK");
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
// Slightly faster: no wraparound.
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
const float score = Dot(q, k.Row(pos), k.Cols());
att[pos] = score;
}
} else {
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
const size_t pos_modulo = div_seq_len.Remainder(pos);
const float score = Dot(q, k.Row(pos_modulo), k.Cols());
att[pos_modulo] = score;
}
}
}
static void PositionalEncodingQK(float* qk, const size_t layer_idx,
const LayerWeightsPtrs& layer,
const AttentionActivations& activations,
const size_t worker, const size_t pos,
const float mul = 1.0f) {
const size_t qkv_dim = layer.layer_config.qkv_dim;
const PostQKType& post_qk = layer.layer_config.post_qk;
// qk is either q or k, so qkv_dim is the length we operate on.
const float* inv_timescale = activations.inv_timescale.PackedScale1();
const bool is_global_layer = activations.config.IsGlobalLayer(layer_idx);
// TODO: add a config flag instead of hardcoding the model.
if (is_global_layer && IsVLM(activations.config.model)) {
inv_timescale = activations.inv_timescale_global.PackedScale1();
}
// PostQKType::Rope
if (post_qk == PostQKType::HalfRope) {
Rope(qk, qkv_dim / 2, inv_timescale, pos, worker);
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, worker);
} else {
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, worker);
}
}
// Accumulates the sum of v (from `kv_cache`) * probability (`att`) into
// `att_out`. Equivalent in gemma/modules.py:
// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
// `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
static HWY_INLINE void WeightedSumV(
const size_t start_pos, const size_t last_pos,
const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
const MatPtrT<KV_t>& v, float* HWY_RESTRICT att_out, const size_t worker) {
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
// Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
// we supported non-transposed B.
// TODO: 2..4x unroll
MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), worker);
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), worker);
}
} else {
{
const size_t pos_mod = div_seq_len.Remainder(start_pos);
MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker);
}
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
const size_t pos_mod = div_seq_len.Remainder(pos);
MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker);
}
}
}
// Calculates the attention outputs for a single q, which may be updated
// in place for RMSNorm.
void SingleDotSoftmaxWeightedSum(
const size_t pos, const size_t start_pos, const size_t last_pos,
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
const size_t layer_idx, const LayerWeightsPtrs& layer,
const AttentionActivations& activations, float* HWY_RESTRICT att,
float* HWY_RESTRICT att_out, const size_t worker) {
const float att_cap = activations.config.att_cap;
const float query_scale = activations.query_scale;
const size_t seq_len =
static_cast<size_t>(activations.div_seq_len.GetDivisor());
// Apply rope and scaling to Q.
if (layer.query_norm_scale.HasPtr()) {
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), 0, q,
layer.layer_config.qkv_dim, worker);
});
}
PositionalEncodingQK(q, layer_idx, layer, activations, worker, pos,
query_scale);
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, worker);
// SoftMax with optional SoftCap yields "probabilities" in att.
const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
MaybeLogitsSoftCap(att_cap, att, att_len, worker);
Softmax(att, att_len, worker, /*temperature=*/1.0f);
WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
worker);
}
// The attention window usually starts at 0 unless `pos` is larger than
// the attention window size, then it is `pos` - window_size + 1.
static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config,
size_t layer_idx) {
const size_t att_window_size = config.attention_window_sizes[layer_idx];
return pos - HWY_MIN(att_window_size - 1, pos);
}
void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer,
AttentionActivations& activations, QBatch& qbatch,
NestedPools& pools) {
static const uint32_t HWY_MAYBE_UNUSED zone_id_par =
PROFILER_ADD_ZONE("Gen.Attention.DotSoftmax.par");
const hwy::Divisor div_qbatch(qbatch.Size());
const LayerConfig& layer_config = layer.layer_config;
const size_t qkv_dim = layer_config.qkv_dim;
// A "head group" in the context of GQA refers to a collection of query
// heads that share the same key and value heads.
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
const size_t cache_layer_size = layer_config.CacheLayerSize();
const size_t seq_len =
static_cast<size_t>(activations.div_seq_len.GetDivisor());
// All layers should have the same number of heads.
HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads);
// For each head/token/query, compute Q.K, softmax, and weighted V.
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
const size_t tq_idx = activations.div_heads.Divide(task);
const size_t head = activations.div_heads.Remainder(task);
#if PROFILER_ENABLED
const hwy::Zone zone(worker, zone_id_par);
#endif
const size_t qi = div_qbatch.Remainder(tq_idx);
const size_t batch_idx = div_qbatch.Divide(tq_idx);
auto& kv_cache = qbatch.KV(qi).kv_cache;
// Find the token position in the query and calculate
// the range of cache positions to attend to.
const size_t pos = qbatch.Pos(qi) + batch_idx;
const size_t start_pos = StartPos(pos, activations.config, layer_idx);
size_t last_pos = pos;
const size_t prefix_end = qbatch.PrefixEnd(qi);
if (prefix_end > 0 && prefix_end - 1 > last_pos) {
// last_pos in QDotK and WeightedSumV is inclusive.
last_pos = prefix_end - 1;
}
float* HWY_RESTRICT q = activations.q.Row(tq_idx) + head * qkv_dim;
float* HWY_RESTRICT att = activations.att.Row(tq_idx) + head * seq_len;
float* HWY_RESTRICT att_out =
activations.att_out.Row(tq_idx) + head * qkv_dim;
// Make strided read-only views into the kv cache for
// this query and head.
const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
const size_t kv_head_offset = layer_idx * cache_layer_size + head_offset;
MatPtrT<KV_t> k("k_view", Extents2D(seq_len, qkv_dim));
k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride());
MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
layer, activations, att, att_out, worker);
};
{
PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
const size_t pkg_idx = 0;
// Full parallelism is helpful, SmallParallelFor is insufficient.
ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads,
pools, pkg_idx, func);
}
}
// Different functions use different naming conventions for the number of
// tokens. Functions that are query-independent, such as RMSNorm*, call the
// count `num_interleaved`. Functions that are query-dependent, such as
// `Attention`, use separate `num_tokens` and `num_queries`. `num_tokens` is the
// number of tokens from one query: 1 for decode, otherwise prefill_tbatch_size.
// Fills activations.q and writes to KV cache.
static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer,
AttentionActivations& activations,
const QBatch& qbatch, const int flags,
MatMulEnv& env) {
PROFILER_ZONE("Gen.Attention.QKV");
const hwy::Divisor div_qbatch(qbatch.Size());
const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor();
const LayerConfig& layer_config = layer.layer_config;
const size_t qkv_dim = layer_config.qkv_dim;
const size_t kv_heads = layer_config.kv_heads;
const size_t cache_layer_size = layer_config.CacheLayerSize();
// The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim,
// model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows.
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w1,
/*add=*/nullptr, env, activations.q);
// Set up MatMul row pointers for writing to KV, which consists of
// `kv_heads` pairs of (k, v) vectors. This safely handles wraparound
// because rows are computed modulo seq_len.
MatPtrT<KV_t> kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(),
layer.qkv_einsum_w2.Rows()));
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
const size_t qi = div_qbatch.Remainder(interleaved_idx);
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
const size_t cache_pos =
activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx);
env.row_ptrs[2][interleaved_idx] = reinterpret_cast<uint8_t*>(
qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size);
}
kv_rows.AttachRowPtrs(env.row_ptrs[2].get());
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
/*add=*/nullptr, env, kv_rows);
// Apply positional encodings for K.
// Note that 2D parallelism is not worth the fork/join overhead because the
// tasks are very lightweight.
env.ctx.pools.Pool(0).Run(
0, kv_heads * num_interleaved,
[&](uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kv_heads;
const size_t interleaved_idx = task / kv_heads;
const size_t qi = div_qbatch.Remainder(interleaved_idx);
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
const size_t pos = qbatch.Pos(qi) + batch_idx;
const size_t cache_pos = activations.div_seq_len.Remainder(pos);
auto& kv_cache = qbatch.KV(qi).kv_cache;
KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
layer_idx * cache_layer_size +
head * qkv_dim * 2;
HWY_ALIGN float kv_f32[2 * kMaxQKVDim];
const hn::ScalableTag<float> df;
DecompressAndZeroPad(df, MakeSpan(kv, 2 * qkv_dim), 0, kv_f32,
2 * qkv_dim);
// Apply further processing to K.
if (layer.key_norm_scale.HasPtr()) {
CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim,
thread);
});
}
PositionalEncodingQK(kv_f32, layer_idx, layer, activations, thread,
pos);
CompressPerThread tls;
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
});
}
// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
// head_dim (`qkv_dim`) into output (`layer_out`).
static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
AttentionActivations& activations,
MatMulEnv& env) {
PROFILER_ZONE("Gen.Attention.SumHeads");
const LayerConfig& layer_config = layer.layer_config;
// att_weights and att_out are concatenated heads, each of length
// layer_config.qkv_dim. Thus the [num_interleaved,
// layer_config.model_dim] matmul output is the sum over heads. Compare
// gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD',
// encoded)
HWY_DASSERT(layer_config.model_dim != 0 && layer_config.heads != 0 &&
layer_config.qkv_dim != 0);
const float* add = layer_config.softmax_attn_output_biases
? layer.attention_output_biases.PackedScale1()
: nullptr;
CallMatMul(activations.att_out, layer.att_weights, add, env,
activations.att_sums);
}
void GemmaAttention(size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer,
AttentionActivations& activations, QBatch& qbatch,
MatMulEnv& env, int flags) {
const LayerConfig& layer_config = layer.layer_config;
HWY_DASSERT(!layer_config.IsMHA()); // No longer supported.
HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0,
"query heads must be a multiple of key-value heads");
(void)layer_config; // only used in HWY_DASSERT
ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
env.ctx.pools);
SumHeads(layer, activations, env);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();

59
gemma/attention.h Normal file
View File

@ -0,0 +1,59 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ATTENTION_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ATTENTION_H_
// Declares GemmaAttention for all SIMD targets.
#include <stddef.h>
#include "gemma/gemma.h"
#include "hwy/highway.h"
namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void SingleDotSoftmaxWeightedSum( \
const size_t pos, const size_t start_pos, const size_t last_pos, \
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, float* HWY_RESTRICT att, \
float* HWY_RESTRICT att_out, size_t worker); \
\
void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
const LayerWeightsPtrs& layer, \
AttentionActivations& activations, \
QBatch& qbatch, NestedPools& pools); \
\
void GemmaAttention(size_t num_tokens, const size_t layer_idx, \
const LayerWeightsPtrs& layer, \
AttentionActivations& activations, QBatch& qbatch, \
MatMulEnv& env, int flags); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE
// Function declarations for each SIMD target. Allows direct call from the
// per-target namespace. We may later replace this with dynamic dispatch if
// the overhead is acceptable.
HWY_VISIT_TARGETS(GEMMA_DECL_ATTENTION)
#undef GEMMA_DECL_ATTENTION
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ATTENTION_H_

View File

@ -0,0 +1,473 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
using System;
using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Text;
namespace GemmaCpp
{
public class GemmaException : Exception
{
public GemmaException(string message) : base(message) { }
}
public class Gemma : IDisposable
{
private IntPtr _context;
private bool _disposed;
// Optional: Allow setting DLL path
public static string DllPath { get; set; } = "gemma.dll";
[DllImport("kernel32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
private static extern IntPtr LoadLibrary(string lpFileName);
static Gemma()
{
// Load DLL from specified path
if (LoadLibrary(DllPath) == IntPtr.Zero)
{
throw new DllNotFoundException($"Failed to load {DllPath}. Error: {Marshal.GetLastWin32Error()}");
}
}
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern IntPtr GemmaCreate(
[MarshalAs(UnmanagedType.LPUTF8Str)] string tokenizerPath,
[MarshalAs(UnmanagedType.LPUTF8Str)] string modelType,
[MarshalAs(UnmanagedType.LPUTF8Str)] string weightsPath,
[MarshalAs(UnmanagedType.LPUTF8Str)] string weightType,
int maxGeneratedTokens);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern void GemmaDestroy(IntPtr context);
// Delegate type for token callbacks
public delegate bool TokenCallback(string token);
// Keep delegate alive for duration of calls
private GCHandle _callbackHandle;
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
private delegate bool GemmaTokenCallback(
[MarshalAs(UnmanagedType.LPUTF8Str)] string text,
IntPtr userData);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern int GemmaGenerate(
IntPtr context,
[MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
[Out] byte[] output,
int maxOutputChars,
GemmaTokenCallback callback,
IntPtr userData);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern int GemmaGenerateMultimodal(
IntPtr context,
[MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
IntPtr image_data, // Renamed param to match C API
int image_width, // Added dimension
int image_height, // Added dimension
[MarshalAs(UnmanagedType.LPUTF8Str)] StringBuilder output, // Output should be StringBuilder for multimodal
int maxOutputChars,
GemmaTokenCallback callback,
IntPtr userData);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern int GemmaCountTokens(
IntPtr context,
[MarshalAs(UnmanagedType.LPUTF8Str)] string text);
// Configuration function imports
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern void GemmaSetMaxGeneratedTokens(IntPtr context, int value);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern void GemmaSetMultiturn(IntPtr context, int value);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern void GemmaSetTemperature(IntPtr context, float value);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern void GemmaSetTopK(IntPtr context, int value);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern void GemmaSetDeterministic(IntPtr context, int value);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern void GemmaSetPrefillTbatchSize(IntPtr context, int value);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaResetConversation")]
private static extern void GemmaResetConversation(IntPtr context);
// Conversation management function imports
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaCreateConversation")]
private static extern int GemmaCreateConversation(
IntPtr context,
[MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaSwitchConversation")]
private static extern int GemmaSwitchConversation(
IntPtr context,
[MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaDeleteConversation")]
private static extern int GemmaDeleteConversation(
IntPtr context,
[MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaHasConversation")]
private static extern int GemmaHasConversation(
IntPtr context,
[MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaGetCurrentConversation")]
[return: MarshalAs(UnmanagedType.LPUTF8Str)] // Marshal the const char* return value as a string
private static extern string GemmaGetCurrentConversation(IntPtr context);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaSaveConversation")]
private static extern void GemmaSaveConversation(IntPtr context);
// Native callback delegate type
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
private delegate void GemmaLogCallback(
[MarshalAs(UnmanagedType.LPUTF8Str)] string message,
IntPtr userData);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern void GemmaSetLogCallback(
IntPtr context,
GemmaLogCallback callback,
IntPtr userData);
private GCHandle _logCallbackHandle;
private bool _loggingEnabled = false;
public Gemma(string tokenizerPath, string weightsPath, int maxGeneratedTokens = 8192)
{
_context = GemmaCreate(tokenizerPath, weightsPath, maxGeneratedTokens);
if (_context == IntPtr.Zero)
{
throw new GemmaException("Failed to create Gemma context");
}
}
// Enable debug logging
public void EnableLogging(bool enable = true)
{
if (enable && !_loggingEnabled)
{
GemmaLogCallback logCallback = (message, _) =>
{
Debug.WriteLine($"Gemma: {message}");
};
_logCallbackHandle = GCHandle.Alloc(logCallback);
GemmaSetLogCallback(_context, logCallback, IntPtr.Zero);
_loggingEnabled = true;
}
else if (!enable && _loggingEnabled)
{
if (_logCallbackHandle.IsAllocated)
_logCallbackHandle.Free();
GemmaSetLogCallback(_context, null, IntPtr.Zero);
_loggingEnabled = false;
}
}
// Configuration methods
public void SetMultiturn(bool enable)
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
GemmaSetMultiturn(_context, enable ? 1 : 0);
Debug.WriteLine($"Gemma: Set multiturn to {(enable ? "enabled" : "disabled")}");
}
public void SetTemperature(float temperature)
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
GemmaSetTemperature(_context, temperature);
Debug.WriteLine($"Gemma: Set temperature to {temperature}");
}
public void SetTopK(int topK)
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
GemmaSetTopK(_context, topK);
Debug.WriteLine($"Gemma: Set topK to {topK}");
}
public void SetDeterministic(bool deterministic)
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
GemmaSetDeterministic(_context, deterministic ? 1 : 0);
Debug.WriteLine($"Gemma: Set deterministic to {(deterministic ? "true" : "false")}");
}
// Renamed public method
public void ResetConversation()
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
GemmaResetConversation(_context); // Call P/Invoke method
Debug.WriteLine("Gemma: Reset active conversation");
}
// Conversation management methods
public bool CreateConversation(string conversationName)
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
bool result = GemmaCreateConversation(_context, conversationName) != 0; // Call P/Invoke method
Debug.WriteLine($"Gemma: Create conversation '{conversationName}' - {(result ? "succeeded" : "failed")}");
return result;
}
public bool SwitchConversation(string conversationName)
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
bool result = GemmaSwitchConversation(_context, conversationName) != 0; // Call P/Invoke method
Debug.WriteLine($"Gemma: Switch to conversation '{conversationName}' - {(result ? "succeeded" : "failed")}");
return result;
}
public bool DeleteConversation(string conversationName)
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
bool result = GemmaDeleteConversation(_context, conversationName) != 0; // Call P/Invoke method
Debug.WriteLine($"Gemma: Delete conversation '{conversationName}' - {(result ? "succeeded" : "failed")}");
return result;
}
public bool HasConversation(string conversationName)
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
bool result = GemmaHasConversation(_context, conversationName) != 0; // Call P/Invoke method
Debug.WriteLine($"Gemma: Has conversation '{conversationName}' - {result}");
return result;
}
public string GetCurrentConversation()
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
string currentConversation = GemmaGetCurrentConversation(_context); // Call P/Invoke method
Debug.WriteLine($"Gemma: Current conversation is '{currentConversation}'");
return currentConversation;
}
public void SaveConversation()
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
GemmaSaveConversation(_context);
Debug.WriteLine($"Gemma: Saved current conversation ('{GetCurrentConversation()}') to prewarmed cache.");
}
public int CountTokens(string prompt)
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
int count = GemmaCountTokens(_context, prompt);
return count;
}
public string Generate(string prompt, int maxOutputChars = 4096)
{
return Generate(prompt, null, maxOutputChars);
}
public string Generate(string prompt, TokenCallback callback, int maxOutputChars = 4096)
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
var outputBuffer = new byte[maxOutputChars * 4]; // Allow for worst case UTF-8 size
GemmaTokenCallback nativeCallback = null;
// Track token count for debugging
int tokenCount = 0;
if (callback != null)
{
nativeCallback = (text, _) =>
{
tokenCount++;
// Log token for debugging
Debug.WriteLine($"Token {tokenCount}: '{text}'");
// Pass token to user callback
return callback(text);
};
_callbackHandle = GCHandle.Alloc(nativeCallback);
}
try
{
int length = GemmaGenerate(_context, prompt, outputBuffer, maxOutputChars,
nativeCallback, IntPtr.Zero);
if (length < 0)
throw new GemmaException("Generation failed");
Debug.WriteLine($"Generation complete: {tokenCount} tokens processed, result length: {length}");
// Convert the byte buffer to a string using UTF-8 encoding
string result = Encoding.UTF8.GetString(outputBuffer, 0, length);
return result;
}
finally
{
if (_callbackHandle.IsAllocated)
_callbackHandle.Free();
}
}
public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, int maxOutputChars = 4096)
{
// Pass width and height to the overloaded method
return GenerateMultimodal(prompt, imageData, imageWidth, imageHeight, null, maxOutputChars);
}
public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, TokenCallback callback, int maxOutputChars = 4096)
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
if (imageData == null || imageData.Length == 0)
throw new ArgumentException("Image data cannot be null or empty", nameof(imageData));
if (imageWidth <= 0 || imageHeight <= 0)
throw new ArgumentException("Image dimensions must be positive");
if (imageData.Length < imageWidth * imageHeight * 3)
throw new ArgumentException("Image data array is too small for the specified dimensions");
var output = new StringBuilder(maxOutputChars);
GemmaTokenCallback nativeCallback = null;
if (callback != null)
{
nativeCallback = (text, _) => callback(text);
_callbackHandle = GCHandle.Alloc(nativeCallback);
}
// Pin the image data so it doesn't move during the native call
GCHandle imageHandle = GCHandle.Alloc(imageData, GCHandleType.Pinned);
try
{
IntPtr imagePtr = imageHandle.AddrOfPinnedObject();
// Pass image dimensions to the native call
int length = GemmaGenerateMultimodal(_context, prompt, imagePtr, imageWidth, imageHeight, output, maxOutputChars,
nativeCallback, IntPtr.Zero);
if (length < 0)
throw new GemmaException("Multimodal generation failed");
return output.ToString();
}
finally
{
imageHandle.Free();
if (_callbackHandle.IsAllocated)
_callbackHandle.Free();
}
}
public void Dispose()
{
if (!_disposed)
{
if (_context != IntPtr.Zero)
{
GemmaDestroy(_context);
_context = IntPtr.Zero;
}
if (_logCallbackHandle.IsAllocated)
_logCallbackHandle.Free();
_disposed = true;
}
}
~Gemma()
{
Dispose();
}
}
}

139
gemma/bindings/c_api.cc Normal file
View File

@ -0,0 +1,139 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef GEMMA_EXPORTS
#define GEMMA_EXPORTS
#endif
#include "gemma/bindings/c_api.h"
extern "C" {
GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
const char* weights_path,
int max_generated_tokens) {
try {
GemmaContext* ctx = GemmaContext::Create(tokenizer_path, weights_path,
max_generated_tokens);
return ctx;
} catch (...) {
return nullptr;
}
}
GEMMA_API void GemmaDestroy(GemmaContext* ctx) {
delete ctx;
}
GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
int max_output_chars, GemmaTokenCallback callback,
void* user_data) {
if (!ctx) return -1;
return ctx->Generate(prompt, output, max_output_chars, callback, user_data);
}
GEMMA_API int GemmaGenerateMultimodal(GemmaContext* ctx, const char* prompt,
const void* image_data, int image_width,
int image_height, char* output,
int max_output_chars,
GemmaTokenCallback callback,
void* user_data) {
if (!ctx) return -1;
return ctx->GenerateMultimodal(prompt, image_data, image_width, image_height,
output, max_output_chars, callback, user_data);
}
GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text) {
if (!ctx || !text) return -1;
return ctx->CountTokens(text);
}
GEMMA_API void GemmaSetLogCallback(GemmaContext* ctx, GemmaLogCallback callback,
void* user_data) {
if (!ctx) return;
ctx->SetLogCallback(callback, user_data);
}
// Configuration functions implementation
GEMMA_API void GemmaSetMaxGeneratedTokens(GemmaContext* ctx, int value) {
if (!ctx) return;
ctx->SetMaxGeneratedTokens(value);
}
GEMMA_API void GemmaSetMultiturn(GemmaContext* ctx, int value) {
if (!ctx) return;
ctx->SetMultiturn(value);
}
GEMMA_API void GemmaSetTemperature(GemmaContext* ctx, float value) {
if (!ctx) return;
ctx->SetTemperature(value);
}
GEMMA_API void GemmaSetTopK(GemmaContext* ctx, int value) {
if (!ctx) return;
ctx->SetTopK(value);
}
GEMMA_API void GemmaSetDeterministic(GemmaContext* ctx, int value) {
if (!ctx) return;
ctx->SetDeterministic(value != 0);
}
GEMMA_API void GemmaSetPrefillTbatchSize(GemmaContext* ctx, int value) {
if (!ctx) return;
ctx->SetPrefillTbatchSize(value);
}
GEMMA_API void GemmaResetConversation(GemmaContext* ctx) { // Renamed function
if (!ctx) return;
ctx->ResetConversation();
}
GEMMA_API int GemmaCreateConversation(GemmaContext* ctx,
const char* conversation_name) {
if (!ctx || !conversation_name) return 0;
return ctx->CreateConversation(conversation_name) ? 1 : 0;
}
GEMMA_API int GemmaSwitchConversation(GemmaContext* ctx,
const char* conversation_name) {
if (!ctx || !conversation_name) return 0;
return ctx->SwitchConversation(conversation_name) ? 1 : 0;
}
GEMMA_API int GemmaDeleteConversation(GemmaContext* ctx,
const char* conversation_name) {
if (!ctx || !conversation_name) return 0;
return ctx->DeleteConversation(conversation_name) ? 1 : 0;
}
GEMMA_API int GemmaHasConversation(GemmaContext* ctx,
const char* conversation_name) {
if (!ctx || !conversation_name) return 0;
return ctx->HasConversation(conversation_name) ? 1 : 0;
}
GEMMA_API const char* GemmaGetCurrentConversation(GemmaContext* ctx) {
if (!ctx) return nullptr;
return ctx->GetCurrentConversation();
}
GEMMA_API void GemmaSaveConversation(GemmaContext* ctx) {
if (!ctx) return;
ctx->SaveConversation();
}
}

86
gemma/bindings/c_api.h Normal file
View File

@ -0,0 +1,86 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_C_API_H_
#define THIRD_PARTY_GEMMA_C_API_H_
#include "gemma/bindings/context.h"
#ifdef _WIN32
#ifdef GEMMA_EXPORTS
#define GEMMA_API __declspec(dllexport)
#else
#define GEMMA_API __declspec(dllimport)
#endif
#else
#define GEMMA_API __attribute__((visibility("default")))
#endif
#ifdef __cplusplus
extern "C" {
#endif
#ifdef __cplusplus
typedef gcpp::GemmaContext GemmaContext;
#else
typedef struct GemmaContext GemmaContext;
#endif
typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
typedef void (*GemmaLogCallback)(const char* message, void* user_data);
GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
const char* weights_path,
int max_generated_tokens);
GEMMA_API void GemmaDestroy(GemmaContext* ctx);
GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
int max_output_chars, GemmaTokenCallback callback,
void* user_data);
GEMMA_API int GemmaGenerateMultimodal(GemmaContext* ctx, const char* prompt,
const void* image_data, int image_width,
int image_height, char* output,
int max_output_chars,
GemmaTokenCallback callback,
void* user_data);
GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text);
GEMMA_API void GemmaSetLogCallback(GemmaContext* ctx, GemmaLogCallback callback,
void* user_data);
// Configuration functions
GEMMA_API void GemmaSetMultiturn(GemmaContext* ctx, int value);
GEMMA_API void GemmaSetTemperature(GemmaContext* ctx, float value);
GEMMA_API void GemmaSetTopK(GemmaContext* ctx, int value);
GEMMA_API void GemmaSetDeterministic(GemmaContext* ctx, int value);
GEMMA_API void GemmaResetConversation(GemmaContext* ctx);
// Conversation management functions (renamed)
GEMMA_API int GemmaCreateConversation(GemmaContext* ctx,
const char* conversation_name);
GEMMA_API int GemmaSwitchConversation(GemmaContext* ctx,
const char* conversation_name);
GEMMA_API int GemmaDeleteConversation(GemmaContext* ctx,
const char* conversation_name);
GEMMA_API int GemmaHasConversation(GemmaContext* ctx,
const char* conversation_name);
GEMMA_API const char* GemmaGetCurrentConversation(GemmaContext* ctx);
GEMMA_API void GemmaSaveConversation(GemmaContext* ctx);
#ifdef __cplusplus
}
#endif
#endif // THIRD_PARTY_GEMMA_C_API_H_

350
gemma/bindings/context.cc Normal file
View File

@ -0,0 +1,350 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/bindings/context.h"
#include <stddef.h>
#include <string.h> // strncpy
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "evals/benchmark_helper.h" // InitGenerator
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/tokenizer.h" // WrapAndTokenize
#include "util/threading.h"
#include "util/threading_context.h"
#include "hwy/profiler.h"
#include "hwy/timer.h"
#ifdef _WIN32
#include <Windows.h>
#endif
#include "gemma/kv_cache.h"
#include "paligemma/image.h"
namespace gcpp {
// ConversationData constructor implementation
ConversationData::ConversationData(const ModelConfig& model_config,
const InferenceArgs& inference_args,
const Allocator& allocator)
: kv_cache(
std::make_unique<KVCache>(model_config, inference_args, allocator)),
abs_pos(0) {}
// ConversationData copy constructor implementation
ConversationData::ConversationData(const ConversationData& other)
: kv_cache(nullptr), abs_pos(other.abs_pos) {
if (other.kv_cache) {
kv_cache = std::make_unique<KVCache>(other.kv_cache->Copy());
}
}
// Initialize static members
GemmaLogCallback GemmaContext::s_log_callback = nullptr;
void* GemmaContext::s_log_user_data = nullptr;
GemmaContext* GemmaContext::Create(const char* tokenizer_path,
const char* weights_path,
int max_generated_tokens) {
std::stringstream ss;
ss << "Creating GemmaContext with tokenizer_path: "
<< (tokenizer_path ? tokenizer_path : "null")
<< ", weights_path: " << (weights_path ? weights_path : "null")
<< ", max_generated_tokens: " << max_generated_tokens;
LogDebug(ss.str().c_str());
ThreadingArgs threading_args;
threading_args.spin = gcpp::Tristate::kFalse;
LoaderArgs loader(tokenizer_path, weights_path);
LogDebug("LoaderArgs created");
// Initialize cached args
LogDebug("Initializing inference args");
InferenceArgs inference_args;
inference_args.Init();
inference_args.max_generated_tokens = max_generated_tokens;
inference_args.temperature = 0.7f;
inference_args.top_k = 1;
inference_args.deterministic = false;
ss.str("");
ss << "Inference args initialized with max_tokens: " << max_generated_tokens
<< ", temperature: " << inference_args.temperature
<< ", top_k: " << inference_args.top_k << ", deterministic: "
<< (inference_args.deterministic ? "true" : "false");
LogDebug(ss.str().c_str());
return new GemmaContext(loader, inference_args, threading_args,
max_generated_tokens);
}
GemmaContext::GemmaContext(const LoaderArgs& loader,
const InferenceArgs& inference_args,
const ThreadingArgs& threading_args,
int max_generated_tokens)
: inference_args(inference_args),
threading_args(threading_args),
ctx(UpdateArgs(threading_args, inference_args)),
matmul_env(ctx),
active_conversation_name("default"),
model(loader, inference_args, matmul_env.ctx) {
std::stringstream ss;
LogDebug("Creating initial ConversationData");
// Create the initial ConversationData object using make_shared
active_conversation = std::make_shared<ConversationData>(
model.Config(), inference_args, ctx.allocator);
LogDebug(
"Storing initial ConversationData in conversation_cache[\"default\"]");
// Store the shared_ptr in the map under the "default" key
conversation_cache["default"] = active_conversation;
LogDebug("GemmaContext constructor completed");
}
// Internal implementation shared by Generate and GenerateMultimodal
int GemmaContext::GenerateInternal(const char* prompt_string,
const void* image_data, int image_width,
int image_height, char* output,
int max_output_chars,
GemmaTokenCallback callback,
void* user_data) {
PROFILER_ZONE("Gen.Internal");
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
size_t prompt_size = 0;
std::stringstream ss;
result_buffer.clear();
InitGenerator(inference_args, gen);
// Ensure we have an active conversation
if (!active_conversation || !active_conversation->kv_cache) {
LogDebug("Generate called with null active_conversation or kv_cache");
return -1;
}
// callback function invoked for each generated token.
auto stream_token = [&, callback, user_data](int token, float) {
// Use abs_pos from the active conversation
++(active_conversation->abs_pos);
const bool in_prompt = tokens_generated_this_turn < prompt_size;
const bool first_response_token = tokens_generated_this_turn == prompt_size;
++tokens_generated_this_turn;
if (in_prompt || model.Config().IsEOS(token)) {
return true;
}
std::string token_text;
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
if (first_response_token) {
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
}
// if we have a managed callback, pass it the token text
if (callback) {
if (!callback(token_text.c_str(), user_data)) {
LogDebug("Callback returned false, stopping generation");
return false;
}
}
result_buffer.append(token_text);
return true;
};
// set up runtime config
TimingInfo timing_info = {};
RuntimeConfig runtime_config = {.gen = &gen,
.stream_token = stream_token,
.use_spinning = threading_args.spin};
inference_args.CopyTo(runtime_config);
size_t prefix_end = 0;
const ModelConfig& model_config = model.Config();
// generate
std::vector<int> prompt;
const size_t pool_dim = model_config.vit_config.pool_dim;
ImageTokens image_tokens(
"image_tokens",
image_data
? Extents2D(model_config.vit_config.seq_len / (pool_dim * pool_dim),
model_config.model_dim)
: Extents2D(0, 0),
ctx.allocator, MatPadding::kOdd);
if (image_data != nullptr) {
HWY_ASSERT(model_config.wrapping == PromptWrapping::PALIGEMMA ||
model_config.wrapping == PromptWrapping::GEMMA_VLM);
Image image;
image.Set(image_width, image_height, static_cast<const float*>(image_data));
// We may need to resize the supplied image depending on whether we're using
// PaliGemma or Gemma 3.
const size_t image_size = model_config.vit_config.image_size;
image.Resize(image_size, image_size);
// Use the existing runtime_config defined earlier in the function.
// RuntimeConfig runtime_config = { ... }; // This was already defined
double image_tokens_start = hwy::platform::Now();
// Pass the populated image object to GenerateImageTokens
model.GenerateImageTokens(runtime_config,
active_conversation->kv_cache->SeqLen(), image,
image_tokens, matmul_env);
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
ss.str("");
ss << "\n\n[ Timing info ] Image token generation took: ";
ss << static_cast<int>(image_tokens_duration * 1000) << " ms\n",
LogDebug(ss.str().c_str());
prompt = WrapAndTokenize(
model.Tokenizer(), model.ChatTemplate(), model_config.wrapping,
active_conversation->abs_pos, prompt_string, image_tokens.Rows());
runtime_config.image_tokens = &image_tokens;
prompt_size = prompt.size();
// The end of the prefix for prefix-LM style attention in Paligemma.
// See Figure 2 of https://arxiv.org/abs/2407.07726.
prefix_end = prompt_size;
} else {
// Text-only case (original logic)
// Use abs_pos from the active conversation
prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
model_config.wrapping,
active_conversation->abs_pos, prompt_string);
prompt_size = prompt.size();
}
// Check if prompt generation failed (e.g., multimodal not implemented yet)
if (prompt.empty() && image_data != nullptr) {
// Already logged the error, just ensure we don't proceed.
return -1;
}
// Create a span from the prompt vector - Generate() expects a hwy::Span,
// which has a different memory footprint to that of a std::vector.
hwy::Span<const int> prompt_span(prompt.data(), prompt.size());
// Pass the KVCache object by reference from the active conversation
model.Generate(runtime_config, prompt_span, active_conversation->abs_pos,
prefix_end, *active_conversation->kv_cache, matmul_env,
timing_info);
// prepare for next turn
if (!inference_args.multiturn ||
model_config.wrapping == PromptWrapping::PALIGEMMA) {
// If not multiturn, or Paligemma (which handles turns differently),
// reset the *active* conversation's position.
active_conversation->abs_pos = 0;
InitGenerator(inference_args, gen);
} else {
// Multi-turn Gemma: Rewind position in the active conversation
// The last token was either EOS, then it should be ignored because it is
// never part of the dialog, see Table 5 in the Gemma-2 paper:
// https://arxiv.org/pdf/2408.00118
// Or we have hit max_generated_tokens, then the last token will be lost.
// (We could store it in stream_token, and then prepend to the next turn,
// but it's not worth the complexity, as multi-turn with max_generated is
// not a common use case.)
// In either case, we need to rewind the active conversation's abs_pos by
// one.
HWY_ASSERT(active_conversation->abs_pos > 0);
active_conversation->abs_pos--;
}
// Copy result buffer to output C-string (ensure null termination)
strncpy(output, result_buffer.c_str(), max_output_chars - 1);
output[max_output_chars - 1] = '\0';
return static_cast<int>(strlen(output));
}
// Public Generate method (wrapper for text-only)
int GemmaContext::Generate(const char* prompt_string, char* output,
int max_output_chars, GemmaTokenCallback callback,
void* user_data) {
// Call the internal implementation with null image_data and 0 dimensions
return GenerateInternal(prompt_string, nullptr, 0, 0, output,
max_output_chars, callback, user_data);
}
// Public GenerateMultimodal method (wrapper)
int GemmaContext::GenerateMultimodal(const char* prompt_string,
const void* image_data, int image_width,
int image_height, char* output,
int max_output_chars,
GemmaTokenCallback callback,
void* user_data) {
if (image_data == nullptr) {
LogDebug(
"GenerateMultimodal called with null image_data. Use Generate for "
"text-only.");
// Or potentially call GenerateInternal with null image_data anyway?
// Returning error seems safer.
return -1;
}
return GenerateInternal(prompt_string, image_data, image_width, image_height,
output, max_output_chars, callback, user_data);
}
int GemmaContext::CountTokens(const char* text) {
LogDebug("CountTokens method started");
std::stringstream ss;
ss << "CountTokens called with text: '" << (text ? text : "null") << "'";
LogDebug(ss.str().c_str());
if (!text) {
LogDebug("CountTokens failed: Invalid parameters");
if (!text) LogDebug(" text is null");
return -1;
}
try {
LogDebug("Creating text string");
std::string text_str(text);
LogDebug("Creating tokens vector");
std::vector<int> tokens;
LogDebug("Encoding text to tokens");
HWY_ASSERT(model.Tokenizer().Encode(text_str, &tokens));
ss.str("");
ss << "Text tokenized into " << tokens.size() << " tokens";
LogDebug(ss.str().c_str());
LogDebug("CountTokens completed successfully");
return static_cast<int>(tokens.size());
} catch (...) {
LogDebug("Unknown exception in CountTokens");
return -1;
}
}
// Get the name of the currently active conversation
const char* GemmaContext::GetCurrentConversation() {
return active_conversation_name.c_str();
}
} // namespace gcpp

316
gemma/bindings/context.h Normal file
View File

@ -0,0 +1,316 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
#include <memory> // For std::shared_ptr, std::make_shared
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
// Logging
#ifdef _WIN32
#include <windows.h>
#else
#include <stdio.h>
#endif
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/kv_cache.h"
#include "ops/matmul.h" // MatMulEnv
#include "hwy/base.h"
#include "hwy/highway.h"
namespace gcpp {
// Struct to hold data for a single conversation thread
struct ConversationData {
ConversationData(const ModelConfig& model_config,
const InferenceArgs& inference_args,
const Allocator& allocator);
ConversationData(const ConversationData& other);
std::unique_ptr<KVCache> kv_cache;
size_t abs_pos = 0;
};
typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
typedef void (*GemmaLogCallback)(const char* message, void* user_data);
class GemmaContext {
private:
GemmaContext(const LoaderArgs& loader, const InferenceArgs& inference_args,
const ThreadingArgs& threading_args, int max_generated_tokens);
public:
static GemmaContext* Create(const char* tokenizer_path,
const char* weights_path,
int max_generated_tokens);
// Returns length of generated text, or -1 on error
int Generate(const char* prompt_string, char* output, int max_output_chars,
GemmaTokenCallback callback, void* user_data);
// Returns length of generated text, or -1 on error
int GenerateMultimodal(const char* prompt_string, const void* image_data,
int image_width, int image_height, char* output,
int max_output_chars, GemmaTokenCallback callback,
void* user_data);
// Returns number of tokens in text, or -1 on error
int CountTokens(const char* text);
// Add new method to set logger
static void SetLogCallback(GemmaLogCallback callback, void* user_data) {
s_log_callback = callback;
s_log_user_data = user_data;
}
// Set max generated tokens
void SetMaxGeneratedTokens(size_t value) {
inference_args.max_generated_tokens = value;
LogDebug("Setting max_generated_tokens to configured value");
}
// Set multiturn flag (0 = disabled, 1 = enabled)
void SetMultiturn(int value) {
inference_args.multiturn = value;
LogDebug("Setting multiturn to configured value");
}
// Set temperature for token generation
void SetTemperature(float value) {
inference_args.temperature = value;
LogDebug("Setting temperature to configured value");
}
// Set top_k parameter for sampling
void SetTopK(int value) {
inference_args.top_k = value;
LogDebug("Setting top_k to configured value");
}
// Set deterministic flag
void SetDeterministic(bool value) {
inference_args.deterministic = value;
// Reset the random number generator for deterministic generation
if (value) {
gen.seed(0x87654321);
}
LogDebug("Setting deterministic flag to configured value");
}
// Set prefill_tbatch_size
void SetPrefillTbatchSize(size_t value) {
inference_args.prefill_tbatch_size = value;
LogDebug("Setting prefill_tbatch_size to configured value");
}
void SaveConversation() {
if (!active_conversation || active_conversation_name.empty()) {
if (!active_conversation) {
LogDebug("SaveConversation: No active conversation to save.");
} else { // active_conversation_name must be empty
LogDebug(
"SaveConversation: Active conversation name is empty. Cannot "
"save.");
}
return;
}
std::string log_msg = "SaveConversation: Attempting to save '";
log_msg += active_conversation_name;
log_msg += "' to prewarmed_cache.";
LogDebug(log_msg.c_str());
// Create a deep copy of the active_conversation via copy ctor.
auto conversation_copy =
std::make_shared<ConversationData>(*active_conversation);
// Store the deep copy in prewarmed_cache.
// If a conversation with the same name already exists, it will be
// overwritten. std::shared_ptr will handle the destruction of the old
// object if it's being replaced.
prewarmed_cache[active_conversation_name] = conversation_copy;
log_msg = "SaveConversation: Successfully saved '";
log_msg += active_conversation_name;
log_msg += "' to prewarmed_cache.";
LogDebug(log_msg.c_str());
}
// Reset the currently active conversation
void ResetConversation() {
if (active_conversation) {
std::string log_prefix = "ResetConversation ('";
log_prefix += active_conversation_name.empty() ? "[unnamed]"
: active_conversation_name;
log_prefix += "'): ";
LogDebug((log_prefix + "Attempting to reset.").c_str());
// Attempt to restore from prewarmed_cache first, regardless of name.
auto it = prewarmed_cache.find(active_conversation_name);
if (it != prewarmed_cache.end() && it->second && it->second->kv_cache) {
// Found in prewarmed_cache and the cached entry is valid.
LogDebug((log_prefix + "Found in prewarmed_cache. Restoring state.")
.c_str());
active_conversation->abs_pos = it->second->abs_pos;
// Perform a deep copy of the KVCache from the prewarmed version.
active_conversation->kv_cache =
std::make_unique<KVCache>(it->second->kv_cache->Copy());
LogDebug((log_prefix + "Successfully restored from prewarmed_cache.")
.c_str());
return;
}
// If not found in prewarmed_cache or prewarmed_cache entry is invalid,
// rewind to initial state.
active_conversation->abs_pos = 0;
// Replace the cache within the current ConversationData object
active_conversation->kv_cache = std::make_unique<KVCache>(
model.Config(), inference_args, ctx.allocator);
LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
} else {
LogDebug("Cannot reset conversation: active_conversation is null");
}
}
// Create a new named conversation
bool CreateConversation(const char* conversation_name) {
std::string name(conversation_name);
if (conversation_cache.count(name)) {
LogDebug("Conversation already exists");
return false;
}
LogDebug("Creating new conversation");
// Create a new ConversationData object using make_shared
conversation_cache[name] = std::make_shared<ConversationData>(
model.Config(), inference_args, ctx.allocator);
return true;
}
// Switch to a named conversation
bool SwitchConversation(const char* conversation_name) {
std::string name(conversation_name);
auto it = conversation_cache.find(name);
if (it == conversation_cache.end()) {
LogDebug("Conversation not found");
return false;
}
LogDebug("Switching active conversation");
active_conversation = it->second;
active_conversation_name = conversation_name;
return true;
}
// Delete a named conversation
bool DeleteConversation(const char* conversation_name) {
std::string name(conversation_name);
auto it = conversation_cache.find(name);
if (it == conversation_cache.end()) {
LogDebug("Conversation not found for deletion");
return false;
}
if (name == "default") {
LogDebug("Cannot delete the default conversation");
return false;
}
if (it->second == active_conversation) {
LogDebug("Cannot delete the currently active conversation");
return false;
}
LogDebug("Deleting conversation");
conversation_cache.erase(it);
auto it2 = prewarmed_cache.find(name);
if (it2 != prewarmed_cache.end()) {
prewarmed_cache.erase(it2);
}
return true;
}
// Check if a named conversation exists
bool HasConversation(const char* conversation_name) {
std::string name(conversation_name);
return conversation_cache.count(name);
}
// Get the name of the currently active conversation
const char* GetCurrentConversation();
private:
// Internal implementation shared by Generate and GenerateMultimodal
int GenerateInternal(const char* prompt_string,
const void* image_data, // Null for text-only generation
int image_width,
int image_height,
char* output, int max_output_chars,
GemmaTokenCallback callback, void* user_data);
// Pointer to the currently active conversation's data
std::shared_ptr<ConversationData> active_conversation;
// Cache of all named conversations
std::unordered_map<std::string, std::shared_ptr<ConversationData>>
conversation_cache;
std::unordered_map<std::string, std::shared_ptr<ConversationData>>
prewarmed_cache;
// Buffers (potentially could be moved into ConversationData if needed
// per-conversation)
std::string prompt_buffer;
std::string result_buffer;
std::vector<int> token_buffer;
// Cached args (remain global for the context)
InferenceArgs inference_args;
ThreadingArgs threading_args;
ThreadingContext ctx;
MatMulEnv matmul_env;
std::string active_conversation_name;
// Model itself (don't move this, needs to be below the args above)
Gemma model;
// Random generator (remains global for the context)
std::mt19937 gen;
// Static members for logging
static GemmaLogCallback s_log_callback;
static void* s_log_user_data;
// Use logging helper method to print messages into a managed callback if
// necessary
static void LogDebug(const char* message) {
if (s_log_callback != nullptr) {
s_log_callback(message, s_log_user_data);
} else {
#ifdef _WIN32
OutputDebugStringA(message);
#else
printf("%s", message);
#endif
}
}
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_

View File

@ -1,177 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/common.h"
#include <math.h> // sqrtf
#include <stddef.h>
#include <string.h>
#include <algorithm> // std::transform
#include <cctype>
#include <string>
#include <vector>
#include "compression/shared.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
constexpr const char* kModelFlags[] = {
"2b-pt", "2b-it", // Gemma 2B
"7b-pt", "7b-it", // Gemma 7B
"gr2b-pt", "gr2b-it", // RecurrentGemma
"tiny", // Gemma Tiny (mostly for debugging)
"gemma2-2b-pt", "gemma2-2b-it", // Gemma2 2B
"9b-pt", "9b-it", // Gemma2 9B
"27b-pt", "27b-it", // Gemma2 27B
"paligemma-224", // PaliGemma 224
"paligemma-448", // PaliGemma 448
"paligemma2-3b-224", // PaliGemma2 3B 224
"paligemma2-3b-448", // PaliGemma2 3B 448
"paligemma2-10b-224", // PaliGemma2 10B 224
"paligemma2-10b-448", // PaliGemma2 10B 448
"gemma3-4b", // Gemma3 4B
"gemma3-1b", // Gemma3 1B
"gemma3-12b", // Gemma3 12B
"gemma3-27b", // Gemma3 27B
};
constexpr Model kModelTypes[] = {
Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B
Model::GEMMA_7B, Model::GEMMA_7B, // Gemma 7B
Model::GRIFFIN_2B, Model::GRIFFIN_2B, // RecurrentGemma
Model::GEMMA_TINY, // Gemma Tiny
Model::GEMMA2_2B, Model::GEMMA2_2B, // Gemma2 2B
Model::GEMMA2_9B, Model::GEMMA2_9B, // Gemma2 9B
Model::GEMMA2_27B, Model::GEMMA2_27B, // Gemma2 27B
Model::PALIGEMMA_224, // PaliGemma 224
Model::PALIGEMMA_448, // PaliGemma 448
Model::PALIGEMMA2_3B_224, // PaliGemma2 3B 224
Model::PALIGEMMA2_3B_448, // PaliGemma2 3B 448
Model::PALIGEMMA2_10B_224, // PaliGemma2 10B 224
Model::PALIGEMMA2_10B_448, // PaliGemma2 10B 448
Model::GEMMA3_4B, // Gemma3 4B
Model::GEMMA3_1B, // Gemma3 1B
Model::GEMMA3_12B, // Gemma3 12B
Model::GEMMA3_27B, // Gemma3 27B
};
constexpr PromptWrapping kPromptWrapping[] = {
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma 2B
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma 7B
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // RecurrentGemma
PromptWrapping::GEMMA_IT, // Gemma Tiny
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma2 2B
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma2 9B
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma2 27B
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PaliGemma 224/448
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 3B 224/448
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 10B 224/448
PromptWrapping::GEMMA_VLM, // Gemma3 4B
PromptWrapping::GEMMA_PT, // Gemma3 1B
PromptWrapping::GEMMA_VLM, // Gemma3 12B
PromptWrapping::GEMMA_VLM, // Gemma3 27B
};
constexpr size_t kNumModelFlags = std::size(kModelFlags);
static_assert(kNumModelFlags == std::size(kModelTypes));
static_assert(kNumModelFlags == std::size(kPromptWrapping));
const char* ParseModelTypeAndWrapping(const std::string& model_flag,
Model& model, PromptWrapping& wrapping) {
static std::string kErrorMessageBuffer =
"Invalid or missing model flag, need to specify one of ";
for (size_t i = 0; i + 1 < kNumModelFlags; ++i) {
kErrorMessageBuffer.append(kModelFlags[i]);
kErrorMessageBuffer.append(", ");
}
kErrorMessageBuffer.append(kModelFlags[kNumModelFlags - 1]);
kErrorMessageBuffer.append(".");
std::string model_type_lc = model_flag;
std::transform(model_type_lc.begin(), model_type_lc.end(),
model_type_lc.begin(), ::tolower);
for (size_t i = 0; i < kNumModelFlags; ++i) {
if (kModelFlags[i] == model_type_lc) {
model = kModelTypes[i];
wrapping = kPromptWrapping[i];
HWY_ASSERT(std::string(ModelString(model, wrapping)) == model_type_lc);
return nullptr;
}
}
return kErrorMessageBuffer.c_str();
}
const char* ModelString(Model model, PromptWrapping wrapping) {
for (size_t i = 0; i < kNumModelFlags; i++) {
if (kModelTypes[i] == model && kPromptWrapping[i] == wrapping)
return kModelFlags[i];
}
HWY_ABORT("Unknown model %d wrapping %d\n", static_cast<int>(model),
static_cast<int>(wrapping));
}
const char* StringFromType(Type type) {
return kTypeStrings[static_cast<size_t>(type)];
}
const char* ParseType(const std::string& type_string, Type& type) {
constexpr size_t kNum = std::size(kTypeStrings);
static std::string kErrorMessageBuffer =
"Invalid or missing type, need to specify one of ";
for (size_t i = 0; i + 1 < kNum; ++i) {
kErrorMessageBuffer.append(kTypeStrings[i]);
kErrorMessageBuffer.append(", ");
}
kErrorMessageBuffer.append(kTypeStrings[kNum - 1]);
kErrorMessageBuffer.append(".");
std::string type_lc = type_string;
std::transform(type_lc.begin(), type_lc.end(), type_lc.begin(), ::tolower);
for (size_t i = 0; i < kNum; ++i) {
if (kTypeStrings[i] == type_lc) {
type = static_cast<Type>(i);
HWY_ASSERT(std::string(StringFromType(type)) == type_lc);
return nullptr;
}
}
return kErrorMessageBuffer.c_str();
}
void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) {
// Instruction-tuned models are trained to expect control tokens.
if (info.wrapping == PromptWrapping::GEMMA_IT) {
// Prepend "<end_of_turn>" if this is a multi-turn dialogue continuation.
const std::string start = (pos == 0)
? "<start_of_turn>user\n"
: "<end_of_turn>\n<start_of_turn>user\n";
prompt = start + prompt + "<end_of_turn>\n<start_of_turn>model\n";
}
}
float EmbeddingScaling(size_t model_dim) {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
sqrtf(static_cast<float>(model_dim))));
}
float ChooseQueryScale(const ModelConfig& config) {
if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
return 1.0f / sqrtf(static_cast<float>(config.model_dim /
config.layer_configs[0].heads));
// QueryScaleType::SqrtKeySize
return 1.0f / sqrtf(static_cast<float>(config.layer_configs[0].qkv_dim));
}
} // namespace gcpp

View File

@ -1,57 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
#include <stddef.h>
#include <string>
#include "compression/shared.h" // PromptWrapping
#include "gemma/configs.h" // IWYU pragma: export
#include "hwy/base.h" // ConvertScalarTo
namespace gcpp {
// Struct to bundle model information.
struct ModelInfo {
Model model;
PromptWrapping wrapping;
Type weight;
};
// Returns error string or nullptr if OK.
// Thread-hostile.
const char* ParseModelTypeAndWrapping(const std::string& model_flag,
Model& model, PromptWrapping& wrapping);
const char* ParseType(const std::string& type_string, Type& type);
// Inverse of ParseModelTypeAndWrapping.
const char* ModelString(Model model, PromptWrapping wrapping);
const char* StringFromType(Type type);
// Wraps the given prompt using the expected control tokens for IT models.
void Wrap(const ModelInfo& info, size_t pos, std::string& prompt);
// Returns the scale value to use for the embedding (basically sqrt model_dim).
float EmbeddingScaling(size_t model_dim);
// Returns the scale value to use for the query in the attention computation.
float ChooseQueryScale(const ModelConfig& config);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_

View File

@ -15,22 +15,31 @@
#include "gemma/configs.h"
#include <cstddef>
#include <iostream>
#include <stddef.h>
#include <stdio.h>
#include <string>
#include <vector>
#include "compression/types.h" // Type
#include "io/fields.h" // IFields
#include "io/io.h" // Path
#include "hwy/base.h"
namespace gcpp {
static constexpr size_t kVocabSize = 256000;
static constexpr size_t kGemmaV3VocabSize = 262144;
static ModelConfig ConfigNoSSM() {
ModelConfig config;
config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w",
"gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"};
config.scale_base_names = {"att_ein", "qkv_ein", "gr_lin_x_w",
"gr_lin_y_w", "gr_lin_out_w", "gr_gate_w",
"gating_ein", "linear_w"};
return config;
}
static ModelConfig ConfigBaseGemmaV1() { return ConfigNoSSM(); }
static ModelConfig ConfigBaseGemmaV2() {
ModelConfig config = ConfigNoSSM();
config.att_cap = 50.0f;
@ -54,17 +63,17 @@ static LayerConfig LayerConfigGemma2_27B(size_t model_dim) {
static ModelConfig ConfigGemma2_27B() {
ModelConfig config = ConfigBaseGemmaV2();
config.model_name = "Gemma2_27B";
config.display_name = "Gemma2_27B";
config.model = Model::GEMMA2_27B;
config.model_dim = 4608;
config.vocab_size = kVocabSize;
config.seq_len = 8192;
config.max_seq_len = 8192;
LayerConfig layer_config = LayerConfigGemma2_27B(config.model_dim);
config.layer_configs = {46, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.num_layers = 46;
config.layer_configs = {config.num_layers, layer_config};
config.query_scale = QueryScaleType::SqrtModelDimDivNumHeads;
config.attention_window_sizes =
RepeatedAttentionWindowSizes<46, 2>({4096, 8192});
RepeatedAttentionWindowSizes<46, 2>({4096, config.max_seq_len});
return config;
}
@ -82,17 +91,17 @@ static LayerConfig LayerConfigGemma2_9B(size_t model_dim) {
static ModelConfig ConfigGemma2_9B() {
ModelConfig config = ConfigBaseGemmaV2();
config.model_name = "Gemma2_9B";
config.display_name = "Gemma2_9B";
config.model = Model::GEMMA2_9B;
config.model_dim = 3584;
config.vocab_size = kVocabSize;
config.seq_len = 8192;
config.max_seq_len = 8192;
LayerConfig layer_config = LayerConfigGemma2_9B(config.model_dim);
config.layer_configs = {42, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.num_layers = 42;
config.layer_configs = {config.num_layers, layer_config};
config.query_scale = QueryScaleType::SqrtKeySize;
config.attention_window_sizes =
RepeatedAttentionWindowSizes<42, 2>({4096, 8192});
RepeatedAttentionWindowSizes<42, 2>({4096, config.max_seq_len});
return config;
}
@ -110,66 +119,17 @@ static LayerConfig LayerConfigGemma2_2B(size_t model_dim) {
static ModelConfig ConfigGemma2_2B() {
ModelConfig config = ConfigBaseGemmaV2();
config.model_name = "Gemma2_2B";
config.display_name = "Gemma2_2B";
config.model = Model::GEMMA2_2B;
config.model_dim = 2304;
config.vocab_size = kVocabSize;
config.seq_len = 8192;
config.max_seq_len = 8192;
LayerConfig layer_config = LayerConfigGemma2_2B(config.model_dim);
config.layer_configs = {26, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.num_layers = 26;
config.layer_configs = {config.num_layers, layer_config};
config.query_scale = QueryScaleType::SqrtKeySize;
config.attention_window_sizes =
RepeatedAttentionWindowSizes<26, 2>({4096, 8192});
return config;
}
static LayerConfig LayerConfigGemma7B(size_t model_dim) {
LayerConfig config;
config.model_dim = model_dim;
config.ff_hidden_dim = 16 * 3072 / 2; // = 24576
config.heads = 16;
config.kv_heads = 16;
config.qkv_dim = 256;
return config;
}
static ModelConfig ConfigGemma7B() {
ModelConfig config = ConfigBaseGemmaV1();
config.model_name = "Gemma7B";
config.model = Model::GEMMA_7B;
config.model_dim = 3072;
config.vocab_size = kVocabSize;
config.seq_len = kSeqLen;
LayerConfig layer_config = LayerConfigGemma7B(config.model_dim);
config.layer_configs = {28, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;
config.attention_window_sizes = FixedAttentionWindowSizes<28>(kSeqLen);
return config;
}
static LayerConfig LayerConfigGemma2B(size_t model_dim) {
LayerConfig config;
config.model_dim = model_dim;
config.ff_hidden_dim = 16 * 2048 / 2; // = 16384
config.heads = 8;
config.kv_heads = 1;
config.qkv_dim = 256;
return config;
}
static ModelConfig ConfigGemma2B() {
ModelConfig config = ConfigBaseGemmaV1();
config.model_name = "Gemma2B";
config.model = Model::GEMMA_2B;
config.model_dim = 2048;
config.vocab_size = kVocabSize;
config.seq_len = kSeqLen;
LayerConfig layer_config = LayerConfigGemma2B(config.model_dim);
config.layer_configs = {18, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.attention_window_sizes = FixedAttentionWindowSizes<18>(kSeqLen);
RepeatedAttentionWindowSizes<26, 2>({4096, config.max_seq_len});
return config;
}
@ -185,17 +145,18 @@ static LayerConfig LayerConfigGemmaTiny(size_t model_dim) {
static ModelConfig ConfigGemmaTiny() {
ModelConfig config = ConfigNoSSM();
config.model_name = "GemmaTiny";
config.display_name = "GemmaTiny";
config.model = Model::GEMMA_TINY;
config.model_dim = 128;
config.vocab_size = 64;
config.seq_len = 32;
config.wrapping = PromptWrapping::GEMMA_IT;
config.model_dim = 32;
config.vocab_size = 32; // at least two f32 vectors
config.max_seq_len = 32;
LayerConfig layer_config = LayerConfigGemmaTiny(config.model_dim);
config.layer_configs = {3, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.num_layers = 2;
config.layer_configs = {config.num_layers, layer_config};
config.query_scale = QueryScaleType::SqrtKeySize;
config.attention_window_sizes = FixedAttentionWindowSizes<3>(32);
// This is required for optimize_test to pass.
config.attention_window_sizes = FixedAttentionWindowSizes<2>(32);
config.att_cap = 50.0f;
config.final_cap = 30.0f;
config.eos_id = 11;
config.secondary_eos_id = 11;
@ -223,23 +184,23 @@ static LayerConfig LayerConfigGriffin2B(size_t model_dim) {
static ModelConfig ConfigGriffin2B() {
ModelConfig config = ConfigNoSSM();
config.model_name = "Griffin2B";
config.display_name = "Griffin2B";
config.model = Model::GRIFFIN_2B;
// Griffin uses local attention, so kSeqLen is actually the local attention
// window.
// Griffin uses local attention, so max_seq_len is actually the local
// attention window.
config.model_dim = 2560;
config.vocab_size = kVocabSize;
config.seq_len = 2048;
config.max_seq_len = 2048;
LayerConfig layer_config = LayerConfigGriffin2B(config.model_dim);
config.layer_configs = {26, layer_config};
for (size_t i = 2; i < config.layer_configs.size(); i += 3) {
config.num_layers = 26;
config.layer_configs = {config.num_layers, layer_config};
for (size_t i = 2; i < config.num_layers; i += 3) {
config.layer_configs[i].type = LayerAttentionType::kGemma;
config.layer_configs[i].griffin_dim = 0;
}
config.num_tensor_scales = 140;
config.attention_window_sizes = FixedAttentionWindowSizes<26>(config.seq_len);
config.attention_window_sizes =
FixedAttentionWindowSizes<26>(config.max_seq_len);
config.use_local_attention = true;
// This is required for optimize_test to pass.
config.final_cap = 0.0f;
return config;
}
@ -273,26 +234,10 @@ static void AddVitConfig(ModelConfig& config, size_t image_size = 224) {
config.vit_config.num_scales = 4 * config.vit_config.layer_configs.size();
}
static ModelConfig ConfigPaliGemma_224() {
ModelConfig config = ConfigGemma2B();
config.model_name = "PaliGemma_224";
config.model = Model::PALIGEMMA_224;
AddVitConfig(config);
return config;
}
static ModelConfig ConfigPaliGemma_448() {
ModelConfig config = ConfigGemma2B();
config.model_name = "PaliGemma_448";
config.model = Model::PALIGEMMA_448;
AddVitConfig(config, /*image_size=*/448);
return config;
}
ModelConfig GetVitConfig(const ModelConfig& config) {
ModelConfig vit_config = ConfigNoSSM();
vit_config.model_dim = config.vit_config.model_dim;
vit_config.seq_len = config.vit_config.seq_len;
vit_config.max_seq_len = config.vit_config.seq_len;
vit_config.layer_configs = config.vit_config.layer_configs;
vit_config.pool_dim = config.vit_config.pool_dim;
vit_config.wrapping = config.wrapping;
@ -303,32 +248,36 @@ ModelConfig GetVitConfig(const ModelConfig& config) {
static ModelConfig ConfigPaliGemma2_3B_224() {
ModelConfig config = ConfigGemma2_2B();
config.model_name = "PaliGemma2_3B_224";
config.display_name = "PaliGemma2_3B_224";
config.model = Model::PALIGEMMA2_3B_224;
config.wrapping = PromptWrapping::PALIGEMMA;
AddVitConfig(config);
return config;
}
static ModelConfig ConfigPaliGemma2_3B_448() {
ModelConfig config = ConfigGemma2_2B();
config.model_name = "PaliGemma2_3B_448";
config.display_name = "PaliGemma2_3B_448";
config.model = Model::PALIGEMMA2_3B_448;
config.wrapping = PromptWrapping::PALIGEMMA;
AddVitConfig(config, /*image_size=*/448);
return config;
}
static ModelConfig ConfigPaliGemma2_10B_224() {
ModelConfig config = ConfigGemma2_9B();
config.model_name = "PaliGemma2_10B_224";
config.display_name = "PaliGemma2_10B_224";
config.model = Model::PALIGEMMA2_10B_224;
config.wrapping = PromptWrapping::PALIGEMMA;
AddVitConfig(config);
return config;
}
static ModelConfig ConfigPaliGemma2_10B_448() {
ModelConfig config = ConfigGemma2_9B();
config.model_name = "PaliGemma2_10B_448";
config.display_name = "PaliGemma2_10B_448";
config.model = Model::PALIGEMMA2_10B_448;
config.wrapping = PromptWrapping::PALIGEMMA;
AddVitConfig(config, /*image_size=*/448);
return config;
}
@ -358,18 +307,19 @@ static LayerConfig LayerConfigGemma3_1B_LM(size_t model_dim) {
static ModelConfig ConfigGemma3_1B() {
ModelConfig config = ConfigBaseGemmaV3();
config.model_name = "Gemma3_1B";
config.display_name = "Gemma3_1B";
config.model = Model::GEMMA3_1B;
config.wrapping = PromptWrapping::GEMMA_VLM;
config.model_dim = 1152;
config.vocab_size = 262144; // new vocab size / tokenizer
config.seq_len = 32 * 1024;
config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer
config.max_seq_len = 32 * 1024;
LayerConfig layer_config = LayerConfigGemma3_1B_LM(config.model_dim);
config.layer_configs = {26, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.num_layers = 26;
config.layer_configs = {config.num_layers, layer_config};
config.query_scale = QueryScaleType::SqrtKeySize;
// interleaved local / global attention
config.attention_window_sizes = RepeatedAttentionWindowSizes<26, 6>(
{512, 512, 512, 512, 512, config.seq_len});
{512, 512, 512, 512, 512, config.max_seq_len});
return config;
}
@ -389,27 +339,29 @@ static LayerConfig LayerConfigGemma3_4B_LM(size_t model_dim) {
// Until we have the SigLIP checkpoints included, we use the LM config directly.
static ModelConfig ConfigGemma3_4B_LM() {
ModelConfig config = ConfigBaseGemmaV3();
config.model_name = "Gemma3_4B";
config.display_name = "Gemma3_4B";
config.model = Model::GEMMA3_4B;
config.wrapping = PromptWrapping::GEMMA_VLM;
config.model_dim = 2560;
config.vocab_size = 262144; // new vocab size / tokenizer
config.seq_len = 32 * 1024;
config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer
config.max_seq_len = 32 * 1024;
LayerConfig layer_config = LayerConfigGemma3_4B_LM(config.model_dim);
config.layer_configs = {34, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.num_layers = 34;
config.layer_configs = {config.num_layers, layer_config};
config.query_scale = QueryScaleType::SqrtKeySize;
// interleaved local / global attention
config.attention_window_sizes = RepeatedAttentionWindowSizes<34, 6>(
{1024, 1024, 1024, 1024, 1024, config.seq_len});
{1024, 1024, 1024, 1024, 1024, config.max_seq_len});
return config;
}
static ModelConfig ConfigGemma3_4B() {
ModelConfig config = ConfigGemma3_4B_LM();
config.model_name = "Gemma3_4B";
config.display_name = "Gemma3_4B";
config.model = Model::GEMMA3_4B;
config.wrapping = PromptWrapping::GEMMA_VLM;
AddVitConfig(config, /*image_size=*/896);
config.vocab_size = 262144;
config.vocab_size = kGemmaV3VocabSize;
config.vit_config.pool_dim = 4;
const size_t num_patches =
config.vit_config.image_size / config.vit_config.patch_width;
@ -436,27 +388,29 @@ static LayerConfig LayerConfigGemma3_12B_LM(size_t model_dim) {
static ModelConfig ConfigGemma3_12B_LM() {
ModelConfig config = ConfigBaseGemmaV3();
config.model_name = "Gemma3_12B";
config.display_name = "Gemma3_12B";
config.model = Model::GEMMA3_12B;
config.wrapping = PromptWrapping::GEMMA_VLM;
config.model_dim = 3840;
config.vocab_size = 262144; // new vocab size / tokenizer
config.seq_len = 32 * 1024;
config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer
config.max_seq_len = 32 * 1024;
LayerConfig layer_config = LayerConfigGemma3_12B_LM(config.model_dim);
config.layer_configs = {48, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.num_layers = 48;
config.layer_configs = {config.num_layers, layer_config};
config.query_scale = QueryScaleType::SqrtKeySize;
// interleaved local / global attention
config.attention_window_sizes = RepeatedAttentionWindowSizes<48, 6>(
{1024, 1024, 1024, 1024, 1024, config.seq_len});
{1024, 1024, 1024, 1024, 1024, config.max_seq_len});
return config;
}
static ModelConfig ConfigGemma3_12B() {
ModelConfig config = ConfigGemma3_12B_LM();
config.model_name = "Gemma3_12B";
config.display_name = "Gemma3_12B";
config.model = Model::GEMMA3_12B;
config.wrapping = PromptWrapping::GEMMA_VLM;
AddVitConfig(config, /*image_size=*/896);
config.vocab_size = 262144;
config.vocab_size = kGemmaV3VocabSize;
config.vit_config.pool_dim = 4;
const size_t num_patches =
config.vit_config.image_size / config.vit_config.patch_width;
@ -483,27 +437,29 @@ static LayerConfig LayerConfigGemma3_27B_LM(size_t model_dim) {
static ModelConfig ConfigGemma3_27B_LM() {
ModelConfig config = ConfigBaseGemmaV3();
config.model_name = "Gemma3_27B";
config.display_name = "Gemma3_27B";
config.model = Model::GEMMA3_27B;
config.wrapping = PromptWrapping::GEMMA_VLM;
config.model_dim = 5376;
config.vocab_size = 262144; // new vocab size / tokenizer
config.seq_len = 32 * 1024;
config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer
config.max_seq_len = 32 * 1024;
LayerConfig layer_config = LayerConfigGemma3_27B_LM(config.model_dim);
config.layer_configs = {62, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.num_layers = 62;
config.layer_configs = {config.num_layers, layer_config};
config.query_scale = QueryScaleType::SqrtKeySize;
// interleaved local / global attention
config.attention_window_sizes = RepeatedAttentionWindowSizes<62, 6>(
{1024, 1024, 1024, 1024, 1024, config.seq_len});
{1024, 1024, 1024, 1024, 1024, config.max_seq_len});
return config;
}
static ModelConfig ConfigGemma3_27B() {
ModelConfig config = ConfigGemma3_27B_LM();
config.model_name = "Gemma3_27B";
config.display_name = "Gemma3_27B";
config.model = Model::GEMMA3_27B;
config.wrapping = PromptWrapping::GEMMA_VLM;
AddVitConfig(config, /*image_size=*/896);
config.vocab_size = 262144;
config.vocab_size = kGemmaV3VocabSize;
config.vit_config.pool_dim = 4;
const size_t num_patches =
config.vit_config.image_size / config.vit_config.patch_width;
@ -515,12 +471,8 @@ static ModelConfig ConfigGemma3_27B() {
return config;
}
ModelConfig ConfigFromModel(Model model) {
static ModelConfig ConfigFromModel(Model model) {
switch (model) {
case Model::GEMMA_2B:
return ConfigGemma2B();
case Model::GEMMA_7B:
return ConfigGemma7B();
case Model::GEMMA2_2B:
return ConfigGemma2_2B();
case Model::GEMMA2_9B:
@ -531,10 +483,6 @@ ModelConfig ConfigFromModel(Model model) {
return ConfigGriffin2B();
case Model::GEMMA_TINY:
return ConfigGemmaTiny();
case Model::PALIGEMMA_224:
return ConfigPaliGemma_224();
case Model::PALIGEMMA_448:
return ConfigPaliGemma_448();
case Model::PALIGEMMA2_3B_224:
return ConfigPaliGemma2_3B_224();
case Model::PALIGEMMA2_3B_448:
@ -556,124 +504,249 @@ ModelConfig ConfigFromModel(Model model) {
}
}
#define TEST_EQUAL(a, b) \
if (a != b) { \
if (debug) \
std::cerr << #a << "=" << a << " != " << #b << "=" << b << "\n"; \
result = false; \
const char* ModelPrefix(Model model) {
switch (model) {
case Model::UNKNOWN:
return "unknown";
case Model::GEMMA2_2B:
return "gemma2-2b";
case Model::GEMMA2_9B:
return "9b";
case Model::GEMMA2_27B:
return "27b";
case Model::GRIFFIN_2B:
return "gr2b";
case Model::GEMMA_TINY:
return "tiny";
case Model::PALIGEMMA2_3B_224:
return "paligemma2-3b-224";
case Model::PALIGEMMA2_3B_448:
return "paligemma2-3b-448";
case Model::PALIGEMMA2_10B_224:
return "paligemma2-10b-224";
case Model::PALIGEMMA2_10B_448:
return "paligemma2-10b-448";
case Model::GEMMA3_4B:
return "gemma3-4b";
case Model::GEMMA3_1B:
return "gemma3-1b";
case Model::GEMMA3_12B:
return "gemma3-12b";
case Model::GEMMA3_27B:
return "gemma3-27b";
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
#define RETURN_IF_NOT_EQUAL(a, b) \
if (a != b) { \
if (debug) \
std::cerr << #a << "=" << a << " != " << #b << "=" << b << "\n"; \
return false; \
}
#define WARN_IF_NOT_EQUAL(a, b) \
if (a != b) { \
std::cerr << #a << "=" << a << " != " << #b << "=" << b << "\n"; \
}
bool LayerConfig::TestEqual(const LayerConfig& other, bool partial,
bool debug) const {
bool result = true;
// Optimized gating may not be set correctly in the c++ configs.
if (debug) {
WARN_IF_NOT_EQUAL(optimized_gating, other.optimized_gating)
}
TEST_EQUAL(model_dim, other.model_dim);
TEST_EQUAL(griffin_dim, other.griffin_dim);
TEST_EQUAL(ff_hidden_dim, other.ff_hidden_dim);
TEST_EQUAL(heads, other.heads);
TEST_EQUAL(kv_heads, other.kv_heads);
TEST_EQUAL(qkv_dim, other.qkv_dim);
TEST_EQUAL(conv1d_width, other.conv1d_width);
if (!partial) {
TEST_EQUAL(ff_biases, other.ff_biases);
TEST_EQUAL(softmax_attn_output_biases, other.softmax_attn_output_biases);
}
TEST_EQUAL(static_cast<int>(post_norm), static_cast<int>(other.post_norm));
TEST_EQUAL(static_cast<int>(type), static_cast<int>(other.type));
TEST_EQUAL(static_cast<int>(activation), static_cast<int>(other.activation));
TEST_EQUAL(static_cast<int>(post_qk), static_cast<int>(other.post_qk));
return result;
}
bool VitConfig::TestEqual(const VitConfig& other, bool partial,
bool debug) const {
bool result = true;
TEST_EQUAL(model_dim, other.model_dim);
TEST_EQUAL(seq_len, other.seq_len);
if (!partial) {
TEST_EQUAL(num_scales, other.num_scales);
PromptWrapping ChooseWrapping(const Model model, Tristate wrapping) {
if (IsPaliGemma(model)) {
if (wrapping != Tristate::kDefault) {
HWY_WARN("Ignoring unnecessary --wrapping for PaliGemma models.");
}
return PromptWrapping::PALIGEMMA;
}
TEST_EQUAL(patch_width, other.patch_width);
TEST_EQUAL(image_size, other.image_size);
RETURN_IF_NOT_EQUAL(layer_configs.size(), other.layer_configs.size());
for (size_t i = 0; i < layer_configs.size(); ++i) {
result &=
layer_configs[i].TestEqual(other.layer_configs[i], partial, debug);
if (IsVLM(model)) {
if (wrapping != Tristate::kDefault) {
HWY_WARN("Ignoring unnecessary --wrapping for VLM models.");
}
return PromptWrapping::GEMMA_VLM;
}
return result;
// Default to IT unless --wrapping=0.
return wrapping == Tristate::kFalse ? PromptWrapping::GEMMA_PT
: PromptWrapping::GEMMA_IT;
}
bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
bool debug) const {
bool result = true;
TEST_EQUAL(model_family_version, other.model_family_version);
// We don't care about model_name, model, wrapping, or weight being different,
// but will output in debug mode if they are.
if (debug) {
WARN_IF_NOT_EQUAL(model_name, other.model_name);
WARN_IF_NOT_EQUAL(static_cast<int>(model), static_cast<int>(other.model));
WARN_IF_NOT_EQUAL(static_cast<int>(wrapping),
static_cast<int>(other.wrapping));
WARN_IF_NOT_EQUAL(static_cast<int>(weight), static_cast<int>(other.weight));
ModelConfig::ModelConfig(const Model model, Type weight,
PromptWrapping wrapping) {
HWY_ASSERT(weight != Type::kUnknown);
HWY_ASSERT(wrapping != PromptWrapping::kSentinel);
this->model = model;
if (model != Model::UNKNOWN) *this = ConfigFromModel(model);
HWY_ASSERT(this->model == model);
this->weight = weight;
this->wrapping = wrapping;
}
static Model FindModel(const std::string& specifier) {
Model found_model = Model::UNKNOWN;
ForEachModel([&](Model model) {
// Some model names are prefixes of other model names
const std::string prefix = std::string(ModelPrefix(model)) + "-";
if (specifier.rfind(prefix, 0) == 0) { // Starts with prefix.
// We only expect one match.
HWY_ASSERT_M(found_model == Model::UNKNOWN, specifier.c_str());
found_model = model;
}
});
HWY_ASSERT_M(found_model != Model::UNKNOWN, specifier.c_str());
return found_model;
}
static Type FindType(const std::string& specifier) {
Type found_type = Type::kUnknown;
for (size_t i = 1; i < kNumTypes; ++i) {
const Type type = static_cast<Type>(i);
if (specifier.find(TypeName(type)) != std::string::npos) { // NOLINT
// We only expect one match.
HWY_ASSERT_M(found_type == Type::kUnknown, specifier.c_str());
found_type = type;
}
}
TEST_EQUAL(model_dim, other.model_dim);
TEST_EQUAL(vocab_size, other.vocab_size);
TEST_EQUAL(seq_len, other.seq_len);
if (!partial) {
TEST_EQUAL(num_tensor_scales, other.num_tensor_scales);
HWY_ASSERT_M(found_type != Type::kUnknown, specifier.c_str());
return found_type;
}
static PromptWrapping FindWrapping(const std::string& specifier) {
PromptWrapping found_wrapping = PromptWrapping::kSentinel;
for (size_t i = 0; i < static_cast<size_t>(PromptWrapping::kSentinel); ++i) {
const PromptWrapping w = static_cast<PromptWrapping>(i);
if (specifier.find(WrappingSuffix(w)) != std::string::npos) { // NOLINT
// We expect zero or one match.
HWY_ASSERT_M(found_wrapping == PromptWrapping::kSentinel,
specifier.c_str());
found_wrapping = w;
}
}
TEST_EQUAL(att_cap, other.att_cap);
TEST_EQUAL(final_cap, other.final_cap);
TEST_EQUAL(absolute_pe, other.absolute_pe);
TEST_EQUAL(use_local_attention, other.use_local_attention);
TEST_EQUAL(static_cast<int>(query_scale),
static_cast<int>(other.query_scale));
RETURN_IF_NOT_EQUAL(layer_configs.size(), other.layer_configs.size());
for (size_t i = 0; i < layer_configs.size(); ++i) {
result &=
layer_configs[i].TestEqual(other.layer_configs[i], partial, debug);
if (found_wrapping == PromptWrapping::kSentinel) {
return ChooseWrapping(FindModel(specifier));
}
RETURN_IF_NOT_EQUAL(attention_window_sizes.size(),
other.attention_window_sizes.size());
for (size_t i = 0; i < attention_window_sizes.size(); ++i) {
TEST_EQUAL(attention_window_sizes[i], other.attention_window_sizes[i]);
return found_wrapping;
}
// Obtains model/weight/wrapping by finding prefix and suffix strings.
ModelConfig::ModelConfig(const std::string& specifier)
: ModelConfig(FindModel(specifier), FindType(specifier),
FindWrapping(specifier)) {}
std::string ModelConfig::Specifier() const {
HWY_ASSERT(model != Model::UNKNOWN);
HWY_ASSERT(weight != Type::kUnknown);
HWY_ASSERT(wrapping != PromptWrapping::kSentinel);
std::string base_name = ModelPrefix(model);
base_name += '-';
base_name += TypeName(weight);
if (wrapping != PromptWrapping::GEMMA_VLM &&
wrapping != PromptWrapping::PALIGEMMA) {
base_name += WrappingSuffix(wrapping);
}
if (!partial) {
if (scale_names != other.scale_names) {
result = false;
if (debug) {
std::cerr << "scale_names mismatch\n";
return base_name;
}
// Returns whether all fields match.
static bool AllEqual(const IFields& a, const IFields& b, bool print) {
const std::vector<uint32_t> serialized_a = a.Write();
const std::vector<uint32_t> serialized_b = b.Write();
if (serialized_a != serialized_b) {
if (print) {
fprintf(stderr, "%s differs. Recommend generating a diff:\n", a.Name());
a.Print();
b.Print();
}
return false;
}
return true;
}
bool LayerConfig::TestEqual(const LayerConfig& other, bool print) const {
return AllEqual(*this, other, print);
}
bool VitConfig::TestEqual(const VitConfig& other, bool print) const {
return AllEqual(*this, other, print);
}
bool ModelConfig::TestEqual(const ModelConfig& other, bool print) const {
// Early out to guard the loop below; a differing number of layers will anyway
// cause a mismatch.
if (layer_configs.size() != other.layer_configs.size()) {
if (print) {
HWY_WARN("Layer configs size mismatch %zu vs %zu", layer_configs.size(),
other.layer_configs.size());
}
return false;
}
// Copy so we can 'ignore' fields by setting them to the same value.
ModelConfig a = *this;
ModelConfig b = other;
// Called by `OverwriteWithCanonical`, so ignore the fields it will set.
a.display_name = b.display_name;
a.model = b.model;
// The following are not yet set by config_converter.py, so we here ignore
// them for purposes of comparison, and there overwrite the converter's config
// with the canonical ModelConfig constructed via (deduced) enum, so that
// these fields will be set.
// `vit_config` is also not yet set, but we must not ignore it because
// otherwise PaliGemma models will be indistinguishable for `configs_test`.
a.pool_dim = b.pool_dim; // ViT
a.eos_id = b.eos_id;
a.secondary_eos_id = b.secondary_eos_id;
a.scale_base_names = b.scale_base_names;
for (size_t i = 0; i < a.layer_configs.size(); ++i) {
a.layer_configs[i].optimized_gating = b.layer_configs[i].optimized_gating;
}
return AllEqual(a, b, print);
}
// Constructs the canonical ModelConfig for each model. If there is one for
// which TestEqual returns true, overwrites `*this` with that and returns true.
bool ModelConfig::OverwriteWithCanonical() {
bool found = false;
const bool print = false;
ForEachModel([&](Model model) {
const ModelConfig config(model, weight, wrapping);
if (config.TestEqual(*this, print)) {
HWY_ASSERT(!found); // Should only find one.
found = true;
*this = config;
}
});
return found;
}
Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) {
switch (layers) {
case 2:
return Model::GEMMA_TINY;
case 26:
if (layer_types & kDeducedGriffin) return Model::GRIFFIN_2B;
if (layer_types & kDeducedViT) return Model::GEMMA3_1B;
return Model::GEMMA2_2B;
case 27:
return (layer_types & kDeduced448) ? Model::PALIGEMMA2_3B_448
: Model::PALIGEMMA2_3B_224;
case 34:
return Model::GEMMA3_4B;
case 42:
if (layer_types & kDeducedViT) {
return (layer_types & kDeduced448) ? Model::PALIGEMMA2_10B_448
: Model::PALIGEMMA2_10B_224;
}
}
}
TEST_EQUAL(norm_num_groups, other.norm_num_groups);
result &= vit_config.TestEqual(other.vit_config, partial, debug);
return result;
}
return Model::GEMMA2_9B;
case 46:
return Model::GEMMA2_27B;
case 48:
return Model::GEMMA3_12B;
case 62:
return Model::GEMMA3_27B;
Model ModelFromConfig(const ModelConfig& config) {
for (Model model : kAllModels) {
ModelConfig model_config = ConfigFromModel(model);
if (config.TestEqual(model_config, /*partial=*/true, /*debug=*/false)) {
return model;
}
// TODO: detect these.
/*
return Model::GEMMA2_772M;
return Model::PALIGEMMA2_772M_224;
*/
default:
HWY_WARN("Failed to deduce model type from %s, layer count %zu types %x.",
blob_path.path.c_str(), layers, layer_types);
return Model::UNKNOWN;
}
return Model::UNKNOWN;
}
} // namespace gcpp

View File

@ -19,35 +19,52 @@
// Model configurations
#include <stddef.h>
#include <stdint.h>
#include <algorithm>
#include <array>
#include <cstdint>
#include <string>
#include <unordered_set>
#include <vector>
#include "compression/fields.h" // IFieldsVisitor
#include "compression/shared.h" // BF16
#include "compression/types.h" // Type
#include "io/fields.h" // IFieldsVisitor
#include "io/io.h" // Path
#include "util/basics.h"
namespace gcpp {
// Allow changing pre-allocated kv cache size as a compiler flag
#ifndef GEMMA_MAX_SEQLEN
#define GEMMA_MAX_SEQLEN 4096
#endif // !GEMMA_MAX_SEQLEN
// Allow changing k parameter of `SampleTopK` as a compiler flag
#ifndef GEMMA_TOPK
#define GEMMA_TOPK 1
#endif // !GEMMA_TOPK
static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
static constexpr size_t kTopK = GEMMA_TOPK;
static constexpr size_t kVocabSize = 256000;
static constexpr size_t kMaxConv1DWidth = 4;
static constexpr size_t kMaxQKVDim = 1024;
using EmbedderInputT = BF16;
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
enum class PromptWrapping {
GEMMA_IT,
GEMMA_PT,
GEMMA_VLM, // for >1B Gemma3
PALIGEMMA,
kSentinel // must be last
};
// This is used in `ModelConfig.Specifier`, so the strings will not change,
// though new ones may be added.
static inline const char* WrappingSuffix(PromptWrapping wrapping) {
switch (wrapping) {
case PromptWrapping::GEMMA_IT:
return "-it";
case PromptWrapping::GEMMA_PT:
return "-pt";
case PromptWrapping::GEMMA_VLM:
return "-vlm";
case PromptWrapping::PALIGEMMA:
return "-pg";
default:
return "-?";
}
}
static inline bool EnumValid(PromptWrapping wrapping) {
return static_cast<size_t>(wrapping) <
static_cast<size_t>(PromptWrapping::kSentinel);
}
enum class LayerAttentionType {
kGemma,
@ -55,63 +72,68 @@ enum class LayerAttentionType {
kVit,
};
inline bool EnumValid(LayerAttentionType type) {
return static_cast<int>(type) >= 0 &&
static_cast<int>(type) <= static_cast<int>(LayerAttentionType::kVit);
static inline bool EnumValid(LayerAttentionType type) {
return type == LayerAttentionType::kGemma ||
type == LayerAttentionType::kGriffinRecurrentBlock ||
type == LayerAttentionType::kVit;
}
// Post attention and ffw normalization type.
enum class PostNormType {
None,
Scale,
kSentinel // must be last
};
inline bool EnumValid(PostNormType type) {
return static_cast<int>(type) >= 0 &&
static_cast<int>(type) <= static_cast<int>(PostNormType::Scale);
static inline bool EnumValid(PostNormType type) {
return static_cast<size_t>(type) <
static_cast<size_t>(PostNormType::kSentinel);
}
// Post qk projection operation type.
enum class PostQKType {
Rope,
HalfRope,
kSentinel // must be last
};
inline bool EnumValid(PostQKType type) {
return static_cast<int>(type) >= 0 &&
static_cast<int>(type) <= static_cast<int>(PostQKType::HalfRope);
static inline bool EnumValid(PostQKType type) {
return static_cast<size_t>(type) <
static_cast<size_t>(PostNormType::kSentinel);
}
// FFW activation function.
enum class ActivationType {
Gelu,
kSentinel // must be last
};
inline bool EnumValid(ActivationType type) {
return static_cast<int>(type) >= 0 &&
static_cast<int>(type) <= static_cast<int>(ActivationType::Gelu);
static inline bool EnumValid(ActivationType type) {
return static_cast<size_t>(type) <
static_cast<size_t>(ActivationType::kSentinel);
}
// Attention query scale.
enum class QueryScaleType {
SqrtKeySize,
SqrtModelDimDivNumHeads,
kSentinel // must be last
};
inline bool EnumValid(QueryScaleType type) {
return static_cast<int>(type) >= 0 &&
static_cast<int>(type) <=
static_cast<int>(QueryScaleType::SqrtModelDimDivNumHeads);
static inline bool EnumValid(QueryScaleType type) {
return static_cast<size_t>(type) <
static_cast<size_t>(QueryScaleType::kSentinel);
}
// Residual connection type.
enum class ResidualType {
Add,
kSentinel // must be last
};
inline bool EnumValid(ResidualType type) {
return static_cast<int>(type) >= 0 &&
static_cast<int>(type) <= static_cast<int>(ResidualType::Add);
static inline bool EnumValid(ResidualType type) {
return static_cast<size_t>(type) <
static_cast<size_t>(ResidualType::kSentinel);
}
template <size_t kNum>
@ -137,17 +159,15 @@ std::vector<uint32_t> RepeatedAttentionWindowSizes(
// Model variants: see configs.cc for details.
enum class Model {
UNKNOWN,
GEMMA_2B,
GEMMA_7B,
GEMMA2_9B,
UNKNOWN = 0,
// 1 and 2 are obsolete.
GEMMA2_9B = 3,
GEMMA2_27B,
GRIFFIN_2B,
GEMMA_TINY,
GEMMA_TINY, // for testing only
GEMMA2_2B,
PALIGEMMA_224,
PALIGEMMA_448,
PALIGEMMA2_3B_224,
// 8 and 9 are obsolete.
PALIGEMMA2_3B_224 = 10,
PALIGEMMA2_3B_448,
PALIGEMMA2_10B_224,
PALIGEMMA2_10B_448,
@ -155,43 +175,64 @@ enum class Model {
GEMMA3_1B,
GEMMA3_12B,
GEMMA3_27B,
kSentinel,
};
// Allows the Model enum to be iterated over.
static constexpr Model kAllModels[] = {
Model::GEMMA_2B, Model::GEMMA_7B, Model::GEMMA2_9B, Model::GEMMA2_27B,
Model::GRIFFIN_2B, Model::GEMMA_TINY, Model::GEMMA2_2B,
Model::PALIGEMMA_224, Model::PALIGEMMA_448, Model::PALIGEMMA2_3B_224,
Model::PALIGEMMA2_3B_448, Model::PALIGEMMA2_10B_224,
Model::PALIGEMMA2_10B_448, Model::GEMMA3_4B, Model::GEMMA3_1B,
Model::GEMMA3_12B, Model::GEMMA3_27B,
};
// Returns canonical model name without the PromptWrapping suffix. This is used
// in Specifier and thus does not change.
const char* ModelPrefix(Model model);
inline bool EnumValid(Model model) {
for (Model m : kAllModels) {
if (m == model) return true;
// Gemma3 is multimodal and has a different prompt wrapping than PaliGemma.
// This is used for deducing the PromptWrapping for pre-2025 BlobStore.
static inline bool IsVLM(Model model) {
return model == Model::GEMMA3_4B || model == Model::GEMMA3_1B ||
model == Model::GEMMA3_12B || model == Model::GEMMA3_27B;
}
static inline bool IsPaliGemma(Model model) {
if (model == Model::PALIGEMMA2_3B_224 || model == Model::PALIGEMMA2_3B_448 ||
model == Model::PALIGEMMA2_10B_224 ||
model == Model::PALIGEMMA2_10B_448) {
return true;
}
return false;
}
// Visits every valid model enum, skipping `UNKNOWN` and `kSentinel`.
template <class Func>
void ForEachModel(const Func& func) {
for (size_t i = static_cast<size_t>(Model::GEMMA2_9B);
i < static_cast<size_t>(Model::kSentinel); ++i) {
if (i == 8 || i == 9) continue;
func(static_cast<Model>(i));
}
}
static inline bool EnumValid(Model model) {
// Valid for purposes of serialization, even if unknown.
if (model == Model::UNKNOWN) return true;
const size_t i = static_cast<size_t>(model);
if (i >= static_cast<size_t>(Model::GEMMA2_9B) &&
i < static_cast<size_t>(Model::kSentinel) && i != 8 && i != 9) {
return true;
}
return false;
}
struct InternalLayerConfig : public IFields {
const char* Name() const override { return "InternalLayerConfig"; }
// Source of truth for field ordering.
void VisitFields(IFieldsVisitor& visitor) override {
// Append new fields here, then update `python/configs.cc`.
}
};
// Per-layer configuration.
struct LayerConfig : public IFields {
// Returns true if *this and other are equal.
// If partial is true, then we don't check for items that are only set after
// the tensors are loaded from the checkpoint.
// If debug is true, then we output the mismatched fields to stderr.
bool TestEqual(const LayerConfig& other, bool partial, bool debug) const;
size_t CacheLayerSize() const { return kv_heads * qkv_dim * 2; }
// Multi-Head Attention?
bool IsMHA() const { return heads == kv_heads; }
// Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
// but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
size_t QStride() const { return qkv_dim * (IsMHA() ? 3 : 1); }
const char* Name() const override { return "LayerConfig"; }
// Source of truth for field ordering.
void VisitFields(IFieldsVisitor& visitor) override {
visitor(model_dim);
visitor(griffin_dim);
@ -208,35 +249,41 @@ struct LayerConfig : public IFields {
visitor(activation);
visitor(post_qk);
visitor(use_qk_norm);
internal.VisitFields(visitor);
// Append new fields here, then update `python/configs.cc`.
}
// Returns whether all fields match.
bool TestEqual(const LayerConfig& other, bool print) const;
size_t CacheLayerSize() const { return kv_heads * qkv_dim * 2; }
// Multi-Head Attention?
bool IsMHA() const { return heads == kv_heads; }
uint32_t model_dim = 0;
uint32_t griffin_dim = 0;
uint32_t ff_hidden_dim = 0;
uint32_t heads = 0;
uint32_t kv_heads = 0;
uint32_t qkv_dim = 0;
uint32_t conv1d_width = 0; // griffin only
uint32_t qkv_dim = 0; // length of Q, K, V vectors (contiguous).
uint32_t conv1d_width = 0; // Griffin only
bool ff_biases = false;
bool softmax_attn_output_biases = false;
bool optimized_gating = true;
bool softmax_attn_output_biases = false; // for Griffin
bool optimized_gating = true; // for Gemma3
PostNormType post_norm = PostNormType::None;
LayerAttentionType type = LayerAttentionType::kGemma;
ActivationType activation = ActivationType::Gelu;
PostQKType post_qk = PostQKType::Rope;
bool use_qk_norm = false;
InternalLayerConfig internal;
};
// Dimensions related to image processing.
struct VitConfig : public IFields {
// Returns true if *this and other are equal.
// If partial is true, then we don't check for items that are only set after
// the tensors are loaded from the checkpoint.
// If debug is true, then we output the mismatched fields to stderr.
bool TestEqual(const VitConfig& other, bool partial, bool debug) const;
const char* Name() const override { return "VitConfig"; }
// Source of truth for field ordering.
void VisitFields(IFieldsVisitor& visitor) override {
visitor(model_dim);
visitor(seq_len);
@ -245,8 +292,12 @@ struct VitConfig : public IFields {
visitor(image_size);
visitor(layer_configs);
visitor(pool_dim);
// Append new fields here, then update `python/configs.cc`.
}
// Returns whether all fields match.
bool TestEqual(const VitConfig& other, bool print) const;
uint32_t model_dim = 0;
uint32_t seq_len = 0;
uint32_t num_scales = 0;
@ -256,20 +307,96 @@ struct VitConfig : public IFields {
std::vector<LayerConfig> layer_configs;
};
// Returns a valid `PromptWrapping` for the given `model`, for passing to the
// `ModelConfig` ctor when the caller does not care about the wrapping. The
// wrapping mode is either determined by the model (for PaliGemma and Gemma3),
// or defaults to IT, subject to user override for PT.
PromptWrapping ChooseWrapping(Model model,
Tristate wrapping = Tristate::kDefault);
struct InternalModelConfig : public IFields {
const char* Name() const override { return "InternalModelConfig"; }
// Source of truth for field ordering.
void VisitFields(IFieldsVisitor& visitor) override {
// Append new fields here, then update `python/configs.cc`.
}
};
struct ModelConfig : public IFields {
// Returns true if *this and other are equal.
// If partial is true, then we don't check for items that are only set after
// the tensors are loaded from the checkpoint.
// If debug is true, then we output the mismatched fields to stderr.
bool TestEqual(const ModelConfig& other, bool partial, bool debug) const;
// Preferred usage (single-file format): default-construct, then deserialize
// from a blob. Also used by `config_converter.py`, which sets sufficient
// fields for `TestEqual` and then calls `OverwriteWithCanonical()`.
ModelConfig() = default;
// For use by `model_store.cc` for pre-2025 format after deducing the model
// from tensors plus a user-specified `wrapping` override (`ChooseWrapping`).
ModelConfig(Model model, Type weight, PromptWrapping wrapping);
// Parses a string returned by `Specifier()`. Used by the exporter to select
// the model from command line arguments. Do not use this elsewhere - the
// second ctor is preferred because it is type-checked.
ModelConfig(const std::string& specifier);
const char* Name() const override { return "ModelConfig"; }
// Source of truth for field ordering.
void VisitFields(IFieldsVisitor& visitor) override {
visitor(model_family_version);
visitor(display_name);
visitor(model);
visitor(wrapping);
visitor(weight);
visitor(num_layers);
visitor(model_dim);
visitor(vocab_size);
visitor(max_seq_len);
visitor(unused_num_tensor_scales);
visitor(att_cap);
visitor(final_cap);
visitor(absolute_pe);
visitor(use_local_attention);
visitor(query_scale);
visitor(layer_configs);
visitor(attention_window_sizes);
visitor(norm_num_groups);
visitor(vit_config);
visitor(pool_dim);
visitor(eos_id);
visitor(secondary_eos_id);
visitor(scale_base_names);
internal.VisitFields(visitor);
// Append new fields here, then update `python/configs.cc`.
}
// Returns whether all fields match except `model` and `display_name`, and
// some others that are not yet set by config_converter.py. This is for
// internal use by `OverwriteWithCanonical`, but potentially useful elsewhere.
bool TestEqual(const ModelConfig& other, bool print) const;
// For each model, constructs its canonical `ModelConfig` and if `TestEqual`
// returns true, overwrites `*this` with that. Otherwise, returns false to
// indicate this is not a known model. Called by `config_converter.py`.
bool OverwriteWithCanonical();
// Returns a string encoding of the model family, size, weight, and
// `PromptWrapping`. Stable/unchanging; can be used as the model file name.
// The third ctor also expects a string returned by this.
std::string Specifier() const;
void AddLayerConfig(const LayerConfig& layer_config) {
layer_configs.push_back(layer_config);
HWY_ASSERT(layer_configs.size() <= num_layers);
}
size_t CachePosSize() const {
size_t num_layers = layer_configs.size();
return num_layers * layer_configs[0].CacheLayerSize();
bool IsGlobalLayer(size_t layer_idx) const {
return attention_window_sizes[layer_idx] == max_seq_len;
}
size_t NumLayersOfTypeBefore(LayerAttentionType type, size_t num) const {
@ -287,77 +414,77 @@ struct ModelConfig : public IFields {
size_t NumHeads() const {
uint32_t num_heads = 0;
for (const auto& layer_config : layer_configs) {
num_heads = std::max(num_heads, layer_config.heads);
num_heads = HWY_MAX(num_heads, layer_config.heads);
}
return num_heads;
}
const char* Name() const override { return "ModelConfig"; }
size_t KVCacheCols() const {
size_t num_layers = layer_configs.size();
return num_layers * layer_configs[0].CacheLayerSize();
}
bool IsEOS(int id) const { return (id == eos_id || id == secondary_eos_id); }
void VisitFields(IFieldsVisitor& visitor) override {
visitor(model_family_version);
visitor(model_name);
visitor(model);
visitor(wrapping);
visitor(weight);
visitor(num_layers);
visitor(model_dim);
visitor(vocab_size);
visitor(seq_len);
visitor(num_tensor_scales);
visitor(att_cap);
visitor(final_cap);
visitor(absolute_pe);
visitor(use_local_attention);
visitor(query_scale);
visitor(layer_configs);
visitor(attention_window_sizes);
visitor(norm_num_groups);
visitor(vit_config);
visitor(pool_dim);
visitor(eos_id);
visitor(secondary_eos_id);
}
// Major version of the model family. It is used as a fallback to distinguish
// between model types when there is no explicit information in the config.
// Major version of the model family, reflecting architecture changes. This is
// more convenient to compare than `Model` because that also includes the
// model size.
uint32_t model_family_version = 1;
std::string model_name;
Model model = Model::UNKNOWN;
// For display only, may change. Use `Specifier()` for setting the
// file name. Not checked by `TestEqual` because `config_converter.py` does
// not set this.
std::string display_name;
Model model = Model::UNKNOWN; // Not checked by `TestEqual`, see above.
PromptWrapping wrapping = PromptWrapping::GEMMA_PT;
Type weight = Type::kUnknown;
uint32_t num_layers = 0;
uint32_t model_dim = 0;
uint32_t vocab_size = 0;
uint32_t seq_len = 0;
uint32_t num_tensor_scales = 0;
uint32_t max_seq_len = 0;
// We no longer set nor use this: config_converter is not able to set this,
// and only pre-2025 format stores scales, and we do not require advance
// knowledge of how many there will be. Any scales present will just be
// assigned in order to the tensors matching `scale_base_names`.
uint32_t unused_num_tensor_scales = 0;
float att_cap = 0.0f;
float final_cap = 0.0f;
bool absolute_pe = false;
bool use_local_attention = false; // griffin only
bool use_local_attention = false; // Griffin only
QueryScaleType query_scale = QueryScaleType::SqrtKeySize;
std::vector<LayerConfig> layer_configs;
std::vector<uint32_t> attention_window_sizes;
std::unordered_set<std::string> scale_names;
uint32_t norm_num_groups = 1;
// Dimensions related to image processing.
VitConfig vit_config;
uint32_t pool_dim = 1; // used only for VitConfig copy
int eos_id = 1;
int secondary_eos_id = 1;
// Tensor base names without a layer suffix, used by `ModelStore` only for
// pre-2025 format.
std::vector<std::string> scale_base_names;
InternalModelConfig internal;
};
// Returns the config for the given model.
ModelConfig ConfigFromModel(Model model);
// Returns the model for the given config, if it matches any standard model.
Model ModelFromConfig(const ModelConfig& config);
// Returns the sub-config for the ViT model of the PaliGemma model.
ModelConfig GetVitConfig(const ModelConfig& config);
enum DeducedLayerTypes {
kDeducedGriffin = 1,
kDeducedViT = 2,
kDeduced448 = 4, // For ViT, 448x448 resolution instead of 224x224.
};
// layer_types is one or more of `DeducedLayerTypes`.
Model DeduceModel(const Path& blob_path, size_t layers, int layer_types);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_

View File

@ -1,461 +1,44 @@
#include "gemma/configs.h"
#include <array>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <stdio.h>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "hwy/aligned_allocator.h"
#include "compression/types.h" // Type
#include "io/fields.h" // Type
namespace gcpp {
template <size_t kNum>
constexpr std::array<LayerAttentionType, kNum> OldFixedLayerConfig(
LayerAttentionType type) {
std::array<LayerAttentionType, kNum> config = {};
for (LayerAttentionType& l : config) {
l = type;
}
return config;
}
TEST(ConfigsTest, TestAll) {
ForEachModel([&](Model model) {
ModelConfig config(model, Type::kSFP, ChooseWrapping(model));
fprintf(stderr, "Testing %s (%s)\n", config.display_name.c_str(),
config.Specifier().c_str());
HWY_ASSERT(config.model == model);
template <size_t kNum>
constexpr std::array<size_t, kNum> OldFixedAttentionWindowSizes(
size_t window_size) {
std::array<size_t, kNum> window_size_configs = {};
for (size_t& l : window_size_configs) {
l = window_size;
}
return window_size_configs;
}
// We can deduce the model/display_name from all other fields.
config.model = Model::UNKNOWN;
const std::string saved_display_name = config.display_name;
config.display_name.clear();
HWY_ASSERT(config.OverwriteWithCanonical());
HWY_ASSERT(config.model == model);
HWY_ASSERT(config.display_name == saved_display_name);
// Repeat window_size_pattern for kNum / kPatternSize times.
template <size_t kNum, size_t kPatternSize>
constexpr std::array<size_t, kNum> OldRepeatedAttentionWindowSizes(
const std::array<size_t, kPatternSize>& window_size_pattern) {
static_assert(kNum % kPatternSize == 0,
"kNum must be a multiple of kPatternSize");
std::array<size_t, kNum> window_size_configs = {};
for (size_t i = 0; i < kNum; ++i) {
window_size_configs[i] = window_size_pattern[i % kPatternSize];
}
return window_size_configs;
}
template <size_t kNumLayers>
constexpr size_t OldNumLayersOfTypeBefore(
const std::array<LayerAttentionType, kNumLayers>& layers,
LayerAttentionType type, size_t num) {
size_t count = 0;
for (size_t i = 0; i < num; i++) {
if (layers[i] == type) count++;
}
return count;
}
template <class TConfig, typename = void>
struct CacheLayerSize {
constexpr size_t operator()() const {
return TConfig::kKVHeads * TConfig::kQKVDim * 2;
}
};
template <class TConfig, typename = void>
struct CachePosSize {
constexpr size_t operator()() const {
return TConfig::kGemmaLayers * CacheLayerSize<TConfig>()();
}
};
struct OldConfigNoVit {
struct VitConfig {
// Some of these are needed to make the compiler happy when trying to
// generate code that will actually never be used.
using Weight = float;
static constexpr int kLayers = 0;
static constexpr std::array<LayerAttentionType, 0> kLayerConfig =
OldFixedLayerConfig<0>(LayerAttentionType::kVit);
static constexpr int kModelDim = 0;
static constexpr int kFFHiddenDim = 0;
static constexpr int kHeads = 1; // Avoid division by 0 in griffin gate_w.
static constexpr int kKVHeads = 0;
static constexpr int kQKVDim = 0;
static constexpr int kSeqLen = 0;
static constexpr ResidualType kResidual = ResidualType::Add;
static constexpr int kGriffinLayers = 0;
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr PostNormType kPostNorm = PostNormType::None;
};
};
struct OldConfigNoSSM : OldConfigNoVit {
static constexpr int kGriffinLayers = 0;
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr bool kUseHalfRope = false;
static constexpr bool kUseLocalAttention = false;
static constexpr bool kInterleaveQKV = true;
static constexpr PostQKType kPostQK = PostQKType::Rope;
static constexpr ActivationType kActivation = ActivationType::Gelu;
static constexpr ResidualType kResidual = ResidualType::Add;
};
struct OldConfigBaseGemmaV1 : OldConfigNoSSM {
static constexpr float kAttCap = 0.0f;
static constexpr float kFinalCap = 0.0f;
static constexpr PostNormType kPostNorm = PostNormType::None;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
};
struct OldConfigBaseGemmaV2 : OldConfigNoSSM {
static constexpr float kAttCap = 50.0f;
static constexpr float kFinalCap = 30.0f;
static constexpr PostNormType kPostNorm = PostNormType::Scale;
};
template <typename TWeight>
struct OldConfigGemma2_27B : public OldConfigBaseGemmaV2 {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = 8192;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 46> kLayerConfig =
OldFixedLayerConfig<46>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 46> kAttentionWindowSizes =
OldRepeatedAttentionWindowSizes<46, 2>({4096, kSeqLen});
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 4608;
static constexpr int kFFHiddenDim = 16 * 4608 / 2; // = 36864
static constexpr int kHeads = 32;
static constexpr int kKVHeads = 16;
static constexpr int kQKVDim = 128; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale =
QueryScaleType::SqrtModelDimDivNumHeads;
};
template <typename TWeight>
struct OldConfigGemma2_9B : public OldConfigBaseGemmaV2 {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = 8192;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 42> kLayerConfig =
OldFixedLayerConfig<42>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 42> kAttentionWindowSizes =
OldRepeatedAttentionWindowSizes<42, 2>({4096, kSeqLen});
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 3584;
static constexpr int kFFHiddenDim = 8 * 3584 / 2; // = 14336
static constexpr int kHeads = 16;
static constexpr int kKVHeads = 8;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
};
template <typename TWeight>
struct OldConfigGemma7B : public OldConfigBaseGemmaV1 {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
OldFixedLayerConfig<28>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 28> kAttentionWindowSizes =
OldFixedAttentionWindowSizes<28>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 3072;
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
static constexpr int kHeads = 16;
static constexpr int kKVHeads = 16; // standard MHA
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
};
template <typename TWeight>
struct OldConfigGemma2B : public OldConfigBaseGemmaV1 {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
OldFixedLayerConfig<18>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 18> kAttentionWindowSizes =
OldFixedAttentionWindowSizes<18>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 2048;
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
static constexpr int kHeads = 8;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
};
template <typename TWeight>
struct OldConfigPaliGemma_224 : public OldConfigGemma2B<TWeight> {
// On the LM side, the vocab size is one difference to Gemma1-2B in the
// architecture. PaliGemma adds 1024 <locNNNN> and 128 <segNNN> tokens.
static constexpr int kVocabSize = 256000 + 1024 + 128; // = 257152
// Sub-config for the Vision-Transformer part.
struct VitConfig : public OldConfigNoSSM {
using Weight = TWeight;
// The ViT parts. https://arxiv.org/abs/2305.13035
// "SoViT-400m/14 [...] has a width of 1152, depth 27, and MLP dim 4304."
static constexpr std::array<LayerAttentionType, 27> kLayerConfig =
OldFixedLayerConfig<27>(LayerAttentionType::kVit);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kModelDim = 1152;
static constexpr int kFFHiddenDim = 4304;
static constexpr int kHeads = 16;
static constexpr int kKVHeads = 16; // standard MHA
static constexpr int kQKVDim = 72;
static constexpr int kSeqLen = 16 * 16; // 256
static constexpr bool kFFBiases = true;
// The Vit part does not have a vocabulary, the image patches are embedded.
static constexpr int kVocabSize = 0;
// Dimensions related to image processing.
static constexpr int kPatchWidth = 14;
static constexpr int kImageSize = 224;
// Necessary constant for the layer configuration.
static constexpr PostNormType kPostNorm = PostNormType::None;
};
};
template <typename TWeight>
struct OldConfigGemma2_2B : public OldConfigBaseGemmaV2 {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = 8192;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 26> kLayerConfig =
OldFixedLayerConfig<26>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
OldRepeatedAttentionWindowSizes<26, 2>({4096, kSeqLen});
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 2304;
static constexpr int kFFHiddenDim = 8 * 2304 / 2; // = 9216
static constexpr int kHeads = 8;
static constexpr int kKVHeads = 4;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
};
template <typename TWeight>
struct OldConfigGemmaTiny : public OldConfigNoSSM {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = 32;
static constexpr int kVocabSize = 64;
static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
OldFixedLayerConfig<3>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 3> kAttentionWindowSizes =
OldFixedAttentionWindowSizes<3>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 128;
static constexpr int kFFHiddenDim = 256;
static constexpr int kHeads = 4;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 16; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr PostNormType kPostNorm = PostNormType::None;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
static constexpr float kAttCap = 0.0f;
// This is required for optimize_test to pass.
static constexpr float kFinalCap = 30.0f;
};
template <typename TWeight>
struct OldConfigGriffin2B : OldConfigNoVit {
using Weight = TWeight; // make accessible where we only have a TConfig
// Griffin uses local attention, so kSeqLen is actually the local attention
// window.
static constexpr int kSeqLen = 2048;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 26> kLayerConfig = {
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
};
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
OldFixedAttentionWindowSizes<26>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = OldNumLayersOfTypeBefore(
kLayerConfig, LayerAttentionType::kGemma, kLayers);
static constexpr int kGriffinLayers = OldNumLayersOfTypeBefore(
kLayerConfig, LayerAttentionType::kGriffinRecurrentBlock, kLayers);
static constexpr int kModelDim = 2560;
static constexpr int kFFHiddenDim = 7680;
static constexpr int kHeads = 10;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr PostNormType kPostNorm = PostNormType::None;
// No SoftCap.
static constexpr float kAttCap = 0.0f;
static constexpr float kFinalCap = 0.0f;
// SSM config.
static constexpr int kConv1dWidth = 4;
static constexpr bool kFFBiases = true;
static constexpr bool kSoftmaxAttnOutputBiases = true;
static constexpr bool kUseHalfRope = true;
static constexpr bool kUseLocalAttention = true;
static constexpr bool kInterleaveQKV = false;
static constexpr int kNumTensorScales = 140;
static constexpr PostQKType kPostQK = PostQKType::Rope;
static constexpr ActivationType kActivation = ActivationType::Gelu;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
static constexpr ResidualType kResidual = ResidualType::Add;
};
template <class TConfig>
void AssertMatch(const ModelConfig& config) {
ASSERT_EQ(TConfig::kModelDim, config.model_dim);
if constexpr (TConfig::VitConfig::kModelDim != 0) {
ASSERT_EQ(TConfig::VitConfig::kModelDim, config.vit_config.model_dim);
ASSERT_EQ(TConfig::VitConfig::kSeqLen, config.vit_config.seq_len);
ASSERT_EQ(TConfig::VitConfig::kNumTensorScales,
config.vit_config.num_scales);
for (size_t i = 0; i < config.vit_config.layer_configs.size(); ++i) {
ASSERT_EQ(TConfig::VitConfig::kLayerConfig[i],
config.vit_config.layer_configs[i].type);
}
}
ASSERT_EQ(TConfig::kVocabSize, config.vocab_size);
ASSERT_EQ(TConfig::kSeqLen, config.seq_len);
ASSERT_EQ(TConfig::kAttCap, config.att_cap);
ASSERT_EQ(TConfig::kFinalCap, config.final_cap);
ASSERT_EQ(TConfig::kAbsolutePE, config.absolute_pe);
ASSERT_EQ(TConfig::kUseLocalAttention, config.use_local_attention);
ASSERT_EQ(TConfig::kQueryScale, config.query_scale);
ASSERT_EQ(TConfig::kGemmaLayers,
config.NumLayersOfType(LayerAttentionType::kGemma));
ASSERT_EQ(TConfig::kGriffinLayers,
config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock));
for (size_t i = 0; i < config.layer_configs.size(); ++i) {
ASSERT_EQ(TConfig::kModelDim, config.layer_configs[i].model_dim);
ASSERT_EQ(TConfig::kFFHiddenDim, config.layer_configs[i].ff_hidden_dim);
ASSERT_EQ(TConfig::kHeads, config.layer_configs[i].heads);
ASSERT_EQ(TConfig::kKVHeads, config.layer_configs[i].kv_heads);
ASSERT_EQ(TConfig::kQKVDim, config.layer_configs[i].qkv_dim);
ASSERT_EQ(TConfig::kConv1dWidth, config.layer_configs[i].conv1d_width);
ASSERT_EQ(TConfig::kFFBiases, config.layer_configs[i].ff_biases);
ASSERT_EQ(TConfig::kSoftmaxAttnOutputBiases,
config.layer_configs[i].softmax_attn_output_biases);
ASSERT_EQ(TConfig::kPostNorm, config.layer_configs[i].post_norm);
ASSERT_EQ(TConfig::kLayerConfig[i], config.layer_configs[i].type);
ASSERT_EQ(TConfig::kActivation, config.layer_configs[i].activation);
PostQKType post_qk = TConfig::kPostQK;
if (TConfig::kUseHalfRope) {
post_qk = PostQKType::HalfRope;
}
ASSERT_EQ(post_qk, config.layer_configs[i].post_qk);
}
ASSERT_EQ(TConfig::kAttentionWindowSizes.size(),
config.attention_window_sizes.size());
for (size_t i = 0; i < config.attention_window_sizes.size(); ++i) {
ASSERT_EQ(TConfig::kAttentionWindowSizes[i],
config.attention_window_sizes[i]);
}
ASSERT_EQ(TConfig::kNumTensorScales, config.num_tensor_scales);
}
ModelConfig RoundTripSerialize(const ModelConfig& config) {
std::vector<uint32_t> config_buffer = config.Write();
ModelConfig deserialized;
deserialized.Read(hwy::Span<const uint32_t>(config_buffer), 0);
return deserialized;
}
TEST(ConfigsTest, OldConfigGemma2B) {
AssertMatch<OldConfigGemma2B<float>>(ConfigFromModel(Model::GEMMA_2B));
ModelConfig config = RoundTripSerialize(ConfigFromModel(Model::GEMMA_2B));
AssertMatch<OldConfigGemma2B<float>>(config);
}
TEST(ConfigsTest, OldConfigGemma7B) {
AssertMatch<OldConfigGemma7B<float>>(ConfigFromModel(Model::GEMMA_7B));
}
TEST(ConfigsTest, OldConfigGemma2_2B) {
AssertMatch<OldConfigGemma2_2B<float>>(ConfigFromModel(Model::GEMMA2_2B));
}
TEST(ConfigsTest, OldConfigGemma2_9B) {
AssertMatch<OldConfigGemma2_9B<float>>(ConfigFromModel(Model::GEMMA2_9B));
}
TEST(ConfigsTest, OldConfigGemma2_27B) {
AssertMatch<OldConfigGemma2_27B<float>>(ConfigFromModel(Model::GEMMA2_27B));
}
TEST(ConfigsTest, OldConfigGriffin2B) {
AssertMatch<OldConfigGriffin2B<float>>(ConfigFromModel(Model::GRIFFIN_2B));
}
TEST(ConfigsTest, OldConfigGemmaTiny) {
AssertMatch<OldConfigGemmaTiny<float>>(ConfigFromModel(Model::GEMMA_TINY));
}
TEST(ConfigsTest, OldConfigPaliGemma_224) {
AssertMatch<OldConfigPaliGemma_224<float>>(
ConfigFromModel(Model::PALIGEMMA_224));
const std::vector<uint32_t> serialized = config.Write();
ModelConfig deserialized;
const IFields::ReadResult result =
deserialized.Read(hwy::Span<const uint32_t>(serialized), /*pos=*/0);
HWY_ASSERT(result.pos == serialized.size());
// We wrote it, so all fields should be known, and no extra.
HWY_ASSERT(result.extra_u32 == 0);
HWY_ASSERT(result.missing_fields == 0);
// All fields should match.
HWY_ASSERT(deserialized.TestEqual(config, /*print=*/true));
HWY_ASSERT(deserialized.model == model);
HWY_ASSERT(deserialized.display_name == saved_display_name);
});
}
} // namespace gcpp

File diff suppressed because it is too large Load Diff

View File

@ -13,173 +13,654 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Defines Gemma member functions; the actual implementations are in
// gemma-inl.h, included from instantiations/*.cc.
// Defines Gemma member functions which dynamic-dispatch into the SIMD
// implementations in gemma-inl.h.
#include "gemma/gemma.h"
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/gemma.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "gemma/attention.h" // includes highway.h
#include "gemma/gemma-inl.h"
#include "gemma/griffin.h" // includes highway.h
#include "gemma/vit.h" // includes highway.h
#ifndef GEMMA_CC_ONCE
#define GEMMA_CC_ONCE
#include <math.h> // sqrtf
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <string>
#include <utility> // std::move
#include <vector>
#include "compression/io.h" // Path
#include "compression/shared.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/model_store.h"
#include "gemma/weights.h"
#include "ops/ops-inl.h"
#include "io/blob_store.h"
#include "io/io.h" // Path
#include "ops/matmul.h"
#include "paligemma/image.h"
#include "util/threading.h"
#include "hwy/highway.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h"
#include "hwy/timer.h"
#endif // GEMMA_CC_ONCE
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
const ModelInfo& info, MatMulEnv& env)
: env_(env), tokenizer_(tokenizer_path) {
model_.Load(weights, info.model, info.weight, info.wrapping,
env_.parallel.Pools().Pool(0),
/*tokenizer_proto=*/nullptr);
}
Gemma::Gemma(const Path& weights, MatMulEnv& env) : env_(env) {
std::string tokenizer_proto;
model_.Load(weights, Model::UNKNOWN, Type::kUnknown, PromptWrapping::GEMMA_IT,
env_.parallel.Pools().Pool(0), &tokenizer_proto);
tokenizer_.Deserialize(tokenizer_proto);
}
Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, MatMulEnv& env)
: env_(env), tokenizer_(std::move(tokenizer)) {
HWY_ASSERT(info.weight == Type::kF32);
model_.Allocate(info.model, info.weight, env_.parallel.Pools().Pool(0));
}
Gemma::~Gemma() {
}
// There are >=3 types of the inference code. To reduce compile time,
// we shard them across multiple translation units in instantiations/*.cc.
// This declares the functions defined there. We use overloading because
// explicit instantiations are still too slow to compile.
#define GEMMA_DECLARE(TWEIGHT) \
extern void GenerateSingle(TWEIGHT, const ModelWeightsStorage& model, \
const RuntimeConfig& runtime_config, \
const PromptTokens& prompt, size_t pos, \
size_t prefix_end, KVCache& kv_cache, \
MatMulEnv* env, TimingInfo& timing_info); \
extern void GenerateBatch( \
TWEIGHT, const ModelWeightsStorage& model, \
const RuntimeConfig& runtime_config, const QueriesPromptTokens& prompts, \
const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, \
const KVCaches& kv_caches, MatMulEnv* env, TimingInfo& timing_info); \
extern void GenerateImageTokens(TWEIGHT, const ModelWeightsStorage& model, \
const RuntimeConfig& runtime_config, \
const Image& image, \
ImageTokens& image_tokens, MatMulEnv* env);
GEMMA_DECLARE(float)
GEMMA_DECLARE(BF16)
GEMMA_DECLARE(NuqStream)
GEMMA_DECLARE(SfpStream)
// Adapters to select from the above overloads via CallForModelWeight.
template <class TConfig>
struct GenerateSingleT {
void operator()(const ModelWeightsStorage& model,
const RuntimeConfig& runtime_config,
const PromptTokens& prompt, size_t pos, size_t prefix_end,
KVCache& kv_cache, MatMulEnv* env,
TimingInfo& timing_info) const {
GenerateSingle(TConfig(), model, runtime_config, prompt, pos, prefix_end,
kv_cache, env, timing_info);
void Attention(LayerAttentionType type, const size_t num_tokens,
const size_t layer_idx, const LayerWeightsPtrs& layer,
Activations& activations, QBatch& qbatch, MatMulEnv& env) {
if (type == LayerAttentionType::kGemma) {
GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch,
env,
/*flags=*/0);
} else {
HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
// KVCache conv1d_cache and rglru_cache have one row per *Griffin* layer,
// so map `layer` to the Griffin layer index.
const size_t griffin_layer =
activations.attention.config.NumLayersOfTypeBefore(type, layer_idx);
GriffinRecurrent(num_tokens, griffin_layer, &layer, activations, qbatch,
env);
}
};
}
template <class TConfig>
struct GenerateBatchT {
void operator()(const ModelWeightsStorage& model,
const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end,
const KVCaches& kv_caches, MatMulEnv* env,
TimingInfo& timing_info) const {
GenerateBatch(TConfig(), model, runtime_config, queries_prompt, queries_pos,
queries_prefix_end, kv_caches, env, timing_info);
}
};
static HWY_NOINLINE void TransformerLayer(const size_t num_tokens,
const size_t layer_idx,
const LayerWeightsPtrs& layer,
Activations& activations,
QBatch& qbatch, MatMulEnv& env) {
const LayerConfig& layer_config = layer.layer_config;
template <class TConfig>
struct GenerateImageTokensT {
void operator()(const ModelWeightsStorage& model,
const RuntimeConfig& runtime_config, const Image& image,
ImageTokens& image_tokens, MatMulEnv* env) const {
GenerateImageTokens(TConfig(), model, runtime_config, image, image_tokens,
env);
RMSNormBatched(activations.x, layer.pre_attention_norm_scale,
activations.attention.pre_att_rms_out, env.ctx);
Attention(layer_config.type, num_tokens, layer_idx, layer, activations,
qbatch, env);
PostNorm(layer_config.post_norm, layer.post_attention_norm_scale,
activations.attention.att_sums, env.ctx);
ResidualConnection(activations.attention.att_sums, activations.x, layer,
/*is_attention=*/true, env.ctx);
RMSNormBatched(activations.x, layer.pre_ffw_norm_scale,
activations.pre_ffw_rms_out, env.ctx);
if (layer_config.type == LayerAttentionType::kVit) {
FFWVit(layer, activations, env);
} else {
FFWNoVit(layer, activations, env);
}
};
PostNorm(layer_config.post_norm, layer.post_ffw_norm_scale,
activations.ffw_out, env.ctx);
ResidualConnection(activations.ffw_out, activations.x, layer,
/*is_attention=*/false, env.ctx);
}
// Returns the scale value to use for the embedding (basically sqrt model_dim).
static float EmbeddingScaling(size_t model_dim) {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(
hwy::ConvertScalarTo<BF16>(sqrtf(static_cast<float>(model_dim))));
}
// `batch_idx` indicates which row of `x` to write to.
// `pos` is the *token*'s position, not the start of the batch, because this is
// called for batches of tokens in prefill, but batches of queries in decode.
//
// For GEMMA_VLM, image tokens are copied into -2 locations (per the Gemma 3
// spec) until we run out of image tokens. This allows for a multi-image prompt
// if -2 locations with appropriate begin/end image tokens are created by the
// calling application.
// Returns new image_token_position.
static HWY_NOINLINE size_t
EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
const ModelConfig& model_config, const WeightsPtrs& weights,
MatStorageT<float>& x, const ImageTokens* image_tokens = nullptr,
size_t image_token_position = 0) {
// Image tokens just need to be copied.
if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
image_tokens != nullptr && token == -2 &&
image_token_position < image_tokens->Rows()) {
hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(qi),
x.Cols() * x.ElementBytes());
return image_token_position + 1;
}
if (model_config.wrapping == PromptWrapping::PALIGEMMA &&
image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) {
hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(qi),
x.Cols() * x.ElementBytes());
return image_token_position;
}
const size_t model_dim = model_config.model_dim;
const float emb_scaling = EmbeddingScaling(model_dim);
const size_t worker = 0; // Not yet parallelized.
HWY_DASSERT(token >= 0);
HWY_DASSERT(token < static_cast<int>(model_config.vocab_size));
CallUpcasted(&weights.embedder_input_embedding, [&](const auto* weights_t) {
// Using `Stride` to compute the offset works for both NUQ (because we use
// an offset and NUQ is never padded) and padded, because non-NUQ types are
// seekable, hence the offset can also skip any padding.
const size_t embedding_ofs = token * weights_t->Stride();
HWY_ASSERT(weights_t->Cols() == model_dim);
const auto embedding_span =
MakeSpan(weights_t->Row(0), embedding_ofs + model_dim);
const hn::ScalableTag<float> df;
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(qi),
model_dim);
MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim, worker);
});
if (model_config.absolute_pe) {
AddAbsolutePositionalEmbeddings(x.Row(qi), model_dim, pos);
}
return image_token_position;
}
// Populates KV cache for batches of tokens from one query at a time. This is
// called if prompts are longer than the query batch size, and also in
// prefix-LM mode (end > 0), which must see all tokens in one batch.
static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const WeightsPtrs& weights,
Activations& activations, QBatch& qbatch,
MatMulEnv& env,
hwy::BitSet4096<>& non_eos) {
PROFILER_ZONE("Gen.PrefillT");
// Batches are important for amortizing loading weights over multiple tokens.
// This is possible in prefill because we know all tokens beforehand, whereas
// decode depends on the previous output token. However, each prefill batch of
// a query requires that preceding batches already wrote to the KV cache,
// hence we sequentially loop over token batches. We can reduce the number of
// iterations by increasing the batch size, but this also increases arithmetic
// intensity, and so we are eventually compute-limited. TransformerLayer uses
// all available threads, so we do not also parallelize over queries, but note
// that PrefillQBatch uses queries as the batch dimension.
const size_t max_tbatch_size = runtime_config.prefill_tbatch_size;
// For each query. `qi` is within the batch, not the global query index.
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
non_eos.Set(qi);
// One query at a time, batching will be the query's prompt tokens.
QBatch qbatch_1 = qbatch.Single(qi);
const size_t prompt_size = qbatch_1.Prompt(0).size();
// In autoregressive mode, we don't need to prefill the last token, so - 1.
size_t prefill_this_query = prompt_size - 1;
const size_t prefix_end_this_query = qbatch_1.PrefixEnd(0);
// We can't attend beyond the prompt_size.
HWY_ASSERT(prefix_end_this_query <= prompt_size);
// Special case: if the prefix includes the last token, we need to prefill
// the last token, too. However, we need to rewind this for the generation
// of the first token. So we need to keep track of this.
// TODO: consider implementing masking instead of this logic?
const bool attend_to_last_token =
(prefill_this_query < prefix_end_this_query);
if (attend_to_last_token) {
// The difference can be at most 1.
prefill_this_query += 1;
HWY_ASSERT(prefill_this_query == prefix_end_this_query);
}
// In prefix-LM mode, we need to look at all the tokens for the prefix in
// one iteration through the layers, so we need a large enough batch size.
HWY_ASSERT(prefix_end_this_query == 0 ||
max_tbatch_size >= prefill_this_query);
// For each batch of tokens in the query:
for (size_t tbatch_start = 0; tbatch_start < prefill_this_query;
tbatch_start += max_tbatch_size) {
const size_t tbatch_size =
HWY_MIN(max_tbatch_size, prefill_this_query - tbatch_start);
activations.SetBatchSize(tbatch_size);
// Fill activations.x (much faster than TransformerLayer).
size_t image_token_position = 0;
for (size_t ti = 0; ti < tbatch_size; ++ti) {
const size_t pos = qbatch_1.Pos(0) + ti;
const size_t pos_in_prompt = tbatch_start + ti;
const int token = qbatch_1.Prompt(0)[pos_in_prompt];
image_token_position = EmbedMMToken(
token, ti, pos, pos_in_prompt, config, weights, activations.x,
runtime_config.image_tokens, image_token_position);
}
// Transformer with one batch of tokens from a single query.
for (size_t layer_idx = 0; layer_idx < config.layer_configs.size();
++layer_idx) {
TransformerLayer(tbatch_size, layer_idx, *weights.GetLayer(layer_idx),
activations, qbatch_1, env);
}
// NOTE: we unconditionally call StreamToken, even if EOS.
for (size_t ti = 0; ti < tbatch_size; ++ti) {
const size_t pos = qbatch_1.Pos(0) + ti;
const size_t pos_in_prompt = tbatch_start + ti;
const int token = qbatch_1.Prompt(0)[pos_in_prompt];
if (pos_in_prompt < prompt_size - 1) {
runtime_config.StreamToken(qbatch_1.QueryIdx(0), pos, token, 0.0f);
} else {
// The last token will be streamed later and we should only get here
// if we need to attend to the last token because it is in the prefix.
HWY_ASSERT(attend_to_last_token);
}
}
qbatch_1.MutablePos(0) += tbatch_size;
} // for tbatch_start
if (attend_to_last_token) {
// We need to rewind the position for the last token that we only
// attended to to make sure the prefix LM sees everything.
// This means we duplicate work on the last prompt token in autoregressive
// decoding. Alternatives: (1) real masking; (2) always prefill the last
// token and only generate the next one from the already prefilled
// activations.
qbatch_1.MutablePos(0) -= 1;
}
}
}
// Embeds PrevToken (one from each query) and calls each TransformerLayer.
// Called by query-batched `PrefillQBatch` and `DecodeStepT`, but not the
// token-batched `PrefillTBatch`.
static HWY_NOINLINE void Transformer(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const WeightsPtrs& weights,
Activations& activations, QBatch& qbatch,
MatMulEnv& env) {
if (HWY_UNLIKELY(runtime_config.layers_output)) {
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
const float token_f = qbatch.PrevToken(qi);
runtime_config.layers_output(qbatch.QueryIdx(qi), qbatch.Pos(qi),
"tokens", -1, &token_f, 1);
}
}
// TODO: parallelize?
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
EmbedMMToken(qbatch.PrevToken(qi), qi, qbatch.Pos(qi),
/*pos_in_prompt=*/0, config, weights, activations.x);
}
for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
TransformerLayer(/*num_tokens=*/1, layer_idx, *weights.GetLayer(layer_idx),
activations, qbatch, env);
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
runtime_config.activations_observer(
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
activations);
}
}
}
// Populates KV cache for the batch queries, one token at a time. Only called
// for autoregressive (non-prefix-LM) prefill, so `queries_prefix_end` == 0.
static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
const ModelConfig& config,
const RuntimeConfig& runtime_config,
const WeightsPtrs& weights,
Activations& activations, QBatch& qbatch,
MatMulEnv& env,
hwy::BitSet4096<>& non_eos) {
PROFILER_ZONE("Gen.PrefillQ");
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
non_eos.Set(qi);
HWY_DASSERT(qbatch.PrefixEnd(qi) == 0);
}
// In autoregressive mode, we don't prefill the last token, hence - 1.
for (size_t pos_in_prompt = 0; pos_in_prompt < max_prompt_size - 1;
++pos_in_prompt) {
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
int token = config.eos_id;
if (pos_in_prompt < qbatch.Prompt(qi).size() - 1) {
token = qbatch.Prompt(qi)[pos_in_prompt];
// Ignore StreamToken return value because requesting to stop does not
// make sense during prefill.
(void)runtime_config.StreamToken(qbatch.QueryIdx(qi), pos_in_prompt,
token, 0.0f);
qbatch.MutablePos(qi) = pos_in_prompt;
}
qbatch.PrevToken(qi) = token;
}
// The input (PrevToken) is one token from each query in the batch.
// Do not call DecodeStepT because it computes logits for token
// probabilities, which are not required for the prompt tokens.
Transformer(config, runtime_config, weights, activations, qbatch, env);
}
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
qbatch.MutablePos(qi) = qbatch.Prompt(qi).size() - 1;
}
}
// Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent
// `DecodeStepT`, and increments `MutablePos`. Also updates `non_eos` if the
// query is at the end of its sequence.
static void StreamAndUpdateEOS(const size_t qi, int token, const float prob,
const ModelConfig& config,
const RuntimeConfig& runtime_config,
QBatch& qbatch, hwy::BitSet4096<>& non_eos) {
HWY_DASSERT(non_eos.Get(qi)); // otherwise, should not be called.
if (HWY_UNLIKELY(!runtime_config.StreamToken(qbatch.QueryIdx(qi),
qbatch.Pos(qi), token, prob))) {
// User decided to stop: set token to primary EOS to trigger IsEOS below.
token = config.eos_id;
HWY_DASSERT(config.IsEOS(token));
}
qbatch.PrevToken(qi) = token;
qbatch.MutablePos(qi) += 1;
// Primary or secondary EOS: mark query as EOS, but still increment (for
// multi-turn, we should still keep the prior EOS).
if (HWY_UNLIKELY(config.IsEOS(token))) non_eos.Clear(qi);
}
// For a batch of queries, runs Transformer, computes logits, samples and
// streams the token.
static void DecodeStepT(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const WeightsPtrs& weights,
const SampleFunc& sample_token,
Activations& activations, QBatch& qbatch,
MatMulEnv& env, hwy::BitSet4096<>& non_eos,
TimingInfo& timing_info) {
HWY_DASSERT(qbatch.Size() == activations.x.Rows());
Transformer(config, runtime_config, weights, activations, qbatch, env);
RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx);
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
runtime_config.activations_observer(
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations);
}
{
PROFILER_ZONE("Gen.EmbeddingMatmul");
// Compute logits from last layer activations.
CallMatMul(activations.x, weights.embedder_input_embedding,
/*add=*/nullptr, env, activations.logits);
}
PROFILER_ZONE("Gen.Softcap+Sample+Stream");
const size_t worker = 0; // TODO: parallelize
non_eos.Foreach([&](size_t qi) {
float* HWY_RESTRICT logits = activations.logits.Row(qi);
MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size, worker);
const TokenAndProb tp = sample_token(logits, config.vocab_size);
timing_info.NotifyGenerated();
StreamAndUpdateEOS(qi, tp.token, tp.prob, config, runtime_config, qbatch,
non_eos);
});
}
static HWY_INLINE SampleFunc
ChooseSampleFunc(const RuntimeConfig& runtime_config) {
// If user provided a sample_func, use it.
if (runtime_config.sample_func) return runtime_config.sample_func;
const size_t worker = 0; // TODO: parallelize
// Fast path for top-1 with no accept_token.
if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
PROFILER_ZONE2(worker, "Gen.Sample Top1");
return Top1OfSoftmax(logits, vocab_size);
};
}
// General case: Softmax with top-k sampling.
return [&runtime_config](float* logits,
size_t vocab_size) HWY_ATTR -> TokenAndProb {
PROFILER_ZONE("Gen.Sample general");
return FusedSoftmaxAndSampleTopK(
logits, runtime_config.top_k, vocab_size, *runtime_config.gen,
runtime_config.temperature, runtime_config.accept_token, worker);
};
}
// Decode: generates one continuation token for each query in `qbatch`.
static void GenerateT(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const WeightsPtrs& weights, Activations& activations,
QBatch& qbatch, MatMulEnv& env, TimingInfo& timing_info) {
// Griffin assumes that the recurrent block cache is zero-initialized.
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
if (qbatch.MutablePos(qi) == 0) {
qbatch.KV(qi).ZeroGriffinCache(); // No-op for non-Griffin models.
}
}
size_t max_prompt_size = 0;
bool all_prefix_end_are_zero = true;
size_t total_prefill_tokens = 0; // only for throughput stats.
const size_t seq_len = qbatch.KV(0).SeqLen();
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
const PromptTokens& prompt = qbatch.Prompt(qi);
max_prompt_size = HWY_MAX(max_prompt_size, prompt.size());
// Prefill stops before size - 1 because the last prompt token is the
// first input token for generation.
total_prefill_tokens += prompt.size() - 1;
// Sanity check: prompts should not be empty, nor start with EOS.
HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
all_prefix_end_are_zero &= qbatch.PrefixEnd(qi) == 0;
// We use a single divisor, so all sequence lengths must be the same.
HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
}
if (max_prompt_size >= seq_len) {
HWY_ABORT("max_prompt_size = %zu, increase --seq_len to at least that.",
max_prompt_size);
}
HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);
// Lacks a constructor to bulk-set, hence initialized by Prefill* which have
// qi loops anyway.
hwy::BitSet4096<> non_eos; // indexed by qi
timing_info.prefill_start = hwy::platform::Now();
// Batch over the larger of prompt length, or queries.
if ((qbatch.Size() > max_prompt_size) && all_prefix_end_are_zero) {
activations.SetBatchSize(qbatch.Size()); // required before PrefillQBatch
PrefillQBatch(max_prompt_size, config, runtime_config, weights, activations,
qbatch, env, non_eos);
} else {
PrefillTBatch(config, runtime_config, weights, activations, qbatch, env,
non_eos);
activations.SetBatchSize(qbatch.Size()); // Restore after PrefillTBatch.
}
HWY_DASSERT(non_eos.Count() == qbatch.Size());
timing_info.NotifyPrefill(total_prefill_tokens);
// queries_pos have been incremented by Prefill.
// Stream the last prompt token from each query, fill activations.gen_tokens.
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
StreamAndUpdateEOS(qi, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config,
runtime_config, qbatch, non_eos);
}
size_t max_gen_steps = runtime_config.max_generated_tokens;
if (max_prompt_size + max_gen_steps > seq_len) {
HWY_WARN("prefill %zu + max_gen_steps %zu > seq_len %zu, truncating.",
max_prompt_size, max_gen_steps, seq_len);
max_gen_steps = seq_len - max_prompt_size;
}
const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
{
timing_info.generate_start = hwy::platform::Now();
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
DecodeStepT(config, runtime_config, weights, sample_token, activations,
qbatch, env, non_eos, timing_info);
}
timing_info.NotifyGenerateDone();
}
}
void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
const ModelConfig& config,
const RuntimeConfig& runtime_config,
const WeightsPtrs& weights, KVCache& kv_cache,
MatMulEnv& env, TimingInfo& timing_info) {
Activations activations(config, runtime_config.prefill_tbatch_size,
kv_cache.SeqLen(), env.ctx.allocator, env.row_ptrs);
AllQueries all_queries(prompt, pos, prefix_end,
hwy::Span<KVCache>(&kv_cache, 1));
QBatch qbatch(/*start=*/0, /*max_size=*/1, all_queries);
GenerateT(config, runtime_config, weights, activations, qbatch, env,
timing_info);
}
// Splits the input into batches of at most `runtime_config.decode_qbatch_size`
// queries, and calls `GenerateT` on each batch.
void GenerateBatchT(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const WeightsPtrs& weights, AllQueries& all_queries,
MatMulEnv& env, TimingInfo& timing_info) {
const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size,
runtime_config.prefill_tbatch_size);
Activations activations(config, max_batch_size,
all_queries[0].kv_cache.SeqLen(), env.ctx.allocator,
env.row_ptrs);
for (size_t start = 0; start < all_queries.NumQueries();
start += runtime_config.decode_qbatch_size) {
QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries);
// Generate a batch of one token for each of `qbatch.Size()` queries.
GenerateT(config, runtime_config, weights, activations, qbatch, env,
timing_info);
}
}
void GenerateImageTokensT(const ModelConfig& config,
const RuntimeConfig& runtime_config, size_t seq_len,
const WeightsPtrs& weights, const Image& image,
ImageTokens& image_tokens, MatMulEnv& env) {
if (config.vit_config.layer_configs.empty()) {
HWY_ABORT("Model does not support generating image tokens.");
}
RuntimeConfig prefill_runtime_config = runtime_config;
const ModelConfig vit_config = GetVitConfig(config);
const size_t num_tokens = vit_config.max_seq_len;
prefill_runtime_config.prefill_tbatch_size =
num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
Activations prefill_activations(vit_config, num_tokens, num_tokens,
env.ctx.allocator, env.row_ptrs);
// Weights are for the full PaliGemma model, not just the ViT part.
PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
prefill_activations, env);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(GenerateSingleT);
HWY_EXPORT(GenerateBatchT);
HWY_EXPORT(GenerateImageTokensT);
Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
ThreadingContext& ctx)
: reader_(loader.weights),
model_(reader_, loader.tokenizer, loader.wrapping),
weights_(model_.Config()),
chat_template_(model_.Tokenizer(), model_.Config().model),
inference_(inference) {
weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, loader, inference,
mat_owners_, ctx);
// Read everything into memory, or `weights_.mapped_` keeps the mapping alive.
reader_.CloseFile();
}
Gemma::~Gemma() = default;
void Gemma::Save(const Path& weights_path, NestedPools& pools) const {
BlobWriter writer(weights_path, pools.Pool());
const std::vector<uint32_t> serialized_mat_ptrs =
weights_.AddTensorDataToWriter(writer);
WriteSingleFile(model_.Config(), model_.Tokenizer(), serialized_mat_ptrs,
writer);
}
void Gemma::Generate(const RuntimeConfig& runtime_config,
const PromptTokens& prompt, size_t pos, size_t prefix_end,
KVCache& kv_cache, TimingInfo& timing_info) {
env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning);
KVCache& kv_cache, MatMulEnv& env,
TimingInfo& timing_info) const {
env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
model_.CallForModelWeight<GenerateSingleT>(
runtime_config, prompt, pos, prefix_end, kv_cache, &env_, timing_info);
HWY_DYNAMIC_DISPATCH(GenerateSingleT)(prompt, pos, prefix_end,
model_.Config(), runtime_config,
weights_, kv_cache, env, timing_info);
env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning);
env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}
void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end,
const KVCaches& kv_caches, TimingInfo& timing_info) {
// If we did not get passed prefix ends (size 0), assume 0 and pass that on.
QueriesPos mutable_queries_prefix_end = queries_prefix_end;
std::vector<size_t> prefix_end_vec;
if (queries_prefix_end.size() == 0) {
prefix_end_vec.resize(queries_prompt.size(), 0);
mutable_queries_prefix_end =
QueriesPos(prefix_end_vec.data(), prefix_end_vec.size());
}
AllQueries& all_queries, MatMulEnv& env,
TimingInfo& timing_info) const {
env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning);
HWY_DYNAMIC_DISPATCH(GenerateBatchT)(model_.Config(), runtime_config,
weights_, all_queries, env, timing_info);
model_.CallForModelWeight<GenerateBatchT>(
runtime_config, queries_prompt, queries_pos, mutable_queries_prefix_end,
kv_caches, &env_, timing_info);
env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning);
env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}
void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
const Image& image, ImageTokens& image_tokens) {
env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning);
size_t seq_len, const Image& image,
ImageTokens& image_tokens,
MatMulEnv& env) const {
env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
model_.CallForModelWeight<GenerateImageTokensT>(runtime_config, image,
image_tokens, &env_);
HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(model_.Config(), runtime_config,
seq_len, weights_, image,
image_tokens, env);
env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning);
}
// Non-template functions moved from gemma-inl.h to avoid ODR violations.
void RangeChecks(const ModelConfig& weights_config,
size_t& max_generated_tokens, const size_t prompt_size) {
if (!weights_config.use_local_attention) {
if (max_generated_tokens > weights_config.seq_len) {
fprintf(stderr,
"WARNING: max_generated_tokens %zu > kSeqLen %u, truncating.\n",
max_generated_tokens, weights_config.seq_len);
max_generated_tokens = weights_config.seq_len;
}
}
HWY_ASSERT(prompt_size > 0);
env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}
} // namespace gcpp
#endif // HWY_ONCE

View File

@ -16,122 +16,154 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
#include <functional>
#include <random>
#include <string>
#include <stdio.h>
#include <vector>
// IWYU pragma: begin_exports
#include "compression/io.h" // Path
#include "gemma/activations.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/gemma_args.h"
#include "gemma/kv_cache.h"
#include "gemma/tokenizer.h"
#include "gemma/model_store.h"
#include "gemma/weights.h"
#include "io/blob_store.h"
#include "io/io.h" // Path
#include "ops/matmul.h" // MatMulEnv
#include "paligemma/image.h"
#include "util/allocator.h" // RowVectorBatch
#include "util/basics.h" // TokenAndProb
#include "util/basics.h" // TokenAndProb
#include "util/threading_context.h"
#include "hwy/timer.h"
// IWYU pragma: end_exports
#include "hwy/aligned_allocator.h" // Span
namespace gcpp {
using PromptTokens = hwy::Span<const int>;
// Batches of independent queries have their own prompt, previous token,
// position in the sequence, and KVCache.
using QueriesPromptTokens = hwy::Span<const PromptTokens>;
using QueriesToken = hwy::Span<const int>;
using QueriesPos = hwy::Span<const size_t>;
using KVCaches = hwy::Span<KVCache>;
struct PerQuery {
PromptTokens prompt;
// StreamFunc is called with (token, probability). For prompt tokens,
// probability is 0.0f. StreamFunc should return false to stop generation and
// true to continue generation.
using StreamFunc = std::function<bool(int, float)>;
// BatchStreamFunc is called with (query_idx, pos, token, probability).
// For prompt tokens, probability is 0.0f.
// StreamFunc should return false to stop generation and true to continue.
using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
// If not empty, AcceptFunc is called with token. It should return false for
// tokens you don't want to generate and true for tokens you want to generate.
using AcceptFunc = std::function<bool(int, float)>;
// If not empty, SampleFunc is called with the logits for the next token, which
// it may modify/overwrite, and its return value is the next generated token
// together with its probability.
using SampleFunc = std::function<TokenAndProb(float*, size_t)>;
// If not empty, LayersOutputFunc is called for layer outputs, specified with:
// - index of query within containing batch (if any); zero otherwise.
// - position in the tokens sequence
// - name of the data, e.g. "tokens" for token IDs
// - layer index (or -1 for global outputs)
// - pointer to the data array
// - size of the data array
using LayersOutputFunc = std::function<void(size_t, size_t, const std::string&,
int, const float*, size_t)>;
// If not empty, ActivationsObserverFunc is invoked after each layer with:
// - per-query position within the tokens sequence
// - layer index (or -1 for post-norm output)
// - activations
using ActivationsObserverFunc =
std::function<void(const QueriesPos& queries_pos, int, const Activations&)>;
// Position in the KV cache: initially zero for the first turn, or when
// multi-turn is NOT desired. Incremented by prefill and `StreamAndUpdateEOS`.
size_t mutable_pos;
// Allows computing the last prefill token as `mutable_pos - initial_pos`,
// which might differ from `prompt.size() - 1` for prefix-LM.
size_t initial_pos;
// Zero for causal attention, or the end of the prefix for prefix-LM style
// attention in Paligemma.
size_t prefix_end;
// ImageTokens are represented as a RowVectorBatch, where each "batch" index
// corresponds to a token for an image patch as computed by the image encoder.
using ImageTokens = RowVectorBatch<float>;
KVCache& kv_cache;
// RuntimeConfig holds configuration for a single generation run.
struct RuntimeConfig {
// If not empty, batch_stream_token is called for each token in the batch,
// instead of stream_token.
bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
if (batch_stream_token) {
return batch_stream_token(query_idx, pos, token, prob);
// Previous token generated for this query, or the last prompt token. Will be
// fed into the next Transformer() call.
int prev_token = 0;
};
// Array of `PerQuery`. Referenced by `QBatch` and passed to `GenerateBatch`.
struct AllQueries {
AllQueries() = default;
// For `GenerateSingleT`: same prompt/pos, replicated for each KV cache.
AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end,
const hwy::Span<KVCache>& kv_caches) {
per_query_.reserve(kv_caches.size());
for (size_t i = 0; i < kv_caches.size(); ++i) {
HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
per_query_.push_back(PerQuery{
.prompt = prompt,
.mutable_pos = pos,
.initial_pos = pos,
.prefix_end = prefix_end,
.kv_cache = kv_caches[i],
});
}
return stream_token(token, prob);
}
// Limit on the number of tokens generated.
size_t max_generated_tokens;
// Batch of queries with initial position set to zero. Causal attention
// is requested via empty or all-zero `prefix_end`.
AllQueries(
const hwy::Span<const PromptTokens>& prompts,
const hwy::Span<KVCache>& kv_caches,
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) {
HWY_ASSERT(prompts.size() == kv_caches.size());
HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0);
per_query_.reserve(kv_caches.size());
for (size_t i = 0; i < kv_caches.size(); ++i) {
HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
per_query_.push_back(PerQuery{
.prompt = prompts[i],
.mutable_pos = 0,
.initial_pos = 0,
.prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i],
.kv_cache = kv_caches[i],
});
}
}
// These defaults are overridden by InferenceArgs::CopyTo(*this):
// Max tokens per batch during prefill.
size_t prefill_tbatch_size = 256;
// Max queries per batch (one token from each) during decode.
size_t decode_qbatch_size = 16;
void Reserve(size_t size) { per_query_.reserve(size); }
void Append(const PerQuery& query) { per_query_.push_back(query); }
// Sampling-related parameters.
float temperature; // Temperature for sampling.
size_t top_k = kTopK; // Top-k for sampling.
std::mt19937* gen; // Random number generator used for sampling.
size_t NumQueries() const { return per_query_.size(); }
int verbosity; // Controls verbosity of printed messages.
PerQuery& operator[](size_t query_idx) {
HWY_DASSERT(query_idx < NumQueries());
return per_query_[query_idx];
}
const PerQuery& operator[](size_t query_idx) const {
HWY_DASSERT(query_idx < NumQueries());
return per_query_[query_idx];
}
// Functions operating on the generated tokens.
StreamFunc stream_token;
BatchStreamFunc batch_stream_token;
AcceptFunc accept_token; // if empty, accepts all tokens.
SampleFunc sample_func; // if empty, uses SampleTopK.
private:
std::vector<PerQuery> per_query_;
};
// Observer callbacks for intermediate data.
LayersOutputFunc layers_output; // if not empty, called after each layer.
ActivationsObserverFunc activations_observer; // if set, called per-layer.
// View into AllQueries: either a batch of queries, or a single query for use
// in PrefillTBatch or GenerateSingleT. Cheap to create because it holds a
// reference to AllQueries.
class QBatch {
public:
QBatch(size_t start, size_t max_size, AllQueries& queries)
: start_(start),
max_size_(max_size),
queries_(queries),
size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) {
HWY_ASSERT(max_size_ <= 4096); // non_eos uses `BitSet4096`.
HWY_DASSERT(size_ != 0);
HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
}
// If not empty, these point to the image tokens and are used in the
// PaliGemma prefix-LM style attention.
const ImageTokens *image_tokens = nullptr;
// Returns a single-query view starting at `qi` relative to this batch.
QBatch Single(size_t qi) const { return QBatch(start_ + qi, 1, queries_); }
// Whether to use thread spinning to reduce barrier synchronization latency.
// Mutable so we can change kDefault to kTrue/kFalse during Generate, because
// RuntimeConfig is const there and is not passed to the Gemma ctor. This
// default decision is likely sufficient because it is based on whether
// threads are successfully pinned.
mutable Tristate use_spinning = Tristate::kDefault;
// How many queries in this batch, <= `queries_.NumQueries()` and `max_size_`.
size_t Size() const { return size_; }
// End-of-sequence token.
int eos_id = EOS_ID;
// Returns index for use with `AllQueries` and `BatchStreamToken`.
size_t QueryIdx(size_t qi) const {
HWY_DASSERT(qi < size_);
return start_ + qi;
}
// Accessor functions to bridge the previous SoA and current AoS layout.
const PromptTokens& Prompt(size_t qi) const {
return queries_[QueryIdx(qi)].prompt;
}
size_t Pos(size_t qi) const { return queries_[QueryIdx(qi)].mutable_pos; }
size_t& MutablePos(size_t qi) { return queries_[QueryIdx(qi)].mutable_pos; }
size_t InitialPos(size_t qi) const {
return queries_[QueryIdx(qi)].initial_pos;
}
size_t PrefixEnd(size_t qi) const {
return queries_[QueryIdx(qi)].prefix_end;
}
KVCache& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; }
private:
size_t start_;
size_t max_size_;
AllQueries& queries_;
size_t size_;
};
struct TimingInfo {
@ -193,82 +225,58 @@ struct TimingInfo {
size_t tokens_generated = 0;
};
// After construction, all methods are const and thread-compatible if using
// separate ThreadingContext for each thread.
class Gemma {
public:
// Reads old format weights file and tokenizer file.
// `env` must remain valid for the lifetime of this Gemma.
Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
MatMulEnv& env);
// Reads new format weights file that contains everything in a single file.
// `env` must remain valid for the lifetime of this Gemma.
Gemma(const Path& weights, MatMulEnv& env);
// Allocates weights, caller is responsible for filling them.
Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, MatMulEnv& env);
// Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
// `ctx` is only used to read tensors, but it is typically also referenced
// by the `MatMulEnv` passed to the Generate* methods.
Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
ThreadingContext& ctx);
~Gemma();
const ModelConfig& GetModelConfig() const { return model_.Config(); }
ModelInfo Info() const {
return ModelInfo({.model = model_.Config().model,
.wrapping = model_.Config().wrapping,
.weight = model_.Config().weight});
}
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
const ModelWeightsStorage& Weights() const { return model_; }
ModelWeightsStorage& MutableWeights() { return model_; }
void Save(const Path& weights, hwy::ThreadPool& pool) {
std::string tokenizer_proto = tokenizer_.Serialize();
model_.Save(tokenizer_proto, weights, pool);
}
const ModelConfig& Config() const { return model_.Config(); }
const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); }
const WeightsPtrs& Weights() const { return weights_; }
WeightsPtrs::Mode WeightReadMode() const { return weight_read_mode_; }
const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
const InferenceArgs& Inference() const { return inference_; }
void Save(const Path& weights_path, NestedPools& pools) const;
// `pos` is the position in the KV cache. Users are responsible for
// incrementing it in the `*StreamFunc`, or setting to zero for single-turn.
void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
size_t pos, KVCache& kv_cache, TimingInfo& timing_info) {
Generate(runtime_config, prompt, pos, /*prefix_end=*/0, kv_cache,
size_t pos, KVCache& kv_cache, MatMulEnv& env,
TimingInfo& timing_info) const {
Generate(runtime_config, prompt, pos, /*prefix_end=*/0, kv_cache, env,
timing_info);
}
// For prefix-LM style attention, we can pass the end of the prefix.
void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
size_t pos, size_t prefix_end, KVCache& kv_cache,
TimingInfo& timing_info);
MatMulEnv& env, TimingInfo& timing_info) const;
// `queries_pos` are the positions in the KV cache. Users are responsible for
// incrementing them in `BatchStreamFunc`, or setting to zero for single-turn.
void GenerateBatch(const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos, const KVCaches& kv_caches,
TimingInfo& timing_info) {
GenerateBatch(runtime_config, queries_prompt, queries_pos,
/*queries_prefix_end=*/{}, kv_caches, timing_info);
}
// For prefix-LM style attention, we can pass the ends of the prefixes.
void GenerateBatch(const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end,
const KVCaches& kv_caches, TimingInfo& timing_info);
AllQueries& all_queries, MatMulEnv& env,
TimingInfo& timing_info) const;
// Generates the image tokens by running the image encoder ViT.
void GenerateImageTokens(const RuntimeConfig& runtime_config,
const Image& image, ImageTokens& image_tokens);
void GenerateImageTokens(const RuntimeConfig& runtime_config, size_t seq_len,
const Image& image, ImageTokens& image_tokens,
MatMulEnv& env) const;
private:
MatMulEnv& env_;
GemmaTokenizer tokenizer_;
// Type-erased so that this can be defined in the header.
ModelWeightsStorage model_;
BlobReader reader_;
ModelStore model_;
std::vector<MatOwner> mat_owners_;
WeightsPtrs weights_;
WeightsPtrs::Mode weight_read_mode_;
GemmaChatTemplate chat_template_;
InferenceArgs inference_;
};
// Adds BOS token and possibly 'turn' annotations, which depend on `info`
// and `pos`, the number of tokens decoded so far; returns the corresponding
// tokens. Asserts that tokenization is successful.
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
const ModelInfo& info, size_t pos,
std::string& prompt);
void RangeChecks(const ModelConfig& weights_config,
size_t& max_generated_tokens, size_t prompt_size);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_

273
gemma/gemma_args.h Normal file
View File

@ -0,0 +1,273 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Shared between various frontends.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
#include <stddef.h>
#include <stdio.h>
#include <functional>
#include <random>
#include <string>
#include "io/io.h" // Path
#include "ops/matmul.h" // MMStorage::kMax*
#include "util/args.h"
#include "util/basics.h" // Tristate
#include "util/mat.h"
#include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h" // HWY_ABORT
#include "hwy/profiler.h"
namespace gcpp {
struct LoaderArgs : public ArgsBase<LoaderArgs> {
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
LoaderArgs(const std::string& tokenizer_path,
const std::string& weights_path) {
Init(); // Init sets to defaults, so assignments must come after Init().
tokenizer.path = tokenizer_path;
weights.path = weights_path;
};
Path tokenizer;
Path weights; // weights file location
Tristate map;
Tristate to_bf16;
Tristate wrapping;
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(tokenizer, "tokenizer", Path(),
"Path name of tokenizer model; only required for pre-2025 format.");
visitor(weights, "weights", Path(),
"Path name of model weights (.sbs) file.\n Required argument.\n");
visitor(map, "map", Tristate::kDefault,
"Enable memory-mapping? -1 = auto, 0 = no, 1 = yes.");
visitor(to_bf16, "to_bf16", Tristate::kDefault,
"Convert weights to bf16? -1 = auto, 0 = no, 1 = yes.");
visitor(wrapping, "wrapping", Tristate::kDefault,
"Enable prompt wrapping? Specify 0 for pre-2025 format PT models.");
}
};
using PromptTokens = hwy::Span<const int>;
// Batches of independent queries have their own prompt, previous token,
// position in the sequence, and KVCache.
using QueriesPromptTokens = hwy::Span<const PromptTokens>;
using QueriesToken = hwy::Span<const int>;
using QueriesPos = hwy::Span<const size_t>;
// ImageTokens are represented as a matrix, where each row corresponds
// to a token for an image patch as computed by the image encoder.
using ImageTokens = MatStorageT<float>;
// StreamFunc is called with (token, probability). For prompt tokens,
// probability is 0.0f. StreamFunc should return false to stop generation and
// true to continue generation.
using StreamFunc = std::function<bool(int, float)>;
// BatchStreamFunc is called with (query_idx, pos, token, probability).
// For prompt tokens, probability is 0.0f. Generation continues if this returns
// true and stops if it returns false. Note that query_idx is absolute, not
// relative to the batch.
using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
// If not empty, AcceptFunc is called with token. It should return false for
// tokens you don't want to generate and true for tokens you want to generate.
using AcceptFunc = std::function<bool(int, float)>;
// If not empty, SampleFunc is called with the logits for the next token, which
// it may modify/overwrite, and its return value is the next generated token
// together with its probability.
using SampleFunc = std::function<TokenAndProb(float*, size_t)>;
// If not empty, LayersOutputFunc is called for layer outputs, specified with:
// - index of query within containing batch (if any); zero otherwise.
// - position in the tokens sequence
// - name of the data, e.g. "tokens" for token IDs
// - layer index (or -1 for global outputs)
// - pointer to the data array
// - size of the data array
using LayersOutputFunc = std::function<void(size_t, size_t, const std::string&,
int, const float*, size_t)>;
// If not empty, ActivationsObserverFunc is invoked after each layer with:
// - per-query position within the tokens sequence
// - layer index (or -1 for post-norm output)
// - activations
struct Activations;
using ActivationsObserverFunc =
std::function<void(const QueriesPos& queries_pos, int, const Activations&)>;
// RuntimeConfig holds configuration for a single generation run.
// TODO: move into InferenceArgs, use that directly.
struct RuntimeConfig {
// If non-null, `batch_stream_token` is called for each token in the batch,
// otherwise `stream_token`. `query_idx` is absolute, not batch-relative.
bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
PROFILER_ZONE("Gen.StreamToken");
if (batch_stream_token) {
return batch_stream_token(query_idx, pos, token, prob);
}
return stream_token(token, prob);
}
// Limit on the number of tokens generated.
size_t max_generated_tokens;
// These defaults are overridden by InferenceArgs::CopyTo(*this):
// Max tokens per batch during prefill.
size_t prefill_tbatch_size = 256;
// Max queries per batch (one token from each) during decode.
size_t decode_qbatch_size = 16;
// Sampling-related parameters.
float temperature; // Temperature for sampling.
size_t top_k = 1; // Top-k for sampling.
std::mt19937* gen; // Random number generator used for sampling.
int verbosity; // Controls verbosity of printed messages.
// Functions operating on the generated tokens.
StreamFunc stream_token;
BatchStreamFunc batch_stream_token;
AcceptFunc accept_token; // if empty, accepts all tokens.
SampleFunc sample_func; // if empty, uses SampleTopK.
// Observer callbacks for intermediate data.
LayersOutputFunc layers_output; // if not empty, called after each layer.
ActivationsObserverFunc activations_observer; // if set, called per-layer.
// If not empty, these point to the image tokens and are used in the
// PaliGemma prefix-LM style attention.
const ImageTokens* image_tokens = nullptr;
// Whether to use thread spinning to reduce barrier synchronization latency.
// Mutable so we can change kDefault to kTrue/kFalse during Generate, because
// RuntimeConfig is const there and is not passed to the Gemma ctor. This
// default decision is likely sufficient because it is based on whether
// threads are successfully pinned.
mutable Tristate use_spinning = Tristate::kDefault;
};
struct InferenceArgs : public ArgsBase<InferenceArgs> {
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
InferenceArgs() { Init(); };
bool IsInteractive() const { return prompt.empty() && prompt_file.Empty(); }
int verbosity;
size_t seq_len;
size_t max_generated_tokens;
size_t prefill_tbatch_size;
size_t decode_qbatch_size;
float temperature;
size_t top_k;
bool deterministic;
bool multiturn;
Path image_file;
std::string prompt; // Bypasses std::getline
// For prompts longer than the Linux terminal's 4K line edit buffer.
Path prompt_file;
std::string eot_line;
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(verbosity, "verbosity", 1,
"Show verbose developer information\n 0 = only print generation "
"output\n 1 = standard user-facing terminal ui\n 2 = show "
"developer/debug info).\n Default = 1.",
1);
visitor(seq_len, "seq_len", size_t{8192},
"Sequence length, capped by ModelConfig.max_seq_len.");
visitor(max_generated_tokens, "max_generated_tokens", size_t{4096},
"Maximum number of tokens to generate.");
visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256},
"Prefill: max tokens per batch.");
visitor(decode_qbatch_size, "decode_qbatch", size_t{16},
"Decode: max queries per batch.");
visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
visitor(top_k, "top_k", size_t{1}, "Number of top-K tokens to sample from",
2);
visitor(deterministic, "deterministic", false,
"Make top-k sampling deterministic", 2);
visitor(multiturn, "multiturn", false,
"Multiturn mode\n 0 = clear KV cache after every "
"interaction\n 1 = continue KV cache after every interaction\n "
" Default : 0 (conversation "
"resets every turn)");
visitor(image_file, "image_file", Path(), "Image file to load.");
visitor(prompt, "prompt", std::string(""),
"Initial prompt for non-interactive mode. When specified, "
"generates a response and exits.",
1);
visitor(prompt_file, "prompt_file", Path(),
"Path to file containing the prompt for non-interactive mode. When "
" specified, generates a response and exits.",
1);
visitor(
eot_line, "eot_line", std::string(""),
"End of turn line. "
"When you specify this, the prompt will be all lines "
"before the line where only the given string appears.\n Default = "
"When a newline is encountered, that signals the end of the turn.",
2);
}
void CopyTo(RuntimeConfig& runtime_config) const {
runtime_config.max_generated_tokens = max_generated_tokens;
runtime_config.prefill_tbatch_size = prefill_tbatch_size;
runtime_config.decode_qbatch_size = decode_qbatch_size;
if (prefill_tbatch_size > MMStorage::kMaxM) {
HWY_ABORT(
"prefill_tbatch_size %zu > kMaxM %zu: specify a smaller value, "
"or increase the constant in MMStorage.\n",
prefill_tbatch_size, MMStorage::kMaxM);
}
if (decode_qbatch_size > MMStorage::kMaxM) {
HWY_ABORT(
"decode_qbatch_size %zu > kMaxM %zu: specify a smaller value, "
"or increase the constant in MMStorage.\n",
decode_qbatch_size, MMStorage::kMaxM);
}
runtime_config.temperature = temperature;
runtime_config.top_k = top_k;
}
};
static inline ThreadingArgs UpdateArgs(const ThreadingArgs& threading_args,
const InferenceArgs& inference_args) {
if (inference_args.decode_qbatch_size >= 256) {
ThreadingArgs copy = threading_args;
copy.max_packages = 1;
return copy;
}
return threading_args;
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_

192
gemma/griffin.cc Normal file
View File

@ -0,0 +1,192 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stddef.h>
#include <stdint.h>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
#include "gemma/activations.h"
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/weights.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/griffin.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "ops/matvec-inl.h"
#include "ops/ops-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,
const LayerWeightsPtrs* layer_weights,
Activations& activations, QBatch& qbatch,
MatMulEnv& env) {
PROFILER_ZONE("Gen.Griffin");
hwy::ThreadPool& pool = env.ctx.pools.Pool(0);
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D df;
const size_t model_dim = layer_weights->layer_config.model_dim;
HWY_DASSERT(model_dim % hn::Lanes(df) == 0);
const size_t heads = layer_weights->layer_config.heads;
const size_t conv_1d_width = layer_weights->layer_config.conv1d_width;
HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even");
const size_t kHeadDim = model_dim / heads;
const size_t kMatrixSize = kHeadDim * kHeadDim;
const size_t num_interleaved = num_tokens * qbatch.Size();
const hwy::Divisor div_qbatch(static_cast<uint32_t>(qbatch.Size()));
GriffinActivations& griffin = activations.griffin;
// X / Y linear layers.
// TODO: MatMul
HWY_DASSERT(griffin.griffin_y.Rows() == griffin.griffin_x.Rows());
HWY_DASSERT(num_interleaved == griffin.griffin_y.Rows());
CallUpcastedSame(
&layer_weights->griffin.linear_x_w, &layer_weights->griffin.linear_y_w,
[&](const auto* wx, const auto* wy) {
for (size_t r = 0; r < num_interleaved; ++r) {
float* HWY_RESTRICT y = griffin.griffin_y.Row(r);
float* HWY_RESTRICT x = griffin.griffin_x.Row(r);
TwoMatVecAdd(
*wx, *wy, 0, model_dim, model_dim,
activations.attention.pre_att_rms_out.Row(r),
/*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(),
/*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(),
/*out0=*/x, /*out1=*/y, pool);
Gelu(y, model_dim);
}
});
// Conv1D.
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
const size_t qi = div_qbatch.Remainder(interleaved_idx);
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
const size_t pos = qbatch.Pos(qi) + batch_idx;
float* HWY_RESTRICT x = griffin.griffin_x.Row(qi);
// cache[i] = input at time t-i.
float* HWY_RESTRICT cache[kMaxConv1DWidth];
cache[0] = x;
for (size_t i = 1; i < conv_1d_width; i++) {
cache[i] =
qbatch.KV(qi).conv1d_cache.Row(griffin_layer) +
((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim;
}
for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
auto xv = hn::Load(df, x + i);
auto accum0 =
hn::Load(df, layer_weights->griffin.conv_biases.PackedScale1() + i);
auto accum1 = hn::Zero(df);
for (size_t l = 0; 2 * l < conv_1d_width; l++) {
auto wv0 =
hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() +
(conv_1d_width - 1 - 2 * l) * model_dim + i);
auto wv1 =
hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() +
(conv_1d_width - 2 - 2 * l) * model_dim + i);
accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
}
hn::Store(hn::Add(accum0, accum1), df, x + i);
hn::Store(xv, df, cache[HWY_MAX(conv_1d_width, 1) - 1] + i);
}
}
// RGLRU
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
const size_t qi = div_qbatch.Remainder(interleaved_idx);
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
const size_t pos = qbatch.Pos(qi) + batch_idx;
float* HWY_RESTRICT x = griffin.griffin_x.Row(qi);
float* HWY_RESTRICT y = griffin.griffin_y.Row(qi);
float* HWY_RESTRICT gate_x = griffin.griffin_gate_x.Row(qi);
float* HWY_RESTRICT a = griffin.griffin_multiplier.Row(qi);
float* HWY_RESTRICT rnn_state =
qbatch.KV(qi).rglru_cache.Row(griffin_layer);
pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
size_t head_offset = head * kHeadDim;
CallUpcasted(&layer_weights->griffin.gate_w, [&](const auto* gate_w) {
TwoOfsMatVecAddLoop(
*gate_w, kMatrixSize * head, kMatrixSize * (heads + head), kHeadDim,
kHeadDim, x + head_offset,
/*add0=*/layer_weights->griffin.gate_biases.PackedScale1() +
head_offset,
/*add1=*/layer_weights->griffin.gate_biases.PackedScale1() +
model_dim + head_offset,
/*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
});
Sigmoid(gate_x + head_offset, kHeadDim);
Sigmoid(a + head_offset, kHeadDim);
const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
HWY_ATTR { return hn::Mul(x, gate_x); };
hn::Transform1(D(), a + head_offset, kHeadDim,
layer_weights->griffin.a.PackedScale1() + head_offset,
fn_mul);
hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
fn_mul);
// RNN scan
HWY_DASSERT(kHeadDim % hn::Lanes(df) == 0);
for (size_t i = 0; i < kHeadDim; i += hn::Lanes(df)) {
auto log_a = hn::Load(df, a + head_offset + i);
auto gated_x = hn::Load(df, x + head_offset + i);
auto rnn = hn::Load(df, rnn_state + head_offset + i);
auto a = hn::Exp(df, log_a);
auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0f)));
if (pos == 0) {
x_multiplier = hn::Set(df, 1.0f);
}
auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn));
hn::Store(new_x, df, rnn_state + head_offset + i);
// Join branches.
auto yv = hn::Load(df, y + head_offset + i);
auto pre_out = hn::Mul(yv, new_x);
hn::Store(pre_out, df, x + head_offset + i);
}
});
} // interleaved_idx
// Final linear layer.
CallMatMul(griffin.griffin_x, layer_weights->griffin.linear_out_w,
layer_weights->griffin.linear_out_biases.PackedScale1(), env,
activations.attention.att_sums);
} // GriffinRecurrent
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();

47
gemma/griffin.h Normal file
View File

@ -0,0 +1,47 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_
// Declares GriffinRecurrent for all SIMD targets.
#include <stddef.h>
#include "gemma/gemma.h"
#include "hwy/highway.h"
namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_GRIFFIN(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void GriffinRecurrent(size_t num_tokens, size_t griffin_layer, \
const LayerWeightsPtrs* layer_weights, \
Activations& activations, QBatch& qbatch, \
MatMulEnv& env); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE
// Function declarations for each SIMD target. Allows direct call from the
// per-target namespace. We may later replace this with dynamic dispatch if
// the overhead is acceptable.
HWY_VISIT_TARGETS(GEMMA_DECL_GRIFFIN)
#undef GEMMA_DECL_GRIFFIN
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_

View File

@ -15,91 +15,75 @@
#include "gemma/kv_cache.h"
#include <algorithm>
#include <stddef.h>
#include "gemma/common.h" // CallForModel
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // ZeroBytes
#include "gemma/configs.h"
#include "gemma/gemma_args.h"
#include "util/mat.h" // ZeroInit
#include "hwy/base.h" // HWY_MAX
namespace gcpp {
void KVCache::ZeroGriffinCache() {
if (conv1d_cache_size != 0) {
hwy::ZeroBytes(conv1d_cache.get(),
conv1d_cache_size * sizeof(conv1d_cache[0]));
}
if (rglru_cache_size != 0) {
hwy::ZeroBytes(rglru_cache.get(),
rglru_cache_size * sizeof(rglru_cache[0]));
}
if (conv1d_cache.Rows() == 0) return;
ZeroInit(conv1d_cache);
ZeroInit(rglru_cache);
}
// prefill_tbatch_size is the maximum number of tokens from one query to
// prefill at a time.
KVCache KVCache::Create(const ModelConfig& weights_config,
size_t prefill_tbatch_size) {
KVCache kv_cache = {};
const size_t size_cache_pos = weights_config.CachePosSize();
if (size_cache_pos != 0) {
// Allocate more so that prefill can always access one batch, even if
// near the end of the sequence.
kv_cache.seq_len = weights_config.seq_len + prefill_tbatch_size;
kv_cache.kv_cache =
hwy::AllocateAligned<float>(kv_cache.seq_len * size_cache_pos);
}
const size_t num_griffin_layers = weights_config.NumLayersOfType(
LayerAttentionType::kGriffinRecurrentBlock);
// TODO(patrickms): Add query batching support for Griffin.
if (num_griffin_layers > 0) {
uint32_t conv1d_width = 0;
for (const auto& layer_config : weights_config.layer_configs) {
conv1d_width = std::max(conv1d_width, layer_config.conv1d_width);
}
const size_t conv1d_cache_size =
num_griffin_layers * (conv1d_width == 0 ? 0 : conv1d_width - 1) *
weights_config.model_dim;
kv_cache.conv1d_cache_size = conv1d_cache_size;
if (conv1d_cache_size != 0) {
kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size);
}
const size_t rglru_cache_size =
num_griffin_layers * weights_config.model_dim;
kv_cache.rglru_cache_size = rglru_cache_size;
if (rglru_cache_size != 0) {
kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size);
}
} // num_griffin_layers
return kv_cache;
static size_t GriffinLayers(const ModelConfig& config) {
return config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock);
}
KVCache KVCache::Copy(const ModelConfig& weights_config,
size_t prefill_tbatch_size) {
KVCache kv_cache_copy = Create(weights_config, prefill_tbatch_size);
static size_t GriffinConv1dCols(const ModelConfig& config) {
size_t conv1d_width = 0;
for (const auto& layer_config : config.layer_configs) {
conv1d_width = HWY_MAX(conv1d_width, layer_config.conv1d_width);
}
// The row offset, in blocks of model_dim is computed mod (conv1d_width - 1),
// hence allocate conv1d_width * model_dim total columns.
return conv1d_width * config.model_dim;
}
const size_t size_cache_pos = weights_config.CachePosSize();
if (size_cache_pos != 0) {
std::copy(kv_cache.get(), kv_cache.get() + size_cache_pos * seq_len,
kv_cache_copy.kv_cache.get());
// Number of rows for KV cache. Note that both rows and cols are u32, and
// the total number of elements can exceed 2^32.
static size_t CappedSeqLen(const ModelConfig& config,
const InferenceArgs& inference_args) {
if (inference_args.seq_len > config.max_seq_len) {
HWY_WARN("Capping seq_len %zu to config.max_seq_len %u.",
inference_args.seq_len, config.max_seq_len);
return config.max_seq_len;
}
return inference_args.seq_len;
}
KVCache::KVCache(const Extents2D& conv1d_extents,
const Extents2D& rglru_extents, const Extents2D& kv_extents,
const Allocator& allocator)
: conv1d_cache("conv1d_cache", conv1d_extents, allocator, MatPadding::kOdd),
rglru_cache("rglru_cache", rglru_extents, allocator, MatPadding::kOdd),
kv_cache("kv", kv_extents, allocator, MatPadding::kOdd),
allocator_(allocator) {}
KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
const Allocator& allocator)
: KVCache(
Extents2D(GriffinLayers(config), GriffinConv1dCols(config)),
Extents2D(GriffinLayers(config), config.model_dim),
Extents2D(CappedSeqLen(config, inference_args), config.KVCacheCols()),
allocator) {}
KVCache KVCache::Copy() {
KVCache copy(conv1d_cache.Extents(), rglru_cache.Extents(),
kv_cache.Extents(), allocator_);
if (conv1d_cache.Rows() != 0) {
CopyMat(conv1d_cache, copy.conv1d_cache);
CopyMat(rglru_cache, copy.rglru_cache);
}
const size_t num_griffin_layers = weights_config.NumLayersOfType(
LayerAttentionType::kGriffinRecurrentBlock);
if (num_griffin_layers > 0) {
if (conv1d_cache_size != 0) {
std::copy(conv1d_cache.get(), conv1d_cache.get() + conv1d_cache_size,
kv_cache_copy.conv1d_cache.get());
}
if (rglru_cache_size != 0) {
std::copy(rglru_cache.get(),
rglru_cache.get() + rglru_cache_size * sizeof(rglru_cache[0]),
kv_cache_copy.rglru_cache.get());
}
}
return kv_cache_copy;
CopyMat(kv_cache, copy.kv_cache);
return copy;
}
} // namespace gcpp

View File

@ -18,34 +18,41 @@
#include <stddef.h>
#include "gemma/common.h" // Model
#include "hwy/aligned_allocator.h"
#include "gemma/configs.h" // ModelConfig
#include "gemma/gemma_args.h" // InferenceArgs
#include "util/basics.h" // BF16
#include "util/mat.h"
namespace gcpp {
using KV_t = float;
struct KVCache {
size_t seq_len = 0; // = kSeqLen + prefill_tbatch_size
KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
const Allocator& allocator);
// seq_len * kGemmaLayers * kKVHeads * kQKVDim * 2
hwy::AlignedFreeUniquePtr<float[]> kv_cache;
// (kConv1dWidth - 1) * kModelDim * kGriffinLayers
hwy::AlignedFreeUniquePtr<float[]> conv1d_cache;
size_t conv1d_cache_size = 0;
// kModelDim * kGriffinLayers
hwy::AlignedFreeUniquePtr<float[]> rglru_cache;
size_t rglru_cache_size = 0;
// Returns a deep copy of the KVCache. Use explicit function instead of
// copy ctor to make the cost explicit.
KVCache Copy();
// Zero-initialize the Griffin recurrent block cache, i.e. the conv1d_cache
// and rglru_cache.
void ZeroGriffinCache();
static KVCache Create(const ModelConfig& weights_config,
size_t prefill_tbatch_size);
size_t SeqLen() const { return kv_cache.Rows(); }
// Returns a deep copy of the KVCache.
KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size);
// [griffin_layers, griffin_conv1d_cols * model_dim]
MatStorageT<float> conv1d_cache;
MatStorageT<float> rglru_cache; // [griffin_layers, model_dim]
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
private:
const Allocator& allocator_;
// For use by other ctor and Copy()
KVCache(const Extents2D& conv1d_extents, const Extents2D& rglru_extents,
const Extents2D& kv_extents, const Allocator& allocator);
};
} // namespace gcpp

464
gemma/model_store.cc Normal file
View File

@ -0,0 +1,464 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/model_store.h"
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <array>
#include <charconv>
#include <cstdlib>
#include <cstring> // strcmp
#include <string>
#include <system_error> // std::errc // NOLINT
#include "compression/types.h"
#include "gemma/configs.h" // ModelConfig, kMaxQKVDim
#include "gemma/tensor_info.h"
#include "gemma/tokenizer.h"
#include "io/blob_store.h"
#include "io/fields.h"
#include "io/io.h" // Path
#include "util/basics.h"
#include "util/threading_context.h"
#include "hwy/base.h"
namespace gcpp {
// Single-file format contains blobs with these names:
static constexpr char kConfigName[] = "config";
static constexpr char kTokenizerName[] = "tokenizer";
static constexpr char kMatPtrsName[] = "toc";
// Pre-2025 format has one metadata blob. 'F' denoted f32.
static constexpr char kDecoratedScalesName[] = "Fscales";
static void WarnIfExtra(const IFields::ReadResult& result, const char* name) {
// No warning if missing_fields > 0: those fields are default-initialized.
if (result.extra_u32) {
HWY_WARN(
"Serialized blob %s has %u extra fields the code is not aware of. "
"Consider updating to the latest code from GitHub.",
name, result.extra_u32);
}
}
// Returns the serialized tokenizer (std::string is required for proto).
// Reads it from a blob or from a separate file if pre-2025.
static std::string ReadTokenizer(BlobReader& reader,
const Path& tokenizer_path) {
std::string tokenizer;
// Check prevents `CallWithSpan` from printing a warning.
if (reader.Find(kTokenizerName)) {
if (!reader.CallWithSpan<char>(
kTokenizerName, [&tokenizer](const hwy::Span<const char> bytes) {
tokenizer.assign(bytes.data(), bytes.size());
})) {
HWY_WARN(
"Reading tokenizer blob failed, please raise an issue. You can "
"instead specify a tokenizer file via --tokenizer.");
}
}
// Read actual tokenizer from blob.
if (!tokenizer.empty() && tokenizer != kMockTokenizer) {
if (!tokenizer_path.Empty()) {
HWY_WARN("--weights has tokenizer but overriding with %s.",
tokenizer_path.path.c_str());
return ReadFileToString(tokenizer_path);
}
return tokenizer;
}
// No blob but user specified path to file: read it or abort.
if (!tokenizer_path.Empty()) {
return ReadFileToString(tokenizer_path);
}
HWY_WARN(
"BlobStore does not contain a tokenizer and no --tokenizer was "
"specified. Tests may continue but inference will fail.");
return kMockTokenizer;
}
using KeyVec = std::vector<std::string>;
class TypePrefix {
public:
static Type TypeFromChar(char c) {
switch (c) {
case 'F':
return Type::kF32;
case 'B':
return Type::kBF16;
case '$':
return Type::kSFP;
case '2':
return Type::kNUQ;
default:
// The other types were not written to pre-2025 files, hence no need to
// encode and check for them here.
return Type::kUnknown;
}
}
TypePrefix(const KeyVec& keys, const BlobReader& reader) {
for (size_t key_idx = 0; key_idx < keys.size(); ++key_idx) {
const std::string& key = keys[key_idx];
const Type type = TypeFromChar(key[0]);
const uint64_t bytes = reader.Range(key_idx).bytes;
bytes_[static_cast<size_t>(type)] += bytes;
blobs_[static_cast<size_t>(type)]++;
total_bytes_ += bytes;
}
}
// Returns true for pre-2025 format, which has type prefixes and thus the
// functions below may be used.
bool HasPrefixes() const {
return bytes_[static_cast<size_t>(Type::kUnknown)] != total_bytes_;
}
// Returns the weight type deduced from the histogram of blobs per type.
// Rationale: We expect a mix of types due to varying precision requirements
// for each tensor. The preferred weight type might not even be the most
// common, because we prioritize higher compression for the *large* tensors.
// Ignore types which only have a few blobs (might be metadata), and assume
// that there would be at least 4 of the large tensors (in particular, global
// attention layers). Hence return the smallest type with >= 4 blobs.
Type DeduceWeightType() const {
size_t min_bits = ~size_t{0};
Type weight_type = Type::kUnknown;
for (size_t i = 0; i < kNumTypes; ++i) {
if (blobs_[i] < 4) continue;
const size_t bits = TypeBits(static_cast<Type>(i));
if (bits < min_bits) {
min_bits = bits;
weight_type = static_cast<Type>(i);
}
}
return weight_type;
}
// Prints statistics on the total size of tensors by type.
void PrintTypeBytes() const {
for (size_t type_idx = 0; type_idx < kNumTypes; ++type_idx) {
const Type type = static_cast<Type>(type_idx);
const uint64_t bytes = bytes_[type_idx];
if (bytes == 0) continue;
const double percent = 100.0 * bytes / total_bytes_;
fprintf(stderr, "%12zu blob bytes (%5.2f%%) of %4s\n",
static_cast<size_t>(bytes), percent, TypeName(type));
}
}
private:
uint64_t total_bytes_ = 0;
std::array<size_t, kNumTypes> bytes_{0};
std::array<size_t, kNumTypes> blobs_{0};
};
// Returns 0 if the blob does not seem to be a per-layer tensor, otherwise the
// layer index.
static size_t LayerIdxFromKey(const std::string& key) {
const auto parse_num = [&key](size_t begin, size_t end) -> int {
HWY_DASSERT(begin <= end);
HWY_DASSERT(end <= key.size());
int val = 0;
auto [ptr, ec] = std::from_chars(key.data() + begin, key.data() + end, val);
return (ec == std::errc()) ? val : -1;
};
const size_t suffix_pos = key.rfind('_');
// If there is no digit after the last underscore, it is not a layer name.
if (suffix_pos == std::string::npos) return 0;
if (suffix_pos == key.size() - 1) return 0;
int layer_idx = parse_num(suffix_pos + 1, key.size());
HWY_ASSERT(layer_idx < 999);
return layer_idx == -1 ? 0 : static_cast<size_t>(layer_idx);
}
// Returns the number of layers based on the largest blob name suffix seen.
// This works with or without type prefixes because it searches for suffixes.
static size_t DeduceNumLayers(const KeyVec& keys) {
// Built-in self-test.
{
HWY_ASSERT(LayerIdxFromKey("gr_conv_w_2") == 2); // common case
HWY_ASSERT(LayerIdxFromKey("prefix_") == 0); // no number
HWY_ASSERT(LayerIdxFromKey("c_embedding") == 0); // per-model
HWY_ASSERT(LayerIdxFromKey("c_final_norm") == 0); // per-model, two _
}
size_t max_layer_idx = 0;
for (const std::string& key : keys) {
max_layer_idx = HWY_MAX(max_layer_idx, LayerIdxFromKey(key));
}
return max_layer_idx + 1;
}
// Looks for known tensor names associated with model families.
// This works with or without type prefixes because it searches for substrings.
static int DeduceLayerTypes(const BlobReader& reader) {
int layer_types = 0;
for (size_t key_idx = 0; key_idx < reader.Keys().size(); ++key_idx) {
const std::string& key = reader.Keys()[key_idx];
if (key.find("gr_conv_w") != std::string::npos) { // NOLINT
return kDeducedGriffin;
}
if (key.find("qkv_ein_w") != std::string::npos) { // NOLINT
layer_types |= kDeducedViT;
}
if (key.find("img_pos_emb") != std::string::npos) { // NOLINT
// About 5.88 elements per pixel; assume at least bf16.
if (reader.Range(key_idx).bytes > 448 * 448 * 5 * sizeof(BF16)) {
layer_types |= kDeduced448;
}
}
}
return layer_types;
}
// `wrapping_override` is forwarded from the command line. For pre-2025 files
// without `ModelConfig`, it is the only way to force PT.
static ModelConfig ReadOrDeduceConfig(BlobReader& reader,
Tristate wrapping_override) {
const TypePrefix type_prefix(reader.Keys(), reader);
Type deduced_weight = Type::kUnknown;
if (type_prefix.HasPrefixes()) {
deduced_weight = type_prefix.DeduceWeightType();
type_prefix.PrintTypeBytes();
}
// Always deduce so we can verify it against the config we read.
const size_t layers = DeduceNumLayers(reader.Keys());
const int layer_types = DeduceLayerTypes(reader);
const Model deduced_model =
DeduceModel(reader.blob_path(), layers, layer_types);
ModelConfig config;
// Check first to prevent `CallWithSpan` from printing a warning.
if (reader.Find(kConfigName)) {
HWY_ASSERT(reader.CallWithSpan<uint32_t>(
kConfigName, [&config](const SerializedSpan serialized) {
const IFields::ReadResult result = config.Read(serialized, 0);
WarnIfExtra(result, kConfigName);
HWY_ASSERT_M(result.pos != 0, "Error deserializing config");
}));
HWY_ASSERT(config.model != Model::UNKNOWN);
HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel);
HWY_ASSERT(config.weight != Type::kUnknown);
for (const LayerConfig& layer_config : config.layer_configs) {
if (static_cast<size_t>(layer_config.qkv_dim) > kMaxQKVDim) {
HWY_ABORT("Increase kMaxQKVDim to at least %u.", layer_config.qkv_dim);
}
}
// We trust the deserialized config, but checking helps to validate the
// deduction, which we rely on below for pre-2025 files.
if (config.model != deduced_model) {
const std::string suffix = WrappingSuffix(config.wrapping);
HWY_WARN("Detected model %s does not match config %s.",
(std::string(ModelPrefix(deduced_model)) + suffix).c_str(),
(std::string(ModelPrefix(config.model)) + suffix).c_str());
}
return config;
}
// Pre-2025 format: no config, rely on deduction plus `wrapping_override`.
return ModelConfig(deduced_model, deduced_weight,
ChooseWrapping(deduced_model, wrapping_override));
}
static std::vector<float> ReadScales(BlobReader& reader,
const ModelConfig& config) {
std::vector<float> scales;
// Check first to prevent `CallWithSpan` from printing a warning. This blob is
// optional even in pre-2025 format; Griffin was the first to include it.
if (reader.Find(kDecoratedScalesName)) {
HWY_ASSERT(reader.CallWithSpan<float>(
kDecoratedScalesName,
[&scales](const hwy::Span<const float> scales_blob) {
scales.assign(scales_blob.cbegin(), scales_blob.cend());
}));
}
return scales;
}
// Single-file format: reads `MatPtr` from the blob; returns false if not found.
bool ModelStore::ReadMatPtrs(BlobReader& reader) {
// Check first to prevent `CallWithSpan` from printing a warning.
if (!reader.Find(kMatPtrsName)) return false;
// For verifying `config_.weight`.
size_t min_bits = ~size_t{0};
Type weight_type = Type::kUnknown;
HWY_ASSERT(reader.CallWithSpan<uint32_t>(
kMatPtrsName, [&, this](SerializedSpan serialized) {
for (size_t pos = 0; pos < serialized.size();) {
MatPtr mat;
const IFields::ReadResult result = mat.Read(serialized, pos);
WarnIfExtra(result, mat.Name());
if (result.pos == 0) {
HWY_ABORT("Deserializing MatPtr %s failed (pos %zu of %zu).",
mat.Name(), pos, serialized.size());
}
pos = result.pos + result.extra_u32;
// Retrieve actual key index because a writer may have written other
// blobs before the tensor data.
const BlobRange* range = reader.Find(mat.Name());
HWY_ASSERT(range);
const size_t key_idx = range->key_idx;
AddMatPtr(key_idx, mat);
const size_t bits = TypeBits(mat.GetType());
if (bits < min_bits) {
min_bits = bits;
weight_type = mat.GetType();
}
}
}));
HWY_ASSERT(weight_type != Type::kUnknown);
HWY_ASSERT(weight_type == config_.weight);
return true;
}
// Pre-2025 format: synthesizes `MatPtr` from the blob names if `!ReadMatPtrs`.
void ModelStore::CreateMatPtrs(BlobReader& reader) {
const TensorInfoRegistry tensors(config_);
const KeyVec& keys = reader.Keys();
mat_ptrs_.reserve(keys.size());
// `key_idx` is the blob index. It is not the same as the index of the
// `MatPtr` in `mat_ptrs_` because not all blobs are tensors.
for (size_t key_idx = 0; key_idx < keys.size(); ++key_idx) {
const Type type = TypePrefix::TypeFromChar(keys[key_idx][0]);
if (type == Type::kUnknown) continue; // likely not a tensor
// Strip type prefix from the key. Still includes layer suffix.
const std::string name = keys[key_idx].substr(1);
const TensorInfo* info = tensors.Find(name);
if (HWY_UNLIKELY(!info)) {
if (name == "scales") continue; // ignore, not a tensor.
HWY_ABORT("Unknown tensor %s.", name.c_str());
}
// Unable to set scale already because they are ordered according to
// `ForEachTensor`, which we do not know here. The initial value is 1.0f
// and we set the correct value in `FindAndUpdateMatPtr`.
AddMatPtr(key_idx, MatPtr(name.c_str(), type, ExtentsFromInfo(info)));
}
HWY_ASSERT(mat_ptrs_.size() <= keys.size());
HWY_ASSERT(mat_ptrs_.size() == key_idx_.size());
}
ModelStore::ModelStore(BlobReader& reader, const Path& tokenizer_path,
Tristate wrapping)
: config_(ReadOrDeduceConfig(reader, wrapping)),
tokenizer_(ReadTokenizer(reader, tokenizer_path)) {
if (!ReadMatPtrs(reader)) { // Pre-2025 format.
CreateMatPtrs(reader);
scales_ = ReadScales(reader, config_);
// ModelConfig serialized a vector of strings. Unpack into a set for more
// efficient lookup.
for (const std::string& name : config_.scale_base_names) {
scale_base_names_.insert(name);
}
// If the model has scales, the config must know about it.
HWY_ASSERT(scales_.empty() || !scale_base_names_.empty());
}
HWY_ASSERT(key_idx_.size() == mat_ptrs_.size());
}
ModelStore::~ModelStore() {
// Sanity check: ensure all scales were consumed.
HWY_ASSERT(scales_consumed_ == scales_.size());
}
const MatPtr* ModelStore::FindMat(const char* name) const {
auto it = mat_idx_for_name_.find(name);
if (it == mat_idx_for_name_.end()) return nullptr;
const size_t mat_idx = it->second;
const MatPtr* file_mat = &mat_ptrs_[mat_idx];
HWY_ASSERT(!strcmp(file_mat->Name(), name));
return file_mat;
}
bool ModelStore::FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const {
const MatPtr* file_mat = FindMat(mat.Name());
if (!file_mat) return false;
if (file_mat->Rows() != mat.Rows() || file_mat->Cols() != mat.Cols()) {
HWY_ABORT("Tensor %s shape %zu %zu mismatches file %zu %zu.", mat.Name(),
mat.Rows(), mat.Cols(), file_mat->Rows(), file_mat->Cols());
}
// `Compress()` output is always packed because it assumes a 1D array.
HWY_ASSERT(mat.IsPacked());
// Update fields. Name already matched, otherwise we would not find it.
// For MatPtr tensors, the type will be `kUnknown`. If it was a `MatPtrT`,
// ensure the type set via code matches the file.
HWY_ASSERT_M(
mat.GetType() == Type::kUnknown || mat.GetType() == file_mat->GetType(),
mat.Name());
mat.SetType(file_mat->GetType());
if (scales_.empty()) {
// `file_mat->Scale()` is either read from file, or we have pre-2025 format
// without the optional scales, and it is default-initialized to 1.0f.
mat.SetScale(file_mat->Scale());
} else { // Pre-2025 with scaling factors: set next if `mat` wants one.
if (scale_base_names_.find(StripLayerSuffix(mat.Name())) !=
scale_base_names_.end()) {
HWY_ASSERT(scales_consumed_ < scales_.size());
mat.SetScale(scales_[scales_consumed_++]);
}
}
key_idx = key_idx_[file_mat - mat_ptrs_.data()];
return true;
}
static void AddBlob(const char* name, const std::vector<uint32_t>& data,
BlobWriter& writer) {
HWY_ASSERT(!data.empty());
writer.Add(name, data.data(), data.size() * sizeof(data[0]));
}
void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer,
const std::vector<uint32_t>& serialized_mat_ptrs,
BlobWriter& writer) {
HWY_ASSERT(config.model != Model::UNKNOWN);
HWY_ASSERT(config.weight != Type::kUnknown);
HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel);
const std::vector<uint32_t> serialized_config = config.Write();
AddBlob(kConfigName, serialized_config, writer);
const std::string serialized_tokenizer = tokenizer.Serialize();
HWY_ASSERT(!serialized_tokenizer.empty());
writer.Add(kTokenizerName, serialized_tokenizer.data(),
serialized_tokenizer.size());
AddBlob(kMatPtrsName, serialized_mat_ptrs, writer);
writer.WriteAll();
}
} // namespace gcpp

111
gemma/model_store.h Normal file
View File

@ -0,0 +1,111 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Reads/writes model metadata (all but the weights) from/to a `BlobStore`.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_MODEL_STORE_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_MODEL_STORE_H_
#include <stddef.h>
#include <stdint.h>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
// IWYU pragma: begin_exports
#include "gemma/configs.h" // ModelConfig
#include "gemma/tokenizer.h"
#include "io/blob_store.h"
#include "io/io.h" // Path
#include "util/basics.h" // Tristate
#include "util/mat.h" // MatPtr
// IWYU pragma: end_exports
#include "util/allocator.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
// Reads and holds the model config, tokenizer and all `MatPtr`: everything
// except the tensor data, which are read/written by `weights.cc`.
//
// As of 2025-04, the `BlobStore` format includes blobs for `ModelConfig`,
// tokenizer, and all `MatPtr` metadata. "Pre-2025" format instead stored the
// tokenizer in a separate file, encoded tensor type in a prefix of the blob
// name, and had a blob for tensor scaling factors. We still support reading
// both, but only write single-file format.
class ModelStore {
public:
// Reads from file(s) or aborts on error. The latter two arguments are only
// used for pre-2025 files.
ModelStore(BlobReader& reader, const Path& tokenizer_path = Path(),
Tristate wrapping = Tristate::kDefault);
~ModelStore();
const ModelConfig& Config() const {
HWY_ASSERT(config_.model != Model::UNKNOWN);
return config_;
}
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
// Returns nullptr if `name` is not available for loading, otherwise the
// metadata of that tensor.
const MatPtr* FindMat(const char* name) const;
// Returns false if `mat` is not available for loading, otherwise updates
// `mat` with metadata from the file and sets `key_idx` for use by
// `BlobReader`. Called via `ReadOrAllocate` in `weights.cc`.
bool FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const;
private:
void AddMatPtr(const size_t key_idx, const MatPtr& mat) {
auto pair_ib = mat_idx_for_name_.insert({mat.Name(), mat_ptrs_.size()});
HWY_ASSERT_M(pair_ib.second, mat.Name()); // Ensure inserted/unique.
mat_ptrs_.push_back(mat);
key_idx_.push_back(key_idx);
}
bool ReadMatPtrs(BlobReader& reader);
void CreateMatPtrs(BlobReader& reader); // Aborts on error.
ModelConfig config_;
GemmaTokenizer tokenizer_;
// All `MatPtr` present in the `BlobStore`, see `ReadMatPtrs`/`CreateMatPtrs`.
std::vector<MatPtr> mat_ptrs_;
// For each of `mat_ptrs_`, the index within `BlobReader::Keys()`. This is
// not necessarily iota because some blobs are not tensors, and callers may
// have added blobs before ours.
std::vector<size_t> key_idx_;
// Index within `mat_ptrs_` and `key_idx_` for each tensor name.
std::unordered_map<std::string, size_t> mat_idx_for_name_;
// Only used if `!ReadMatPtrs` (pre-2025 format):
std::vector<float> scales_;
std::unordered_set<std::string> scale_base_names_;
mutable size_t scales_consumed_ = 0;
};
// Adds metadata blobs to `writer` and writes everything to `path`. This
// produces a single BlobStore file holding everything required for inference.
void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer,
const std::vector<uint32_t>& serialized_mat_ptrs,
BlobWriter& writer);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_MODEL_STORE_H_

View File

@ -15,22 +15,22 @@
// Command line text interface to gemma.
#include <stdio.h>
#include <iostream>
#include <random>
#include <string>
#include <string_view>
#include <vector>
// Placeholder for internal header, do not modify.
#include "compression/shared.h" // PromptWrapping
#include "compression/types.h" // PromptWrapping
#include "evals/benchmark_helper.h"
#include "gemma/common.h"
#include "gemma/gemma.h" // Gemma
#include "ops/matmul.h" // MatMulEnv
#include "gemma/gemma_args.h"
#include "gemma/tokenizer.h" // WrapAndTokenize
#include "ops/matmul.h" // MatMulEnv
#include "paligemma/image.h"
#include "util/app.h"
#include "util/args.h" // HasHelp
#include "util/threading.h"
#include "hwy/base.h"
#include "hwy/highway.h"
#include "hwy/profiler.h"
@ -55,9 +55,8 @@ static constexpr std::string_view kAsciiArtBanner = R""(
|___/ |_| |_|
)"";
std::string GetPrompt(std::istream& input, int verbosity,
std::string_view eot_line) {
PROFILER_ZONE("Gen.input");
std::string GetPromptFromStream(std::istream& input, int verbosity,
std::string_view eot_line) {
if (verbosity >= 1) {
std::cout << "> " << std::flush;
}
@ -77,36 +76,55 @@ std::string GetPrompt(std::istream& input, int verbosity,
return prompt_string;
}
// Get prompt either from interactive input or command line
std::string GetPrompt(const InferenceArgs& inference) {
PROFILER_ZONE("Gen.input");
// If prompt is provided via command line, use that
if (!inference.prompt.empty()) {
return inference.prompt;
}
if (!inference.prompt_file.Empty()) {
return ReadFileToString(inference.prompt_file);
}
return GetPromptFromStream(std::cin, inference.verbosity, inference.eot_line);
}
// The main Read-Eval-Print Loop.
void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
const InferenceArgs& args, const AcceptFunc& accept_token,
std::string& eot_line) {
void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
const Gemma& gemma, KVCache& kv_cache, MatMulEnv& env) {
PROFILER_ZONE("Gen.misc");
size_t abs_pos = 0; // across turns
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
size_t prompt_size = 0;
const ModelConfig& config = gemma.Config();
std::mt19937 gen;
InitGenerator(args, gen);
InitGenerator(inference, gen);
const bool have_image = !args.image_file.path.empty();
const bool have_image = !inference.image_file.path.empty();
Image image;
ImageTokens image_tokens;
const size_t pool_dim = config.vit_config.pool_dim;
ImageTokens image_tokens(
"image_tokens",
have_image ? Extents2D(config.vit_config.seq_len / (pool_dim * pool_dim),
config.model_dim)
: Extents2D(0, 0),
env.ctx.allocator, MatPadding::kOdd);
image_tokens.AllocateAndAttachRowPtrs(env.row_ptrs);
if (have_image) {
size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
image_tokens = ImageTokens(Extents2D(
model.GetModelConfig().vit_config.seq_len / (pool_dim * pool_dim),
model.GetModelConfig().model_dim));
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA ||
model.Info().wrapping == PromptWrapping::GEMMA_VLM);
HWY_ASSERT(image.ReadPPM(args.image_file.path));
const size_t image_size = model.GetModelConfig().vit_config.image_size;
HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA ||
config.wrapping == PromptWrapping::GEMMA_VLM);
HWY_ASSERT(image.ReadPPM(inference.image_file.path));
const size_t image_size = config.vit_config.image_size;
image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {
.gen = &gen, .verbosity = app.verbosity, .use_spinning = app.spin};
RuntimeConfig runtime_config = {.gen = &gen,
.verbosity = inference.verbosity,
.use_spinning = threading.spin};
double image_tokens_start = hwy::platform::Now();
model.GenerateImageTokens(runtime_config, image, image_tokens);
if (app.verbosity >= 1) {
gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image,
image_tokens, env);
if (inference.verbosity >= 1) {
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
fprintf(stderr,
"\n\n[ Timing info ] Image token generation took: %d ms\n",
@ -121,21 +139,21 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
const bool first_response_token = tokens_generated_this_turn == prompt_size;
++tokens_generated_this_turn;
if (in_prompt) {
if (app.verbosity >= 1) {
std::cerr << "." << std::flush;
if (inference.verbosity >= 1) {
std::cout << "." << std::flush;
}
return true;
} else if (model.GetModelConfig().IsEOS(token)) {
if (app.verbosity >= 2) {
} else if (config.IsEOS(token)) {
if (inference.verbosity >= 2) {
std::cout << "\n[ End ]\n";
}
return true;
}
std::string token_text;
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
HWY_ASSERT(gemma.Tokenizer().Decode(std::vector<int>{token}, &token_text));
if (first_response_token) {
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
if (app.verbosity >= 1) {
if (inference.verbosity >= 1) {
std::cout << "\n\n";
}
}
@ -146,72 +164,77 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
while (true) { // Loop until user quits.
tokens_generated_this_turn = 0;
// Read prompt and handle special commands.
std::string prompt_string = GetPrompt(std::cin, app.verbosity, eot_line);
if (!std::cin) return;
// If !eot_line.empty(), we append \n, so only look at the first 2 chars.
if (prompt_string.size() >= 2 && prompt_string[0] == '%') {
if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return;
if (prompt_string[1] == 'c' || prompt_string[1] == 'C') {
abs_pos = 0;
const std::string prompt_string = GetPrompt(inference);
const bool is_interactive = inference.IsInteractive();
if (is_interactive) { // handle special commands:
if (!std::cin) return;
// If !eot_line.empty(), we append \n, so only look at the first 2 chars.
if (prompt_string.size() >= 2 && prompt_string[0] == '%') {
if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return;
if (prompt_string[1] == 'c' || prompt_string[1] == 'C') {
abs_pos = 0;
continue;
}
}
if (prompt_string.empty()) {
std::cout << "Use '%q' to quit.\n";
continue;
}
}
if (prompt_string.empty()) {
std::cout << "Use '%q' to quit.\n";
continue;
// Set up runtime config.
TimingInfo timing_info = {.verbosity = inference.verbosity};
RuntimeConfig runtime_config = {.gen = &gen,
.verbosity = inference.verbosity,
.stream_token = stream_token,
.use_spinning = threading.spin};
inference.CopyTo(runtime_config);
std::vector<int> prompt;
size_t prefix_end = 0;
if (have_image) {
prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),
config.wrapping, abs_pos, prompt_string,
image_tokens.Rows());
runtime_config.image_tokens = &image_tokens;
prompt_size = prompt.size();
if (config.wrapping == PromptWrapping::PALIGEMMA) {
// The end of the prefix for prefix-LM style attention in Paligemma.
// See Figure 2 of https://arxiv.org/abs/2407.07726.
prefix_end = prompt_size;
// We need to look at all the tokens for the prefix.
// NOTE: Online softmax is on the roadmap, after which this requirement
// can be lifted.
runtime_config.prefill_tbatch_size = prompt_size;
}
} else {
prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),
config.wrapping, abs_pos, prompt_string);
prompt_size = prompt.size();
}
// Wrap, tokenize and maybe log prompt tokens.
std::vector<int> prompt = WrapAndTokenize(
model.Tokenizer(), model.Info(), abs_pos, prompt_string);
prompt_size = prompt.size();
if constexpr (kVerboseLogTokens) {
for (int i = 0; i < prompt_size; ++i) {
fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
}
}
// Set up runtime config.
TimingInfo timing_info = {.verbosity = app.verbosity};
RuntimeConfig runtime_config = {.gen = &gen,
.verbosity = app.verbosity,
.stream_token = stream_token,
.accept_token = accept_token,
.use_spinning = app.spin};
args.CopyTo(runtime_config);
size_t prefix_end = 0;
if (have_image) {
runtime_config.image_tokens = &image_tokens;
if (model.Info().wrapping == PromptWrapping::PALIGEMMA) {
prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0);
} else if (model.Info().wrapping == PromptWrapping::GEMMA_VLM) {
size_t seq_len = model.GetModelConfig().vit_config.seq_len;
size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
prompt =
WrapVLM(model.Tokenizer(), model.Info(), abs_pos, prompt,
image_tokens.BatchSize(), seq_len / (pool_dim * pool_dim));
}
prompt_size = prompt.size();
// The end of the prefix for prefix-LM style attention in Paligemma.
// See Figure 2 of https://arxiv.org/abs/2407.07726.
prefix_end = prompt_size;
// We need to look at all the tokens for the prefix.
runtime_config.prefill_tbatch_size = prompt_size;
}
// Generate until EOS or max_generated_tokens.
if (app.verbosity >= 1) {
if (inference.verbosity >= 1) {
std::cerr << "\n[ Reading prompt ] " << std::flush;
}
model.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache,
gemma.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache, env,
timing_info);
std::cout << "\n\n";
// In non-interactive mode, we only process one prompt/turn.
if (!is_interactive) break;
// Prepare for the next turn. Works only for PaliGemma.
if (!args.multiturn || model.Info().wrapping == PromptWrapping::PALIGEMMA) {
if (!inference.multiturn || config.wrapping == PromptWrapping::PALIGEMMA) {
abs_pos = 0; // Start a new turn at position 0.
InitGenerator(args, gen);
InitGenerator(inference, gen);
} else {
// The last token was either EOS, then it should be ignored because it is
// never part of the dialog, see Table 5 in the Gemma-2 paper:
@ -227,20 +250,17 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
}
}
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference) {
PROFILER_ZONE("Run.misc");
// Note that num_threads is an upper bound; we also limit to the number of
// detected and enabled cores.
const BoundedTopology topology = CreateTopology(app);
NestedPools pools = CreatePools(topology, app);
MatMulEnv env(topology, pools);
if (app.verbosity >= 2) env.print_best = true;
Gemma model = CreateGemma(loader, env);
KVCache kv_cache =
KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size);
ThreadingContext ctx(UpdateArgs(threading, inference));
MatMulEnv env(ctx);
if (inference.verbosity >= 2) env.print_best = true;
const Gemma gemma(loader, inference, ctx);
KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
if (app.verbosity >= 1) {
if (inference.verbosity >= 1) {
std::string instructions =
"*Usage*\n"
" Enter an instruction and press enter (%C resets conversation, "
@ -261,46 +281,37 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
instructions += multiturn;
instructions += examples;
std::cout << "\033[2J\033[1;1H" // clear screen
<< kAsciiArtBanner << "\n\n";
ShowConfig(loader, inference, app, topology, pools);
std::cout << "\n" << instructions << "\n";
// Skip the banner and instructions in non-interactive mode
if (inference.IsInteractive()) {
std::cout << "\033[2J\033[1;1H" // clear screen
<< kAsciiArtBanner << "\n\n";
ShowConfig(loader, threading, inference, gemma.Config(),
gemma.WeightReadMode(), ctx);
std::cout << "\n" << instructions << "\n";
}
}
ReplGemma(model, kv_cache, app, inference, AcceptFunc(), app.eot_line);
ReplGemma(threading, inference, gemma, kv_cache, env);
}
} // namespace gcpp
int main(int argc, char** argv) {
gcpp::InternalInit();
{
PROFILER_ZONE("Startup.misc");
// Placeholder for internal init, do not modify.
gcpp::LoaderArgs loader(argc, argv);
gcpp::ThreadingArgs threading(argc, argv);
gcpp::InferenceArgs inference(argc, argv);
gcpp::AppArgs app(argc, argv);
if (gcpp::HasHelp(argc, argv)) {
std::cerr << gcpp::kAsciiArtBanner;
gcpp::ShowHelp(loader, inference, app);
gcpp::ShowHelp(loader, threading, inference);
return 0;
}
if (const char* error = loader.Validate()) {
std::cerr << gcpp::kAsciiArtBanner;
gcpp::ShowHelp(loader, inference, app);
HWY_ABORT("\nInvalid args: %s", error);
}
if (const char* error = inference.Validate()) {
std::cerr << gcpp::kAsciiArtBanner;
gcpp::ShowHelp(loader, inference, app);
HWY_ABORT("\nInvalid args: %s", error);
}
gcpp::Run(loader, inference, app);
gcpp::Run(loader, threading, inference);
}
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
return 0;

View File

@ -1,607 +0,0 @@
#include "gemma/tensor_index.h"
#include <stddef.h>
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>
#include "compression/shared.h"
#include "gemma/configs.h"
namespace gcpp {
namespace {
// Returns the non-layer tensors for the model.
std::vector<TensorInfo> ModelTensors(const ModelConfig& config) {
return {
TensorInfo{
.name = "c_embedding",
.source_names = {"embedder/input_embedding"},
.axes = {0, 1},
.shape = {config.vocab_size, config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "c_final_norm",
.source_names = {"final_norm/scale"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "enc_norm_bias",
.source_names = {"img/Transformer/encoder_norm/bias"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "enc_norm_scale",
.source_names = {"img/Transformer/encoder_norm/scale"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "img_emb_bias",
.source_names = {"img/embedding/bias"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "img_emb_kernel",
.source_names = {"img/embedding/kernel"},
.axes = {3, 0, 1, 2},
.shape = {config.vit_config.model_dim, config.vit_config.patch_width,
config.vit_config.patch_width, 3},
.min_size = Type::kBF16,
.cols_take_extra_dims = true,
},
TensorInfo{
.name = "img_head_bias",
.source_names = {"img/head/bias", "embedder/mm_input_projection/b"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "img_head_kernel",
.source_names = {"img/head/kernel", "embedder/mm_input_projection/w"},
.axes = {1, 0},
.shape = {config.model_dim, config.vit_config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "img_pos_emb",
.source_names = {"img/pos_embedding"},
.axes = {0, 1},
.shape = {/*1,*/ config.vit_config.seq_len,
config.vit_config.model_dim},
.min_size = Type::kF32,
},
// RMS norm applied to soft tokens prior to pos embedding.
TensorInfo{
.name = "mm_embed_norm",
.source_names = {"embedder/mm_soft_embedding_norm/scale"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
},
};
}
// Returns the tensors for the given image layer config.
std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
const LayerConfig& layer_config,
const int img_layer_idx) {
return {
// Vit layers.
TensorInfo{
.name = "attn_out_w",
.source_names = {"MultiHeadDotProductAttention_0/out/kernel"},
.axes = {2, 0, 1},
.shape = {config.vit_config.model_dim, layer_config.heads,
layer_config.qkv_dim},
.min_size = Type::kBF16,
.cols_take_extra_dims = true,
},
TensorInfo{
.name = "attn_out_b",
.source_names = {"MultiHeadDotProductAttention_0/out/bias"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "q_ein_w",
.source_names = {"MultiHeadDotProductAttention_0/query/kernel"},
.axes = {1, 2, 0},
.shape = {layer_config.heads, layer_config.qkv_dim,
config.vit_config.model_dim},
.concat_names = {"qkv_ein_w", "k_ein_w", "v_ein_w"},
.concat_axis = 1,
.min_size = Type::kBF16,
},
TensorInfo{
.name = "k_ein_w",
.source_names = {"MultiHeadDotProductAttention_0/key/kernel"},
.axes = {1, 2, 0},
.shape = {layer_config.heads, layer_config.qkv_dim,
config.vit_config.model_dim},
.concat_names = {""},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "v_ein_w",
.source_names = {"MultiHeadDotProductAttention_0/value/kernel"},
.axes = {1, 2, 0},
.shape = {layer_config.heads, layer_config.qkv_dim,
config.vit_config.model_dim},
.concat_names = {""},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "qkv_ein_w",
.source_names = {"MultiHeadDotProductAttention_0/qkv/kernel"},
.axes = {1, 2, 0},
.shape = {layer_config.heads, 3 * layer_config.qkv_dim,
config.vit_config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "q_ein_b",
.source_names = {"MultiHeadDotProductAttention_0/query/bias"},
.axes = {0, 1},
.shape = {layer_config.heads, layer_config.qkv_dim},
.concat_names = {"qkv_ein_b", "k_ein_b", "v_ein_b"},
.concat_axis = 1,
.min_size = Type::kF32,
},
TensorInfo{
.name = "k_ein_b",
.source_names = {"MultiHeadDotProductAttention_0/key/bias"},
.axes = {0, 1},
.shape = {layer_config.kv_heads, layer_config.qkv_dim},
.concat_names = {""},
.min_size = Type::kF32,
},
TensorInfo{
.name = "v_ein_b",
.source_names = {"MultiHeadDotProductAttention_0/value/bias"},
.axes = {0, 1},
.shape = {layer_config.kv_heads, layer_config.qkv_dim},
.concat_names = {""},
.min_size = Type::kF32,
},
TensorInfo{
.name = "qkv_ein_b",
.source_names = {"MultiHeadDotProductAttention_0/qkv/bias"},
.axes = {0, 1},
.shape = {layer_config.heads + layer_config.kv_heads * 2,
layer_config.qkv_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "linear_0_w",
.source_names = {"MlpBlock_0/Dense_0/kernel"},
.axes = {1, 0},
.shape = {layer_config.ff_hidden_dim, config.vit_config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "linear_0_b",
.source_names = {"MlpBlock_0/Dense_0/bias"},
.axes = {0},
.shape = {layer_config.ff_hidden_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "linear_1_w",
.source_names = {"MlpBlock_0/Dense_1/kernel"},
.axes = {1, 0},
.shape = {config.vit_config.model_dim, layer_config.ff_hidden_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "linear_1_b",
.source_names = {"MlpBlock_0/Dense_1/bias"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "ln_0_bias",
.source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias",
"img/Transformer/encoderblock_" +
std::to_string(img_layer_idx) +
"/LayerNorm_0/bias"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "ln_0_scale",
.source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale",
"img/Transformer/encoderblock_" +
std::to_string(img_layer_idx) +
"/LayerNorm_0/scale"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "ln_1_bias",
.source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias",
"img/Transformer/encoderblock_" +
std::to_string(img_layer_idx) +
"/LayerNorm_1/bias"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "ln_1_scale",
.source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale",
"img/Transformer/encoderblock_" +
std::to_string(img_layer_idx) +
"/LayerNorm_1/scale"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
},
};
}
// Returns the tensors for the given LLM layer config.
std::vector<TensorInfo> LLMLayerTensors(const ModelConfig& config,
const LayerConfig& layer_config,
bool reshape_att) {
std::vector<TensorInfo> tensors = {
TensorInfo{
.name = "key_norm",
.source_names = {"attn/_key_norm/scale"},
.axes = {0},
.shape = {layer_config.qkv_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "query_norm",
.source_names = {"attn/_query_norm/scale"},
.axes = {0},
.shape = {layer_config.qkv_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "qkv1_w",
.source_names = {"attn/q_einsum/w"},
.axes = {0, 2, 1},
.shape = {layer_config.heads * layer_config.qkv_dim,
config.model_dim},
.concat_names = {"qkv_ein", "qkv2_w"},
},
TensorInfo{
.name = "qkv2_w",
.source_names = {"attn/kv_einsum/w"},
.axes = {1, 0, 3, 2},
.shape = {2 * layer_config.kv_heads * layer_config.qkv_dim,
config.model_dim},
.concat_names = {""},
},
TensorInfo{
.name = "q_ein",
.source_names = {"attention_block/proj_q/kernel"},
.axes = {1, 0},
.shape = {layer_config.model_dim, layer_config.model_dim},
.concat_names = {"qkv_ein", "k_ein", "v_ein"},
},
TensorInfo{
.name = "k_ein",
.source_names = {"attention_block/proj_k/kernel"},
.axes = {1, 0},
.shape = {layer_config.qkv_dim, layer_config.model_dim},
.concat_names = {""},
},
TensorInfo{
.name = "v_ein",
.source_names = {"attention_block/proj_v/kernel"},
.axes = {1, 0},
.shape = {layer_config.qkv_dim, layer_config.model_dim},
.concat_names = {""},
},
TensorInfo{
.name = "qkv_ein",
.source_names = {"attn/qkv_einsum/w"},
.axes = {1, 0, 3, 2},
.shape = {(layer_config.heads + 2 * layer_config.kv_heads) *
layer_config.qkv_dim,
config.model_dim},
},
TensorInfo{
.name = "attn_ob",
.source_names = {"attention_block/proj_final/bias"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kF32,
},
// Griffin layers.
TensorInfo{
.name = "gr_lin_x_w",
.source_names = {"recurrent_block/linear_x/kernel"},
.axes = {1, 0},
.shape = {layer_config.griffin_dim, layer_config.griffin_dim},
},
TensorInfo{
.name = "gr_lin_x_b",
.source_names = {"recurrent_block/linear_x/bias"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "gr_lin_y_w",
.source_names = {"recurrent_block/linear_y/kernel"},
.axes = {1, 0},
.shape = {layer_config.griffin_dim, layer_config.griffin_dim},
},
TensorInfo{
.name = "gr_lin_y_b",
.source_names = {"recurrent_block/linear_y/bias"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "gr_lin_out_w",
.source_names = {"recurrent_block/linear_out/kernel"},
.axes = {1, 0},
.shape = {layer_config.griffin_dim, layer_config.griffin_dim},
},
TensorInfo{
.name = "gr_lin_out_b",
.source_names = {"recurrent_block/linear_out/bias"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "gr_conv_w",
.source_names = {"recurrent_block/conv_1d/w"},
.axes = {0, 1},
.shape = {layer_config.conv1d_width, layer_config.griffin_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "gr_conv_b",
.source_names = {"recurrent_block/conv_1d/b"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "gr1_gate_w",
.source_names = {"recurrent_block/rg_lru/input_gate/w"},
.axes = {0, 2, 1},
.shape = {layer_config.heads,
layer_config.griffin_dim / layer_config.heads,
layer_config.griffin_dim / layer_config.heads},
.concat_names = {"gr_gate_w", "gr2_gate_w"},
},
TensorInfo{
.name = "gr2_gate_w",
.source_names = {"recurrent_block/rg_lru/a_gate/w"},
.axes = {0, 2, 1},
.shape = {layer_config.heads,
layer_config.griffin_dim / layer_config.heads,
layer_config.griffin_dim / layer_config.heads},
.concat_names = {""},
},
TensorInfo{
.name = "gr_gate_w",
.source_names = {"recurrent_block/rg_lru/gate/w"},
.axes = {0, 2, 1},
.shape = {2 * layer_config.heads,
layer_config.griffin_dim / layer_config.heads,
layer_config.griffin_dim / layer_config.heads},
},
TensorInfo{
.name = "gr1_gate_b",
.source_names = {"recurrent_block/rg_lru/input_gate/b"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.concat_names = {"gr_gate_b", "gr2_gate_b"},
.min_size = Type::kF32,
},
TensorInfo{
.name = "gr2_gate_b",
.source_names = {"recurrent_block/rg_lru/a_gate/b"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.concat_names = {""},
.min_size = Type::kF32,
},
TensorInfo{
.name = "gr_gate_b",
.source_names = {"recurrent_block/rg_lru/input_gate/b"},
.axes = {0, 1},
.shape = {2 * layer_config.griffin_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "gr_a",
.source_names = {"recurrent_block/rg_lru/a_param"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
.scaled_softplus = true,
},
TensorInfo{
.name = "gating_ein",
.source_names = {"mlp/gating_einsum/w", "mlp/gating_einsum",
"mlp_block/ffw_up/w"},
.axes = {0, layer_config.optimized_gating ? 1u : 2u,
layer_config.optimized_gating ? 2u : 1u},
.shape = {2, layer_config.ff_hidden_dim, config.model_dim},
},
TensorInfo{
.name = "gating1_w",
.source_names = {"none"},
.axes = {0, layer_config.optimized_gating ? 1u : 2u,
layer_config.optimized_gating ? 2u : 1u},
.shape = {layer_config.ff_hidden_dim, config.model_dim},
},
TensorInfo{
.name = "gating2_w",
.source_names = {"none"},
.axes = {0, layer_config.optimized_gating ? 1u : 2u,
layer_config.optimized_gating ? 2u : 1u},
.shape = {layer_config.ff_hidden_dim, config.model_dim},
},
TensorInfo{
.name = "linear_w",
.source_names = {"mlp/linear/w", "mlp/linear",
"mlp_block/ffw_down/kernel"},
.axes = {1, 0},
.shape = {config.model_dim, layer_config.ff_hidden_dim},
},
TensorInfo{
.name = "pre_att_ns",
.source_names = {"pre_attention_norm/scale",
"temporal_pre_norm/scale"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "pre_ff_ns",
.source_names = {"pre_ffw_norm/scale", "channel_pre_norm/scale"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "post_att_ns",
.source_names = {"post_attention_norm/scale"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "post_ff_ns",
.source_names = {"post_ffw_norm/scale"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "ffw_gat_b",
.source_names = {"mlp_block/ffw_up/b"},
.axes = {0},
.shape = {2 * layer_config.ff_hidden_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "ffw_out_b",
.source_names = {"mlp_block/ffw_down/bias"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kF32,
},
};
if (reshape_att) {
tensors.push_back(TensorInfo{
.name = "att_w",
.source_names = {"attn/attn_vec_einsum/w",
"attention_block/proj_final/kernel"},
.preshape = {layer_config.heads, layer_config.qkv_dim,
config.model_dim},
.axes = {2, 0, 1},
.shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim},
.cols_take_extra_dims = true,
});
tensors.push_back(TensorInfo{
.name = "att_ein",
.shape = {layer_config.heads, config.model_dim, layer_config.qkv_dim},
});
} else {
tensors.push_back(TensorInfo{
.name = "att_ein",
.source_names = {"attn/attn_vec_einsum/w",
"attention_block/proj_final/kernel"},
.preshape = {layer_config.heads, layer_config.qkv_dim,
config.model_dim},
.axes = {0, 2, 1},
.shape = {layer_config.heads, config.model_dim, layer_config.qkv_dim},
});
tensors.push_back(TensorInfo{
.name = "att_w",
.shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim},
.cols_take_extra_dims = true,
});
}
return tensors;
}
} // namespace
TensorIndex::TensorIndex(const ModelConfig& config, int llm_layer_idx,
int img_layer_idx, bool reshape_att)
: config_(config),
llm_layer_idx_(llm_layer_idx),
img_layer_idx_(img_layer_idx) {
int layer_idx = std::max(llm_layer_idx_, img_layer_idx_);
std::string suffix;
if (layer_idx >= 0) {
suffix = "_" + std::to_string(layer_idx);
}
if (llm_layer_idx < 0 && img_layer_idx < 0) {
tensors_ = ModelTensors(config);
} else if (llm_layer_idx_ < 0 && 0 <= img_layer_idx &&
img_layer_idx < config.vit_config.layer_configs.size()) {
const auto& layer_config = config.vit_config.layer_configs[img_layer_idx];
tensors_ = ImageLayerTensors(config, layer_config, img_layer_idx);
} else if (0 <= llm_layer_idx &&
llm_layer_idx < config.layer_configs.size()) {
const auto& layer_config = config.layer_configs[llm_layer_idx];
tensors_ = LLMLayerTensors(config, layer_config, reshape_att);
}
for (size_t i = 0; i < tensors_.size(); ++i) {
std::string key = tensors_[i].name + suffix;
name_map_.insert({key, i});
}
}
TensorInfo TensorIndex::TensorInfoFromSourcePath(
const std::string& path) const {
for (const auto& tensor : tensors_) {
for (const auto& source_name : tensor.source_names) {
auto pos = path.rfind(source_name);
if (pos != std::string::npos && path.size() == pos + source_name.size())
return tensor;
}
}
return TensorInfo();
}
const TensorInfo* TensorIndex::FindName(const std::string& name) const {
std::string name_to_find = name;
if (!std::isdigit(name[name.size() - 1])) {
if (img_layer_idx_ >= 0 && llm_layer_idx_ < 0) {
name_to_find = name + "_" + std::to_string(img_layer_idx_);
} else if (llm_layer_idx_ >= 0) {
name_to_find = name + "_" + std::to_string(llm_layer_idx_);
}
}
auto it = name_map_.find(name_to_find);
if (it == name_map_.end()) {
return nullptr;
}
return &tensors_[it->second];
}
} // namespace gcpp

View File

@ -1,101 +0,0 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_
#include <stddef.h>
#include <string>
#include <unordered_map>
#include <vector>
#include "compression/shared.h"
#include "gemma/configs.h"
namespace gcpp {
// Universal tensor information. Holds enough information to construct a
// tensor in LayerWeightsPtrs/ModelWeightsPtrs, as well as to export the
// tensor from the python model with necessary transpose/reshape info.
struct TensorInfo {
// The name of the tensor in the sbs file
std::string name;
// Strings to match to the end of the name of the tensor in the python model.
std::vector<std::string> source_names;
// Initial reshape shape. Use only as a last resort when input may have
// dimensions combined that need to be split before the transpose, as it
// defeats the post-transpose shape checking. Normally empty.
std::vector<size_t> preshape;
// Transpose axes arg. If the input tensor has more dimensions than axes,
// then leading dimensions are collapsed until the number of axes matches.
std::vector<size_t> axes;
// Expected final shape of the tensor after reshape/transpose.
// Note that this is the shape of the tensor during export,
// not the shape of the tensor in the sbs file, as the sbs file
// is restricted to 2D tensors. With few exceptions, the sbs file
// tensor rows gather all the excess dimensions. See cols_take_extra_dims.
std::vector<size_t> shape;
// List of names to concatenate with, used only if multiple tensors are
// concatenated into one. The first tensor in the concatenation should have
// concat names thus: The first name is the name of the result, and the
// tensors with the remaining names are concatenated after this.
// The remaining tensors to be concatenated should have just a single
// empty string in concat_names to indicate that they have been consumed.
std::vector<std::string> concat_names;
// Axis at which to concatenate.
size_t concat_axis = 0;
// The minimum compression weight type for this tensor. The default is
// kNUQ, which provides maximum compression. Other values such as kBF16
// or kF32 can be used to limit the compression to a specific type.
Type min_size = Type::kNUQ;
// Whether to apply scaled softplus to the data.
bool scaled_softplus = false;
// Whether the columns or the rows take any extra dimensions.
// If false, then [10, 20, 30] -> [10*20, 30] and [30] -> [1, 30].
// If true, then [10, 20, 30] -> [10, 20*30] and [30] -> [1, 30].
bool cols_take_extra_dims = false;
};
// Universal index of tensor information, which can be built for a specific
// layer_idx.
class TensorIndex {
public:
// Builds a list of TensorInfo for the given layer_idx.
// If reshape_att is true, the attn_vec_einsum tensor is reshaped.
TensorIndex(const ModelConfig& config, int llm_layer_idx, int img_layer_idx,
bool reshape_att);
~TensorIndex() = default;
// Returns the TensorInfo whose source_name matches the end of the given path,
// or an empty TensorInfo if not found.
// NOTE: that the returned TensorInfo is a copy, so that the source
// TensorIndex can be destroyed without affecting the returned TensorInfo.
TensorInfo TensorInfoFromSourcePath(const std::string& path) const;
// Returns the TensorInfo whose name matches the given name,
// or an empty TensorInfo if not found.
// NOTE: that the returned TensorInfo is a copy, so that the source
// TensorIndex can be destroyed without affecting the returned TensorInfo.
TensorInfo TensorInfoFromName(const std::string& name) const {
const TensorInfo* info = FindName(name);
if (info == nullptr) return TensorInfo();
return *info;
}
// Returns the TensorInfo for the given tensor name, for concise construction
// of ModelWeightsPtrs/LayerWeightsPtrs.
const TensorInfo* FindName(const std::string& name) const;
private:
// Config that was used to build the tensor index.
const ModelConfig& config_;
// Layer that this tensor index is for - either LLM or image.
int llm_layer_idx_;
int img_layer_idx_;
// List of tensor information for this layer.
std::vector<TensorInfo> tensors_;
// Map from tensor name to index in tensors_.
std::unordered_map<std::string, size_t> name_map_;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_

View File

@ -1,72 +0,0 @@
#include "gemma/tensor_index.h"
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "compression/compress.h"
#include "compression/shared.h"
#include "gemma/configs.h"
#include "gemma/weights.h"
#include "util/basics.h"
#include "hwy/aligned_allocator.h"
namespace gcpp {
namespace {
// Tests that each tensor in the model can be found by exactly one TensorIndex,
// and that the TensorIndex returns the correct shape and name for the tensor,
// for all models.
TEST(TensorIndexTest, FindName) {
for (Model model : kAllModels) {
fprintf(stderr, "Testing model %d\n", static_cast<int>(model));
ModelConfig config = ConfigFromModel(model);
std::vector<TensorIndex> tensor_indexes;
tensor_indexes.emplace_back(config, /*llm_layer_idx=*/-1,
/*img_layer_idx=*/-1,
/*split_and_reshape=*/false);
for (size_t llm_layer_idx = 0; llm_layer_idx < config.layer_configs.size();
++llm_layer_idx) {
tensor_indexes.emplace_back(config, static_cast<int>(llm_layer_idx),
/*img_layer_idx=*/-1,
/*split_and_reshape=*/false);
}
for (size_t img_layer_idx = 0;
img_layer_idx < config.vit_config.layer_configs.size();
++img_layer_idx) {
tensor_indexes.emplace_back(config, /*llm_layer_idx=*/-1,
static_cast<int>(img_layer_idx),
/*split_and_reshape=*/false);
}
// For each tensor in any model, exactly one TensorIndex should find it.
ModelWeightsPtrs<SfpStream> weights(config);
ModelWeightsPtrs<SfpStream>::ForEachTensor(
{&weights}, ForEachType::kInitNoToc,
[&tensor_indexes](const char* name, hwy::Span<MatPtr*> tensors) {
int num_found = 0;
const MatPtr& tensor = *tensors[0];
for (const auto& tensor_index : tensor_indexes) {
// Skip the type marker prefix, but we want the layer index suffix.
std::string name_to_find(name + 1, strlen(name) - 1);
const TensorInfo* info = tensor_index.FindName(name_to_find);
if (info != nullptr) {
// Test that the MatPtr can be constructed from the TensorInfo,
// and that the dimensions match.
MatPtrT<SfpStream> mat_ptr(tensor.Name(), tensor_index);
EXPECT_STREQ(tensor.Name(), mat_ptr.Name())
<< "on tensor " << name;
EXPECT_EQ(tensor.Rows(), mat_ptr.Rows()) << "on tensor " << name;
EXPECT_EQ(tensor.Cols(), mat_ptr.Cols()) << "on tensor " << name;
++num_found;
}
}
EXPECT_EQ(num_found, 1) << " for tensor " << name;
});
}
}
} // namespace
} // namespace gcpp

593
gemma/tensor_info.cc Normal file
View File

@ -0,0 +1,593 @@
#include "gemma/tensor_info.h"
#include <stddef.h>
#include <stdint.h>
#include <string>
#include "compression/types.h"
#include "gemma/configs.h"
namespace gcpp {
void TensorInfoRegistry::Add(const std::string& suffix,
const TensorInfo& info) {
const size_t idx = tensors_.size();
tensors_.push_back(info);
// Also add suffix to `concat_names`.
for (std::string& name : tensors_.back().concat_names) {
name += suffix;
}
const std::string name = info.base_name + suffix;
// Ensure successful insertion because `suffix` ensures uniqueness for
// per-layer tensors, and per-model should only be inserted once.
HWY_ASSERT_M(idx_from_name_.insert({name, idx}).second, name.c_str());
}
// Non-layer tensors.
void TensorInfoRegistry::AddModelTensors(const ModelConfig& config) {
const std::string no_suffix;
Add(no_suffix, {
.base_name = "c_embedding",
.source_names = {"embedder/input_embedding"},
.axes = {0, 1},
.shape = {config.vocab_size, config.model_dim},
.min_size = Type::kBF16,
});
Add(no_suffix, {
.base_name = "c_final_norm",
.source_names = {"final_norm/scale"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kBF16,
});
Add(no_suffix, {
.base_name = "enc_norm_bias",
.source_names = {"img/Transformer/encoder_norm/bias"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
});
Add(no_suffix, {
.base_name = "enc_norm_scale",
.source_names = {"img/Transformer/encoder_norm/scale"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
});
Add(no_suffix, {
.base_name = "img_emb_bias",
.source_names = {"img/embedding/bias"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kF32,
});
Add(no_suffix,
{
.base_name = "img_emb_kernel",
.source_names = {"img/embedding/kernel"},
.axes = {3, 0, 1, 2},
.shape = {config.vit_config.model_dim, config.vit_config.patch_width,
config.vit_config.patch_width, 3},
.min_size = Type::kBF16,
.cols_take_extra_dims = true,
});
Add(no_suffix,
{
.base_name = "img_head_bias",
.source_names = {"img/head/bias", "embedder/mm_input_projection/b"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kF32,
});
Add(no_suffix,
{
.base_name = "img_head_kernel",
.source_names = {"img/head/kernel", "embedder/mm_input_projection/w"},
.axes = {1, 0},
.shape = {config.model_dim, config.vit_config.model_dim},
.min_size = Type::kBF16,
});
Add(no_suffix, {
.base_name = "img_pos_emb",
.source_names = {"img/pos_embedding"},
.axes = {0, 1},
.shape = {/*1,*/ config.vit_config.seq_len,
config.vit_config.model_dim},
.min_size = Type::kF32,
});
// RMS norm applied to soft tokens prior to pos embedding.
Add(no_suffix, {
.base_name = "mm_embed_norm",
.source_names = {"embedder/mm_soft_embedding_norm/scale"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
});
}
// Returns the tensors for the given image layer config.
void TensorInfoRegistry::AddImageLayerTensors(const ModelConfig& config,
const LayerConfig& layer_config,
const size_t img_layer_idx) {
const std::string suffix = LayerSuffix(img_layer_idx);
// Vit layers.
Add(suffix, {
.base_name = "attn_out_w",
.source_names = {"MultiHeadDotProductAttention_0/out/kernel"},
.axes = {2, 0, 1},
.shape = {config.vit_config.model_dim, layer_config.heads,
layer_config.qkv_dim},
.min_size = Type::kBF16,
.cols_take_extra_dims = true,
});
Add(suffix, {
.base_name = "attn_out_b",
.source_names = {"MultiHeadDotProductAttention_0/out/bias"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kF32,
});
Add(suffix,
{
.base_name = "q_ein_w",
.source_names = {"MultiHeadDotProductAttention_0/query/kernel"},
.axes = {1, 2, 0},
.shape = {layer_config.heads, layer_config.qkv_dim,
config.vit_config.model_dim},
.concat_names = {"qkv_ein_w", "k_ein_w", "v_ein_w"},
.concat_axis = 1,
.min_size = Type::kBF16,
});
Add(suffix, {
.base_name = "k_ein_w",
.source_names = {"MultiHeadDotProductAttention_0/key/kernel"},
.axes = {1, 2, 0},
.shape = {layer_config.heads, layer_config.qkv_dim,
config.vit_config.model_dim},
.concat_names = {""},
.min_size = Type::kBF16,
});
Add(suffix,
{
.base_name = "v_ein_w",
.source_names = {"MultiHeadDotProductAttention_0/value/kernel"},
.axes = {1, 2, 0},
.shape = {layer_config.heads, layer_config.qkv_dim,
config.vit_config.model_dim},
.concat_names = {""},
.min_size = Type::kBF16,
});
Add(suffix, {
.base_name = "qkv_ein_w",
.source_names = {"MultiHeadDotProductAttention_0/qkv/kernel"},
.axes = {1, 2, 0},
.shape = {layer_config.heads, 3 * layer_config.qkv_dim,
config.vit_config.model_dim},
.min_size = Type::kBF16,
});
Add(suffix, {
.base_name = "q_ein_b",
.source_names = {"MultiHeadDotProductAttention_0/query/bias"},
.axes = {0, 1},
.shape = {layer_config.heads, layer_config.qkv_dim},
.concat_names = {"qkv_ein_b", "k_ein_b", "v_ein_b"},
.concat_axis = 1,
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "k_ein_b",
.source_names = {"MultiHeadDotProductAttention_0/key/bias"},
.axes = {0, 1},
.shape = {layer_config.kv_heads, layer_config.qkv_dim},
.concat_names = {""},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "v_ein_b",
.source_names = {"MultiHeadDotProductAttention_0/value/bias"},
.axes = {0, 1},
.shape = {layer_config.kv_heads, layer_config.qkv_dim},
.concat_names = {""},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "qkv_ein_b",
.source_names = {"MultiHeadDotProductAttention_0/qkv/bias"},
.axes = {0, 1},
.shape = {layer_config.heads + layer_config.kv_heads * 2,
layer_config.qkv_dim},
.min_size = Type::kF32,
});
Add(suffix,
{
.base_name = "linear_0_w",
.source_names = {"MlpBlock_0/Dense_0/kernel"},
.axes = {1, 0},
.shape = {layer_config.ff_hidden_dim, config.vit_config.model_dim},
.min_size = Type::kBF16,
});
Add(suffix, {
.base_name = "linear_0_b",
.source_names = {"MlpBlock_0/Dense_0/bias"},
.axes = {0},
.shape = {layer_config.ff_hidden_dim},
.min_size = Type::kF32,
});
Add(suffix,
{
.base_name = "linear_1_w",
.source_names = {"MlpBlock_0/Dense_1/kernel"},
.axes = {1, 0},
.shape = {config.vit_config.model_dim, layer_config.ff_hidden_dim},
.min_size = Type::kBF16,
});
Add(suffix, {
.base_name = "linear_1_b",
.source_names = {"MlpBlock_0/Dense_1/bias"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kF32,
});
Add(suffix,
{
.base_name = "ln_0_bias",
.source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias",
"img/Transformer/encoderblock_" +
std::to_string(img_layer_idx) +
"/LayerNorm_0/bias"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
});
Add(suffix,
{
.base_name = "ln_0_scale",
.source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale",
"img/Transformer/encoderblock_" +
std::to_string(img_layer_idx) +
"/LayerNorm_0/scale"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
});
Add(suffix,
{
.base_name = "ln_1_bias",
.source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias",
"img/Transformer/encoderblock_" +
std::to_string(img_layer_idx) +
"/LayerNorm_1/bias"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
});
Add(suffix,
{
.base_name = "ln_1_scale",
.source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale",
"img/Transformer/encoderblock_" +
std::to_string(img_layer_idx) +
"/LayerNorm_1/scale"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
});
}
void TensorInfoRegistry::AddGriffinLayerTensors(const LayerConfig& layer_config,
const size_t layer_idx) {
const std::string suffix = LayerSuffix(layer_idx);
Add(suffix, {
.base_name = "gr_lin_x_w",
.source_names = {"recurrent_block/linear_x/kernel"},
.axes = {1, 0},
.shape = {layer_config.griffin_dim, layer_config.griffin_dim},
});
Add(suffix, {
.base_name = "gr_lin_x_b",
.source_names = {"recurrent_block/linear_x/bias"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "gr_lin_y_w",
.source_names = {"recurrent_block/linear_y/kernel"},
.axes = {1, 0},
.shape = {layer_config.griffin_dim, layer_config.griffin_dim},
});
Add(suffix, {
.base_name = "gr_lin_y_b",
.source_names = {"recurrent_block/linear_y/bias"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "gr_lin_out_w",
.source_names = {"recurrent_block/linear_out/kernel"},
.axes = {1, 0},
.shape = {layer_config.griffin_dim, layer_config.griffin_dim},
});
Add(suffix, {
.base_name = "gr_lin_out_b",
.source_names = {"recurrent_block/linear_out/bias"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
});
Add(suffix,
{
.base_name = "gr_conv_w",
.source_names = {"recurrent_block/conv_1d/w"},
.axes = {0, 1},
.shape = {layer_config.conv1d_width, layer_config.griffin_dim},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "gr_conv_b",
.source_names = {"recurrent_block/conv_1d/b"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "gr1_gate_w",
.source_names = {"recurrent_block/rg_lru/input_gate/w"},
.axes = {0, 2, 1},
.shape = {layer_config.heads,
layer_config.griffin_dim / layer_config.heads,
layer_config.griffin_dim / layer_config.heads},
.concat_names = {"gr_gate_w", "gr2_gate_w"},
});
Add(suffix, {
.base_name = "gr2_gate_w",
.source_names = {"recurrent_block/rg_lru/a_gate/w"},
.axes = {0, 2, 1},
.shape = {layer_config.heads,
layer_config.griffin_dim / layer_config.heads,
layer_config.griffin_dim / layer_config.heads},
.concat_names = {""},
});
Add(suffix, {
.base_name = "gr_gate_w",
.source_names = {"recurrent_block/rg_lru/gate/w"},
.axes = {0, 2, 1},
.shape = {2 * layer_config.heads,
layer_config.griffin_dim / layer_config.heads,
layer_config.griffin_dim / layer_config.heads},
});
Add(suffix, {
.base_name = "gr1_gate_b",
.source_names = {"recurrent_block/rg_lru/input_gate/b"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.concat_names = {"gr_gate_b", "gr2_gate_b"},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "gr2_gate_b",
.source_names = {"recurrent_block/rg_lru/a_gate/b"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.concat_names = {""},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "gr_gate_b",
.source_names = {"recurrent_block/rg_lru/input_gate/b"},
.axes = {0, 1},
.shape = {2 * layer_config.griffin_dim},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "gr_a",
.source_names = {"recurrent_block/rg_lru/a_param"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
.scaled_softplus = true,
});
}
void TensorInfoRegistry::AddLayerTensors(const ModelConfig& config,
const LayerConfig& layer_config,
const size_t layer_idx) {
const std::string suffix = LayerSuffix(layer_idx);
Add(suffix, {
.base_name = "key_norm",
.source_names = {"attn/_key_norm/scale"},
.axes = {0},
.shape = {layer_config.qkv_dim},
.min_size = Type::kBF16,
});
Add(suffix, {
.base_name = "query_norm",
.source_names = {"attn/_query_norm/scale"},
.axes = {0},
.shape = {layer_config.qkv_dim},
.min_size = Type::kBF16,
});
Add(suffix, {
.base_name = "qkv1_w",
.source_names = {"attn/q_einsum/w"},
.axes = {0, 2, 1},
.shape = {layer_config.heads * layer_config.qkv_dim,
config.model_dim},
.concat_names = {"qkv_ein", "qkv2_w"},
});
Add(suffix, {
.base_name = "qkv2_w",
.source_names = {"attn/kv_einsum/w"},
.axes = {1, 0, 3, 2},
.shape = {2 * layer_config.kv_heads * layer_config.qkv_dim,
config.model_dim},
.concat_names = {""},
});
Add(suffix, {
.base_name = "q_ein",
.source_names = {"attention_block/proj_q/kernel"},
.axes = {1, 0},
.shape = {layer_config.model_dim, layer_config.model_dim},
.concat_names = {"qkv_ein", "k_ein", "v_ein"},
});
Add(suffix, {
.base_name = "k_ein",
.source_names = {"attention_block/proj_k/kernel"},
.axes = {1, 0},
.shape = {layer_config.qkv_dim, layer_config.model_dim},
.concat_names = {""},
});
Add(suffix, {
.base_name = "v_ein",
.source_names = {"attention_block/proj_v/kernel"},
.axes = {1, 0},
.shape = {layer_config.qkv_dim, layer_config.model_dim},
.concat_names = {""},
});
Add(suffix, {
.base_name = "qkv_ein",
.source_names = {"attn/qkv_einsum/w"},
.axes = {1, 0, 3, 2},
.shape = {(layer_config.heads + 2 * layer_config.kv_heads) *
layer_config.qkv_dim,
config.model_dim},
});
Add(suffix, {
.base_name = "attn_ob",
.source_names = {"attention_block/proj_final/bias"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "gating_ein",
.source_names = {"mlp/gating_einsum/w", "mlp/gating_einsum",
"mlp_block/ffw_up/w"},
.axes = {0, layer_config.optimized_gating ? 1u : 2u,
layer_config.optimized_gating ? 2u : 1u},
.shape = {2, layer_config.ff_hidden_dim, config.model_dim},
});
Add(suffix, {
.base_name = "gating1_w",
.source_names = {"none"},
.axes = {0, layer_config.optimized_gating ? 1u : 2u,
layer_config.optimized_gating ? 2u : 1u},
.shape = {layer_config.ff_hidden_dim, config.model_dim},
});
Add(suffix, {
.base_name = "gating2_w",
.source_names = {"none"},
.axes = {0, layer_config.optimized_gating ? 1u : 2u,
layer_config.optimized_gating ? 2u : 1u},
.shape = {layer_config.ff_hidden_dim, config.model_dim},
});
Add(suffix, {
.base_name = "linear_w",
.source_names = {"mlp/linear/w", "mlp/linear",
"mlp_block/ffw_down/kernel"},
.axes = {1, 0},
.shape = {config.model_dim, layer_config.ff_hidden_dim},
});
Add(suffix, {
.base_name = "pre_att_ns",
.source_names = {"pre_attention_norm/scale",
"temporal_pre_norm/scale"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kBF16,
});
Add(suffix,
{
.base_name = "pre_ff_ns",
.source_names = {"pre_ffw_norm/scale", "channel_pre_norm/scale"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kBF16,
});
Add(suffix, {
.base_name = "post_att_ns",
.source_names = {"post_attention_norm/scale"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kBF16,
});
Add(suffix, {
.base_name = "post_ff_ns",
.source_names = {"post_ffw_norm/scale"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kBF16,
});
Add(suffix, {
.base_name = "ffw_gat_b",
.source_names = {"mlp_block/ffw_up/b"},
.axes = {0},
.shape = {2 * layer_config.ff_hidden_dim},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "ffw_out_b",
.source_names = {"mlp_block/ffw_down/bias"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kF32,
});
Add(suffix,
{
.base_name = "att_ein",
.source_names = {"attn/attn_vec_einsum/w",
"attention_block/proj_final/kernel"},
.preshape = {layer_config.heads, layer_config.qkv_dim,
config.model_dim},
.axes = {0, 2, 1},
.shape = {layer_config.heads, config.model_dim, layer_config.qkv_dim},
});
Add(suffix,
{
.base_name = "att_w",
.shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim},
.cols_take_extra_dims = true,
});
if (config.model == Model::GRIFFIN_2B) {
AddGriffinLayerTensors(layer_config, layer_idx);
}
}
TensorInfoRegistry::TensorInfoRegistry(const ModelConfig& config) {
// Upper bound on the number of `Add()` calls in `Add*Tensors()`. Loose bound
// in case those are changed without updating this. Better to allocate a bit
// more than to 1.5-2x the size if too little.
tensors_.reserve(10 + 32 * config.layer_configs.size() +
24 * config.vit_config.layer_configs.size());
AddModelTensors(config);
for (size_t i = 0; i < config.layer_configs.size(); ++i) {
AddLayerTensors(config, config.layer_configs[i], i);
}
for (size_t i = 0; i < config.vit_config.layer_configs.size(); ++i) {
AddImageLayerTensors(config, config.vit_config.layer_configs[i], i);
}
}
TensorInfo TensorInfoRegistry::TensorInfoFromSourcePath(const std::string& path,
int layer_idx) const {
for (const TensorInfo& tensor : tensors_) {
for (const std::string& source_name : tensor.source_names) {
// path ends with source_name?
const size_t pos = path.rfind(source_name);
if (pos != std::string::npos && path.size() == pos + source_name.size()) {
std::string name = tensor.base_name;
if (layer_idx >= 0) name += LayerSuffix(static_cast<size_t>(layer_idx));
return TensorInfoFromName(name);
}
}
}
return TensorInfo();
}
} // namespace gcpp

141
gemma/tensor_info.h Normal file
View File

@ -0,0 +1,141 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_
#include <stddef.h>
#include <string>
#include <unordered_map>
#include <vector>
#include "compression/types.h" // Type
#include "gemma/configs.h"
#include "util/basics.h" // Extents2D
namespace gcpp {
// Tensor metadata. This is far more than required to construct the `MatPtr` in
// `LayerWeightsPtrs/WeightsPtrs`; they only use `.shape` via `ExtentsFromInfo`.
// This is also bound to Python and filled by the exporter.
struct TensorInfo {
// The base name of the tensor without a layer suffix.
std::string base_name;
// Strings to match to the end of the name of the tensor in the python model.
std::vector<std::string> source_names;
// Initial reshape shape. Use only as a last resort when input may have
// dimensions combined that need to be split before the transpose, as it
// defeats the post-transpose shape checking. Normally empty.
std::vector<size_t> preshape;
// Transpose axes arg. If the input tensor has more dimensions than axes,
// then leading dimensions are collapsed until the number of axes matches.
std::vector<size_t> axes;
// Expected final shape of the tensor after reshape/transpose.
// Note that this is the shape of the tensor during export,
// not the shape of the tensor in the sbs file, as the sbs file
// is restricted to 2D tensors. With few exceptions, the sbs file
// tensor rows gather all the excess dimensions. See cols_take_extra_dims.
std::vector<size_t> shape;
// List of names to concatenate with, used only if multiple tensors are
// concatenated into one. The first tensor in the concatenation should have
// concat names thus: The first name is the name of the result, and the
// tensors with the remaining names are concatenated after this.
// The remaining tensors to be concatenated should have just a single
// empty string in concat_names to indicate that they have been consumed.
std::vector<std::string> concat_names;
// Axis at which to concatenate.
size_t concat_axis = 0;
// The highest permissible compression for this tensor. The default is
// kNUQ, which provides maximum compression. Other values such as kBF16
// or kF32 can be used to limit the compression to a specific type.
Type min_size = Type::kNUQ;
// Whether to apply scaled softplus to the data.
bool scaled_softplus = false;
// Whether the columns or the rows take any extra dimensions.
// If false, then [10, 20, 30] -> [10*20, 30] and [30] -> [1, 30].
// If true, then [10, 20, 30] -> [10, 20*30] and [30] -> [1, 30].
bool cols_take_extra_dims = false;
};
// Collapses/expands the tensor dims into 2D extents, which may be 0, 0 for
// not-present tensors such as ViT in a text-only model. Safely handles nullptr
// returned from `TensorInfoRegistry::Find`, hence not a member function.
static inline Extents2D ExtentsFromInfo(const TensorInfo* tensor) {
if (tensor == nullptr) return Extents2D(0, 0);
size_t cols = tensor->shape.back();
size_t rows = 1;
if (tensor->cols_take_extra_dims) {
rows = tensor->shape[0];
for (size_t i = 1; i < tensor->shape.size() - 1; ++i) {
cols *= tensor->shape[i];
}
} else { // rows take extra dims
for (size_t i = 0; i < tensor->shape.size() - 1; ++i) {
rows *= tensor->shape[i];
}
}
// Sometimes only one of rows or cols is zero; set both for consistency.
if (rows == 0 || cols == 0) rows = cols = 0;
return Extents2D(rows, cols);
}
static inline std::string LayerSuffix(size_t layer_idx) {
return std::string("_") + std::to_string(layer_idx);
}
// Returns tensor base name without any layer suffix.
static inline std::string StripLayerSuffix(const std::string& name) {
return name.substr(0, name.rfind('_'));
}
// Holds all `TensorInfo` for a model and retrieves them by (unique) name.
class TensorInfoRegistry {
public:
explicit TensorInfoRegistry(const ModelConfig& config);
~TensorInfoRegistry() = default;
// Returns nullptr if not found, otherwise the `TensorInfo` for the given
// `name`, which either lacks a suffix, or is per-layer and ends with
// `LayerSuffix(layer_idx)`. Used in `WeightsPtrs/LayerWeightsPtrs`.
const TensorInfo* Find(const std::string& name) const {
auto it = idx_from_name_.find(name);
if (it == idx_from_name_.end()) return nullptr;
return &tensors_[it->second];
}
// Returns a copy of the `TensorInfo` whose name matches the given name, or a
// default-constructed `TensorInfo` if not found. Destroying
// `TensorInfoRegistry` afterward will not invalidate the returned value.
TensorInfo TensorInfoFromName(const std::string& name) const {
const TensorInfo* info = Find(name);
if (info == nullptr) return TensorInfo();
return *info;
}
// Returns a copy of the `TensorInfo` whose source_name matches the end of the
// given path, and whose name ends with the given layer_idx, otherwise a
// default-constructed `TensorInfo`. Destroying `TensorInfoRegistry`
// afterward will not invalidate the returned value.
TensorInfo TensorInfoFromSourcePath(const std::string& path,
int layer_idx) const;
private:
// `suffix` is empty (only) for per-model tensors, otherwise `LayerSuffix`.
void Add(const std::string& suffix, const TensorInfo& info);
void AddModelTensors(const ModelConfig& config);
void AddLayerTensors(const ModelConfig& config,
const LayerConfig& layer_config, size_t layer_idx);
void AddGriffinLayerTensors(const LayerConfig& layer_config,
size_t layer_idx);
void AddImageLayerTensors(const ModelConfig& config,
const LayerConfig& layer_config,
size_t img_layer_idx);
std::vector<TensorInfo> tensors_;
// Includes entries for base name *and* the suffixed name for each layer.
std::unordered_map<std::string, size_t> idx_from_name_;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_

40
gemma/tensor_info_test.cc Normal file
View File

@ -0,0 +1,40 @@
#include "gemma/tensor_info.h"
#include <stdio.h>
#include "gtest/gtest.h"
#include "compression/types.h" // SfpStream
#include "gemma/configs.h"
#include "gemma/weights.h"
#include "util/mat.h"
#include "hwy/base.h" // HWY_ASSERT_M
namespace gcpp {
namespace {
// Tests for all models that each tensor in the model can be found and that the
// TensorInfoRegistry returns the correct shape and name for the tensor.
TEST(TensorInfoRegistryTest, Find) {
ForEachModel([&](Model model) {
const ModelConfig config(model, Type::kSFP, ChooseWrapping(model));
fprintf(stderr, "Testing %s (%s)\n", config.display_name.c_str(),
config.Specifier().c_str());
const TensorInfoRegistry tensors(config);
// Each tensor in the model should be known/found.
WeightsPtrs weights(config);
weights.ForEachTensor(nullptr, nullptr, [&tensors](const TensorArgs& t) {
const TensorInfo* info = tensors.Find(t.mat.Name());
HWY_ASSERT_M(info, t.mat.Name());
// Test that the `MatPtr` can be constructed from the TensorInfo,
// and that the dimensions match.
const MatPtr mat_ptr(t.mat.Name(), Type::kUnknown,
ExtentsFromInfo(tensors.Find(t.mat.Name())));
EXPECT_STREQ(t.mat.Name(), mat_ptr.Name()) << t.mat.Name();
EXPECT_EQ(t.mat.Rows(), mat_ptr.Rows()) << t.mat.Name();
EXPECT_EQ(t.mat.Cols(), mat_ptr.Cols()) << t.mat.Name();
});
});
}
} // namespace
} // namespace gcpp

View File

@ -21,9 +21,7 @@
#include <string>
#include <vector>
#include "compression/io.h" // Path
#include "compression/shared.h" // PromptWrapping
#include "gemma/common.h" // Wrap
#include "gemma/configs.h" // PromptWrapping
#include "hwy/base.h" // HWY_ASSERT
#include "hwy/profiler.h"
// copybara:import_next_line:sentencepiece
@ -37,24 +35,20 @@ constexpr bool kShowTokenization = false;
class GemmaTokenizer::Impl {
public:
Impl() = default;
explicit Impl(const Path& tokenizer_path) {
PROFILER_ZONE("Startup.tokenizer");
spp_ = std::make_unique<sentencepiece::SentencePieceProcessor>();
if (!spp_->Load(tokenizer_path.path).ok()) {
HWY_ABORT("Failed to load the tokenizer file.");
}
}
// Loads the tokenizer from a serialized proto.
explicit Impl(const std::string& tokenizer_proto) {
if (tokenizer_proto == kMockTokenizer) return;
PROFILER_ZONE("Startup.tokenizer");
spp_ = std::make_unique<sentencepiece::SentencePieceProcessor>();
if (!spp_->LoadFromSerializedProto(tokenizer_proto).ok()) {
fprintf(stderr, "serialized proto size=%zu.\n", tokenizer_proto.size());
HWY_ABORT("Failed to load the tokenizer from serialized proto.");
HWY_ABORT("Failed to load tokenizer from %zu byte serialized proto.",
tokenizer_proto.size());
}
}
std::string Serialize() const { return spp_->serialized_model_proto(); }
std::string Serialize() const {
return spp_ ? spp_->serialized_model_proto() : kMockTokenizer;
}
bool Encode(const std::string& input,
std::vector<std::string>* pieces) const {
@ -82,22 +76,18 @@ class GemmaTokenizer::Impl {
std::unique_ptr<sentencepiece::SentencePieceProcessor> spp_;
};
GemmaTokenizer::GemmaTokenizer(const Path& tokenizer_path) {
impl_ = std::make_unique<Impl>(tokenizer_path);
GemmaTokenizer::GemmaTokenizer(const std::string& tokenizer_proto)
: impl_(std::make_unique<Impl>(tokenizer_proto)) {
HWY_ASSERT(impl_);
}
// Default suffices, but they must be defined after GemmaTokenizer::Impl.
GemmaTokenizer::GemmaTokenizer() = default;
GemmaTokenizer::~GemmaTokenizer() = default;
GemmaTokenizer::GemmaTokenizer(GemmaTokenizer&& other) = default;
GemmaTokenizer& GemmaTokenizer::operator=(GemmaTokenizer&& other) = default;
std::string GemmaTokenizer::Serialize() const { return impl_->Serialize(); }
void GemmaTokenizer::Deserialize(const std::string& tokenizer_proto) {
impl_ = std::make_unique<Impl>(tokenizer_proto);
}
bool GemmaTokenizer::Encode(const std::string& input,
std::vector<std::string>* pieces) const {
return impl_->Encode(input, pieces);
@ -114,57 +104,109 @@ bool GemmaTokenizer::Decode(const std::vector<int>& ids,
return impl_->Decode(ids, detokenized);
}
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
const ModelInfo& info, size_t pos,
std::string& prompt) {
Wrap(info, pos, prompt);
GemmaChatTemplate::GemmaChatTemplate(const GemmaTokenizer& tokenizer,
Model model) {
sot_user_.reserve(3);
if (!tokenizer.Encode("<start_of_turn>user\n", &sot_user_)) return;
sot_model_.reserve(3);
HWY_ASSERT(tokenizer.Encode("<start_of_turn>model\n", &sot_model_));
eot_.reserve(2);
HWY_ASSERT(tokenizer.Encode("<end_of_turn>\n", &eot_));
std::vector<int> tokens;
HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
// Both pre-trained and instruction-tuned require BOS as first token.
if (pos == 0) {
tokens.insert(tokens.begin(), BOS_ID);
}
// PaliGemma separator. The SEP token "\n" is always tokenized separately.
if (info.wrapping == PromptWrapping::PALIGEMMA
// || info.wrapping == PromptWrapping::GEMMA_VLM
) {
std::vector<int> sep_tokens;
HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());
}
return tokens;
HWY_ASSERT(tokenizer.Encode("\n", &pali_sep_));
vlm_soi_.reserve(2);
HWY_ASSERT(tokenizer.Encode("\n\n<start_of_image>", &vlm_soi_));
vlm_eoi_.reserve(2);
HWY_ASSERT(tokenizer.Encode("<end_of_image>\n\n", &vlm_eoi_));
}
std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info,
size_t pos, std::vector<int>& tokens,
size_t image_batch_size, size_t max_image_batch_size) {
HWY_ASSERT(info.wrapping == PromptWrapping::GEMMA_VLM);
size_t num_images = hwy::DivCeil(image_batch_size, max_image_batch_size);
std::vector<int> GemmaChatTemplate::Apply(size_t pos,
const std::vector<int>& ids) const {
HWY_ASSERT_M(!sot_user_.empty() && !sot_model_.empty() && !eot_.empty(),
"GemmaChatTemplate has not been initialized.");
std::vector<int> out;
out.reserve(eot_.size() + sot_user_.size() + ids.size() + eot_.size() +
sot_model_.size());
std::vector<int> sep_tokens;
HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
std::string begin_image_prompt = "\n\n<start_of_image>";
std::vector<int> begin_image_tokens =
WrapAndTokenize(tokenizer, info, pos, begin_image_prompt);
std::string end_image_prompt = "<end_of_image>\n\n";
std::vector<int> end_image_tokens =
WrapAndTokenize(tokenizer, info, pos, end_image_prompt);
for (size_t i = 0; i < num_images; ++i) {
tokens.insert(tokens.begin(), begin_image_tokens.begin(),
begin_image_tokens.end());
tokens.insert(tokens.begin() + begin_image_tokens.size(), image_batch_size,
-2);
tokens.insert(tokens.begin() + begin_image_tokens.size() + image_batch_size,
end_image_tokens.begin(), end_image_tokens.end());
// Start with BOS, or prepend end_of_turn if this is a continuation.
if (pos == 0) {
out.push_back(BOS_ID);
} else {
out.insert(out.cend(), eot_.cbegin(), eot_.cend());
}
// Start of user turn, user prompt, end of turn; then start of model turn.
out.insert(out.cend(), sot_user_.cbegin(), sot_user_.cend());
out.insert(out.cend(), ids.cbegin(), ids.cend());
out.insert(out.cend(), eot_.cbegin(), eot_.cend());
out.insert(out.cend(), sot_model_.cbegin(), sot_model_.cend());
return out;
}
return tokens;
std::vector<int> GemmaChatTemplate::WrapPali(const std::vector<int>& text_part,
size_t image_batch_size) const {
HWY_ASSERT_M(!pali_sep_.empty(),
"GemmaChatTemplate has not been initialized.");
std::vector<int> out;
out.reserve(image_batch_size + 1 + text_part.size() + pali_sep_.size());
out.resize(image_batch_size, 0);
out.push_back(BOS_ID);
out.insert(out.cend(), text_part.cbegin(), text_part.cend());
out.insert(out.cend(), pali_sep_.cbegin(), pali_sep_.cend());
return out;
}
std::vector<int> GemmaChatTemplate::WrapVLM(const std::vector<int>& text_part,
size_t image_batch_size) const {
HWY_ASSERT_M(!vlm_soi_.empty() && !vlm_eoi_.empty(),
"GemmaChatTemplate has not been initialized.");
std::vector<int> out;
out.reserve(text_part.size() + vlm_soi_.size() + image_batch_size +
vlm_eoi_.size());
out.insert(out.cend(), text_part.cbegin(), text_part.cend());
out.insert(out.cend(), vlm_soi_.cbegin(), vlm_soi_.cend());
out.insert(out.cend(), image_batch_size, -2);
out.insert(out.cend(), vlm_eoi_.cbegin(), vlm_eoi_.cend());
return out;
}
// Text
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
const GemmaChatTemplate& chat_template,
const PromptWrapping wrapping, size_t pos,
const std::string& prompt) {
std::vector<int> tokens;
HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
switch (wrapping) {
case PromptWrapping::GEMMA_IT:
case PromptWrapping::GEMMA_VLM:
return chat_template.Apply(pos, tokens);
default:
if (pos == 0) {
tokens.insert(tokens.cbegin(), BOS_ID);
}
return tokens;
}
}
// Vision
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
const GemmaChatTemplate& chat_template,
const PromptWrapping wrapping, size_t pos,
const std::string& prompt,
size_t image_batch_size) {
std::vector<int> text_part;
HWY_ASSERT(tokenizer.Encode(prompt, &text_part));
switch (wrapping) {
case PromptWrapping::PALIGEMMA:
HWY_ASSERT(pos == 0);
return chat_template.WrapPali(text_part, image_batch_size);
case PromptWrapping::GEMMA_VLM:
return chat_template.Apply(
pos, chat_template.WrapVLM(text_part, image_batch_size));
default:
HWY_ASSERT_M(false, "Current variant does not support vision prompt.");
}
}
} // namespace gcpp

Some files were not shown because too many files have changed in this diff Show More